2023-04-01 00:36:31 +02:00
|
|
|
import tensorflow as tf
|
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
# Turn on memory growth for every visible GPU so TensorFlow allocates device memory on
# demand instead of reserving it all up front. TensorFlow only accepts this setting
# before the GPUs have been initialized, which is why it runs first.
physical_devices = tf.config.list_physical_devices("GPU")
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
|
|
|
import utils
|
2023-04-17 06:19:03 +02:00
|
|
|
from twitter.cuad.representation.models.optimization import create_optimizer
|
|
|
|
from twitter.cuad.representation.models.text_encoder import TextEncoder
|
|
|
|
from twitter.hmli.nimbus.modeling.feature_encoder import FeatureEncoder
|
|
|
|
from twitter.hmli.nimbus.modeling.feature_loader import BigQueryFeatureLoader
|
|
|
|
from twitter.hmli.nimbus.modeling.model_config import (
|
|
|
|
EncodingType,
|
|
|
|
Feature,
|
|
|
|
FeatureType,
|
|
|
|
LogType,
|
|
|
|
Model,
|
|
|
|
)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
# NOTE(review): `[...]` is a list holding a single Ellipsis; the real category names
# appear to be elided in this copy — confirm before running.
cat_names = [...]

# One CONTINUOUS feature per category name.
category_features = [
    Feature(name=category, ftype=FeatureType.CONTINUOUS) for category in cat_names
]

# Model inputs: the tweet text (BERT-encoded), three media/NSFW signals, then the
# per-category features.
features = [
    Feature(
        name="tweet_text_with_media_annotations",
        ftype=FeatureType.STRING,
        encoding=EncodingType.BERT,
    ),
    Feature(name="precision_nsfw", ftype=FeatureType.CONTINUOUS),
    Feature(name="has_media", ftype=FeatureType.BINARY),
    Feature(name="num_media", ftype=FeatureType.DISCRETE),
    *category_features,
]
|
|
|
|
|
|
|
|
# Model description consumed by the feature loader and FeatureEncoder below.
ptos_prototype = Model(
    features=features,
    name="ptos_prototype",
    export_path="...",
)
print(ptos_prototype)

# NOTE(review): COMPUTE_PROJECT is not defined anywhere in the visible source. It must
# come from an earlier (possibly elided) cell, otherwise this line raises NameError.
cq_loader = BigQueryFeatureLoader(gcp_project=COMPUTE_PROJECT)
|
|
|
|
# Target columns, in the order used for the label vector built by parse_labeled_data.
# The list is written in lexicographic order. get_positive_weights sorts the weights
# table by label, so the two line up only while this stays sorted.
labels = [
    "has_non_punitive_action",
    "has_punitive_action",
    "has_punitive_action_contains_self_harm",
    "has_punitive_action_encourage_self_harm",
    "has_punitive_action_episodic",
    "has_punitive_action_episodic_hateful_conduct",
    "has_punitive_action_other_abuse_policy",
    "has_punitive_action_without_self_harm",
]
|
|
|
|
|
|
|
|
# Training-split SQL.
# - `{{feature_names}}` is an escaped brace pair inside this f-string, so the resulting
#   SQL contains the literal text `{feature_names}`. Presumably
#   BigQueryFeatureLoader.load_features fills that placeholder in — verify.
# - The label columns are interpolated right away from `labels`.
# - The remainder of the query is elided (`...`) in this copy.
# NOTE(review): the `|` lines inside the triple-quoted string are part of the SQL text
# as written here; they look like copy/paste residue — confirm against the real query.
train_query = f"""
|
|
|
|
SELECT
|
|
|
|
{{feature_names}},
|
|
|
|
{",".join(labels)},
|
|
|
|
...
|
|
|
|
"""
|
|
|
|
# Validation-split SQL: same SELECT layout as train_query (feature placeholder, then labels).
val_query = f"""
|
|
|
|
SELECT
|
|
|
|
{{feature_names}},
|
|
|
|
{",".join(labels)},
|
|
|
|
...
|
|
|
|
"""
|
|
|
|
|
|
|
|
print(train_query)

# Load both splits with the same loader call, differing only in the SQL. The two empty
# positional strings are passed through unchanged; their meaning is not visible here.
train, val = (
    cq_loader.load_features(ptos_prototype, "", "", custom_query=query)
    for query in (train_query, val_query)
)

