the-algorithm/trust_and_safety_models/toxicity/load_model.py
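"""Model construction and loading utilities for the toxicity pipeline: builds
BERT-based classifiers (in-house Twitter BERTs or BERTweet), selects losses,
and compiles models with PR-AUC and ROC-AUC metrics.
"""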
import os

from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR, MAX_SEQ_LENGTH
from toxicity_ml_pipeline.utils.helpers import execute_command

import tensorflow as tf

try:
    from toxicity_ml_pipeline.optim.losses import MaskedBCE
except ImportError:
    print("No MaskedBCE loss")

try:
    from twitter.cuad.representation.models.text_encoder import TextEncoder
except ModuleNotFoundError:
    print("No TextEncoder package")

try:
    from transformers import TFAutoModelForSequenceClassification
except ModuleNotFoundError:
    print("No HuggingFace package")

LOCAL_MODEL_DIR = os.path.join(LOCAL_DIR, "models")

def reload_model_weights(weights_dir, language, **kwargs):
    optimizer = tf.keras.optimizers.Adam(0.01)
    model_type = (
        "twitter_bert_base_en_uncased_mlm"
        if language == "en"
        else "twitter_multilingual_bert_base_cased_mlm"
    )
    model = load(optimizer=optimizer, seed=42, model_type=model_type, **kwargs)
    model.load_weights(weights_dir)
    return model

def _locally_copy_models(model_type):
    if model_type == "twitter_multilingual_bert_base_cased_mlm":
        preprocessor = "bert_multi_cased_preprocess_3"
    elif model_type == "twitter_bert_base_en_uncased_mlm":
        preprocessor = "bert_en_uncased_preprocess_3"
    else:
        raise NotImplementedError

    copy_cmd = """mkdir {local_dir}
gsutil cp -r ...
gsutil cp -r ..."""
    execute_command(
        copy_cmd.format(
            model_type=model_type, preprocessor=preprocessor, local_dir=LOCAL_MODEL_DIR
        )
    )
    return preprocessor

def load_encoder(model_type, trainable):
    try:
        model = TextEncoder(
            max_seq_lengths=MAX_SEQ_LENGTH,
            model_type=model_type,
            cluster="gcp",
            trainable=trainable,
            enable_dynamic_shapes=True,
        )
    except (OSError, tf.errors.AbortedError) as e:
        # Remote loading failed; copy the artifacts locally and retry.
        print(e)
        preprocessor = _locally_copy_models(model_type)
        model = TextEncoder(
            max_seq_lengths=MAX_SEQ_LENGTH,
            local_model_path=f"models/{model_type}",
            local_preprocessor_path=f"models/{preprocessor}",
            cluster="gcp",
            trainable=trainable,
            enable_dynamic_shapes=True,
        )

    return model

def get_loss(loss_name, from_logits, **kwargs):
    loss_name = loss_name.lower()
    if loss_name == "bce":
        print("Binary CE loss")
        return tf.keras.losses.BinaryCrossentropy(from_logits=from_logits)

    if loss_name == "cce":
        print("Categorical cross-entropy loss")
        return tf.keras.losses.CategoricalCrossentropy(from_logits=from_logits)

    if loss_name == "scce":
        print("Sparse categorical cross-entropy loss")
        return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits)

    if loss_name == "focal_bce":
        gamma = kwargs.get("gamma", 2)
        print("Focal binary CE loss", gamma)
        return tf.keras.losses.BinaryFocalCrossentropy(gamma=gamma, from_logits=from_logits)

    if loss_name == "masked_bce":
        multitask = kwargs.get("multitask", False)
        if from_logits or multitask:
            raise NotImplementedError
        print("Masked Binary Cross Entropy")
        return MaskedBCE()

    if loss_name == "inv_kl_loss":
        raise NotImplementedError

    raise ValueError(
        f"This loss name is not valid: {loss_name}. Accepted loss names: BCE, masked BCE, CCE, sCCE, "
        "Focal_BCE, inv_KL_loss"
    )

def _add_additional_embedding_layer(doc_embedding, glorot, seed):
    doc_embedding = tf.keras.layers.Dense(
        768, activation="tanh", kernel_initializer=glorot
    )(doc_embedding)
    doc_embedding = tf.keras.layers.Dropout(rate=0.1, seed=seed)(doc_embedding)
    return doc_embedding

def _get_bias(**kwargs):
    smart_bias_value = kwargs.get("smart_bias_value", 0)
    print("Smart bias init to ", smart_bias_value)
    output_bias = tf.keras.initializers.Constant(smart_bias_value)
    return output_bias

def load_inhouse_bert(model_type, trainable, seed, **kwargs):
    inputs = tf.keras.layers.Input(shape=(), dtype=tf.string)
    encoder = load_encoder(model_type=model_type, trainable=trainable)
    doc_embedding = encoder([inputs])["pooled_output"]
    doc_embedding = tf.keras.layers.Dropout(rate=0.1, seed=seed)(doc_embedding)

    glorot = tf.keras.initializers.glorot_uniform(seed=seed)
    if kwargs.get("additional_layer", False):
        doc_embedding = _add_additional_embedding_layer(doc_embedding, glorot, seed)

    if kwargs.get("content_num_classes", None):
        # Multitask setup: one head for the target task, one for content classes.
        probs = get_last_layer(
            glorot=glorot, last_layer_name="target_output", **kwargs
        )(doc_embedding)
        second_probs = get_last_layer(
            num_classes=kwargs["content_num_classes"],
            last_layer_name="content_output",
            glorot=glorot,
        )(doc_embedding)
        probs = [probs, second_probs]
    else:
        probs = get_last_layer(glorot=glorot, **kwargs)(doc_embedding)

    model = tf.keras.models.Model(inputs=inputs, outputs=probs)
    return model, False

def get_last_layer(**kwargs):
    output_bias = _get_bias(**kwargs)
    if "glorot" in kwargs:
        glorot = kwargs["glorot"]
    else:
        glorot = tf.keras.initializers.glorot_uniform(seed=kwargs["seed"])
    layer_name = kwargs.get("last_layer_name", "dense_1")

    if kwargs.get("num_classes", 1) > 1:
        last_layer = tf.keras.layers.Dense(
            kwargs["num_classes"],
            activation="softmax",
            kernel_initializer=glorot,
            bias_initializer=output_bias,
            name=layer_name,
        )
    elif kwargs.get("num_raters", 1) > 1:
        if kwargs.get("multitask", False):
            raise NotImplementedError
        last_layer = tf.keras.layers.Dense(
            kwargs["num_raters"],
            activation="sigmoid",
            kernel_initializer=glorot,
            bias_initializer=output_bias,
            name="probs",
        )
    else:
        last_layer = tf.keras.layers.Dense(
            1,
            activation="sigmoid",
            kernel_initializer=glorot,
            bias_initializer=output_bias,
            name=layer_name,
        )
    return last_layer

def load_bertweet(**kwargs):
    bert = TFAutoModelForSequenceClassification.from_pretrained(
        os.path.join(LOCAL_MODEL_DIR, "bertweet-base"),
        num_labels=1,
        classifier_dropout=0.1,
        hidden_size=768,
    )
    if "num_classes" in kwargs and kwargs["num_classes"] > 2:
        raise NotImplementedError
    return bert, True

def load(
    optimizer,
    seed,
    model_type="twitter_multilingual_bert_base_cased_mlm",
    loss_name="BCE",
    trainable=True,
    **kwargs,
):
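    """Build and compile the toxicity model.

    Dispatches to `load_bertweet` or `load_inhouse_bert`, attaches PR-AUC and
    ROC-AUC metrics, and compiles with the requested loss (plus a weighted
    second loss when a content head is configured).
    """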
    if model_type == "bertweet-base":
        model, from_logits = load_bertweet()
    else:
        model, from_logits = load_inhouse_bert(model_type, trainable, seed, **kwargs)

    pr_auc = tf.keras.metrics.AUC(curve="PR", name="pr_auc", from_logits=from_logits)
    roc_auc = tf.keras.metrics.AUC(curve="ROC", name="roc_auc", from_logits=from_logits)
    loss = get_loss(loss_name, from_logits, **kwargs)

    if kwargs.get("content_num_classes", None):
        # Two-headed model: weight the content loss relative to the target loss.
        second_loss = get_loss(
            loss_name=kwargs["content_loss_name"], from_logits=from_logits
        )
        loss_weights = {
            "content_output": kwargs["content_loss_weight"],
            "target_output": 1,
        }
        model.compile(
            optimizer=optimizer,
            loss={"content_output": second_loss, "target_output": loss},
            loss_weights=loss_weights,
            metrics=[pr_auc, roc_auc],
        )
    else:
        model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=[pr_auc, roc_auc],
        )

    model.summary()
    print("logits: ", from_logits)
    return model
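

# Illustrative usage (a minimal sketch, not part of the original module): how a
# caller might build a compiled model and later restore fine-tuned weights. The
# checkpoint directory below is hypothetical, and `load` assumes the relevant
# artifacts are reachable (TextEncoder for the in-house BERTs, a local
# "bertweet-base" directory for BERTweet).
if __name__ == "__main__":
    example_model = load(
        optimizer=tf.keras.optimizers.Adam(0.01),
        seed=42,
        model_type="twitter_bert_base_en_uncased_mlm",
        loss_name="BCE",
        trainable=True,
    )
    # Restoring weights from an earlier run; the path is a placeholder.
    # example_model = reload_model_weights(
    #     weights_dir=os.path.join(LOCAL_DIR, "checkpoints", "..."), language="en"
    # )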