Mirror of https://github.com/twitter/the-algorithm.git (synced 2024-06-03 01:38:45 +02:00)

Commit 8f93b2e618: fixed trust_and_safety_models
Parent: 66644c1771

The commit adds type annotations, applies PEP 8 constant naming, and makes small readability cleanups across the trust_and_safety_models pipelines.
@@ -3,7 +3,7 @@ import logging
 import os
 import pkgutil
 import sys
-from typing import List, Union
+from typing import Dict, List, Optional
 from urllib.parse import urlsplit
 
 import apache_beam as beam
@@ -11,7 +11,7 @@ import faiss
 from apache_beam.options.pipeline_options import PipelineOptions
 
 
-def parse_d6w_config(argv: Union[List[str], None] = None):
+def parse_d6w_config(argv: Optional[List[str]] = None):
   """Parse d6w config.
   :param argv: d6w config
   :return: dictionary containing d6w config
@@ -93,8 +93,8 @@ def get_bq_query():
   return pkgutil.get_data(__name__, "bq.sql").decode("utf-8")
 
 
-def parse_metric(config):
-  metric_str: str = config["metric"].lower()
+def parse_metric(config: Dict[str, str]):
+  metric_str = config["metric"].lower()
   if metric_str == "l2":
     return faiss.METRIC_L2
   elif metric_str == "ip":
@@ -142,10 +142,7 @@ def run_pipeline(argv: List[str] = []):
         config["metric"],
         config["gpu"],
       )
-    )
-
-    # Make linter happy
-    index_built
+    )  # pylint: disable=unused-variable
 
 
 class MergeAndBuildIndex(beam.CombineFn):
@@ -159,7 +156,7 @@ class MergeAndBuildIndex(beam.CombineFn):
   def create_accumulator(self):
     return []
 
-  def add_input(self, accumulator, element):
+  def add_input(self, accumulator: List, element) -> List:
     accumulator.append(element)
     return accumulator
 
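Note on the typing change: `Optional[List[str]]` is shorthand for `Union[List[str], None]`, so `parse_d6w_config` keeps exactly the same contract. A minimal standard-library check of that equivalence:

import typing

# Optional[X] is literally defined as Union[X, None].
assert typing.Optional[typing.List[str]] == typing.Union[typing.List[str], None]

The inline `# pylint: disable=unused-variable` pragma replaces the earlier `index_built` no-op reference, silencing the linter without a dummy statement.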
@@ -1,5 +1,5 @@
 # checkstyle: noqa
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 import tensorflow.compat.v1 as tf
 
@@ -7,13 +7,13 @@ from toxicity_ml_pipeline.settings.hcomp_settings import TOXIC_35
 
 TOXIC_35_set = set(TOXIC_35)
 
-url_group = r"(\bhttps?:\/\/\S+)"
-mention_group = r"(\B@\S+)"
-urls_mentions_re = re.compile(url_group + r"|" + mention_group, re.IGNORECASE)
-url_re = re.compile(url_group, re.IGNORECASE)
-mention_re = re.compile(mention_group, re.IGNORECASE)
-newline_re = re.compile(r"\n+", re.IGNORECASE)
-and_re = re.compile(r"&\s?amp\s?;", re.IGNORECASE)
+URL_GROUP = r"(\bhttps?:\/\/\S+)"
+MENTION_GROUP = r"(\B@\S+)"
+URLS_MENTIONS_RE = re.compile(URL_GROUP + r"|" + MENTION_GROUP, re.IGNORECASE)
+URL_RE = re.compile(URL_GROUP, re.IGNORECASE)
+MENTION_RE = re.compile(MENTION_GROUP, re.IGNORECASE)
+NEWLINE_RE = re.compile(r"\n+", re.IGNORECASE)
+AND_RE = re.compile(r"&\s?amp\s?;", re.IGNORECASE)
 
 
 class DataframeCleaner(ABC):
@@ -98,22 +98,22 @@ class DefaultENNoPreprocessor(DataframeCleaner):
 
 
 class DefaultENPreprocessor(DefaultENNoPreprocessor):
-  def _clean(self, adhoc_df):
+  def _clean(self, adhoc_df: pd.DataFrame) -> pd.DataFrame:
     print(
       "... removing \\n and replacing @mentions and URLs by placeholders. "
       "Emoji filtering is not done."
     )
     adhoc_df["text"] = [
-      url_re.sub("URL", tweet) for tweet in adhoc_df.raw_text.values
+      URL_RE.sub("URL", tweet) for tweet in adhoc_df.raw_text.values
     ]
     adhoc_df["text"] = [
-      mention_re.sub("MENTION", tweet) for tweet in adhoc_df.text.values
+      MENTION_RE.sub("MENTION", tweet) for tweet in adhoc_df.text.values
     ]
     adhoc_df["text"] = [
-      newline_re.sub(" ", tweet).lstrip(" ").rstrip(" ")
+      NEWLINE_RE.sub(" ", tweet).lstrip(" ").rstrip(" ")
       for tweet in adhoc_df.text.values
     ]
-    adhoc_df["text"] = [and_re.sub("&", tweet) for tweet in adhoc_df.text.values]
+    adhoc_df["text"] = [AND_RE.sub("&", tweet) for tweet in adhoc_df.text.values]
     return adhoc_df
 
 
@@ -121,10 +121,10 @@ class Defaulti18nPreprocessor(DataframeCleaner):
   def _clean(self, adhoc_df):
     print("... removing @mentions, \\n and URLs. Emoji filtering is not done.")
     adhoc_df["text"] = [
-      urls_mentions_re.sub("", tweet) for tweet in adhoc_df.raw_text.values
+      URLS_MENTIONS_RE.sub("", tweet) for tweet in adhoc_df.raw_text.values
     ]
     adhoc_df["text"] = [
-      newline_re.sub(" ", tweet).lstrip(" ").rstrip(" ")
+      NEWLINE_RE.sub(" ", tweet).lstrip(" ").rstrip(" ")
       for tweet in adhoc_df.text.values
     ]
     return adhoc_df
@@ -222,7 +222,7 @@ class ENLoaderWithSampling(ENLoader):
 
   def sample(
     self,
-    df,
+    df: pd.DataFrame,
     first_set: pd.DataFrame,
     second_set: pd.DataFrame,
     keyword_sampling: bool,
@@ -300,13 +300,13 @@ class I18nLoader(DataframeLoader):
     self.accepted_languages = ACCEPTED_LANGUAGES
     self.query_settings = dict(QUERY_SETTINGS)
 
-  def produce_query(self, language, query, dataset, table, lang):
+  def produce_query(self, language: str, query: str, dataset: str, table: str, lang: str) -> str:
     query = query.format(dataset=dataset, table=table)
     add_query = f"AND reviewed.{lang}='{language}'"
     query += add_query
     return query
 
-  def query_keys(self, language, task=2, size="50"):
+  def query_keys(self, language: str, task: int = 2, size: str = "50"):
     if task == 2:
       if language == "ar":
         self.query_settings["adhoc_v2"]["table"] = "..."
@@ -323,7 +323,7 @@ class I18nLoader(DataframeLoader):
         f"There are no other tasks than 2 or 3. {task} does not exist."
       )
 
-  def load_data(self, language, test=False, task=2):
+  def load_data(self, language: str, test: bool = False, task: int = 2):
    if language not in self.accepted_languages:
      raise ValueError(
        f"Language not in the data {language}. Accepted values are "
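The renames above apply the PEP 8 convention that module-level constants, including precompiled regexes, are written in UPPER_CASE. A self-contained sketch of the same pattern, reusing the patterns from the diff (the sample text is ours):

import re

URL_GROUP = r"(\bhttps?:\/\/\S+)"
MENTION_GROUP = r"(\B@\S+)"
URLS_MENTIONS_RE = re.compile(URL_GROUP + r"|" + MENTION_GROUP, re.IGNORECASE)

# Compiling once at import time avoids re-parsing the pattern on every call.
print(URLS_MENTIONS_RE.sub("", "see https://example.com from @someone"))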
@@ -78,11 +78,9 @@ class BalancedMiniBatchLoader(object):
     )
-
     self.n_inner_splits = n_inner_splits if n_inner_splits is not None else INNER_CV
-
     self.seed = seed
     self.mb_size = mb_size
     self.fold = fold
 
     self.sample_weights = sample_weights
     self.dual_head = dual_head
     self.huggingface = huggingface
@@ -99,7 +97,7 @@ class BalancedMiniBatchLoader(object):
       os.path.join(local_model_dir, "bertweet-base"), normalization=True
     )
 
-  def tokenize_function(self, el):
+  def tokenize_function(self, el: dict) -> dict:
     return self.tokenizer(
       el["text"],
       max_length=MAX_SEQ_LENGTH,
@@ -110,12 +108,12 @@ class BalancedMiniBatchLoader(object):
       return_attention_mask=False,
     )
 
-  def _get_stratified_kfold(self, n_splits):
+  def _get_stratified_kfold(self, n_splits: int) -> StratifiedKFold:
     return StratifiedKFold(shuffle=True, n_splits=n_splits, random_state=self.seed)
 
-  def _get_time_fold(self, df):
-    test_begin_date = pandas.to_datetime(self.test_begin_date).date()
-    test_end_date = pandas.to_datetime(self.test_end_date).date()
+  def _get_time_fold(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    test_begin_date = pd.to_datetime(self.test_begin_date).date()
+    test_end_date = pd.to_datetime(self.test_end_date).date()
     print(f"Test is going from {test_begin_date} to {test_end_date}.")
     test_data = df.query("@test_begin_date <= date <= @test_end_date")
 
@@ -123,7 +121,7 @@ class BalancedMiniBatchLoader(object):
     other_set = df.query(query)
     return other_set, test_data
 
-  def _get_outer_cv_fold(self, df):
+  def _get_outer_cv_fold(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
     labels = df.int_label
     stratifier = self._get_stratified_kfold(n_splits=self.n_outer_splits)
 
@@ -132,20 +130,18 @@ class BalancedMiniBatchLoader(object):
       if k == self.fold:
         break
       k += 1
-
     train_data = df.iloc[train_index].copy()
     test_data = df.iloc[test_index].copy()
-
     return train_data, test_data
 
-  def get_steps_per_epoch(self, nb_pos_examples):
+  def get_steps_per_epoch(self, nb_pos_examples: int) -> int:
     return int(
       max(TARGET_POS_PER_EPOCH, nb_pos_examples)
       / self.mb_size
       / self.perc_training_tox
     )
 
-  def make_huggingface_tensorflow_ds(self, group, mb_size=None, shuffle=True):
+  def make_huggingface_tensorflow_ds(self, group, mb_size=None, shuffle: bool = True):
     huggingface_ds = Dataset.from_pandas(group).map(
       self.tokenize_function, batched=True
     )
@@ -164,7 +160,9 @@ class BalancedMiniBatchLoader(object):
       return tensorflow_ds.repeat()
     return tensorflow_ds
 
-  def make_pure_tensorflow_ds(self, df, nb_samples):
+  def make_pure_tensorflow_ds(
+    self, df: pd.DataFrame, nb_samples: int
+  ) -> tf.data.Dataset:
     buffer_size = nb_samples * 2
 
     if self.sample_weights is not None:
@@ -188,8 +186,11 @@ class BalancedMiniBatchLoader(object):
     return ds
 
   def get_balanced_dataset(
-    self, training_data, size_limit=None, return_as_batch=True
-  ):
+    self,
+    training_data: pd.DataFrame,
+    size_limit: int = None,
+    return_as_batch: bool = True,
+  ) -> tf.data.Dataset:
     training_data = training_data.sample(frac=1, random_state=self.seed)
     nb_samples = training_data.shape[0] if not size_limit else size_limit
 
@@ -198,8 +199,13 @@ class BalancedMiniBatchLoader(object):
     if size_limit:
       training_data = training_data[: size_limit * num_classes]
 
+    percent_tox = (
+      100
+      * training_data[training_data.int_label == toxic_class].shape[0]
+      / nb_samples
+    )
     print(
-      f"... {nb_samples} examples, incl. {(100 * training_data[training_data.int_label == toxic_class].shape[0] / nb_samples):.2f}% tox in train, {num_classes} classes"
+      f"... {nb_samples} examples, incl. {percent_tox:.2f}% tox in train, {num_classes} classes"
     )
     label_groups = training_data.groupby("int_label")
     if self.huggingface:
@@ -273,7 +279,9 @@ class BalancedMiniBatchLoader(object):
 
     yield mini_batches, steps_per_epoch, val_data, test_data
 
-  def simple_cv_load(self, full_df: pd.DataFrame):
+  def simple_cv_load(
+    self, full_df: pd.DataFrame
+  ) -> Tuple[tf.data.Dataset, pd.DataFrame, int]:
     full_df = self._compute_int_labels(full_df)
 
     train_data, test_data = self.get_outer_fold(df=full_df)
@@ -287,7 +295,9 @@ class BalancedMiniBatchLoader(object):
 
     return mini_batches, test_data, steps_per_epoch
 
-  def no_cv_load(self, full_df: pd.DataFrame):
+  def no_cv_load(
+    self, full_df: pd.DataFrame
+  ) -> Tuple[tf.data.Dataset, pd.DataFrame, int]:
     full_df = self._compute_int_labels(full_df)
 
     val_test = full_df[full_df.origin == "precision"].copy(deep=True)
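Two notes on this file. First, the long inline f-string expression is hoisted into the named `percent_tox` intermediate, which keeps the log line readable. Second, parameters such as `size_limit: int = None` default to None, so the precise PEP 484 spelling would be `Optional[int]`; a hypothetical stricter version of that idea (names ours, not the repo's):

from typing import Optional

def truncate(rows: list, size_limit: Optional[int] = None) -> list:
    # None means "no limit", which is why the annotation must admit None.
    return rows if size_limit is None else rows[:size_limit]

assert truncate([1, 2, 3]) == [1, 2, 3]
assert truncate([1, 2, 3], size_limit=2) == [1, 2]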
@@ -22,7 +22,7 @@ except ModuleNotFoundError:
 LOCAL_MODEL_DIR = os.path.join(LOCAL_DIR, "models")
 
 
-def reload_model_weights(weights_dir, language, **kwargs):
+def reload_model_weights(weights_dir, language: str, **kwargs):
   optimizer = tf.keras.optimizers.Adam(0.01)
   model_type = (
     "twitter_bert_base_en_uncased_mlm"
@@ -35,7 +35,7 @@ def reload_model_weights(weights_dir, language, **kwargs):
   return model
 
 
-def _locally_copy_models(model_type):
+def _locally_copy_models(model_type: str):
   if model_type == "twitter_multilingual_bert_base_cased_mlm":
     preprocessor = "bert_multi_cased_preprocess_3"
   elif model_type == "twitter_bert_base_en_uncased_mlm":
@@ -43,9 +43,7 @@ def _locally_copy_models(model_type):
   else:
     raise NotImplementedError
 
-  copy_cmd = """mkdir {local_dir}
-gsutil cp -r ...
-gsutil cp -r ..."""
+  copy_cmd = "mkdir {local_dir}\ngsutil cp -r ...\ngsutil cp -r ..."
   execute_command(
     copy_cmd.format(
       model_type=model_type, preprocessor=preprocessor, local_dir=LOCAL_MODEL_DIR
@@ -55,7 +53,7 @@ gsutil cp -r ..."""
   return preprocessor
 
 
-def load_encoder(model_type, trainable):
+def load_encoder(model_type: str, trainable: bool):
   try:
     model = TextEncoder(
       max_seq_lengths=MAX_SEQ_LENGTH,
@@ -80,7 +78,7 @@ def load_encoder(model_type, trainable):
   return model
 
 
-def get_loss(loss_name, from_logits, **kwargs):
+def get_loss(loss_name: str, from_logits, **kwargs):
   loss_name = loss_name.lower()
   if loss_name == "bce":
     print("Binary CE loss")
@@ -117,7 +115,7 @@ def get_loss(loss_name, from_logits, **kwargs):
   )
 
 
-def _add_additional_embedding_layer(doc_embedding, glorot, seed):
+def _add_additional_embedding_layer(doc_embedding, glorot, seed: int):
   doc_embedding = tf.keras.layers.Dense(
     768, activation="tanh", kernel_initializer=glorot
   )(doc_embedding)
@@ -213,11 +211,11 @@ def load_bertweet(**kwargs):
 
 
 def load(
-  optimizer,
-  seed,
-  model_type="twitter_multilingual_bert_base_cased_mlm",
-  loss_name="BCE",
-  trainable=True,
+  optimizer: str,
+  seed: int,
+  model_type: str = "twitter_multilingual_bert_base_cased_mlm",
+  loss_name: str = "BCE",
+  trainable: bool = True,
   **kwargs,
 ):
   if model_type == "bertweet-base":
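The `copy_cmd` rewrite is behavior-preserving: a triple-quoted string with literal newlines equals the same text written with `\n` escapes, and `.format()` still fills the placeholders. A quick check (the `...` path fragments are elided in the source and kept as-is; "/tmp/models" is a hypothetical local path):

triple = """mkdir {local_dir}
gsutil cp -r ...
gsutil cp -r ..."""
single = "mkdir {local_dir}\ngsutil cp -r ...\ngsutil cp -r ..."
assert triple == single
print(single.format(local_dir="/tmp/models"))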
@@ -1,6 +1,8 @@
 import os
 from collections import defaultdict
+from typing import List
 
+import numpy as np
 import tensorflow as tf
 import wandb
 from sklearn.metrics import average_precision_score, roc_auc_score
@@ -87,13 +89,13 @@ class GradientLoggingTensorBoard(SyncingTensorBoard):
 class AdditionalResultLogger(tf.keras.callbacks.Callback):
   def __init__(
     self,
-    data,
-    set_,
-    fixed_recall=0.85,
-    from_logits=False,
-    dataset_transform_func=None,
-    batch_size=64,
-    dual_head=None,
+    data: tf.data.Dataset,
+    set_: str,
+    fixed_recall: float = 0.85,
+    from_logits: bool = False,
+    dataset_transform_func: callable = None,
+    batch_size: int = 64,
+    dual_head: List[str] = None,
     *args,
     **kwargs,
   ):
@@ -144,7 +146,7 @@ class AdditionalResultLogger(tf.keras.callbacks.Callback):
     self.fixed_recall = fixed_recall
     self.batch_size = batch_size
 
-  def compute_precision_fixed_recall(self, labels, preds):
+  def compute_precision_fixed_recall(self, labels: np.ndarray, preds: np.ndarray):
     result, _ = compute_precision_fixed_recall(
       labels=labels, preds=preds, fixed_recall=self.fixed_recall
     )
@@ -159,7 +161,9 @@ class AdditionalResultLogger(tf.keras.callbacks.Callback):
     if self.counter % 2000 == 0:
       self.additional_evaluations(step=self.counter, eval_time="batch")
 
-  def _binary_evaluations(self, preds, label_name=None, class_index=None):
+  def _binary_evaluations(
+    self, preds: np.ndarray, label_name=None, class_index: int = None
+  ):
     mask = None
     curr_labels = self.labels
     if label_name is not None:
@@ -180,7 +184,7 @@ class AdditionalResultLogger(tf.keras.callbacks.Callback):
       "roc_auc": roc_auc_score(y_true=curr_labels, y_score=preds),
     }
 
-  def _multiclass_evaluations(self, preds):
+  def _multiclass_evaluations(self, preds: np.ndarray):
     pr_auc_l = average_precision_score(
       y_true=self.labels, y_score=preds, **self.metric_kw
     )
@@ -192,7 +196,7 @@ class AdditionalResultLogger(tf.keras.callbacks.Callback):
 
     return metrics
 
-  def additional_evaluations(self, step, eval_time):
+  def additional_evaluations(self, step: int, eval_time: str):
     print("Evaluating ", self.set_, eval_time, step)
 
     preds = self.model.predict(x=self.data, batch_size=self.batch_size)
@@ -232,7 +236,7 @@ class AdditionalResultLogger(tf.keras.callbacks.Callback):
 
     self.log_metrics(metrics, step=step, eval_time=eval_time)
 
-  def log_metrics(self, metrics_d, step, eval_time):
+  def log_metrics(self, metrics_d: dict, step: int, eval_time: str):
     commit = False if self.set_ == "validation" else True
     to_report = {self.set_: {**metrics_d, **self.best_metrics}}
 
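One caveat on the new `__init__` annotations: `callable` is the builtin predicate, and although Python accepts it as an annotation at runtime, the conventional static-typing spelling for a None-defaulted hook is `Optional[Callable]`. A hypothetical stricter version of that parameter (names ours):

from typing import Callable, Optional

def apply_transform(ds: list, dataset_transform_func: Optional[Callable] = None) -> list:
    # Fall back to the identity when no transform hook is supplied.
    return dataset_transform_func(ds) if dataset_transform_func is not None else ds

assert apply_transform([1, 2]) == [1, 2]
assert apply_transform([1, 2], dataset_transform_func=lambda d: d[::-1]) == [2, 1]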
@@ -3,7 +3,7 @@ from keras import backend
 from keras.utils import losses_utils, tf_utils
 
 
-def inv_kl_divergence(y_true, y_pred):
+def inv_kl_divergence(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
   y_pred = tf.convert_to_tensor(y_pred)
   y_true = tf.cast(y_true, y_pred.dtype)
   y_true = backend.clip(y_true, backend.epsilon(), 1)
@@ -11,7 +11,7 @@ def inv_kl_divergence(y_true, y_pred):
   return tf.reduce_sum(y_pred * tf.math.log(y_pred / y_true), axis=-1)
 
 
-def masked_bce(y_true, y_pred):
+def masked_bce(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
   y_true = tf.cast(y_true, dtype=tf.float32)
   mask = y_true != -1
 
@@ -28,7 +28,7 @@ class LossFunctionWrapper(tf.keras.losses.Loss):
     self.fn = fn
     self._fn_kwargs = kwargs
 
-  def call(self, y_true, y_pred):
+  def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
     if tf.is_tensor(y_pred) and tf.is_tensor(y_true):
       y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true)
 
@@ -37,7 +37,7 @@ class LossFunctionWrapper(tf.keras.losses.Loss):
     )
     return ag_fn(y_true, y_pred, **self._fn_kwargs)
 
-  def get_config(self):
+  def get_config(self) -> dict:
     config = {}
     for k, v in self._fn_kwargs.items():
       config[k] = backend.eval(v) if tf_utils.is_tensor_or_variable(v) else v
@@ -47,11 +47,13 @@ class LossFunctionWrapper(tf.keras.losses.Loss):
 
 class InvKLD(LossFunctionWrapper):
   def __init__(
-    self, reduction=losses_utils.ReductionV2.AUTO, name="inv_kl_divergence"
+    self, reduction=losses_utils.ReductionV2.AUTO, name: str = "inv_kl_divergence"
   ):
     super().__init__(inv_kl_divergence, name=name, reduction=reduction)
 
 
 class MaskedBCE(LossFunctionWrapper):
-  def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name="masked_bce"):
+  def __init__(
+    self, reduction=losses_utils.ReductionV2.AUTO, name: str = "masked_bce"
+  ):
     super().__init__(masked_bce, name=name, reduction=reduction)
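For context on `masked_bce`: labels of -1 mark entries to ignore, so the loss is computed only where `y_true != -1`. A minimal sketch of that idea (the reduction here is our own assumption; the repo's exact averaging may differ):

import tensorflow as tf

def masked_bce_sketch(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
  y_true = tf.cast(y_true, tf.float32)
  mask = y_true != -1  # -1 labels are placeholders and contribute no loss
  return tf.reduce_mean(
    tf.keras.losses.binary_crossentropy(
      tf.boolean_mask(y_true, mask)[:, None],
      tf.boolean_mask(y_pred, mask)[:, None],
    )
  )

print(masked_bce_sketch(tf.constant([1.0, -1.0, 0.0]), tf.constant([0.9, 0.5, 0.1])))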
@@ -1,11 +1,18 @@
 import numpy as np
+import pandas as pd
 import tensorflow as tf
 from toxicity_ml_pipeline.load_model import reload_model_weights
 from toxicity_ml_pipeline.utils.helpers import load_inference_func, upload_model
 
 
 def score(
-  language, df, gcs_model_path, batch_size=64, text_col="text", kw="", **kwargs
+  language: str,
+  df: pd.DataFrame,
+  gcs_model_path: str,
+  batch_size: int = 64,
+  text_col: str = "text",
+  kw: str = "",
+  **kwargs,
 ):
   if language != "en":
     raise NotImplementedError(
@@ -41,7 +48,13 @@ def score(
   )
 
 
-def _get_score(inference_func, df, text_col="text", kw="", batch_size=64):
+def _get_score(
+  inference_func: tf.function,
+  df: pd.DataFrame,
+  text_col: str = "text",
+  kw: str = "",
+  batch_size: int = 64,
+) -> pd.DataFrame:
   score_col = f"prediction_{kw}"
   beginning = 0
   end = df.shape[0]
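`_get_score` walks the dataframe in `batch_size` windows between `beginning` and `end` and stores predictions in `prediction_{kw}`. The loop body is truncated in the diff; a sketch of that batching pattern under our own assumptions (the stand-in `inference_func` is just a callable over a batch of texts, and the helper name is ours):

import numpy as np
import pandas as pd

def get_score_sketch(
  inference_func, df: pd.DataFrame, text_col: str = "text", kw: str = "", batch_size: int = 64
) -> pd.DataFrame:
  score_col = f"prediction_{kw}"
  preds = []
  for beginning in range(0, df.shape[0], batch_size):
    batch = df[text_col].values[beginning : beginning + batch_size]
    preds.append(inference_func(batch))  # one forward pass per window
  df[score_col] = np.concatenate(preds)
  return df

demo = pd.DataFrame({"text": ["a", "bb", "ccc"]})
print(get_score_sketch(lambda batch: np.array([len(t) for t in batch]), demo, batch_size=2))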
@@ -3,6 +3,7 @@ from datetime import datetime
 from importlib import import_module
 
 import numpy as np
+import pandas as pd
 import tensorflow as tf
 from toxicity_ml_pipeline.data.data_preprocessing import (
   DefaultENNoPreprocessor,
@@ -247,7 +248,9 @@ class Trainer(object):
       return warm_up_schedule
     return learning_rate_fn
 
-  def get_optimizer(self, schedule):
+  def get_optimizer(
+    self, schedule: tf.keras.optimizers.schedules.LearningRateSchedule
+  ):
     optim_args = {
       "learning_rate": schedule,
       "beta_1": 0.9,
@@ -277,7 +280,7 @@ class Trainer(object):
 
     return optimizer, callbacks
 
-  def load_data(self):
+  def load_data(self) -> pd.DataFrame:
     if self.project == 435 or self.project == 211:
       if self.dataset_type is None:
         data_loader = ENLoader(
@@ -299,7 +302,7 @@ class Trainer(object):
 
     return df
 
-  def preprocess(self, df):
+  def preprocess(self, df: pd.DataFrame):
     if self.project == 435 or self.project == 211:
       if self.preprocessing is None:
         data_prepro = DefaultENNoPreprocessor()
@@ -318,7 +321,7 @@ class Trainer(object):
       num_classes=self.num_classes,
     )
 
-  def load_model(self, optimizer):
+  def load_model(self, optimizer: tf.keras.optimizers.Optimizer):
     smart_bias_value = (
       np.log(self.perc_training_tox / (1 - self.perc_training_tox))
       if self.smart_bias_init
@@ -354,7 +357,12 @@ class Trainer(object):
     return model
 
   def _train_single_fold(
-    self, mb_generator, test_data, steps_per_epoch, fold, val_data=None
+    self,
+    mb_generator: tf.data.Dataset,
+    test_data: pd.DataFrame,
+    steps_per_epoch: int,
+    fold: int,
+    val_data: pd.DataFrame = None,
   ):
     steps_per_epoch = 100 if self.test else steps_per_epoch
 
@@ -376,7 +384,6 @@ class Trainer(object):
     }
 
     model.fit(mb_generator, **training_args)
-    return
 
   def train_full_model(self):
     print("Setting up random seed.")
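Dropping the trailing `return` in `_train_single_fold` changes nothing observable: a function that falls off its final statement returns None exactly as a bare `return` does.

def with_bare_return():
  return

def without_return():
  pass

assert with_bare_return() is without_return() is None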
@@ -13,7 +13,7 @@ except ModuleNotFoundError:
   pass
 
 
-def upload_model(full_gcs_model_path):
+def upload_model(full_gcs_model_path: str):
   folder_name = full_gcs_model_path
   if folder_name[:5] != "gs://":
     folder_name = "gs://" + folder_name
@@ -50,7 +50,9 @@ def upload_model(full_gcs_model_path):
   return weights_dir
 
 
-def compute_precision_fixed_recall(labels, preds, fixed_recall):
+def compute_precision_fixed_recall(
+  labels: np.ndarray, preds: np.ndarray, fixed_recall: float
+):
   precision_values, recall_values, thresholds = precision_recall_curve(
     y_true=labels, probas_pred=preds
   )
@@ -61,7 +63,7 @@ def compute_precision_fixed_recall(labels, preds, fixed_recall):
   return result, thresholds[index_recall - 1]
 
 
-def load_inference_func(model_folder):
+def load_inference_func(model_folder: str):
   model = tf.saved_model.load(model_folder, ["serve"])
   inference_func = model.signatures["serving_default"]
   return inference_func
@@ -73,7 +75,7 @@ def execute_query(client, query):
   return df
 
 
-def execute_command(cmd, print_=True):
+def execute_command(cmd: str, print_: bool = True):
   s = subprocess.run(cmd, shell=True, capture_output=print_, check=True)
   if print_:
     print(s.stderr.decode("utf-8"))
@@ -95,9 +97,7 @@ def check_gpu():
     print(l)
 
 
-def set_seeds(seed):
+def set_seeds(seed: int):
   np.random.seed(seed)
-
   python_random.seed(seed)
-
   tf.random.set_seed(seed)
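`set_seeds` pins all three random sources the pipeline uses; seeding NumPy, the stdlib `random` module, and TensorFlow together is what makes training runs repeatable. Rebuilt as a self-contained snippet from the lines in the diff:

import random as python_random

import numpy as np
import tensorflow as tf

def set_seeds(seed: int):
  np.random.seed(seed)  # NumPy-based sampling and shuffling
  python_random.seed(seed)  # stdlib random
  tf.random.set_seed(seed)  # TensorFlow op- and graph-level seeds

set_seeds(42)
print(np.random.rand(), tf.random.uniform(()).numpy())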
@@ -401,13 +401,13 @@ class ExperimentTracker(object):
       logging.error("Failed to export feature spec. Error: %s", str(err))
 
   @property
-  def path(self) -> Union[Dict[str, str], None]:
+  def path(self) -> Optional[Dict[str, str]]:
     if self.disabled:
       return None
     return get_components_from_id(self.tracking_path, ensure_valid_id=False)
 
   @property
-  def experiment_id(self) -> Union[str, None]:
+  def experiment_id(self) -> Optional[str]:
     """Return the experiment id."""
     if self.disabled:
       return None
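Same `Union[X, None]` to `Optional[X]` cleanup as in the first file, here on property return types; `Optional` spells out the None-on-disabled contract of `path` and `experiment_id`. A hypothetical mirror of that contract (names ours):

from typing import Dict, Optional

def path(disabled: bool) -> Optional[Dict[str, str]]:
  # None signals that tracking is disabled, matching the property above.
  return None if disabled else {"experiment": "tracking-id"}

assert path(True) is None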