print(train.describe(model=ptos_prototype))
|
|
|
|
|
|
|
|
# Training hyper-parameters. They are also logged as the wandb run config.
params = dict(
    max_seq_lengths=128,
    batch_size=196,  # per replica; scaled by the replica count for BATCH_SIZE
    lr=1e-5,
    optimizer_type="adamw",
    warmup_steps=0,
    cls_dropout_rate=0.1,
    epochs=30,
    steps_per_epoch=5000,
    model_type="twitter_multilingual_bert_base_cased_mlm",
    mixed_precision=True,
)

# Notebook display cell.
params
|
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
|
2023-04-01 00:36:31 +02:00
|
|
|
def parse_labeled_data(row_dict):
    """Split one row into ``(features, label_values)`` for supervised training.

    Each name in the module-level ``labels`` list is popped from ``row_dict``, so the
    dict passed in is modified in place. The popped values are collected in ``labels``
    order. The same dict, with the labels removed, is returned as the features.
    """
    label_values = []
    for label_name in labels:
        label_values.append(row_dict.pop(label_name))
    return row_dict, label_values
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
# Data-parallel training. The global batch is the per-replica batch_size scaled by the
# number of replicas kept in sync.
mirrored_strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE = params["batch_size"] * mirrored_strategy.num_replicas_in_sync

# Training stream, built in four steps:
#   1. split labels off each row;
#   2. shuffle over a buffer of 100 global batches;
#   3. batch;
#   4. repeat forever (model.fit bounds each epoch with steps_per_epoch).
train_ds = train.to_tf_dataset().map(parse_labeled_data)
train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).repeat()

# Validation stream: same parsing and batching, no shuffle, single pass.
val_ds = val.to_tf_dataset().map(parse_labeled_data).batch(BATCH_SIZE)

# Sanity check: print the first training batch.
for record in train_ds.take(1):
    tf.print(record)
|
|
|
|
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
def get_positive_weights(
    table_name="tos-data-media-full",
    project_id="twttr-abusive-interact-prod",
    dataset_id="tos_policy",
):
    """Compute positive-class weights used to counter label imbalance.

    Fetches the weights table via ``utils.get_label_weights``. Returns its
    ``positive_class_weight`` column, ordered by the ``label`` column, as a float32
    tensor. The defaults reproduce the previously hard-coded source. Pass other values
    to read weights computed for a different table, project or dataset.

    NOTE(review): the result follows the lexicographic order of ``label``. The loss
    presumably pairs these weights position-by-position with the label vector from
    ``parse_labeled_data``, which follows the module-level ``labels`` list. The two
    orders agree only while both of these hold:
      * ``labels`` stays written in sorted order (it currently is);
      * the table's ``label`` values sort the same way as ``labels``.
    Confirm both before adding or renaming labels.

    Args:
        table_name: First positional argument forwarded to ``utils.get_label_weights``
            (presumably the source table name — verify).
        project_id: ``project_id`` forwarded to ``utils.get_label_weights``.
        dataset_id: ``dataset_id`` forwarded to ``utils.get_label_weights``.

    Returns:
        A 1-D float32 ``tf.Tensor`` with one weight per row of the weights table.
    """
    label_weights_df = utils.get_label_weights(
        table_name,
        project_id=project_id,
        dataset_id=dataset_id,
    )
    positive_weights = label_weights_df.sort_values(by="label")["positive_class_weight"]
    return tf.cast(positive_weights.to_numpy(), dtype=tf.float32)


