mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-01-08 18:30:43 +01:00
Merge 74f37ca086
into fb54d8b549
This commit is contained in:
commit
95e3ab2e25
@ -10,4 +10,4 @@ if __name__ == "__main__":
|
||||
lolly_model_scorer = LollyModelScorer(data_example_parser=DBv2DataExampleParser(lolly_model_reader))
|
||||
|
||||
score = lolly_model_scorer.score(data_example=sys.argv[2])
|
||||
print(score)
|
||||
|
||||
|
@ -31,7 +31,7 @@ ptos_prototype = Model(
|
||||
export_path="...",
|
||||
features=features,
|
||||
)
|
||||
print(ptos_prototype)
|
||||
|
||||
|
||||
cq_loader = BigQueryFeatureLoader(gcp_project=COMPUTE_PROJECT)
|
||||
labels = [
|
||||
@ -58,7 +58,6 @@ SELECT
|
||||
...
|
||||
"""
|
||||
|
||||
print(train_query)
|
||||
train = cq_loader.load_features(ptos_prototype, "", "", custom_query=train_query)
|
||||
val = cq_loader.load_features(ptos_prototype, "", "", custom_query=val_query)
|
||||
print(train.describe(model=ptos_prototype))
|
||||
@ -105,7 +104,7 @@ def get_positive_weights():
|
||||
return pos_weight_tensor
|
||||
|
||||
pos_weight_tensor = get_positive_weights()
|
||||
print(pos_weight_tensor)
|
||||
|
||||
|
||||
class TextEncoderPooledOutput(TextEncoder):
|
||||
def call(self, x):
|
||||
@ -224,7 +223,6 @@ SELECT
|
||||
test = cq_loader.load_features(ptos_prototype, "", "", custom_query=test_query)
|
||||
test = test.to_tf_dataset().map(parse_labeled_data)
|
||||
|
||||
print(test)
|
||||
|
||||
test_only_media = test.filter(lambda x, y: tf.equal(x["has_media"], True))
|
||||
test_only_nsfw = test.filter(lambda x, y: tf.greater_equal(x["precision_nsfw"], 0.95))
|
||||
@ -268,7 +266,6 @@ for name, df in subsets.items():
|
||||
metrics[name] = eval_model(candidate_model, df)
|
||||
[(name, m.pr_auc) for name, m in metrics.items()]
|
||||
for name, x in [(name, m.pr_auc.to_string(index=False).strip().split("\n")) for name, m in metrics.items()]:
|
||||
print(name)
|
||||
for y in x:
|
||||
print(y.strip(), end="\t")
|
||||
print(".")
|
||||
|
@ -406,7 +406,6 @@ precision, recall, thresholds = pr
|
||||
|
||||
auc_precision_recall = sklearn.metrics.auc(recall, precision)
|
||||
|
||||
print(auc_precision_recall)
|
||||
|
||||
plt.figure(figsize=(15, 10))
|
||||
plt.plot(recall, precision)
|
||||
@ -415,12 +414,10 @@ plt.xlabel("recall")
|
||||
plt.ylabel("precision")
|
||||
|
||||
ptAt50 = get_point_for_recall(0.5, recall, precision)
|
||||
print(ptAt50)
|
||||
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
|
||||
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
|
||||
|
||||
ptAt90 = get_point_for_recall(0.9, recall, precision)
|
||||
print(ptAt90)
|
||||
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
|
||||
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
|
||||
|
||||
@ -437,7 +434,6 @@ plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf')
|
||||
precision, recall, thresholds = pr_sens_prev
|
||||
|
||||
auc_precision_recall = sklearn.metrics.auc(recall, precision)
|
||||
print(auc_precision_recall)
|
||||
plt.figure(figsize=(15, 10))
|
||||
|
||||
plt.plot(recall, precision)
|
||||
@ -446,12 +442,10 @@ plt.xlabel("recall")
|
||||
plt.ylabel("precision")
|
||||
|
||||
ptAt50 = get_point_for_recall(0.5, recall, precision)
|
||||
print(ptAt50)
|
||||
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
|
||||
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
|
||||
|
||||
ptAt90 = get_point_for_recall(0.9, recall, precision)
|
||||
print(ptAt90)
|
||||
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
|
||||
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
|
||||
|
||||
|
@ -98,7 +98,6 @@ class Trainer(object):
|
||||
|
||||
for var, default_value in experiment_settings["default"].items():
|
||||
override_val = experiment_settings[experiment_id].get(var, default_value)
|
||||
print("Setting ", var, override_val)
|
||||
self.__setattr__(var, override_val)
|
||||
|
||||
self.content_loss_weight = content_loss_weight if self.dual_head else None
|
||||
|
Loading…
Reference in New Issue
Block a user