This commit is contained in:
Ankit Singh 2023-05-22 17:37:30 -05:00 committed by GitHub
commit 95e3ab2e25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 3 additions and 13 deletions

View File

@ -10,4 +10,4 @@ if __name__ == "__main__":
lolly_model_scorer = LollyModelScorer(data_example_parser=DBv2DataExampleParser(lolly_model_reader))
score = lolly_model_scorer.score(data_example=sys.argv[2])
print(score)

View File

@ -31,7 +31,7 @@ ptos_prototype = Model(
export_path="...",
features=features,
)
print(ptos_prototype)
cq_loader = BigQueryFeatureLoader(gcp_project=COMPUTE_PROJECT)
labels = [
@ -58,7 +58,6 @@ SELECT
...
"""
print(train_query)
train = cq_loader.load_features(ptos_prototype, "", "", custom_query=train_query)
val = cq_loader.load_features(ptos_prototype, "", "", custom_query=val_query)
print(train.describe(model=ptos_prototype))
@ -105,7 +104,7 @@ def get_positive_weights():
return pos_weight_tensor
pos_weight_tensor = get_positive_weights()
print(pos_weight_tensor)
class TextEncoderPooledOutput(TextEncoder):
def call(self, x):
@ -224,7 +223,6 @@ SELECT
test = cq_loader.load_features(ptos_prototype, "", "", custom_query=test_query)
test = test.to_tf_dataset().map(parse_labeled_data)
print(test)
test_only_media = test.filter(lambda x, y: tf.equal(x["has_media"], True))
test_only_nsfw = test.filter(lambda x, y: tf.greater_equal(x["precision_nsfw"], 0.95))
@ -268,7 +266,6 @@ for name, df in subsets.items():
metrics[name] = eval_model(candidate_model, df)
[(name, m.pr_auc) for name, m in metrics.items()]
for name, x in [(name, m.pr_auc.to_string(index=False).strip().split("\n")) for name, m in metrics.items()]:
print(name)
for y in x:
print(y.strip(), end="\t")
print(".")

View File

@ -406,7 +406,6 @@ precision, recall, thresholds = pr
auc_precision_recall = sklearn.metrics.auc(recall, precision)
print(auc_precision_recall)
plt.figure(figsize=(15, 10))
plt.plot(recall, precision)
@ -415,12 +414,10 @@ plt.xlabel("recall")
plt.ylabel("precision")
ptAt50 = get_point_for_recall(0.5, recall, precision)
print(ptAt50)
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
ptAt90 = get_point_for_recall(0.9, recall, precision)
print(ptAt90)
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
@ -437,7 +434,6 @@ plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf')
precision, recall, thresholds = pr_sens_prev
auc_precision_recall = sklearn.metrics.auc(recall, precision)
print(auc_precision_recall)
plt.figure(figsize=(15, 10))
plt.plot(recall, precision)
@ -446,12 +442,10 @@ plt.xlabel("recall")
plt.ylabel("precision")
ptAt50 = get_point_for_recall(0.5, recall, precision)
print(ptAt50)
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
ptAt90 = get_point_for_recall(0.9, recall, precision)
print(ptAt90)
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')

View File

@ -98,7 +98,6 @@ class Trainer(object):
for var, default_value in experiment_settings["default"].items():
override_val = experiment_settings[experiment_id].get(var, default_value)
print("Setting ", var, override_val)
self.__setattr__(var, override_val)
self.content_loss_weight = content_loss_weight if self.dual_head else None