pos_weight_tensor = get_positive_weights()
print(pos_weight_tensor)
|
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
|
2023-04-01 00:36:31 +02:00
|
|
|
class TextEncoderPooledOutput(TextEncoder):
    """TextEncoder whose output is reduced to its ``"pooled_output"`` entry.

    The parent's ``call`` receives the input wrapped in a one-element list, and only the
    ``"pooled_output"`` value of its result is returned. This is the form handed to
    ``FeatureEncoder.build_model_head`` as ``text_encoder``.
    """

    def call(self, x):
        encoder_outputs = super().call([x])
        return encoder_outputs["pooled_output"]

    def get_config(self):
        # No constructor arguments are added on top of TextEncoder's, so the parent's
        # config is already complete.
        config = super().get_config()
        return config
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
with mirrored_strategy.scope():
    # Mixed precision.
    # The previous code called tf.train.experimental.enable_mixed_precision_graph_rewrite
    # after the model was built. That is a TF1 graph-rewrite API: it was dropped from the
    # TF2 namespace (only tf.compat.v1 keeps it) and it does not apply to eager/Keras
    # training. The Keras policy must instead be set *before* any layer is constructed,
    # so that every layer built below picks it up.
    # NOTE(review): confirm TextEncoder and FeatureEncoder's preprocessing layers behave
    # correctly under mixed_float16.
    _previous_dtype_policy = tf.keras.mixed_precision.global_policy()
    if params.get("mixed_precision"):
        tf.keras.mixed_precision.set_global_policy("mixed_float16")

    # Text tower: pooled output of the encoder, fine-tuned end to end.
    text_encoder_pooled_output = TextEncoderPooledOutput(
        params["max_seq_lengths"], model_type=params["model_type"], trainable=True
    )

    # Keras inputs plus the shared representation for all features of ptos_prototype.
    fe = FeatureEncoder(train)
    inputs, preprocessing_head = fe.build_model_head(
        model=ptos_prototype, text_encoder=text_encoder_pooled_output
    )

    # Classification head: dropout, then one logit per entry of `labels`.
    # The logits layer is pinned to float32 so the loss and metric see full-precision
    # values even when the rest of the model computes in float16.
    cls_dropout = tf.keras.layers.Dropout(
        params["cls_dropout_rate"], name="cls_dropout"
    )
    outputs = cls_dropout(preprocessing_head)
    outputs = tf.keras.layers.Dense(len(labels), name="output", dtype="float32")(outputs)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    # PR-AUC computed per label and averaged (multi_label=True); predictions are logits.
    pr_auc = tf.keras.metrics.AUC(
        curve="PR", num_thresholds=1000, multi_label=True, from_logits=True
    )

    def custom_loss(y_true, y_pred):
        """Multi-label loss weighted by ``pos_weight_tensor`` (see get_positive_weights)."""
        return utils.multilabel_weighted_loss(y_true, y_pred, weights=pos_weight_tensor)

    optimizer = create_optimizer(
        init_lr=params["lr"],
        num_train_steps=params["epochs"] * params["steps_per_epoch"],
        num_warmup_steps=params["warmup_steps"],
        optimizer_type=params["optimizer_type"],
    )

    # Under a mixed_float16 model policy, Model.compile wraps the optimizer in a
    # LossScaleOptimizer. That supplies the loss scaling float16 gradients need.
    model.compile(optimizer=optimizer, loss=custom_loss, metrics=[pr_auc])

    # Restore the previous global policy so models built later in this script (e.g. the
    # reloaded baseline) are unaffected. The layers above keep the policy they were
    # built with.
    tf.keras.mixed_precision.set_global_policy(_previous_dtype_policy)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
# Notebook inspection cells.
# - model.summary() prints.
# - The two bare expressions only display something when run interactively; in a plain
#   script they are evaluated and discarded.
model.summary()
model.weights
pr_auc.name
|
|
|
|
|
|
|
|
import getpass
|
2023-04-17 06:19:03 +02:00
|
|
|
|
2023-04-01 00:36:31 +02:00
|
|
|
import wandb
|
|
|
|
from wandb.keras import WandbCallback
|
2023-04-17 06:19:03 +02:00
|
|
|
|
2023-04-01 00:36:31 +02:00
|
|
|
# Log in and start a wandb run. If a FileNotFoundError is raised anywhere in this
# section, fall back to a disabled run so the rest of the script still works.
# The key lookup and login arguments are elided (`...`) in this copy; presumably the
# key is read from a file — verify.
try:
    wandb_key = ...
    wandb.login(...)
    run = wandb.init(
        entity="absv",
        project="ptos_with_media",
        group="new-split-trains",
        name="tweet-text-w-nsfw-1.1",
        notes="tweet text with only (num_media, precision_nsfw). on full train set, new split.",
        config=params,
        sync_tensorboard=True,
    )
except FileNotFoundError:
    print("Wandb key not found")
    run = wandb.init(mode="disabled")
