diff --git a/ann/src/main/python/dataflow/faiss_index_bq_dataset.py b/ann/src/main/python/dataflow/faiss_index_bq_dataset.py
index 1863cabef..fa6101c5a 100644
--- a/ann/src/main/python/dataflow/faiss_index_bq_dataset.py
+++ b/ann/src/main/python/dataflow/faiss_index_bq_dataset.py
@@ -6,8 +6,8 @@ import sys
 from urllib.parse import urlsplit
 
 import apache_beam as beam
-from apache_beam.options.pipeline_options import PipelineOptions
 import faiss
+from apache_beam.options.pipeline_options import PipelineOptions
 
 
 def parse_d6w_config(argv=None):
@@ -160,8 +160,8 @@ class MergeAndBuildIndex(beam.CombineFn):
     import subprocess
 
     import faiss
-    from google.cloud import storage
     import numpy as np
+    from google.cloud import storage
 
     client = storage.Client()
     bucket = client.get_bucket(self.bucket_name)
diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/example_weights.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/example_weights.py
index cf0c38ecc..5bda19274 100644
--- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/example_weights.py
+++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/example_weights.py
@@ -1,5 +1,6 @@
 # checkstyle: noqa
 import tensorflow.compat.v1 as tf
+
 from .constants import INDEX_BY_LABEL, LABEL_NAMES
 
 # TODO: Read these from command line arguments, since they specify the existing example weights in the input data.
diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/lolly/data_helpers.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/lolly/data_helpers.py
index 723dd626c..af1faefaf 100644
--- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/lolly/data_helpers.py
+++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/lolly/data_helpers.py
@@ -1,7 +1,9 @@
 # checkstyle: noqa
 import tensorflow.compat.v1 as tf
+
 from ..constants import EB_SCORE_IDX
