This commit is contained in:
Maria Novik 2023-05-22 17:38:50 -05:00 committed by GitHub
commit 361b07d4fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,6 @@
import logging
import kerastuner as kt import kerastuner as kt
import math
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import random import random
@ -14,8 +15,9 @@ from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Dense
from google.cloud import storage from google.cloud import storage
LOG = logging.getLogger(__name__)
physical_devices = tf.config.list_physical_devices('GPU') physical_devices = tf.config.list_physical_devices('GPU')
physical_devices
tf.config.set_visible_devices([tf.config.PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')], 'GPU') tf.config.set_visible_devices([tf.config.PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')], 'GPU')
tf.config.get_visible_devices('GPU') tf.config.get_visible_devices('GPU')
@ -58,11 +60,9 @@ test_batch_size = 256
validation_batch_size = 256 validation_batch_size = 256
do_resample = False do_resample = False
def class_func(features, label):
return label
resample_fn = tf.data.experimental.rejection_resample( resample_fn = tf.data.experimental.rejection_resample(
class_func, target_dist = [0.5, 0.5], seed=0 lambda _, x: x, target_dist = [0.5, 0.5], seed=0
) )
train_glob = f"{input_root}/train/tfrecord/*.tfrecord" train_glob = f"{input_root}/train/tfrecord/*.tfrecord"
train_files = tf.io.gfile.glob(train_glob) train_files = tf.io.gfile.glob(train_glob)
@ -135,7 +135,7 @@ for example in tqdm(check_ds):
if label == 1: if label == 1:
pos_cnt += 1 pos_cnt += 1
cnt += 1 cnt += 1
print(f'{cnt} train entries with {pos_cnt} positive') LOG.info(f'{cnt} train entries with {pos_cnt} positive')
metrics = [] metrics = []
@ -264,7 +264,7 @@ bucket = client.get_bucket(...)
copy_local_directory_to_gcs(model_path, bucket, model_path) copy_local_directory_to_gcs(model_path, bucket, model_path)
copy_local_directory_to_gcs('tuner_dir', bucket, 'tuner_dir') copy_local_directory_to_gcs('tuner_dir', bucket, 'tuner_dir')
loaded_model = tf.keras.models.load_model(model_path) loaded_model = tf.keras.models.load_model(model_path)
print(history.history.keys()) LOG.info(history.history.keys())
plt.figure(figsize = (20, 5)) plt.figure(figsize = (20, 5))
@ -319,7 +319,7 @@ for label in test_labels:
else: else:
n_test_neg +=1 n_test_neg +=1
print(f'n_test = {n_test}, n_pos = {n_test_pos}, n_neg = {n_test_neg}') LOG.info(f'n_test = {n_test}, n_pos = {n_test_pos}, n_neg = {n_test_neg}')
n_test_sens_prev_pos = 0 n_test_sens_prev_pos = 0
n_test_sens_prev_neg = 0 n_test_sens_prev_neg = 0
@ -332,7 +332,7 @@ for label in test_sens_prev_labels:
else: else:
n_test_sens_prev_neg +=1 n_test_sens_prev_neg +=1
print(f'n_test_sens_prev = {n_test_sens_prev}, n_pos_sens_prev = {n_test_sens_prev_pos}, n_neg = {n_test_sens_prev_neg}') LOG.info(f'n_test_sens_prev = {n_test_sens_prev}, n_pos_sens_prev = {n_test_sens_prev_pos}, n_neg = {n_test_sens_prev_neg}')
test_weights = np.ones(np.asarray(test_preds).shape) test_weights = np.ones(np.asarray(test_preds).shape)
@ -385,7 +385,7 @@ plt.subplot(1, 3, 2)
plt.plot(pr[2], pr[1][0:-1]) plt.plot(pr[2], pr[1][0:-1])
plt.xlabel("threshold") plt.xlabel("threshold")
plt.ylabel("recall") plt.ylabel("recall")
plt.title("Keras", size=20) plt.title("Keras", fontsize=20)
plt.subplot(1, 3, 3) plt.subplot(1, 3, 3)
@ -406,7 +406,7 @@ precision, recall, thresholds = pr
auc_precision_recall = sklearn.metrics.auc(recall, precision) auc_precision_recall = sklearn.metrics.auc(recall, precision)
print(auc_precision_recall) LOG.info(auc_precision_recall)
plt.figure(figsize=(15, 10)) plt.figure(figsize=(15, 10))
plt.plot(recall, precision) plt.plot(recall, precision)
@ -415,12 +415,12 @@ plt.xlabel("recall")
plt.ylabel("precision") plt.ylabel("precision")
ptAt50 = get_point_for_recall(0.5, recall, precision) ptAt50 = get_point_for_recall(0.5, recall, precision)
print(ptAt50) LOG.info(ptAt50)
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r') plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r') plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
ptAt90 = get_point_for_recall(0.9, recall, precision) ptAt90 = get_point_for_recall(0.9, recall, precision)
print(ptAt90) LOG.info(ptAt50)
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b') plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b') plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
@ -428,8 +428,8 @@ ptAt50fmt = "%.4f" % ptAt50[1]
ptAt90fmt = "%.4f" % ptAt90[1] ptAt90fmt = "%.4f" % ptAt90[1]
aucFmt = "%.4f" % auc_precision_recall aucFmt = "%.4f" % auc_precision_recall
plt.title( plt.title(
f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...}} ({...} pos), N_test={n_test} ({n_test_pos} pos)", f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test} ({n_test_pos} pos)",
size=20 fontsize=20
) )
plt.subplots_adjust(top=0.72) plt.subplots_adjust(top=0.72)
plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf') plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf')
@ -437,7 +437,7 @@ plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf')
precision, recall, thresholds = pr_sens_prev precision, recall, thresholds = pr_sens_prev
auc_precision_recall = sklearn.metrics.auc(recall, precision) auc_precision_recall = sklearn.metrics.auc(recall, precision)
print(auc_precision_recall) LOG.info(auc_precision_recall)
plt.figure(figsize=(15, 10)) plt.figure(figsize=(15, 10))
plt.plot(recall, precision) plt.plot(recall, precision)
@ -446,12 +446,12 @@ plt.xlabel("recall")
plt.ylabel("precision") plt.ylabel("precision")
ptAt50 = get_point_for_recall(0.5, recall, precision) ptAt50 = get_point_for_recall(0.5, recall, precision)
print(ptAt50) LOG.info(ptAt50)
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r') plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r') plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
ptAt90 = get_point_for_recall(0.9, recall, precision) ptAt90 = get_point_for_recall(0.9, recall, precision)
print(ptAt90) LOG.info(ptAt90)
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b') plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b') plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
@ -460,7 +460,7 @@ ptAt90fmt = "%.4f" % ptAt90[1]
aucFmt = "%.4f" % auc_precision_recall aucFmt = "%.4f" % auc_precision_recall
plt.title( plt.title(
f"Keras (nsfw sens prev test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test_sens_prev} ({n_test_sens_prev_pos} pos)", f"Keras (nsfw sens prev test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test_sens_prev} ({n_test_sens_prev_pos} pos)",
size=20 fontsize=20
) )
plt.subplots_adjust(top=0.72) plt.subplots_adjust(top=0.72)
plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_sens_prev_test.pdf') plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_sens_prev_test.pdf')