mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-01-07 01:48:16 +01:00
Compare commits
2 Commits
53fc9c078c
...
361b07d4fa
Author | SHA1 | Date | |
---|---|---|---|
|
361b07d4fa | ||
|
96b37a8f1b |
@ -1,5 +1,6 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
import kerastuner as kt
|
import kerastuner as kt
|
||||||
import math
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import random
|
import random
|
||||||
@ -14,8 +15,9 @@ from tensorflow.keras.models import Sequential
|
|||||||
from tensorflow.keras.layers import Dense
|
from tensorflow.keras.layers import Dense
|
||||||
from google.cloud import storage
|
from google.cloud import storage
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
physical_devices = tf.config.list_physical_devices('GPU')
|
physical_devices = tf.config.list_physical_devices('GPU')
|
||||||
physical_devices
|
|
||||||
|
|
||||||
tf.config.set_visible_devices([tf.config.PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')], 'GPU')
|
tf.config.set_visible_devices([tf.config.PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')], 'GPU')
|
||||||
tf.config.get_visible_devices('GPU')
|
tf.config.get_visible_devices('GPU')
|
||||||
@ -58,11 +60,9 @@ test_batch_size = 256
|
|||||||
validation_batch_size = 256
|
validation_batch_size = 256
|
||||||
|
|
||||||
do_resample = False
|
do_resample = False
|
||||||
def class_func(features, label):
|
|
||||||
return label
|
|
||||||
|
|
||||||
resample_fn = tf.data.experimental.rejection_resample(
|
resample_fn = tf.data.experimental.rejection_resample(
|
||||||
class_func, target_dist = [0.5, 0.5], seed=0
|
lambda _, x: x, target_dist = [0.5, 0.5], seed=0
|
||||||
)
|
)
|
||||||
train_glob = f"{input_root}/train/tfrecord/*.tfrecord"
|
train_glob = f"{input_root}/train/tfrecord/*.tfrecord"
|
||||||
train_files = tf.io.gfile.glob(train_glob)
|
train_files = tf.io.gfile.glob(train_glob)
|
||||||
@ -135,7 +135,7 @@ for example in tqdm(check_ds):
|
|||||||
if label == 1:
|
if label == 1:
|
||||||
pos_cnt += 1
|
pos_cnt += 1
|
||||||
cnt += 1
|
cnt += 1
|
||||||
print(f'{cnt} train entries with {pos_cnt} positive')
|
LOG.info(f'{cnt} train entries with {pos_cnt} positive')
|
||||||
|
|
||||||
metrics = []
|
metrics = []
|
||||||
|
|
||||||
@ -264,7 +264,7 @@ bucket = client.get_bucket(...)
|
|||||||
copy_local_directory_to_gcs(model_path, bucket, model_path)
|
copy_local_directory_to_gcs(model_path, bucket, model_path)
|
||||||
copy_local_directory_to_gcs('tuner_dir', bucket, 'tuner_dir')
|
copy_local_directory_to_gcs('tuner_dir', bucket, 'tuner_dir')
|
||||||
loaded_model = tf.keras.models.load_model(model_path)
|
loaded_model = tf.keras.models.load_model(model_path)
|
||||||
print(history.history.keys())
|
LOG.info(history.history.keys())
|
||||||
|
|
||||||
plt.figure(figsize = (20, 5))
|
plt.figure(figsize = (20, 5))
|
||||||
|
|
||||||
@ -319,7 +319,7 @@ for label in test_labels:
|
|||||||
else:
|
else:
|
||||||
n_test_neg +=1
|
n_test_neg +=1
|
||||||
|
|
||||||
print(f'n_test = {n_test}, n_pos = {n_test_pos}, n_neg = {n_test_neg}')
|
LOG.info(f'n_test = {n_test}, n_pos = {n_test_pos}, n_neg = {n_test_neg}')
|
||||||
|
|
||||||
n_test_sens_prev_pos = 0
|
n_test_sens_prev_pos = 0
|
||||||
n_test_sens_prev_neg = 0
|
n_test_sens_prev_neg = 0
|
||||||
@ -332,7 +332,7 @@ for label in test_sens_prev_labels:
|
|||||||
else:
|
else:
|
||||||
n_test_sens_prev_neg +=1
|
n_test_sens_prev_neg +=1
|
||||||
|
|
||||||
print(f'n_test_sens_prev = {n_test_sens_prev}, n_pos_sens_prev = {n_test_sens_prev_pos}, n_neg = {n_test_sens_prev_neg}')
|
LOG.info(f'n_test_sens_prev = {n_test_sens_prev}, n_pos_sens_prev = {n_test_sens_prev_pos}, n_neg = {n_test_sens_prev_neg}')
|
||||||
|
|
||||||
test_weights = np.ones(np.asarray(test_preds).shape)
|
test_weights = np.ones(np.asarray(test_preds).shape)
|
||||||
|
|
||||||
@ -385,7 +385,7 @@ plt.subplot(1, 3, 2)
|
|||||||
plt.plot(pr[2], pr[1][0:-1])
|
plt.plot(pr[2], pr[1][0:-1])
|
||||||
plt.xlabel("threshold")
|
plt.xlabel("threshold")
|
||||||
plt.ylabel("recall")
|
plt.ylabel("recall")
|
||||||
plt.title("Keras", size=20)
|
plt.title("Keras", fontsize=20)
|
||||||
|
|
||||||
plt.subplot(1, 3, 3)
|
plt.subplot(1, 3, 3)
|
||||||
|
|
||||||
@ -406,7 +406,7 @@ precision, recall, thresholds = pr
|
|||||||
|
|
||||||
auc_precision_recall = sklearn.metrics.auc(recall, precision)
|
auc_precision_recall = sklearn.metrics.auc(recall, precision)
|
||||||
|
|
||||||
print(auc_precision_recall)
|
LOG.info(auc_precision_recall)
|
||||||
|
|
||||||
plt.figure(figsize=(15, 10))
|
plt.figure(figsize=(15, 10))
|
||||||
plt.plot(recall, precision)
|
plt.plot(recall, precision)
|
||||||
@ -415,12 +415,12 @@ plt.xlabel("recall")
|
|||||||
plt.ylabel("precision")
|
plt.ylabel("precision")
|
||||||
|
|
||||||
ptAt50 = get_point_for_recall(0.5, recall, precision)
|
ptAt50 = get_point_for_recall(0.5, recall, precision)
|
||||||
print(ptAt50)
|
LOG.info(ptAt50)
|
||||||
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
|
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
|
||||||
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
|
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
|
||||||
|
|
||||||
ptAt90 = get_point_for_recall(0.9, recall, precision)
|
ptAt90 = get_point_for_recall(0.9, recall, precision)
|
||||||
print(ptAt90)
|
LOG.info(ptAt90)
|
||||||
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
|
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
|
||||||
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
|
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
|
||||||
|
|
||||||
@ -428,8 +428,8 @@ ptAt50fmt = "%.4f" % ptAt50[1]
|
|||||||
ptAt90fmt = "%.4f" % ptAt90[1]
|
ptAt90fmt = "%.4f" % ptAt90[1]
|
||||||
aucFmt = "%.4f" % auc_precision_recall
|
aucFmt = "%.4f" % auc_precision_recall
|
||||||
plt.title(
|
plt.title(
|
||||||
f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...}} ({...} pos), N_test={n_test} ({n_test_pos} pos)",
|
f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test} ({n_test_pos} pos)",
|
||||||
size=20
|
fontsize=20
|
||||||
)
|
)
|
||||||
plt.subplots_adjust(top=0.72)
|
plt.subplots_adjust(top=0.72)
|
||||||
plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf')
|
plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf')
|
||||||
@ -437,7 +437,7 @@ plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf')
|
|||||||
precision, recall, thresholds = pr_sens_prev
|
precision, recall, thresholds = pr_sens_prev
|
||||||
|
|
||||||
auc_precision_recall = sklearn.metrics.auc(recall, precision)
|
auc_precision_recall = sklearn.metrics.auc(recall, precision)
|
||||||
print(auc_precision_recall)
|
LOG.info(auc_precision_recall)
|
||||||
plt.figure(figsize=(15, 10))
|
plt.figure(figsize=(15, 10))
|
||||||
|
|
||||||
plt.plot(recall, precision)
|
plt.plot(recall, precision)
|
||||||
@ -446,12 +446,12 @@ plt.xlabel("recall")
|
|||||||
plt.ylabel("precision")
|
plt.ylabel("precision")
|
||||||
|
|
||||||
ptAt50 = get_point_for_recall(0.5, recall, precision)
|
ptAt50 = get_point_for_recall(0.5, recall, precision)
|
||||||
print(ptAt50)
|
LOG.info(ptAt50)
|
||||||
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
|
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
|
||||||
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
|
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
|
||||||
|
|
||||||
ptAt90 = get_point_for_recall(0.9, recall, precision)
|
ptAt90 = get_point_for_recall(0.9, recall, precision)
|
||||||
print(ptAt90)
|
LOG.info(ptAt90)
|
||||||
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
|
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
|
||||||
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
|
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
|
||||||
|
|
||||||
@ -460,7 +460,7 @@ ptAt90fmt = "%.4f" % ptAt90[1]
|
|||||||
aucFmt = "%.4f" % auc_precision_recall
|
aucFmt = "%.4f" % auc_precision_recall
|
||||||
plt.title(
|
plt.title(
|
||||||
f"Keras (nsfw sens prev test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test_sens_prev} ({n_test_sens_prev_pos} pos)",
|
f"Keras (nsfw sens prev test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test_sens_prev} ({n_test_sens_prev_pos} pos)",
|
||||||
size=20
|
fontsize=20
|
||||||
)
|
)
|
||||||
plt.subplots_adjust(top=0.72)
|
plt.subplots_adjust(top=0.72)
|
||||||
plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_sens_prev_test.pdf')
|
plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_sens_prev_test.pdf')
|
Loading…
Reference in New Issue
Block a user