mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-06-12 22:28:50 +02:00
fixed trust_and_safety_models
This commit is contained in:
parent
66644c1771
commit
8f93b2e618
|
@ -3,7 +3,7 @@ import logging
|
|||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
from typing import List, Union
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
import apache_beam as beam
|
||||
|
@ -11,7 +11,7 @@ import faiss
|
|||
from apache_beam.options.pipeline_options import PipelineOptions
|
||||
|
||||
|
||||
def parse_d6w_config(argv: Union[List[str], None] = None):
|
||||
def parse_d6w_config(argv: Optional[List[str]] = None):
|
||||
"""Parse d6w config.
|
||||
:param argv: d6w config
|
||||
:return: dictionary containing d6w config
|
||||
|
@ -93,8 +93,8 @@ def get_bq_query():
|
|||
return pkgutil.get_data(__name__, "bq.sql").decode("utf-8")
|
||||
|
||||
|
||||
def parse_metric(config):
|
||||
metric_str: str = config["metric"].lower()
|
||||
def parse_metric(config: Dict[str, str]):
|
||||
metric_str = config["metric"].lower()
|
||||
if metric_str == "l2":
|
||||
return faiss.METRIC_L2
|
||||
elif metric_str == "ip":
|
||||
|
@ -142,10 +142,7 @@ def run_pipeline(argv: List[str] = []):
|
|||
config["metric"],
|
||||
config["gpu"],
|
||||
)
|
||||
)
|
||||
|
||||
# Make linter happy
|
||||
index_built
|
||||
) # pylint: disable=unused-variable
|
||||
|
||||
|
||||
class MergeAndBuildIndex(beam.CombineFn):
|
||||
|
@ -159,7 +156,7 @@ class MergeAndBuildIndex(beam.CombineFn):
|
|||
def create_accumulator(self):
|
||||
return []
|
||||
|
||||
def add_input(self, accumulator, element):
|
||||
def add_input(self, accumulator: List, element) -> List:
|
||||
accumulator.append(element)
|
||||
return accumulator
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# checkstyle: noqa
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
|
|
@ -7,13 +7,13 @@ from toxicity_ml_pipeline.settings.hcomp_settings import TOXIC_35
|
|||
|
||||
TOXIC_35_set = set(TOXIC_35)
|
||||
|
||||
url_group = r"(\bhttps?:\/\/\S+)"
|
||||
mention_group = r"(\B@\S+)"
|
||||
urls_mentions_re = re.compile(url_group + r"|" + mention_group, re.IGNORECASE)
|
||||
url_re = re.compile(url_group, re.IGNORECASE)
|
||||
mention_re = re.compile(mention_group, re.IGNORECASE)
|
||||
newline_re = re.compile(r"\n+", re.IGNORECASE)
|
||||
and_re = re.compile(r"&\s?amp\s?;", re.IGNORECASE)
|
||||
URL_GROUP = r"(\bhttps?:\/\/\S+)"
|
||||
MENTION_GROUP = r"(\B@\S+)"
|
||||
URLS_MENTIONS_RE = re.compile(URL_GROUP + r"|" + MENTION_GROUP, re.IGNORECASE)
|
||||
URL_RE = re.compile(URL_GROUP, re.IGNORECASE)
|
||||
MENTION_RE = re.compile(MENTION_GROUP, re.IGNORECASE)
|
||||
NEWLINE_RE = re.compile(r"\n+", re.IGNORECASE)
|
||||
AND_RE = re.compile(r"&\s?amp\s?;", re.IGNORECASE)
|
||||
|
||||
|
||||
class DataframeCleaner(ABC):
|
||||
|
@ -98,22 +98,22 @@ class DefaultENNoPreprocessor(DataframeCleaner):
|
|||
|
||||
|
||||
class DefaultENPreprocessor(DefaultENNoPreprocessor):
|
||||
def _clean(self, adhoc_df):
|
||||
def _clean(self, adhoc_df: pd.DataFrame) -> pd.DataFrame:
|
||||
print(
|
||||
"... removing \\n and replacing @mentions and URLs by placeholders. "
|
||||
"Emoji filtering is not done."
|
||||
)
|
||||
adhoc_df["text"] = [
|
||||
url_re.sub("URL", tweet) for tweet in adhoc_df.raw_text.values
|
||||
URL_RE.sub("URL", tweet) for tweet in adhoc_df.raw_text.values
|
||||
]
|
||||
adhoc_df["text"] = [
|
||||
mention_re.sub("MENTION", tweet) for tweet in adhoc_df.text.values
|
||||
MENTION_RE.sub("MENTION", tweet) for tweet in adhoc_df.text.values
|
||||
]
|
||||
adhoc_df["text"] = [
|
||||
newline_re.sub(" ", tweet).lstrip(" ").rstrip(" ")
|
||||
NEWLINE_RE.sub(" ", tweet).lstrip(" ").rstrip(" ")
|
||||
for tweet in adhoc_df.text.values
|
||||
]
|
||||
adhoc_df["text"] = [and_re.sub("&", tweet) for tweet in adhoc_df.text.values]
|
||||
adhoc_df["text"] = [AND_RE.sub("&", tweet) for tweet in adhoc_df.text.values]
|
||||
return adhoc_df
|
||||
|
||||
|
||||
|
@ -121,10 +121,10 @@ class Defaulti18nPreprocessor(DataframeCleaner):
|
|||
def _clean(self, adhoc_df):
|
||||
print("... removing @mentions, \\n and URLs. Emoji filtering is not done.")
|
||||
adhoc_df["text"] = [
|
||||
urls_mentions_re.sub("", tweet) for tweet in adhoc_df.raw_text.values
|
||||
URLS_MENTIONS_RE.sub("", tweet) for tweet in adhoc_df.raw_text.values
|
||||
]
|
||||
adhoc_df["text"] = [
|
||||
newline_re.sub(" ", tweet).lstrip(" ").rstrip(" ")
|
||||
NEWLINE_RE.sub(" ", tweet).lstrip(" ").rstrip(" ")
|
||||
for tweet in adhoc_df.text.values
|
||||
]
|
||||
return adhoc_df
|
||||
|
|
|
@ -222,7 +222,7 @@ class ENLoaderWithSampling(ENLoader):
|
|||
|
||||
def sample(
|
||||
self,
|
||||
df,
|
||||
df: pd.DataFrame,
|
||||
first_set: pd.DataFrame,
|
||||
second_set: pd.DataFrame,
|
||||
keyword_sampling: bool,
|
||||
|
@ -300,13 +300,13 @@ class I18nLoader(DataframeLoader):
|
|||
self.accepted_languages = ACCEPTED_LANGUAGES
|
||||
self.query_settings = dict(QUERY_SETTINGS)
|
||||
|
||||
def produce_query(self, language, query, dataset, table, lang):
|
||||
def produce_query(self, language: str, query: str, dataset: str, table: str, lang: str) -> str:
|
||||
query = query.format(dataset=dataset, table=table)
|
||||
add_query = f"AND reviewed.{lang}='{language}'"
|
||||
query += add_query
|
||||
return query
|
||||
|
||||
def query_keys(self, language, task=2, size="50"):
|
||||
def query_keys(self, language: str, task: int=2, size: str="50"):
|
||||
if task == 2:
|
||||
if language == "ar":
|
||||
self.query_settings["adhoc_v2"]["table"] = "..."
|
||||
|
@ -323,7 +323,7 @@ class I18nLoader(DataframeLoader):
|
|||
f"There are no other tasks than 2 or 3. {task} does not exist."
|
||||
)
|
||||
|
||||
def load_data(self, language, test=False, task=2):
|
||||
def load_data(self, language: str, test: bool=False, task: int=2):
|
||||
if language not in self.accepted_languages:
|
||||
raise ValueError(
|
||||
f"Language not in the data {language}. Accepted values are "
|
||||
|
|
|
@ -78,11 +78,9 @@ class BalancedMiniBatchLoader(object):
|
|||
)
|
||||
|
||||
self.n_inner_splits = n_inner_splits if n_inner_splits is not None else INNER_CV
|
||||
|
||||
self.seed = seed
|
||||
self.mb_size = mb_size
|
||||
self.fold = fold
|
||||
|
||||
self.sample_weights = sample_weights
|
||||
self.dual_head = dual_head
|
||||
self.huggingface = huggingface
|
||||
|
@ -99,7 +97,7 @@ class BalancedMiniBatchLoader(object):
|
|||
os.path.join(local_model_dir, "bertweet-base"), normalization=True
|
||||
)
|
||||
|
||||
def tokenize_function(self, el):
|
||||
def tokenize_function(self, el: dict) -> dict:
|
||||
return self.tokenizer(
|
||||
el["text"],
|
||||
max_length=MAX_SEQ_LENGTH,
|
||||
|
@ -110,12 +108,12 @@ class BalancedMiniBatchLoader(object):
|
|||
return_attention_mask=False,
|
||||
)
|
||||
|
||||
def _get_stratified_kfold(self, n_splits):
|
||||
def _get_stratified_kfold(self, n_splits: int) -> StratifiedKFold:
|
||||
return StratifiedKFold(shuffle=True, n_splits=n_splits, random_state=self.seed)
|
||||
|
||||
def _get_time_fold(self, df):
|
||||
test_begin_date = pandas.to_datetime(self.test_begin_date).date()
|
||||
test_end_date = pandas.to_datetime(self.test_end_date).date()
|
||||
def _get_time_fold(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
||||
test_begin_date = pd.to_datetime(self.test_begin_date).date()
|
||||
test_end_date = pd.to_datetime(self.test_end_date).date()
|
||||
print(f"Test is going from {test_begin_date} to {test_end_date}.")
|
||||
test_data = df.query("@test_begin_date <= date <= @test_end_date")
|
||||
|
||||
|
@ -123,7 +121,7 @@ class BalancedMiniBatchLoader(object):
|
|||
other_set = df.query(query)
|
||||
return other_set, test_data
|
||||
|
||||
def _get_outer_cv_fold(self, df):
|
||||
def _get_outer_cv_fold(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
||||
labels = df.int_label
|
||||
stratifier = self._get_stratified_kfold(n_splits=self.n_outer_splits)
|
||||
|
||||
|
@ -132,20 +130,18 @@ class BalancedMiniBatchLoader(object):
|
|||
if k == self.fold:
|
||||
break
|
||||
k += 1
|
||||
|
||||
train_data = df.iloc[train_index].copy()
|
||||
test_data = df.iloc[test_index].copy()
|
||||
|
||||
return train_data, test_data
|
||||
|
||||
def get_steps_per_epoch(self, nb_pos_examples):
|
||||
def get_steps_per_epoch(self, nb_pos_examples: int) -> int:
|
||||
return int(
|
||||
max(TARGET_POS_PER_EPOCH, nb_pos_examples)
|
||||
/ self.mb_size
|
||||
/ self.perc_training_tox
|
||||
)
|
||||
|
||||
def make_huggingface_tensorflow_ds(self, group, mb_size=None, shuffle=True):
|
||||
def make_huggingface_tensorflow_ds(self, group, mb_size=None, shuffle: bool = True):
|
||||
huggingface_ds = Dataset.from_pandas(group).map(
|
||||
self.tokenize_function, batched=True
|
||||
)
|
||||
|
@ -164,7 +160,9 @@ class BalancedMiniBatchLoader(object):
|
|||
return tensorflow_ds.repeat()
|
||||
return tensorflow_ds
|
||||
|
||||
def make_pure_tensorflow_ds(self, df, nb_samples):
|
||||
def make_pure_tensorflow_ds(
|
||||
self, df: pd.DataFrame, nb_samples: int
|
||||
) -> tf.data.Dataset:
|
||||
buffer_size = nb_samples * 2
|
||||
|
||||
if self.sample_weights is not None:
|
||||
|
@ -188,8 +186,11 @@ class BalancedMiniBatchLoader(object):
|
|||
return ds
|
||||
|
||||
def get_balanced_dataset(
|
||||
self, training_data, size_limit=None, return_as_batch=True
|
||||
):
|
||||
self,
|
||||
training_data: pd.DataFrame,
|
||||
size_limit: int = None,
|
||||
return_as_batch: bool = True,
|
||||
) -> tf.data.Dataset:
|
||||
training_data = training_data.sample(frac=1, random_state=self.seed)
|
||||
nb_samples = training_data.shape[0] if not size_limit else size_limit
|
||||
|
||||
|
@ -198,8 +199,13 @@ class BalancedMiniBatchLoader(object):
|
|||
if size_limit:
|
||||
training_data = training_data[: size_limit * num_classes]
|
||||
|
||||
percent_tox = (
|
||||
100
|
||||
* training_data[training_data.int_label == toxic_class].shape[0]
|
||||
/ nb_samples
|
||||
)
|
||||
print(
|
||||
f"... {nb_samples} examples, incl. {(100 * training_data[training_data.int_label == toxic_class].shape[0] / nb_samples):.2f}% tox in train, {num_classes} classes"
|
||||
f"... {nb_samples} examples, incl. {percent_tox:.2f}% tox in train, {num_classes} classes"
|
||||
)
|
||||
label_groups = training_data.groupby("int_label")
|
||||
if self.huggingface:
|
||||
|
@ -273,7 +279,9 @@ class BalancedMiniBatchLoader(object):
|
|||
|
||||
yield mini_batches, steps_per_epoch, val_data, test_data
|
||||
|
||||
def simple_cv_load(self, full_df: pd.DataFrame):
|
||||
def simple_cv_load(
|
||||
self, full_df: pd.DataFrame
|
||||
) -> Tuple[tf.data.Dataset, pd.DataFrame, int]:
|
||||
full_df = self._compute_int_labels(full_df)
|
||||
|
||||
train_data, test_data = self.get_outer_fold(df=full_df)
|
||||
|
@ -287,7 +295,9 @@ class BalancedMiniBatchLoader(object):
|
|||
|
||||
return mini_batches, test_data, steps_per_epoch
|
||||
|
||||
def no_cv_load(self, full_df: pd.DataFrame):
|
||||
def no_cv_load(
|
||||
self, full_df: pd.DataFrame
|
||||
) -> Tuple[tf.data.Dataset, pd.DataFrame, int]:
|
||||
full_df = self._compute_int_labels(full_df)
|
||||
|
||||
val_test = full_df[full_df.origin == "precision"].copy(deep=True)
|
||||
|
|
|
@ -22,7 +22,7 @@ except ModuleNotFoundError:
|
|||
LOCAL_MODEL_DIR = os.path.join(LOCAL_DIR, "models")
|
||||
|
||||
|
||||
def reload_model_weights(weights_dir, language, **kwargs):
|
||||
def reload_model_weights(weights_dir, language: str, **kwargs):
|
||||
optimizer = tf.keras.optimizers.Adam(0.01)
|
||||
model_type = (
|
||||
"twitter_bert_base_en_uncased_mlm"
|
||||
|
@ -35,7 +35,7 @@ def reload_model_weights(weights_dir, language, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
def _locally_copy_models(model_type):
|
||||
def _locally_copy_models(model_type: str):
|
||||
if model_type == "twitter_multilingual_bert_base_cased_mlm":
|
||||
preprocessor = "bert_multi_cased_preprocess_3"
|
||||
elif model_type == "twitter_bert_base_en_uncased_mlm":
|
||||
|
@ -43,9 +43,7 @@ def _locally_copy_models(model_type):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
copy_cmd = """mkdir {local_dir}
|
||||
gsutil cp -r ...
|
||||
gsutil cp -r ..."""
|
||||
copy_cmd = "mkdir {local_dir}\ngsutil cp -r ...\ngsutil cp -r ..."
|
||||
execute_command(
|
||||
copy_cmd.format(
|
||||
model_type=model_type, preprocessor=preprocessor, local_dir=LOCAL_MODEL_DIR
|
||||
|
@ -55,7 +53,7 @@ gsutil cp -r ..."""
|
|||
return preprocessor
|
||||
|
||||
|
||||
def load_encoder(model_type, trainable):
|
||||
def load_encoder(model_type: str, trainable: bool):
|
||||
try:
|
||||
model = TextEncoder(
|
||||
max_seq_lengths=MAX_SEQ_LENGTH,
|
||||
|
@ -80,7 +78,7 @@ def load_encoder(model_type, trainable):
|
|||
return model
|
||||
|
||||
|
||||
def get_loss(loss_name, from_logits, **kwargs):
|
||||
def get_loss(loss_name: str, from_logits, **kwargs):
|
||||
loss_name = loss_name.lower()
|
||||
if loss_name == "bce":
|
||||
print("Binary CE loss")
|
||||
|
@ -117,7 +115,7 @@ def get_loss(loss_name, from_logits, **kwargs):
|
|||
)
|
||||
|
||||
|
||||
def _add_additional_embedding_layer(doc_embedding, glorot, seed):
|
||||
def _add_additional_embedding_layer(doc_embedding, glorot, seed: int):
|
||||
doc_embedding = tf.keras.layers.Dense(
|
||||
768, activation="tanh", kernel_initializer=glorot
|
||||
)(doc_embedding)
|
||||
|
@ -213,11 +211,11 @@ def load_bertweet(**kwargs):
|
|||
|
||||
|
||||
def load(
|
||||
optimizer,
|
||||
seed,
|
||||
model_type="twitter_multilingual_bert_base_cased_mlm",
|
||||
loss_name="BCE",
|
||||
trainable=True,
|
||||
optimizer: str,
|
||||
seed: int,
|
||||
model_type: str = "twitter_multilingual_bert_base_cased_mlm",
|
||||
loss_name: str = "BCE",
|
||||
trainable: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if model_type == "bertweet-base":
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import wandb
|
||||
from sklearn.metrics import average_precision_score, roc_auc_score
|
||||
|
@ -87,13 +89,13 @@ class GradientLoggingTensorBoard(SyncingTensorBoard):
|
|||
class AdditionalResultLogger(tf.keras.callbacks.Callback):
|
||||
def __init__(
|
||||
self,
|
||||
data,
|
||||
set_,
|
||||
fixed_recall=0.85,
|
||||
from_logits=False,
|
||||
dataset_transform_func=None,
|
||||
batch_size=64,
|
||||
dual_head=None,
|
||||
data: tf.data.Dataset,
|
||||
set_: str,
|
||||
fixed_recall: float = 0.85,
|
||||
from_logits: bool = False,
|
||||
dataset_transform_func: callable = None,
|
||||
batch_size: int = 64,
|
||||
dual_head: List[str] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -144,7 +146,7 @@ class AdditionalResultLogger(tf.keras.callbacks.Callback):
|
|||
self.fixed_recall = fixed_recall
|
||||
self.batch_size = batch_size
|
||||
|
||||
def compute_precision_fixed_recall(self, labels, preds):
|
||||
def compute_precision_fixed_recall(self, labels: np.ndarray, preds: np.ndarray):
|
||||
result, _ = compute_precision_fixed_recall(
|
||||
labels=labels, preds=preds, fixed_recall=self.fixed_recall
|
||||
)
|
||||
|
@ -159,7 +161,9 @@ class AdditionalResultLogger(tf.keras.callbacks.Callback):
|
|||
if self.counter % 2000 == 0:
|
||||
self.additional_evaluations(step=self.counter, eval_time="batch")
|
||||
|
||||
def _binary_evaluations(self, preds, label_name=None, class_index=None):
|
||||
def _binary_evaluations(
|
||||
self, preds: np.ndarray, label_name=None, class_index: int = None
|
||||
):
|
||||
mask = None
|
||||
curr_labels = self.labels
|
||||
if label_name is not None:
|
||||
|
@ -180,7 +184,7 @@ class AdditionalResultLogger(tf.keras.callbacks.Callback):
|
|||
"roc_auc": roc_auc_score(y_true=curr_labels, y_score=preds),
|
||||
}
|
||||
|
||||
def _multiclass_evaluations(self, preds):
|
||||
def _multiclass_evaluations(self, preds: np.ndarray):
|
||||
pr_auc_l = average_precision_score(
|
||||
y_true=self.labels, y_score=preds, **self.metric_kw
|
||||
)
|
||||
|
@ -192,7 +196,7 @@ class AdditionalResultLogger(tf.keras.callbacks.Callback):
|
|||
|
||||
return metrics
|
||||
|
||||
def additional_evaluations(self, step, eval_time):
|
||||
def additional_evaluations(self, step: int, eval_time: str):
|
||||
print("Evaluating ", self.set_, eval_time, step)
|
||||
|
||||
preds = self.model.predict(x=self.data, batch_size=self.batch_size)
|
||||
|
@ -232,7 +236,7 @@ class AdditionalResultLogger(tf.keras.callbacks.Callback):
|
|||
|
||||
self.log_metrics(metrics, step=step, eval_time=eval_time)
|
||||
|
||||
def log_metrics(self, metrics_d, step, eval_time):
|
||||
def log_metrics(self, metrics_d: dict, step: int, eval_time: str):
|
||||
commit = False if self.set_ == "validation" else True
|
||||
to_report = {self.set_: {**metrics_d, **self.best_metrics}}
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from keras import backend
|
|||
from keras.utils import losses_utils, tf_utils
|
||||
|
||||
|
||||
def inv_kl_divergence(y_true, y_pred):
|
||||
def inv_kl_divergence(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
|
||||
y_pred = tf.convert_to_tensor(y_pred)
|
||||
y_true = tf.cast(y_true, y_pred.dtype)
|
||||
y_true = backend.clip(y_true, backend.epsilon(), 1)
|
||||
|
@ -11,7 +11,7 @@ def inv_kl_divergence(y_true, y_pred):
|
|||
return tf.reduce_sum(y_pred * tf.math.log(y_pred / y_true), axis=-1)
|
||||
|
||||
|
||||
def masked_bce(y_true, y_pred):
|
||||
def masked_bce(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
|
||||
y_true = tf.cast(y_true, dtype=tf.float32)
|
||||
mask = y_true != -1
|
||||
|
||||
|
@ -28,7 +28,7 @@ class LossFunctionWrapper(tf.keras.losses.Loss):
|
|||
self.fn = fn
|
||||
self._fn_kwargs = kwargs
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
|
||||
if tf.is_tensor(y_pred) and tf.is_tensor(y_true):
|
||||
y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true)
|
||||
|
||||
|
@ -37,7 +37,7 @@ class LossFunctionWrapper(tf.keras.losses.Loss):
|
|||
)
|
||||
return ag_fn(y_true, y_pred, **self._fn_kwargs)
|
||||
|
||||
def get_config(self):
|
||||
def get_config(self) -> dict:
|
||||
config = {}
|
||||
for k, v in self._fn_kwargs.items():
|
||||
config[k] = backend.eval(v) if tf_utils.is_tensor_or_variable(v) else v
|
||||
|
@ -47,11 +47,13 @@ class LossFunctionWrapper(tf.keras.losses.Loss):
|
|||
|
||||
class InvKLD(LossFunctionWrapper):
|
||||
def __init__(
|
||||
self, reduction=losses_utils.ReductionV2.AUTO, name="inv_kl_divergence"
|
||||
self, reduction=losses_utils.ReductionV2.AUTO, name: str = "inv_kl_divergence"
|
||||
):
|
||||
super().__init__(inv_kl_divergence, name=name, reduction=reduction)
|
||||
|
||||
|
||||
class MaskedBCE(LossFunctionWrapper):
|
||||
def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name="masked_bce"):
|
||||
def __init__(
|
||||
self, reduction=losses_utils.ReductionV2.AUTO, name: str = "masked_bce"
|
||||
):
|
||||
super().__init__(masked_bce, name=name, reduction=reduction)
|
||||
|
|
|
@ -1,11 +1,18 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
from toxicity_ml_pipeline.load_model import reload_model_weights
|
||||
from toxicity_ml_pipeline.utils.helpers import load_inference_func, upload_model
|
||||
|
||||
|
||||
def score(
|
||||
language, df, gcs_model_path, batch_size=64, text_col="text", kw="", **kwargs
|
||||
language: str,
|
||||
df: pd.DataFrame,
|
||||
gcs_model_path: str,
|
||||
batch_size: int = 64,
|
||||
text_col: str = "text",
|
||||
kw: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
if language != "en":
|
||||
raise NotImplementedError(
|
||||
|
@ -41,7 +48,13 @@ def score(
|
|||
)
|
||||
|
||||
|
||||
def _get_score(inference_func, df, text_col="text", kw="", batch_size=64):
|
||||
def _get_score(
|
||||
inference_func: tf.function,
|
||||
df: pd.DataFrame,
|
||||
text_col: str = "text",
|
||||
kw: str = "",
|
||||
batch_size: int = 64,
|
||||
) -> pd.DataFrame:
|
||||
score_col = f"prediction_{kw}"
|
||||
beginning = 0
|
||||
end = df.shape[0]
|
||||
|
|
|
@ -3,6 +3,7 @@ from datetime import datetime
|
|||
from importlib import import_module
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
from toxicity_ml_pipeline.data.data_preprocessing import (
|
||||
DefaultENNoPreprocessor,
|
||||
|
@ -247,7 +248,9 @@ class Trainer(object):
|
|||
return warm_up_schedule
|
||||
return learning_rate_fn
|
||||
|
||||
def get_optimizer(self, schedule):
|
||||
def get_optimizer(
|
||||
self, schedule: tf.keras.optimizers.schedules.LearningRateSchedule
|
||||
):
|
||||
optim_args = {
|
||||
"learning_rate": schedule,
|
||||
"beta_1": 0.9,
|
||||
|
@ -277,7 +280,7 @@ class Trainer(object):
|
|||
|
||||
return optimizer, callbacks
|
||||
|
||||
def load_data(self):
|
||||
def load_data(self) -> pd.DataFrame:
|
||||
if self.project == 435 or self.project == 211:
|
||||
if self.dataset_type is None:
|
||||
data_loader = ENLoader(
|
||||
|
@ -299,7 +302,7 @@ class Trainer(object):
|
|||
|
||||
return df
|
||||
|
||||
def preprocess(self, df):
|
||||
def preprocess(self, df: pd.DataFrame):
|
||||
if self.project == 435 or self.project == 211:
|
||||
if self.preprocessing is None:
|
||||
data_prepro = DefaultENNoPreprocessor()
|
||||
|
@ -318,7 +321,7 @@ class Trainer(object):
|
|||
num_classes=self.num_classes,
|
||||
)
|
||||
|
||||
def load_model(self, optimizer):
|
||||
def load_model(self, optimizer: tf.keras.optimizers.Optimizer):
|
||||
smart_bias_value = (
|
||||
np.log(self.perc_training_tox / (1 - self.perc_training_tox))
|
||||
if self.smart_bias_init
|
||||
|
@ -354,7 +357,12 @@ class Trainer(object):
|
|||
return model
|
||||
|
||||
def _train_single_fold(
|
||||
self, mb_generator, test_data, steps_per_epoch, fold, val_data=None
|
||||
self,
|
||||
mb_generator: tf.data.Dataset,
|
||||
test_data: pd.DataFrame,
|
||||
steps_per_epoch: int,
|
||||
fold: int,
|
||||
val_data: pd.DataFrame = None,
|
||||
):
|
||||
steps_per_epoch = 100 if self.test else steps_per_epoch
|
||||
|
||||
|
@ -376,7 +384,6 @@ class Trainer(object):
|
|||
}
|
||||
|
||||
model.fit(mb_generator, **training_args)
|
||||
return
|
||||
|
||||
def train_full_model(self):
|
||||
print("Setting up random seed.")
|
||||
|
|
|
@ -13,7 +13,7 @@ except ModuleNotFoundError:
|
|||
pass
|
||||
|
||||
|
||||
def upload_model(full_gcs_model_path):
|
||||
def upload_model(full_gcs_model_path: str):
|
||||
folder_name = full_gcs_model_path
|
||||
if folder_name[:5] != "gs://":
|
||||
folder_name = "gs://" + folder_name
|
||||
|
@ -50,7 +50,9 @@ def upload_model(full_gcs_model_path):
|
|||
return weights_dir
|
||||
|
||||
|
||||
def compute_precision_fixed_recall(labels, preds, fixed_recall):
|
||||
def compute_precision_fixed_recall(
|
||||
labels: np.ndarray, preds: np.ndarray, fixed_recall: float
|
||||
):
|
||||
precision_values, recall_values, thresholds = precision_recall_curve(
|
||||
y_true=labels, probas_pred=preds
|
||||
)
|
||||
|
@ -61,7 +63,7 @@ def compute_precision_fixed_recall(labels, preds, fixed_recall):
|
|||
return result, thresholds[index_recall - 1]
|
||||
|
||||
|
||||
def load_inference_func(model_folder):
|
||||
def load_inference_func(model_folder: str):
|
||||
model = tf.saved_model.load(model_folder, ["serve"])
|
||||
inference_func = model.signatures["serving_default"]
|
||||
return inference_func
|
||||
|
@ -73,7 +75,7 @@ def execute_query(client, query):
|
|||
return df
|
||||
|
||||
|
||||
def execute_command(cmd, print_=True):
|
||||
def execute_command(cmd: str, print_: bool = True):
|
||||
s = subprocess.run(cmd, shell=True, capture_output=print_, check=True)
|
||||
if print_:
|
||||
print(s.stderr.decode("utf-8"))
|
||||
|
@ -95,9 +97,7 @@ def check_gpu():
|
|||
print(l)
|
||||
|
||||
|
||||
def set_seeds(seed):
|
||||
def set_seeds(seed: int):
|
||||
np.random.seed(seed)
|
||||
|
||||
python_random.seed(seed)
|
||||
|
||||
tf.random.set_seed(seed)
|
||||
|
|
|
@ -401,13 +401,13 @@ class ExperimentTracker(object):
|
|||
logging.error("Failed to export feature spec. Error: %s", str(err))
|
||||
|
||||
@property
|
||||
def path(self) -> Union[Dict[str, str], None]:
|
||||
def path(self) -> Optional[Dict[str, str]]:
|
||||
if self.disabled:
|
||||
return None
|
||||
return get_components_from_id(self.tracking_path, ensure_valid_id=False)
|
||||
|
||||
@property
|
||||
def experiment_id(self) -> Union[str, None]:
|
||||
def experiment_id(self) -> Optional[str]:
|
||||
"""Return the experiment id."""
|
||||
if self.disabled:
|
||||
return None
|
||||
|
|
Loading…
Reference in New Issue
Block a user