import datetime import os from dataclasses import asdict import numpy as np import pandas as pd import tensorflow as tf import tensorflow_hub as hub import utils import wandb try: wandb_key = ... wandb.login(...) run = wandb.init(project='ptos_with_media', group='new-split-trains', notes='tweet text with only (num_media, precision_nsfw). on full train set, new split.', entity='absv', config=params, name='tweet-text-w-nsfw-1.1', sync_tensorboard=True) except FileNotFoundError: print('Wandb key not found') run = wandb.init(mode='disabled') from notebook_eval_utils import EvalConfig, SparseMultilabelEvaluator from twitter.cuad.representation.models.optimization import create_optimizer from twitter.cuad.representation.models.text_encoder import TextEncoder from twitter.hmli.nimbus.modeling.feature_encoder import FeatureEncoder from twitter.hmli.nimbus.modeling.feature_loader import BigQueryFeatureLoader from twitter.hmli.nimbus.modeling.model_config import ( EncodingType, Feature, FeatureType, Model, ) physical_devices = tf.config.list_physical_devices('GPU') for device in physical_devices: tf.config.experimental.set_memory_growth(device, True) cat_names = [ ... ] category_features = [Feature(name=cat_name, ftype=FeatureType.CONTINUOUS) for cat_name in cat_names] features = [ Feature(name="tweet_text_with_media_annotations", ftype=FeatureType.STRING, encoding=EncodingType.BERT), Feature(name="precision_nsfw", ftype=FeatureType.CONTINUOUS), Feature(name="has_media", ftype=FeatureType.BINARY), Feature(name="num_media", ftype=FeatureType.DISCRETE) ] + category_features ptos_prototype = Model( name='ptos_prototype', export_path="...", features=features, ) print(ptos_prototype) cq_loader = BigQueryFeatureLoader(gcp_project=COMPUTE_PROJECT) labels = [ "has_non_punitive_action", "has_punitive_action", "has_punitive_action_contains_self_harm", "has_punitive_action_encourage_self_harm", "has_punitive_action_episodic", "has_punitive_action_episodic_hateful_conduct", "has_punitive_action_other_abuse_policy", "has_punitive_action_without_self_harm" ] train_query = f""" SELECT {{feature_names}}, {",".join(labels)}, ... """ val_query = f""" SELECT {{feature_names}}, {",".join(labels)}, ... """ print(train_query) train = cq_loader.load_features(ptos_prototype, "", "", custom_query=train_query) val = cq_loader.load_features(ptos_prototype, "", "", custom_query=val_query) print(train.describe(model=ptos_prototype)) params = { 'max_seq_lengths': 128, 'batch_size': 196, 'lr': 1e-5, 'optimizer_type': 'adamw', 'warmup_steps': 0, 'cls_dropout_rate': 0.1, 'epochs': 30, 'steps_per_epoch': 5000, 'model_type': 'twitter_multilingual_bert_base_cased_mlm', 'mixed_precision': True, } def parse_labeled_data(row_dict): label = [row_dict.pop(l) for l in labels] return row_dict, label mirrored_strategy = tf.distribute.MirroredStrategy() BATCH_SIZE = params['batch_size'] * mirrored_strategy.num_replicas_in_sync train_ds = train.to_tf_dataset().map(parse_labeled_data).shuffle(BATCH_SIZE*100).batch(BATCH_SIZE).repeat() val_ds = val.to_tf_dataset().map(parse_labeled_data).batch(BATCH_SIZE) for record in train_ds: tf.print(record) break def get_positive_weights(): """Computes positive weights used for class imbalance from training data.""" label_weights_df = utils.get_label_weights( "tos-data-media-full", project_id="twttr-abusive-interact-prod", dataset_id="tos_policy" ) pos_weight_tensor = tf.cast( label_weights_df.sort_values(by='label').positive_class_weight, dtype=tf.float32 ) return pos_weight_tensor pos_weight_tensor = get_positive_weights() print(pos_weight_tensor) class TextEncoderPooledOutput(TextEncoder): def call(self, x): return super().call([x])["pooled_output"] def get_config(self): return super().get_config() with mirrored_strategy.scope(): text_encoder_pooled_output = TextEncoderPooledOutput( params['max_seq_lengths'], model_type=params['model_type'], trainable=True ) fe = FeatureEncoder(train) inputs, preprocessing_head = fe.build_model_head(model=ptos_prototype, text_encoder=text_encoder_pooled_output) cls_dropout = tf.keras.layers.Dropout(params['cls_dropout_rate'], name="cls_dropout") outputs = cls_dropout(preprocessing_head) outputs = tf.keras.layers.Dense(8, name="output", dtype="float32")(outputs) model = tf.keras.Model( inputs=inputs, outputs=outputs ) pr_auc = tf.keras.metrics.AUC(curve="PR", num_thresholds=1000, multi_label=True, from_logits=True) def custom_loss(y_true, y_pred): return utils.multilabel_weighted_loss(y_true, y_pred, weights=pos_weight_tensor) optimizer = create_optimizer( init_lr=params["lr"], num_train_steps=(params["epochs"] * params["steps_per_epoch"]), num_warmup_steps=params["warmup_steps"], optimizer_type=params["optimizer_type"], ) if params.get("mixed_precision"): optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) model.compile( optimizer=optimizer, loss=custom_loss, metrics=[pr_auc] ) model.weights model.summary() pr_auc.name start_train_time = datetime.datetime.now() print(start_train_time.strftime("%m-%d-%Y (%H:%M:%S)")) checkpoint_path = os.path.join("...") print("Saving model checkpoints here: ", checkpoint_path) cp_callback = tf.keras.callbacks.ModelCheckpoint( filepath=os.path.join(checkpoint_path, "model.{epoch:04d}.tf"), verbose=1, monitor=f'val_{pr_auc.name}', mode='max', save_freq='epoch', save_best_only=True ) early_stopping_callback = tf.keras.callbacks.EarlyStopping(patience=7, monitor=f"val_{pr_auc.name}", mode="max") model.fit(train_ds, epochs=params["epochs"], validation_data=val_ds, callbacks=[cp_callback, early_stopping_callback], steps_per_epoch=params["steps_per_epoch"], verbose=2) gs_model_path = ... reloaded_keras_layer = hub.KerasLayer(gs_model_path) inputs = tf.keras.layers.Input(name="tweet__core__tweet__text", shape=(1,), dtype=tf.string) output = reloaded_keras_layer(inputs) v7_model = tf.keras.models.Model(inputs=inputs, outputs=output) pr_auc = tf.keras.metrics.AUC(curve="PR", name="pr_auc") roc_auc = tf.keras.metrics.AUC(curve="ROC", name="roc_auc") v7_model.compile(metrics=[pr_auc, roc_auc]) model.load_weights("...") candidate_model = model with mirrored_strategy.scope(): candidate_eval = candidate_model.evaluate(val_ds) test_query = f""" SELECT {",".join(ptos_prototype.feature_names())}, has_media, precision_nsfw, {",".join(labels)}, ... """ test = cq_loader.load_features(ptos_prototype, "", "", custom_query=test_query) test = test.to_tf_dataset().map(parse_labeled_data) print(test) test_only_media = test.filter(lambda x, y: tf.equal(x["has_media"], True)) test_only_nsfw = test.filter(lambda x, y: tf.greater_equal(x["precision_nsfw"], 0.95)) test_no_media = test.filter(lambda x, y: tf.equal(x["has_media"], False)) test_media_not_nsfw = test.filter(lambda x, y: tf.logical_and(tf.equal(x["has_media"], True), tf.less(x["precision_nsfw"], 0.95))) for d in [test, test_only_media, test_only_nsfw, test_no_media, test_media_not_nsfw]: print(d.reduce(0, lambda x, _: x + 1).numpy()) def display_metrics(probs, targets, labels=labels): eval_config = EvalConfig(prediction_threshold=0.5, precision_k=0.9) for eval_mode, y_mask in [("implicit", np.ones(targets.shape))]: print("Evaluation mode", eval_mode) metrics = SparseMultilabelEvaluator.evaluate( targets, np.array(probs), y_mask, classes=labels, eval_config=eval_config ) metrics_df = pd.DataFrame.from_dict(asdict(metrics)["per_topic_metrics"]).transpose() metrics_df["pos_to_neg"] = metrics_df["num_pos_samples"] / (metrics_df["num_neg_samples"] + 1) display(metrics_df.median()) display(metrics_df) return metrics_df def eval_model(model, df): with mirrored_strategy.scope(): targets = np.stack(list(df.map(lambda x, y: y).as_numpy_iterator()), axis=0) df = df.padded_batch(BATCH_SIZE) preds = model.predict(df) return display_metrics(preds, targets) subsets = {"test": test, "test_only_media": test_only_media, "test_only_nsfw": test_only_nsfw, "test_no_media": test_no_media, "test_media_not_nsfw": test_media_not_nsfw} metrics = {} for name, df in subsets.items(): metrics[name] = eval_model(candidate_model, df) [(name, m.pr_auc) for name, m in metrics.items()] for name, x in [(name, m.pr_auc.to_string(index=False).strip().split("\n")) for name, m in metrics.items()]: print(name) for y in x: print(y.strip(), end="\t") print(".") for d in [test, test_only_media, test_only_nsfw, test_no_media, test_media_not_nsfw]: print(d.reduce(0, lambda x, _: x + 1).numpy())