|
2023-04-01 00:36:31 +02:00
|
|
|
import datetime
|
|
|
|
import os
|
|
|
|
|
|
|
|
start_train_time = datetime.datetime.now()
print(start_train_time.strftime("%m-%d-%Y (%H:%M:%S)"))

checkpoint_path = os.path.join("...")
print("Saving model checkpoints here: ", checkpoint_path)

# Both callbacks track the validation value of the compiled PR-AUC metric, higher is
# better. The checkpoint keeps only the best epoch. Training stops after 7 epochs
# without improvement.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(checkpoint_path, "model.{epoch:04d}.tf"),
    monitor=f"val_{pr_auc.name}",
    mode="max",
    save_best_only=True,
    save_freq="epoch",
    verbose=1,
)
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor=f"val_{pr_auc.name}", mode="max", patience=7
)

# NOTE(review): WandbCallback is imported above but not passed here, and no TensorBoard
# callback is registered. sync_tensorboard likely has nothing to sync, so confirm that
# training metrics actually reach wandb.
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=params["epochs"],
    steps_per_epoch=params["steps_per_epoch"],
    callbacks=[cp_callback, early_stopping_callback],
    verbose=2,
)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
import tensorflow_hub as hub
|
|
|
|
|
|
|
|
# Baseline model loaded with hub.KerasLayer from gs_model_path (elided here). It is
# wrapped with a single string input and compiled with PR/ROC AUC metrics only.
# NOTE(review): this rebinds the module-level `inputs` and `pr_auc` names, and v7_model
# is not used again in the visible code.
gs_model_path = ...
reloaded_keras_layer = hub.KerasLayer(gs_model_path)

inputs = tf.keras.layers.Input(
    name="tweet__core__tweet__text",
    shape=(1,),
    dtype=tf.string,
)
output = reloaded_keras_layer(inputs)
v7_model = tf.keras.Model(inputs=inputs, outputs=output)

pr_auc = tf.keras.metrics.AUC(curve="PR", name="pr_auc")
roc_auc = tf.keras.metrics.AUC(curve="ROC", name="roc_auc")
v7_model.compile(metrics=[pr_auc, roc_auc])
|
|
|
|
|
|
|
|
# Reload the chosen checkpoint (path elided) into the trained model and score it on the
# validation split. candidate_model is the same object as model, not a copy.
model.load_weights("...")
candidate_model = model

with mirrored_strategy.scope():
    candidate_eval = candidate_model.evaluate(val_ds)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
# Test-split SQL.
# Unlike train_query/val_query, which leave a literal `{feature_names}` placeholder,
# this query interpolates the feature columns directly from
# ptos_prototype.feature_names().
# has_media and precision_nsfw are selected explicitly because the subset filters below
# read them from each row.
# NOTE(review): both columns are also declared in `features`, so they likely appear
# twice in the SELECT list. Confirm what feature_names() returns and whether
# BigQuery/the loader tolerate duplicate column names.
# NOTE(review): the `|` lines inside the string are part of the SQL text as written.
test_query = f"""
|
|
|
|
SELECT
|
|
|
|
{",".join(ptos_prototype.feature_names())},
|
|
|
|
has_media,
|
|
|
|
precision_nsfw,
|
|
|
|
{",".join(labels)},
|
|
|
|
...
|
|
|
|
"""
|
|
|
|
|
|
|
|
# Unbatched (features, labels) test stream.
test = (
    cq_loader.load_features(ptos_prototype, "", "", custom_query=test_query)
    .to_tf_dataset()
    .map(parse_labeled_data)
)
print(test)

# Slices of the test set by media presence and NSFW score (0.95 cut-off).
test_only_media = test.filter(
    lambda features, _label: tf.equal(features["has_media"], True)
)
test_only_nsfw = test.filter(
    lambda features, _label: tf.greater_equal(features["precision_nsfw"], 0.95)
)
test_no_media = test.filter(
    lambda features, _label: tf.equal(features["has_media"], False)
)
test_media_not_nsfw = test.filter(
    lambda features, _label: tf.logical_and(
        tf.equal(features["has_media"], True),
        tf.less(features["precision_nsfw"], 0.95),
    )
)

