import os

import tensorflow as tf

from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR, MAX_SEQ_LENGTH
from toxicity_ml_pipeline.utils.helpers import execute_command

try:
    from toxicity_ml_pipeline.optim.losses import MaskedBCE
except ImportError:
    print("No MaskedBCE loss")

try:
    from twitter.cuad.representation.models.text_encoder import TextEncoder
except ModuleNotFoundError:
    print("No TextEncoder package")

try:
    from transformers import TFAutoModelForSequenceClassification
except ModuleNotFoundError:
    print("No HuggingFace package")

LOCAL_MODEL_DIR = os.path.join(LOCAL_DIR, "models")


def reload_model_weights(weights_dir, language, **kwargs):
    """Rebuild the architecture for the given language and restore its trained weights."""
    optimizer = tf.keras.optimizers.Adam(0.01)
    model_type = (
        "twitter_bert_base_en_uncased_mlm"
        if language == "en"
        else "twitter_multilingual_bert_base_cased_mlm"
    )
    model = load(optimizer=optimizer, seed=42, model_type=model_type, **kwargs)
    model.load_weights(weights_dir)
    return model


def _locally_copy_models(model_type):
    """Copy a model and its preprocessor from GCS to LOCAL_MODEL_DIR (source paths elided)."""
    if model_type == "twitter_multilingual_bert_base_cased_mlm":
        preprocessor = "bert_multi_cased_preprocess_3"
    elif model_type == "twitter_bert_base_en_uncased_mlm":
        preprocessor = "bert_en_uncased_preprocess_3"
    else:
        raise NotImplementedError

    copy_cmd = """mkdir {local_dir}
gsutil cp -r ...
gsutil cp -r ..."""
    execute_command(
        copy_cmd.format(model_type=model_type, preprocessor=preprocessor, local_dir=LOCAL_MODEL_DIR)
    )
    return preprocessor


def load_encoder(model_type, trainable):
    """Load the in-house text encoder, falling back to a locally copied model on failure."""
    try:
        model = TextEncoder(
            max_seq_lengths=MAX_SEQ_LENGTH,
            model_type=model_type,
            cluster="gcp",
            trainable=trainable,
            enable_dynamic_shapes=True,
        )
    except (OSError, tf.errors.AbortedError) as e:
        print(e)
        preprocessor = _locally_copy_models(model_type)
        model = TextEncoder(
            max_seq_lengths=MAX_SEQ_LENGTH,
            local_model_path=f"models/{model_type}",
            local_preprocessor_path=f"models/{preprocessor}",
            cluster="gcp",
            trainable=trainable,
            enable_dynamic_shapes=True,
        )
    return model


def get_loss(loss_name, from_logits, **kwargs):
    """Map a case-insensitive loss name to a Keras loss instance."""
    loss_name = loss_name.lower()
    if loss_name == "bce":
        print("Binary CE loss")
        return tf.keras.losses.BinaryCrossentropy(from_logits=from_logits)

    if loss_name == "cce":
        print("Categorical cross-entropy loss")
        return tf.keras.losses.CategoricalCrossentropy(from_logits=from_logits)

    if loss_name == "scce":
        print("Sparse categorical cross-entropy loss")
        return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits)

    if loss_name == "focal_bce":
        gamma = kwargs.get("gamma", 2)
        print("Focal binary CE loss", gamma)
        return tf.keras.losses.BinaryFocalCrossentropy(gamma=gamma, from_logits=from_logits)

    if loss_name == "masked_bce":
        multitask = kwargs.get("multitask", False)
        if from_logits or multitask:
            raise NotImplementedError
        print("Masked binary cross-entropy loss")
        return MaskedBCE()

    if loss_name == "inv_kl_loss":
        raise NotImplementedError

    raise ValueError(
        f"This loss name is not valid: {loss_name}. "
        "Accepted loss names: BCE, masked BCE, CCE, sCCE, Focal_BCE, inv_KL_loss"
    )
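
# Illustrative sketch only: the real MaskedBCE lives in
# toxicity_ml_pipeline.optim.losses and may differ. This sketch assumes
# missing labels are encoded with a sentinel value (-1 here, an assumption)
# that should not contribute to the loss.
def _masked_bce_sketch(y_true, y_pred, sentinel=-1.0):
    y_true = tf.cast(y_true, tf.float32)
    # Keep only entries whose label is real, i.e. not the sentinel.
    mask = tf.not_equal(y_true, sentinel)
    y_true = tf.boolean_mask(y_true, mask)
    y_pred = tf.boolean_mask(y_pred, mask)
    # Plain binary cross-entropy over the surviving (label, prediction) pairs.
    return tf.keras.losses.binary_crossentropy(y_true, y_pred)
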

def _add_additional_embedding_layer(doc_embedding, glorot, seed):
    """Add an extra tanh projection plus dropout on top of the document embedding."""
    doc_embedding = tf.keras.layers.Dense(768, activation="tanh", kernel_initializer=glorot)(
        doc_embedding
    )
    doc_embedding = tf.keras.layers.Dropout(rate=0.1, seed=seed)(doc_embedding)
    return doc_embedding


def _get_bias(**kwargs):
    """Build a constant bias initializer for the output layer ("smart" bias init)."""
    smart_bias_value = kwargs.get("smart_bias_value", 0)
    print("Smart bias init to", smart_bias_value)
    output_bias = tf.keras.initializers.Constant(smart_bias_value)
    return output_bias
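
# Sketch of how smart_bias_value is commonly chosen for an imbalanced binary
# task (this helper is illustrative and not part of the pipeline): setting the
# final bias to log(n_pos / n_neg) makes sigmoid(bias) equal the
# positive-class prior, so early training does not spend steps just learning
# the base rate.
def _smart_bias_from_prior(n_pos, n_neg):
    import math  # local import keeps this illustrative helper self-contained

    return math.log(n_pos / n_neg)
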

def load_inhouse_bert(model_type, trainable, seed, **kwargs):
    """Build a classifier on top of the in-house BERT encoder.

    Returns the model and from_logits=False, since every head ends in a
    sigmoid or softmax activation.
    """
    inputs = tf.keras.layers.Input(shape=(), dtype=tf.string)
    encoder = load_encoder(model_type=model_type, trainable=trainable)
    doc_embedding = encoder([inputs])["pooled_output"]
    doc_embedding = tf.keras.layers.Dropout(rate=0.1, seed=seed)(doc_embedding)

    glorot = tf.keras.initializers.glorot_uniform(seed=seed)
    if kwargs.get("additional_layer", False):
        doc_embedding = _add_additional_embedding_layer(doc_embedding, glorot, seed)

    if kwargs.get("content_num_classes", None):
        # Two-headed model: one head for the main target, one for content type.
        probs = get_last_layer(glorot=glorot, last_layer_name="target_output", **kwargs)(
            doc_embedding
        )
        second_probs = get_last_layer(
            num_classes=kwargs["content_num_classes"],
            last_layer_name="content_output",
            glorot=glorot,
        )(doc_embedding)
        probs = [probs, second_probs]
    else:
        probs = get_last_layer(glorot=glorot, **kwargs)(doc_embedding)
    model = tf.keras.models.Model(inputs=inputs, outputs=probs)

    return model, False


def get_last_layer(**kwargs):
    """Build the output Dense layer: softmax for multi-class, sigmoid otherwise."""
    output_bias = _get_bias(**kwargs)
    if "glorot" in kwargs:
        glorot = kwargs["glorot"]
    else:
        glorot = tf.keras.initializers.glorot_uniform(seed=kwargs["seed"])
    layer_name = kwargs.get("last_layer_name", "dense_1")

    if kwargs.get("num_classes", 1) > 1:
        last_layer = tf.keras.layers.Dense(
            kwargs["num_classes"],
            activation="softmax",
            kernel_initializer=glorot,
            bias_initializer=output_bias,
            name=layer_name,
        )
    elif kwargs.get("num_raters", 1) > 1:
        # One sigmoid unit per rater.
        if kwargs.get("multitask", False):
            raise NotImplementedError
        last_layer = tf.keras.layers.Dense(
            kwargs["num_raters"],
            activation="sigmoid",
            kernel_initializer=glorot,
            bias_initializer=output_bias,
            name="probs",
        )
    else:
        last_layer = tf.keras.layers.Dense(
            1,
            activation="sigmoid",
            kernel_initializer=glorot,
            bias_initializer=output_bias,
            name=layer_name,
        )
    return last_layer


def load_bertweet(**kwargs):
    """Load a local BERTweet checkpoint.

    Returns the model and from_logits=True, since the HuggingFace
    classification head outputs raw logits.
    """
    bert = TFAutoModelForSequenceClassification.from_pretrained(
        os.path.join(LOCAL_MODEL_DIR, "bertweet-base"),
        num_labels=1,
        classifier_dropout=0.1,
        hidden_size=768,
    )
    if "num_classes" in kwargs and kwargs["num_classes"] > 2:
        raise NotImplementedError
    return bert, True


def load(
    optimizer,
    seed,
    model_type="twitter_multilingual_bert_base_cased_mlm",
    loss_name="BCE",
    trainable=True,
    **kwargs,
):
    """Build and compile a toxicity classifier of the requested type."""
    if model_type == "bertweet-base":
        model, from_logits = load_bertweet()
    else:
        model, from_logits = load_inhouse_bert(model_type, trainable, seed, **kwargs)

    pr_auc = tf.keras.metrics.AUC(curve="PR", name="pr_auc", from_logits=from_logits)
    roc_auc = tf.keras.metrics.AUC(curve="ROC", name="roc_auc", from_logits=from_logits)

    loss = get_loss(loss_name, from_logits, **kwargs)
    if kwargs.get("content_num_classes", None):
        # Weighted multi-task compilation: the content head gets its own loss.
        second_loss = get_loss(loss_name=kwargs["content_loss_name"], from_logits=from_logits)
        loss_weights = {"content_output": kwargs["content_loss_weight"], "target_output": 1}
        model.compile(
            optimizer=optimizer,
            loss={"content_output": second_loss, "target_output": loss},
            loss_weights=loss_weights,
            metrics=[pr_auc, roc_auc],
        )
    else:
        model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=[pr_auc, roc_auc],
        )

    model.summary()
    print("logits:", from_logits)

    return model
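
# Usage sketch with hypothetical values; it assumes a local copy of
# bertweet-base under LOCAL_MODEL_DIR and the transformers package.
if __name__ == "__main__":
    adam = tf.keras.optimizers.Adam(learning_rate=1e-5)
    toxicity_model = load(
        optimizer=adam,
        seed=42,
        model_type="bertweet-base",
        loss_name="BCE",
    )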