+
 
 # The rationale behind this logic is available at TQ-9678.
 def get_lolly_logits(labels):
   '''
diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/lolly/score.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/lolly/score.py
index 5692616c2..88ee6f391 100644
--- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/lolly/score.py
+++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/lolly/score.py
@@ -4,7 +4,6 @@ from .parsers import DBv2DataExampleParser
 from .reader import LollyModelReader
 from .scorer import LollyModelScorer
 
-
 if __name__ == "__main__":
   lolly_model_reader = LollyModelReader(lolly_model_file_path=sys.argv[1])
   lolly_model_scorer = LollyModelScorer(data_example_parser=DBv2DataExampleParser(lolly_model_reader))
diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/metrics.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/metrics.py
index 6919914f8..d926a23d0 100644
--- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/metrics.py
+++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/metrics.py
@@ -1,10 +1,13 @@
 # checkstyle: noqa
-import tensorflow.compat.v1 as tf
 from collections import OrderedDict
+
+import tensorflow.compat.v1 as tf
+
+import twml
+
 from .constants import EB_SCORE_IDX
 from .lolly.data_helpers import get_lolly_scores
-import twml
 
 
 def get_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1):
   """
diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/discretizer_builder.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/discretizer_builder.py
index 82c31bde0..6b2e9559c 100644
--- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/discretizer_builder.py
+++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/discretizer_builder.py
@@ -1,7 +1,8 @@
-from .hashing_utils import make_feature_id
+import numpy as np
 from twml.contrib.layers.hashing_discretizer import HashingDiscretizer
-import numpy as np
+
+from .hashing_utils import make_feature_id
 
 
 class TFModelDiscretizerBuilder(object):
diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/hashing_utils.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/hashing_utils.py
index 2c57f8d63..d5140da25 100644
--- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/hashing_utils.py
+++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/hashing_utils.py
@@ -1,6 +1,5 @@
-from twitter.deepbird.io.util import _get_feature_id
-
 import numpy as np
+from twitter.deepbird.io.util import _get_feature_id
 
 
 def numpy_hashing_uniform(the_id, bin_idx, output_bits):
diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/weights_initializer_builder.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/weights_initializer_builder.py
index 63491ea38..4a1bd5b67 100644
--- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/weights_initializer_builder.py
+++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/tf_model/weights_initializer_builder.py
@@ -1,9 +1,10 @@
-from .hashing_utils import make_feature_id, numpy_hashing_uniform
-
 import numpy as np
 import tensorflow.compat.v1 as tf
+
 import twml
+from .hashing_utils import make_feature_id, numpy_hashing_uniform
+
 
 class TFModelWeightsInitializerBuilder(object):
   def __init__(self, num_bits):
diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py
index 6ef181f5f..c0cff5bc0 100644
--- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py
+++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py
@@ -1,26 +1,32 @@
 # checkstyle: noqa
+from datetime import datetime
+
 import tensorflow.compat.v1 as tf
-from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn
+import tensorflow_hub as hub
+from tensorflow.compat.v1 import logging
+from tensorflow.python.estimator.export.export import (
+    build_raw_serving_input_receiver_fn,
+)
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
-import tensorflow_hub as hub
-
-from datetime import datetime
-from tensorflow.compat.v1 import logging
 from twitter.deepbird.projects.timelines.configs import all_configs
+
+import twml
+from twml.contrib.calibrators.common_calibrators import (
+    build_percentile_discretizer_graph,
+    calibrate_discretizer_and_export,
+)
 from twml.trainers import DataRecordTrainer
-from twml.contrib.calibrators.common_calibrators import build_percentile_discretizer_graph
-from twml.contrib.calibrators.common_calibrators import calibrate_discretizer_and_export
-from .metrics import get_multi_binary_class_metric_fn
-from .constants import TARGET_LABEL_IDX, PREDICTED_CLASSES
+
+from .constants import PREDICTED_CLASSES, TARGET_LABEL_IDX
 from .example_weights import add_weight_arguments, make_weights_tensor
 from .lolly.data_helpers import get_lolly_logits
-from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder
 from .lolly.reader import LollyModelReader
+from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder
+from .metrics import get_multi_binary_class_metric_fn
 from .tf_model.discretizer_builder import TFModelDiscretizerBuilder
 from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuilder
-import twml
 
 
 def get_feature_values(features_values, params):
   if params.lolly_model_tsv:
diff --git a/trust_and_safety_models/abusive/abusive_model.py b/trust_and_safety_models/abusive/abusive_model.py
index 06fff4ed2..89f0319e4 100644
--- a/trust_and_safety_models/abusive/abusive_model.py
+++ b/trust_and_safety_models/abusive/abusive_model.py
@@ -1,19 +1,45 @@
+import datetime
+import os
+from dataclasses import asdict
+
+import numpy as np
+import pandas as pd
 import tensorflow as tf
+import tensorflow_hub as hub
+import utils
+import wandb
+
+try:
+  wandb_key = ...
+  wandb.login(...)
+  run = wandb.init(project='ptos_with_media',
+                   group='new-split-trains',
+                   notes='tweet text with only (num_media, precision_nsfw). on full train set, new split.',
+                   entity='absv',
+                   config=params,
+                   name='tweet-text-w-nsfw-1.1',
+                   sync_tensorboard=True)
+except FileNotFoundError:
+  print('Wandb key not found')
+  run = wandb.init(mode='disabled')
+
+
+from notebook_eval_utils import EvalConfig, SparseMultilabelEvaluator
+from twitter.cuad.representation.models.optimization import create_optimizer
+from twitter.cuad.representation.models.text_encoder import TextEncoder
+from twitter.hmli.nimbus.modeling.feature_encoder import FeatureEncoder
+from twitter.hmli.nimbus.modeling.feature_loader import BigQueryFeatureLoader
+from twitter.hmli.nimbus.modeling.model_config import (
+    EncodingType,
+    Feature,
+    FeatureType,
+    Model,
+)
 
 physical_devices = tf.config.list_physical_devices('GPU')
 for device in physical_devices:
   tf.config.experimental.set_memory_growth(device, True)
-from twitter.hmli.nimbus.modeling.model_config import FeatureType, EncodingType, Feature, Model, LogType
-from twitter.hmli.nimbus.modeling.feature_loader import BigQueryFeatureLoader
-from twitter.cuad.representation.models.text_encoder import TextEncoder
-from twitter.cuad.representation.models.optimization import create_optimizer
-from twitter.hmli.nimbus.modeling.feature_encoder import FeatureEncoder
-
-import numpy as np
-import pandas as pd
-import utils
-
 cat_names = [
 ...
 ]
@@ -75,7 +101,6 @@ params = {
   'model_type': 'twitter_multilingual_bert_base_cased_mlm',
   'mixed_precision': True,
 }
-params
 
 def parse_labeled_data(row_dict):
   label = [row_dict.pop(l) for l in labels]
@@ -134,7 +159,9 @@ with mirrored_strategy.scope():
   )
   pr_auc = tf.keras.metrics.AUC(curve="PR", num_thresholds=1000, multi_label=True, from_logits=True)
 
-  custom_loss = lambda y_true, y_pred: utils.multilabel_weighted_loss(y_true, y_pred, weights=pos_weight_tensor)
+  def custom_loss(y_true, y_pred):
+    return utils.multilabel_weighted_loss(y_true, y_pred, weights=pos_weight_tensor)
+
   optimizer = create_optimizer(
     init_lr=params["lr"],
     num_train_steps=(params["epochs"] * params["steps_per_epoch"]),
@@ -154,25 +181,6 @@ model.weights
 model.summary()
 pr_auc.name
 
-import getpass
-import wandb
-from wandb.keras import WandbCallback
-try:
-  wandb_key = ...
-  wandb.login(...)
-  run = wandb.init(project='ptos_with_media',
-                   group='new-split-trains',
-                   notes='tweet text with only (num_media, precision_nsfw). on full train set, new split.',
-                   entity='absv',
-                   config=params,
-                   name='tweet-text-w-nsfw-1.1',
-                   sync_tensorboard=True)
-except FileNotFoundError:
-  print('Wandb key not found')
-  run = wandb.init(mode='disabled')
-import datetime
-import os
-
 start_train_time = datetime.datetime.now()
 print(start_train_time.strftime("%m-%d-%Y (%H:%M:%S)"))
 checkpoint_path = os.path.join("...")
@@ -195,8 +203,6 @@ model.fit(train_ds, epochs=params["epochs"], validation_data=val_ds, callbacks=[
           steps_per_epoch=params["steps_per_epoch"],
           verbose=2)
 
-import tensorflow_hub as hub
-
 gs_model_path = ...
 reloaded_keras_layer = hub.KerasLayer(gs_model_path)
 inputs = tf.keras.layers.Input(name="tweet__core__tweet__text", shape=(1,), dtype=tf.string)
@@ -233,9 +239,6 @@ test_media_not_nsfw = test.filter(lambda x, y: tf.logical_and(tf.equal(x["has_me
 for d in [test, test_only_media, test_only_nsfw, test_no_media, test_media_not_nsfw]:
   print(d.reduce(0, lambda x, _: x + 1).numpy())
 
-from notebook_eval_utils import SparseMultilabelEvaluator, EvalConfig
-from dataclasses import asdict
-
 def display_metrics(probs, targets, labels=labels):
   eval_config = EvalConfig(prediction_threshold=0.5, precision_k=0.9)
   for eval_mode, y_mask in [("implicit", np.ones(targets.shape))]:
@@ -273,4 +276,4 @@ for name, x in [(name, m.pr_auc.to_string(index=False).strip().split("\n")) for
     print(y.strip(), end="\t")
   print(".")
 for d in [test, test_only_media, test_only_nsfw, test_no_media, test_media_not_nsfw]:
-  print(d.reduce(0, lambda x, _: x + 1).numpy())
\ No newline at end of file
+  print(d.reduce(0, lambda x, _: x + 1).numpy())
diff --git a/trust_and_safety_models/nsfw/nsfw_media.py b/trust_and_safety_models/nsfw/nsfw_media.py
index b5dfebb65..2c9202fc7 100644
--- a/trust_and_safety_models/nsfw/nsfw_media.py
+++ b/trust_and_safety_models/nsfw/nsfw_media.py
@@ -1,21 +1,19 @@
+import glob
+import os
+import random
+
 import kerastuner as kt
-import math
 import numpy as np
 import pandas as pd
-import random
 import sklearn.metrics
 import tensorflow as tf
-import os
-import glob
-
-from tqdm import tqdm
-from matplotlib import pyplot as plt
-from tensorflow.keras.models import Sequential
-from tensorflow.keras.layers import Dense
 from google.cloud import storage
+from matplotlib import pyplot as plt
+from tensorflow.keras.layers import Dense
+from tensorflow.keras.models import Sequential
+from tqdm import tqdm
 
 physical_devices = tf.config.list_physical_devices('GPU')
-physical_devices
 tf.config.set_visible_devices([tf.config.PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')], 'GPU')
 tf.config.get_visible_devices('GPU')
@@ -89,7 +87,7 @@ test_ds = test_ds.map(lambda x: preprocess_embedding_example(x, positive_label=p
 
 if use_sens_prev_data:
   test_sens_prev_glob = f"{sens_prev_input_root}/test/tfrecord/*.tfrecord"
-  test_sens_prev_files = tf.io.gfile.glob(test_sens_prev_glob) 
+  test_sens_prev_files = tf.io.gfile.glob(test_sens_prev_glob)
   if not len(test_sens_prev_files):
     raise ValueError(f"Did not find any eval files matching {test_sens_prev_glob}")
 
@@ -109,12 +107,12 @@ train_ds = train_ds.repeat()
 
 if has_validation_data:
   eval_glob = f"{input_root}/validation/tfrecord/*.tfrecord"
-  eval_files = tf.io.gfile.glob(eval_glob) 
+  eval_files = tf.io.gfile.glob(eval_glob)
 
   if use_sens_prev_data:
     eval_sens_prev_glob = f"{sens_prev_input_root}/validation/tfrecord/*.tfrecord"
     eval_sens_prev_files = tf.io.gfile.glob(eval_sens_prev_glob)
-    eval_files = eval_files + eval_sens_prev_files 
+    eval_files = eval_files + eval_sens_prev_files
 
 
   if not len(eval_files):
@@ -428,7 +426,7 @@ ptAt50fmt = "%.4f" % ptAt50[1]
 ptAt90fmt = "%.4f" % ptAt90[1]
 aucFmt = "%.4f" % auc_precision_recall
 plt.title(
-  f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...}} ({...} pos), N_test={n_test} ({n_test_pos} pos)",
+  f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test} ({n_test_pos} pos)",
   size=20
 )
 plt.subplots_adjust(top=0.72)
diff --git a/trust_and_safety_models/nsfw/nsfw_text.py b/trust_and_safety_models/nsfw/nsfw_text.py
index 980fc8fd4..879c12e52 100644
--- a/trust_and_safety_models/nsfw/nsfw_text.py
+++ b/trust_and_safety_models/nsfw/nsfw_text.py
@@ -1,14 +1,16 @@
-from datetime import datetime
-from functools import reduce
 import os
-import pandas as pd
-import re
-from sklearn.metrics import average_precision_score, classification_report, precision_recall_curve, PrecisionRecallDisplay
-from sklearn.model_selection import train_test_split
-import tensorflow as tf
-import matplotlib.pyplot as plt
 import re
+from datetime import datetime
 
+import matplotlib.pyplot as plt
+import pandas as pd
+import tensorflow as tf
+from sklearn.metrics import (
+    average_precision_score,
+    classification_report,
+    precision_recall_curve,
+)
+from sklearn.model_selection import train_test_split
 
 from twitter.cuad.representation.models.optimization import create_optimizer
 from twitter.cuad.representation.models.text_encoder import TextEncoder
diff --git a/trust_and_safety_models/toxicity/data/data_preprocessing.py b/trust_and_safety_models/toxicity/data/data_preprocessing.py
index f7da608f6..279e51d82 100644
--- a/trust_and_safety_models/toxicity/data/data_preprocessing.py
+++ b/trust_and_safety_models/toxicity/data/data_preprocessing.py
@@ -1,10 +1,8 @@
-from abc import ABC
 import re
-
-from toxicity_ml_pipeline.settings.hcomp_settings import TOXIC_35
+from abc import ABC
 
 import numpy as np
-
+from toxicity_ml_pipeline.settings.hcomp_settings import TOXIC_35
 
 TOXIC_35_set = set(TOXIC_35)
@@ -84,7 +82,7 @@ class DefaultENNoPreprocessor(DataframeCleaner):
     else:
       raise NotImplementedError
 
-    if "filter_low_agreements" in kwargs and kwargs["filter_low_agreements"] == True:
+    if "filter_low_agreements" in kwargs and kwargs["filter_low_agreements"] is True:
       df.drop(df[(df.agreement_rate <= 0.6)].index, axis=0, inplace=True)
       raise NotImplementedError
diff --git a/trust_and_safety_models/toxicity/data/dataframe_loader.py b/trust_and_safety_models/toxicity/data/dataframe_loader.py
index f3855d6b5..94d028ba8 100644
--- a/trust_and_safety_models/toxicity/data/dataframe_loader.py
+++ b/trust_and_safety_models/toxicity/data/dataframe_loader.py
@@ -287,7 +287,7 @@ class ENLoaderWithSampling(ENLoader):
 
 class I18nLoader(DataframeLoader):
   def __init__(self):
    super().__init__(project=...)
-    from archive.settings.... import ACCEPTED_LANGUAGES, QUERY_SETTINGS
+    from archive.settings import ACCEPTED_LANGUAGES, QUERY_SETTINGS
 
     self.accepted_languages = ACCEPTED_LANGUAGES
     self.query_settings = dict(QUERY_SETTINGS)
diff --git a/trust_and_safety_models/toxicity/data/mb_generator.py b/trust_and_safety_models/toxicity/data/mb_generator.py
index 58a89f8c5..b2a6bdd20 100644
--- a/trust_and_safety_models/toxicity/data/mb_generator.py
+++ b/trust_and_safety_models/toxicity/data/mb_generator.py
@@ -1,6 +1,10 @@
-from importlib import import_module
 import os
+from importlib import import_module
 
+import numpy as np
+import pandas
+import tensorflow as tf
+from sklearn.model_selection import StratifiedKFold
 from toxicity_ml_pipeline.settings.default_settings_tox import (
   INNER_CV,
   LOCAL_DIR,
@@ -12,12 +16,6 @@ from toxicity_ml_pipeline.settings.default_settings_tox import (
 )
 from toxicity_ml_pipeline.utils.helpers import execute_command
 
-import numpy as np
-import pandas
-from sklearn.model_selection import StratifiedKFold
-import tensorflow as tf
-
-
 try:
   from transformers import AutoTokenizer, DataCollatorWithPadding
 except ModuleNotFoundError:
diff --git a/trust_and_safety_models/toxicity/load_model.py b/trust_and_safety_models/toxicity/load_model.py
index 7b271066f..cec761e67 100644
--- a/trust_and_safety_models/toxicity/load_model.py
+++ b/trust_and_safety_models/toxicity/load_model.py
@@ -1,14 +1,13 @@
 import os
 
 from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR, MAX_SEQ_LENGTH
+
 try:
   from toxicity_ml_pipeline.optim.losses import MaskedBCE
 except ImportError:
   print('No MaskedBCE loss')
-from toxicity_ml_pipeline.utils.helpers import execute_command
-
 import tensorflow as tf
-
+from toxicity_ml_pipeline.utils.helpers import execute_command
 
 try:
   from twitter.cuad.representation.models.text_encoder import TextEncoder
@@ -102,7 +101,7 @@ def get_loss(loss_name, from_logits, **kwargs):
     multitask = kwargs.get("multitask", False)
     if from_logits or multitask:
       raise NotImplementedError
-    print(f'Masked Binary Cross Entropy')
+    print('Masked Binary Cross Entropy')
     return MaskedBCE()
 
   if loss_name == "inv_kl_loss":
diff --git a/trust_and_safety_models/toxicity/optim/callbacks.py b/trust_and_safety_models/toxicity/optim/callbacks.py
index bbf8d7c97..e6ff2f374 100644
--- a/trust_and_safety_models/toxicity/optim/callbacks.py
+++ b/trust_and_safety_models/toxicity/optim/callbacks.py
@@ -1,14 +1,16 @@
-from collections import defaultdict
 import os
+from collections import defaultdict
 
-from toxicity_ml_pipeline.settings.default_settings_tox import REMOTE_LOGDIR
-from toxicity_ml_pipeline.settings.default_settings_abs import LABEL_NAMES
-from toxicity_ml_pipeline.utils.absv_utils import parse_labeled_data
-from toxicity_ml_pipeline.utils.helpers import compute_precision_fixed_recall, execute_command
-
-from sklearn.metrics import average_precision_score, roc_auc_score
 import tensorflow as tf
 import wandb
+from sklearn.metrics import average_precision_score, roc_auc_score
+from toxicity_ml_pipeline.settings.default_settings_abs import LABEL_NAMES
+from toxicity_ml_pipeline.settings.default_settings_tox import REMOTE_LOGDIR
+from toxicity_ml_pipeline.utils.absv_utils import parse_labeled_data
+from toxicity_ml_pipeline.utils.helpers import (
+    compute_precision_fixed_recall,
+    execute_command,
+)
 
 
 class NothingCallback(tf.keras.callbacks.Callback):
diff --git a/trust_and_safety_models/toxicity/optim/losses.py b/trust_and_safety_models/toxicity/optim/losses.py
index 273c6676e..0e9e90936 100644
--- a/trust_and_safety_models/toxicity/optim/losses.py
+++ b/trust_and_safety_models/toxicity/optim/losses.py
@@ -1,7 +1,7 @@
 import tensorflow as tf
-from keras.utils import tf_utils
-from keras.utils import losses_utils
 from keras import backend
+from keras.utils import losses_utils, tf_utils
+
 
 def inv_kl_divergence(y_true, y_pred):
   y_pred = tf.convert_to_tensor(y_pred)
diff --git a/trust_and_safety_models/toxicity/rescoring.py b/trust_and_safety_models/toxicity/rescoring.py
index 71d95ed76..fb3ccad1b 100644
--- a/trust_and_safety_models/toxicity/rescoring.py
+++ b/trust_and_safety_models/toxicity/rescoring.py
@@ -1,8 +1,7 @@
-from toxicity_ml_pipeline.load_model import reload_model_weights
-from toxicity_ml_pipeline.utils.helpers import load_inference_func, upload_model
-
 import numpy as np
 import tensorflow as tf
+from toxicity_ml_pipeline.load_model import reload_model_weights
+from toxicity_ml_pipeline.utils.helpers import load_inference_func, upload_model
 
 
 def score(language, df, gcs_model_path, batch_size=64, text_col="text", kw="", **kwargs):
diff --git a/trust_and_safety_models/toxicity/settings/default_settings_tox.py b/trust_and_safety_models/toxicity/settings/default_settings_tox.py
index 0968b9adc..9b121f39a 100644
--- a/trust_and_safety_models/toxicity/settings/default_settings_tox.py
+++ b/trust_and_safety_models/toxicity/settings/default_settings_tox.py
@@ -1,6 +1,5 @@
 import os
 
-
 TEAM_PROJECT = "twttr-toxicity-prod"
 try:
   from google.cloud import bigquery
@@ -16,7 +15,7 @@ else:
   CLIENT = None
   print("Issue at logging time", e)
 
-TRAINING_DATA_LOCATION = f"..."
+TRAINING_DATA_LOCATION = "..."
 GCS_ADDRESS = "..."
 LOCAL_DIR = os.getcwd()
 REMOTE_LOGDIR = "{GCS_ADDRESS}/logs"
diff --git a/trust_and_safety_models/toxicity/train.py b/trust_and_safety_models/toxicity/train.py
index de450ee7b..01e30bdb9 100644
--- a/trust_and_safety_models/toxicity/train.py
+++ b/trust_and_safety_models/toxicity/train.py
@@ -1,14 +1,16 @@
+import os
 from datetime import datetime
 from importlib import import_module
-import os
 
+import numpy as np
+import tensorflow as tf
 from toxicity_ml_pipeline.data.data_preprocessing import (
   DefaultENNoPreprocessor,
   DefaultENPreprocessor,
 )
 from toxicity_ml_pipeline.data.dataframe_loader import ENLoader, ENLoaderWithSampling
 from toxicity_ml_pipeline.data.mb_generator import BalancedMiniBatchLoader
-from toxicity_ml_pipeline.load_model import load, get_last_layer
+from toxicity_ml_pipeline.load_model import get_last_layer, load
 from toxicity_ml_pipeline.optim.callbacks import (
   AdditionalResultLogger,
   ControlledStoppingCheckpointCallback,
@@ -19,6 +21,8 @@ from toxicity_ml_pipeline.optim.schedulers import WarmUp
 from toxicity_ml_pipeline.settings.default_settings_abs import GCS_ADDRESS as ABS_GCS
 from toxicity_ml_pipeline.settings.default_settings_tox import (
   GCS_ADDRESS as TOX_GCS,
+)
+from toxicity_ml_pipeline.settings.default_settings_tox import (
   MODEL_DIR,
   RANDOM_SEED,
   REMOTE_LOGDIR,
@@ -26,10 +30,6 @@ from toxicity_ml_pipeline.settings.default_settings_tox import (
 )
 from toxicity_ml_pipeline.utils.helpers import check_gpu, set_seeds, upload_model
 
-import numpy as np
-import tensorflow as tf
-
-
 try:
   from tensorflow_addons.optimizers import AdamW
 except ModuleNotFoundError:
@@ -139,9 +139,9 @@ class Trainer(object):
     )
     print("------- Experiment name: ", experiment_name)
     self.logdir = (
-      f"..."
+      "..."
       if self.test
-      else f"..."
+      else "..."
     )
 
     self.checkpoint_path = f"{self.model_dir}/{experiment_name}"
diff --git a/trust_and_safety_models/toxicity/utils/helpers.py b/trust_and_safety_models/toxicity/utils/helpers.py
index c21d7eb1c..30b374768 100644
--- a/trust_and_safety_models/toxicity/utils/helpers.py
+++ b/trust_and_safety_models/toxicity/utils/helpers.py
@@ -3,11 +3,9 @@ import os
 import random as python_random
 import subprocess
 
-from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR
-
 import numpy as np
 from sklearn.metrics import precision_recall_curve
-
+from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR
 
 try:
   import tensorflow as tf