# Row count of each slice; each count is a full pass over that dataset.
for d in (test, test_only_media, test_only_nsfw, test_no_media, test_media_not_nsfw):
    print(d.reduce(0, lambda x, _: x + 1).numpy())
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
from dataclasses import asdict
|
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
from notebook_eval_utils import EvalConfig, SparseMultilabelEvaluator
|
|
|
|
|
|
|
|
|
2023-04-01 00:36:31 +02:00
|
|
|
def display_metrics(probs, targets, labels=labels):
    """Evaluate multi-label predictions, display the per-label metrics, and return them.

    Runs ``SparseMultilabelEvaluator.evaluate`` once, in "implicit" mode. The mask is an
    all-ones array shaped like ``targets``. Uses a 0.5 prediction threshold and
    ``precision_k=0.9``.

    NOTE(review): ``display`` is the IPython/Jupyter builtin. It is not imported here, so
    this function only works inside a notebook kernel.

    Args:
        probs: Model scores, converted with ``np.array``. Expected to align row-for-row
            with ``targets``.
        targets: Ground-truth label array. Its shape also sizes the mask.
        labels: Class names passed to the evaluator. The default is bound once, at
            definition time, to the module-level ``labels`` list.

    Returns:
        DataFrame built from ``per_topic_metrics`` and transposed (presumably one row per
        label), with an added ``pos_to_neg`` column.
    """
    eval_config = EvalConfig(prediction_threshold=0.5, precision_k=0.9)
    evaluation_passes = (("implicit", np.ones(targets.shape)),)

    for eval_mode, y_mask in evaluation_passes:
        print("Evaluation mode", eval_mode)
        result = SparseMultilabelEvaluator.evaluate(
            targets, np.array(probs), y_mask, classes=labels, eval_config=eval_config
        )
        per_topic = asdict(result)["per_topic_metrics"]
        metrics_df = pd.DataFrame.from_dict(per_topic).transpose()
        # The +1 keeps the ratio finite for labels with no negative samples.
        negatives_plus_one = metrics_df["num_neg_samples"] + 1
        metrics_df["pos_to_neg"] = metrics_df["num_pos_samples"] / negatives_plus_one
        display(metrics_df.median())
        display(metrics_df)

    return metrics_df
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
def eval_model(model, df):
    """Predict on a labelled dataset and display per-label metrics via display_metrics.

    Args:
        model: Keras model producing one score per label.
        df: Unbatched ``tf.data`` dataset of ``(features, label_vector)`` pairs, as
            produced by ``parse_labeled_data``.

    Returns:
        The metrics DataFrame returned by ``display_metrics``.
    """
    with mirrored_strategy.scope():
        # The dataset is traversed twice: first to collect targets, then to predict.
        # Without caching, each traversal re-reads the source. If the source does not
        # yield rows in a stable order, targets and predictions silently misalign.
        # After the first full pass, cache() replays those exact elements in the same
        # order, and the source is read only once.
        # NOTE(review): this holds the whole evaluated subset in memory.
        df = df.cache()
        targets = np.stack(list(df.map(lambda x, y: y).as_numpy_iterator()), axis=0)
        preds = model.predict(df.padded_batch(BATCH_SIZE))
    return display_metrics(preds, targets)
|
|
|
|
|
|
|
|
|
|
|
|
# Evaluation slices, in display order.
subsets = dict(
    test=test,
    test_only_media=test_only_media,
    test_only_nsfw=test_only_nsfw,
    test_no_media=test_no_media,
    test_media_not_nsfw=test_media_not_nsfw,
)

# Per-slice metrics tables for the reloaded candidate model.
metrics = {}
for name, df in subsets.items():
    metrics[name] = eval_model(candidate_model, df)

# Notebook display cell: (slice name, pr_auc column of that slice's metrics table).
[(name, m.pr_auc) for name, m in metrics.items()]
|
2023-04-17 06:19:03 +02:00
|
|
|
# For each slice, print its name. Then print its pr_auc values on one line, each value
# followed by a tab, and finish with ".".
for name, m in metrics.items():
    pr_auc_lines = m.pr_auc.to_string(index=False).strip().split("\n")
    print(name)
    print("\t".join(line.strip() for line in pr_auc_lines), end="\t")
    print(".")

# Row count of each slice, in the same order as `subsets`.
for d in subsets.values():
    print(d.reduce(0, lambda x, _: x + 1).numpy())
|