mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-12-22 18:21:51 +01:00
Compare commits
2 Commits
659b5688d3
...
95e3ab2e25
Author | SHA1 | Date | |
---|---|---|---|
|
95e3ab2e25 | ||
|
74f37ca086 |
@ -10,4 +10,4 @@ if __name__ == "__main__":
|
|||||||
lolly_model_scorer = LollyModelScorer(data_example_parser=DBv2DataExampleParser(lolly_model_reader))
|
lolly_model_scorer = LollyModelScorer(data_example_parser=DBv2DataExampleParser(lolly_model_reader))
|
||||||
|
|
||||||
score = lolly_model_scorer.score(data_example=sys.argv[2])
|
score = lolly_model_scorer.score(data_example=sys.argv[2])
|
||||||
print(score)
|
|
||||||
|
@ -31,7 +31,7 @@ ptos_prototype = Model(
|
|||||||
export_path="...",
|
export_path="...",
|
||||||
features=features,
|
features=features,
|
||||||
)
|
)
|
||||||
print(ptos_prototype)
|
|
||||||
|
|
||||||
cq_loader = BigQueryFeatureLoader(gcp_project=COMPUTE_PROJECT)
|
cq_loader = BigQueryFeatureLoader(gcp_project=COMPUTE_PROJECT)
|
||||||
labels = [
|
labels = [
|
||||||
@ -58,7 +58,6 @@ SELECT
|
|||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print(train_query)
|
|
||||||
train = cq_loader.load_features(ptos_prototype, "", "", custom_query=train_query)
|
train = cq_loader.load_features(ptos_prototype, "", "", custom_query=train_query)
|
||||||
val = cq_loader.load_features(ptos_prototype, "", "", custom_query=val_query)
|
val = cq_loader.load_features(ptos_prototype, "", "", custom_query=val_query)
|
||||||
print(train.describe(model=ptos_prototype))
|
print(train.describe(model=ptos_prototype))
|
||||||
@ -105,7 +104,7 @@ def get_positive_weights():
|
|||||||
return pos_weight_tensor
|
return pos_weight_tensor
|
||||||
|
|
||||||
pos_weight_tensor = get_positive_weights()
|
pos_weight_tensor = get_positive_weights()
|
||||||
print(pos_weight_tensor)
|
|
||||||
|
|
||||||
class TextEncoderPooledOutput(TextEncoder):
|
class TextEncoderPooledOutput(TextEncoder):
|
||||||
def call(self, x):
|
def call(self, x):
|
||||||
@ -224,7 +223,6 @@ SELECT
|
|||||||
test = cq_loader.load_features(ptos_prototype, "", "", custom_query=test_query)
|
test = cq_loader.load_features(ptos_prototype, "", "", custom_query=test_query)
|
||||||
test = test.to_tf_dataset().map(parse_labeled_data)
|
test = test.to_tf_dataset().map(parse_labeled_data)
|
||||||
|
|
||||||
print(test)
|
|
||||||
|
|
||||||
test_only_media = test.filter(lambda x, y: tf.equal(x["has_media"], True))
|
test_only_media = test.filter(lambda x, y: tf.equal(x["has_media"], True))
|
||||||
test_only_nsfw = test.filter(lambda x, y: tf.greater_equal(x["precision_nsfw"], 0.95))
|
test_only_nsfw = test.filter(lambda x, y: tf.greater_equal(x["precision_nsfw"], 0.95))
|
||||||
@ -268,7 +266,6 @@ for name, df in subsets.items():
|
|||||||
metrics[name] = eval_model(candidate_model, df)
|
metrics[name] = eval_model(candidate_model, df)
|
||||||
[(name, m.pr_auc) for name, m in metrics.items()]
|
[(name, m.pr_auc) for name, m in metrics.items()]
|
||||||
for name, x in [(name, m.pr_auc.to_string(index=False).strip().split("\n")) for name, m in metrics.items()]:
|
for name, x in [(name, m.pr_auc.to_string(index=False).strip().split("\n")) for name, m in metrics.items()]:
|
||||||
print(name)
|
|
||||||
for y in x:
|
for y in x:
|
||||||
print(y.strip(), end="\t")
|
print(y.strip(), end="\t")
|
||||||
print(".")
|
print(".")
|
||||||
|
@ -406,7 +406,6 @@ precision, recall, thresholds = pr
|
|||||||
|
|
||||||
auc_precision_recall = sklearn.metrics.auc(recall, precision)
|
auc_precision_recall = sklearn.metrics.auc(recall, precision)
|
||||||
|
|
||||||
print(auc_precision_recall)
|
|
||||||
|
|
||||||
plt.figure(figsize=(15, 10))
|
plt.figure(figsize=(15, 10))
|
||||||
plt.plot(recall, precision)
|
plt.plot(recall, precision)
|
||||||
@ -415,12 +414,10 @@ plt.xlabel("recall")
|
|||||||
plt.ylabel("precision")
|
plt.ylabel("precision")
|
||||||
|
|
||||||
ptAt50 = get_point_for_recall(0.5, recall, precision)
|
ptAt50 = get_point_for_recall(0.5, recall, precision)
|
||||||
print(ptAt50)
|
|
||||||
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
|
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
|
||||||
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
|
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
|
||||||
|
|
||||||
ptAt90 = get_point_for_recall(0.9, recall, precision)
|
ptAt90 = get_point_for_recall(0.9, recall, precision)
|
||||||
print(ptAt90)
|
|
||||||
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
|
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
|
||||||
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
|
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
|
||||||
|
|
||||||
@ -437,7 +434,6 @@ plt.savefig('recall_precision_nsfw_Keras_with_twitter_CLIP_MU_test.pdf')
|
|||||||
precision, recall, thresholds = pr_sens_prev
|
precision, recall, thresholds = pr_sens_prev
|
||||||
|
|
||||||
auc_precision_recall = sklearn.metrics.auc(recall, precision)
|
auc_precision_recall = sklearn.metrics.auc(recall, precision)
|
||||||
print(auc_precision_recall)
|
|
||||||
plt.figure(figsize=(15, 10))
|
plt.figure(figsize=(15, 10))
|
||||||
|
|
||||||
plt.plot(recall, precision)
|
plt.plot(recall, precision)
|
||||||
@ -446,12 +442,10 @@ plt.xlabel("recall")
|
|||||||
plt.ylabel("precision")
|
plt.ylabel("precision")
|
||||||
|
|
||||||
ptAt50 = get_point_for_recall(0.5, recall, precision)
|
ptAt50 = get_point_for_recall(0.5, recall, precision)
|
||||||
print(ptAt50)
|
|
||||||
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
|
plt.plot( [ptAt50[0],ptAt50[0]], [0,ptAt50[1]], 'r')
|
||||||
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
|
plt.plot([0, ptAt50[0]], [ptAt50[1], ptAt50[1]], 'r')
|
||||||
|
|
||||||
ptAt90 = get_point_for_recall(0.9, recall, precision)
|
ptAt90 = get_point_for_recall(0.9, recall, precision)
|
||||||
print(ptAt90)
|
|
||||||
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
|
plt.plot( [ptAt90[0],ptAt90[0]], [0,ptAt90[1]], 'b')
|
||||||
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
|
plt.plot([0, ptAt90[0]], [ptAt90[1], ptAt90[1]], 'b')
|
||||||
|
|
||||||
|
@ -98,7 +98,6 @@ class Trainer(object):
|
|||||||
|
|
||||||
for var, default_value in experiment_settings["default"].items():
|
for var, default_value in experiment_settings["default"].items():
|
||||||
override_val = experiment_settings[experiment_id].get(var, default_value)
|
override_val = experiment_settings[experiment_id].get(var, default_value)
|
||||||
print("Setting ", var, override_val)
|
|
||||||
self.__setattr__(var, override_val)
|
self.__setattr__(var, override_val)
|
||||||
|
|
||||||
self.content_loss_weight = content_loss_weight if self.dual_head else None
|
self.content_loss_weight = content_loss_weight if self.dual_head else None
|
||||||
|
Loading…
Reference in New Issue
Block a user