mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-01-07 01:48:16 +01:00
Compare commits
2 Commits
d03bf89cbf
...
4a11b533cd
Author | SHA1 | Date | |
---|---|---|---|
|
4a11b533cd | ||
|
858525b9d8 |
@ -6,8 +6,8 @@ import sys
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
import apache_beam as beam
|
||||
from apache_beam.options.pipeline_options import PipelineOptions
|
||||
import faiss
|
||||
from apache_beam.options.pipeline_options import PipelineOptions
|
||||
|
||||
|
||||
def parse_d6w_config(argv=None):
|
||||
@ -160,8 +160,8 @@ class MergeAndBuildIndex(beam.CombineFn):
|
||||
import subprocess
|
||||
|
||||
import faiss
|
||||
from google.cloud import storage
|
||||
import numpy as np
|
||||
from google.cloud import storage
|
||||
|
||||
client = storage.Client()
|
||||
bucket = client.get_bucket(self.bucket_name)
|
||||
|
@ -1,5 +1,6 @@
|
||||
# checkstyle: noqa
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from .constants import INDEX_BY_LABEL, LABEL_NAMES
|
||||
|
||||
# TODO: Read these from command line arguments, since they specify the existing example weights in the input data.
|
||||
|
@ -1,7 +1,9 @@
|
||||
# checkstyle: noqa
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from ..constants import EB_SCORE_IDX
|
||||
|
||||
|
||||
# The rationale behind this logic is available at TQ-9678.
|
||||
def get_lolly_logits(labels):
|
||||
'''
|
||||
|
@ -4,7 +4,6 @@ from .parsers import DBv2DataExampleParser
|
||||
from .reader import LollyModelReader
|
||||
from .scorer import LollyModelScorer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
lolly_model_reader = LollyModelReader(lolly_model_file_path=sys.argv[1])
|
||||
lolly_model_scorer = LollyModelScorer(data_example_parser=DBv2DataExampleParser(lolly_model_reader))
|
||||
|
@ -1,10 +1,13 @@
|
||||
# checkstyle: noqa
|
||||
import tensorflow.compat.v1 as tf
|
||||
from collections import OrderedDict
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
import twml
|
||||
|
||||
from .constants import EB_SCORE_IDX
|
||||
from .lolly.data_helpers import get_lolly_scores
|
||||
|
||||
import twml
|
||||
|
||||
def get_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1):
|
||||
"""
|
||||
|
@ -1,7 +1,8 @@
|
||||
from .hashing_utils import make_feature_id
|
||||
import numpy as np
|
||||
|
||||
from twml.contrib.layers.hashing_discretizer import HashingDiscretizer
|
||||
import numpy as np
|
||||
|
||||
from .hashing_utils import make_feature_id
|
||||
|
||||
|
||||
class TFModelDiscretizerBuilder(object):
|
||||
|
@ -1,6 +1,5 @@
|
||||
from twitter.deepbird.io.util import _get_feature_id
|
||||
|
||||
import numpy as np
|
||||
from twitter.deepbird.io.util import _get_feature_id
|
||||
|
||||
|
||||
def numpy_hashing_uniform(the_id, bin_idx, output_bits):
|
||||
|
@ -1,9 +1,10 @@
|
||||
from .hashing_utils import make_feature_id, numpy_hashing_uniform
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
import twml
|
||||
|
||||
from .hashing_utils import make_feature_id, numpy_hashing_uniform
|
||||
|
||||
|
||||
class TFModelWeightsInitializerBuilder(object):
|
||||
def __init__(self, num_bits):
|
||||
|
@ -1,26 +1,32 @@
|
||||
# checkstyle: noqa
|
||||
from datetime import datetime
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn
|
||||
import tensorflow_hub as hub
|
||||
from tensorflow.compat.v1 import logging
|
||||
from tensorflow.python.estimator.export.export import (
|
||||
build_raw_serving_input_receiver_fn,
|
||||
)
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from datetime import datetime
|
||||
from tensorflow.compat.v1 import logging
|
||||
from twitter.deepbird.projects.timelines.configs import all_configs
|
||||
|
||||
import twml
|
||||
from twml.contrib.calibrators.common_calibrators import (
|
||||
build_percentile_discretizer_graph,
|
||||
calibrate_discretizer_and_export,
|
||||
)
|
||||
from twml.trainers import DataRecordTrainer
|
||||
from twml.contrib.calibrators.common_calibrators import build_percentile_discretizer_graph
|
||||
from twml.contrib.calibrators.common_calibrators import calibrate_discretizer_and_export
|
||||
from .metrics import get_multi_binary_class_metric_fn
|
||||
from .constants import TARGET_LABEL_IDX, PREDICTED_CLASSES
|
||||
|
||||
from .constants import PREDICTED_CLASSES, TARGET_LABEL_IDX
|
||||
from .example_weights import add_weight_arguments, make_weights_tensor
|
||||
from .lolly.data_helpers import get_lolly_logits
|
||||
from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder
|
||||
from .lolly.reader import LollyModelReader
|
||||
from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder
|
||||
from .metrics import get_multi_binary_class_metric_fn
|
||||
from .tf_model.discretizer_builder import TFModelDiscretizerBuilder
|
||||
from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuilder
|
||||
|
||||
import twml
|
||||
|
||||
def get_feature_values(features_values, params):
|
||||
if params.lolly_model_tsv:
|
||||
|
@ -1,19 +1,45 @@
|
||||
import datetime
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
import utils
|
||||
import wandb
|
||||
|
||||
try:
|
||||
wandb_key = ...
|
||||
wandb.login(...)
|
||||
run = wandb.init(project='ptos_with_media',
|
||||
group='new-split-trains',
|
||||
notes='tweet text with only (num_media, precision_nsfw). on full train set, new split.',
|
||||
entity='absv',
|
||||
config=params,
|
||||
name='tweet-text-w-nsfw-1.1',
|
||||
sync_tensorboard=True)
|
||||
except FileNotFoundError:
|
||||
print('Wandb key not found')
|
||||
run = wandb.init(mode='disabled')
|
||||
|
||||
|
||||
from notebook_eval_utils import EvalConfig, SparseMultilabelEvaluator
|
||||
from twitter.cuad.representation.models.optimization import create_optimizer
|
||||
from twitter.cuad.representation.models.text_encoder import TextEncoder
|
||||
from twitter.hmli.nimbus.modeling.feature_encoder import FeatureEncoder
|
||||
from twitter.hmli.nimbus.modeling.feature_loader import BigQueryFeatureLoader
|
||||
from twitter.hmli.nimbus.modeling.model_config import (
|
||||
EncodingType,
|
||||
Feature,
|
||||
FeatureType,
|
||||
Model,
|
||||
)
|
||||
|
||||
physical_devices = tf.config.list_physical_devices('GPU')
|
||||
for device in physical_devices:
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
|
||||
from twitter.hmli.nimbus.modeling.model_config import FeatureType, EncodingType, Feature, Model, LogType
|
||||
from twitter.hmli.nimbus.modeling.feature_loader import BigQueryFeatureLoader
|
||||
from twitter.cuad.representation.models.text_encoder import TextEncoder
|
||||
from twitter.cuad.representation.models.optimization import create_optimizer
|
||||
from twitter.hmli.nimbus.modeling.feature_encoder import FeatureEncoder
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import utils
|
||||
|
||||
cat_names = [
|
||||
...
|
||||
]
|
||||
@ -75,7 +101,6 @@ params = {
|
||||
'model_type': 'twitter_multilingual_bert_base_cased_mlm',
|
||||
'mixed_precision': True,
|
||||
}
|
||||
params
|
||||
|
||||
def parse_labeled_data(row_dict):
|
||||
label = [row_dict.pop(l) for l in labels]
|
||||
@ -134,7 +159,9 @@ with mirrored_strategy.scope():
|
||||
)
|
||||
pr_auc = tf.keras.metrics.AUC(curve="PR", num_thresholds=1000, multi_label=True, from_logits=True)
|
||||
|
||||
custom_loss = lambda y_true, y_pred: utils.multilabel_weighted_loss(y_true, y_pred, weights=pos_weight_tensor)
|
||||
def custom_loss(y_true, y_pred):
|
||||
return utils.multilabel_weighted_loss(y_true, y_pred, weights=pos_weight_tensor)
|
||||
|
||||
optimizer = create_optimizer(
|
||||
init_lr=params["lr"],
|
||||
num_train_steps=(params["epochs"] * params["steps_per_epoch"]),
|
||||
@ -154,25 +181,6 @@ model.weights
|
||||
model.summary()
|
||||
pr_auc.name
|
||||
|
||||
import getpass
|
||||
import wandb
|
||||
from wandb.keras import WandbCallback
|
||||
try:
|
||||
wandb_key = ...
|
||||
wandb.login(...)
|
||||
run = wandb.init(project='ptos_with_media',
|
||||
group='new-split-trains',
|
||||
notes='tweet text with only (num_media, precision_nsfw). on full train set, new split.',
|
||||
entity='absv',
|
||||
config=params,
|
||||
name='tweet-text-w-nsfw-1.1',
|
||||
sync_tensorboard=True)
|
||||
except FileNotFoundError:
|
||||
print('Wandb key not found')
|
||||
run = wandb.init(mode='disabled')
|
||||
import datetime
|
||||
import os
|
||||
|
||||
start_train_time = datetime.datetime.now()
|
||||
print(start_train_time.strftime("%m-%d-%Y (%H:%M:%S)"))
|
||||
checkpoint_path = os.path.join("...")
|
||||
@ -195,8 +203,6 @@ model.fit(train_ds, epochs=params["epochs"], validation_data=val_ds, callbacks=[
|
||||
steps_per_epoch=params["steps_per_epoch"],
|
||||
verbose=2)
|
||||
|
||||
import tensorflow_hub as hub
|
||||
|
||||
gs_model_path = ...
|
||||
reloaded_keras_layer = hub.KerasLayer(gs_model_path)
|
||||
inputs = tf.keras.layers.Input(name="tweet__core__tweet__text", shape=(1,), dtype=tf.string)
|
||||
@ -233,9 +239,6 @@ test_media_not_nsfw = test.filter(lambda x, y: tf.logical_and(tf.equal(x["has_me
|
||||
for d in [test, test_only_media, test_only_nsfw, test_no_media, test_media_not_nsfw]:
|
||||
print(d.reduce(0, lambda x, _: x + 1).numpy())
|
||||
|
||||
from notebook_eval_utils import SparseMultilabelEvaluator, EvalConfig
|
||||
from dataclasses import asdict
|
||||
|
||||
def display_metrics(probs, targets, labels=labels):
|
||||
eval_config = EvalConfig(prediction_threshold=0.5, precision_k=0.9)
|
||||
for eval_mode, y_mask in [("implicit", np.ones(targets.shape))]:
|
||||
|
@ -1,21 +1,19 @@
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
|
||||
import kerastuner as kt
|
||||
import math
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import random
|
||||
import sklearn.metrics
|
||||
import tensorflow as tf
|
||||
import os
|
||||
import glob
|
||||
|
||||
from tqdm import tqdm
|
||||
from matplotlib import pyplot as plt
|
||||
from tensorflow.keras.models import Sequential
|
||||
from tensorflow.keras.layers import Dense
|
||||
from google.cloud import storage
|
||||
from matplotlib import pyplot as plt
|
||||
from tensorflow.keras.layers import Dense
|
||||
from tensorflow.keras.models import Sequential
|
||||
from tqdm import tqdm
|
||||
|
||||
physical_devices = tf.config.list_physical_devices('GPU')
|
||||
physical_devices
|
||||
|
||||
tf.config.set_visible_devices([tf.config.PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')], 'GPU')
|
||||
tf.config.get_visible_devices('GPU')
|
||||
@ -89,7 +87,7 @@ test_ds = test_ds.map(lambda x: preprocess_embedding_example(x, positive_label=p
|
||||
|
||||
if use_sens_prev_data:
|
||||
test_sens_prev_glob = f"{sens_prev_input_root}/test/tfrecord/*.tfrecord"
|
||||
test_sens_prev_files = tf.io.gfile.glob(test_sens_prev_glob)
|
||||
test_sens_prev_files = tf.io.gfile.glob(test_sens_prev_glob)
|
||||
|
||||
if not len(test_sens_prev_files):
|
||||
raise ValueError(f"Did not find any eval files matching {test_sens_prev_glob}")
|
||||
@ -109,12 +107,12 @@ train_ds = train_ds.repeat()
|
||||
|
||||
if has_validation_data:
|
||||
eval_glob = f"{input_root}/validation/tfrecord/*.tfrecord"
|
||||
eval_files = tf.io.gfile.glob(eval_glob)
|
||||
eval_files = tf.io.gfile.glob(eval_glob)
|
||||
|
||||
if use_sens_prev_data:
|
||||
eval_sens_prev_glob = f"{sens_prev_input_root}/validation/tfrecord/*.tfrecord"
|
||||
eval_sens_prev_files = tf.io.gfile.glob(eval_sens_prev_glob)
|
||||
eval_files = eval_files + eval_sens_prev_files
|
||||
eval_files = eval_files + eval_sens_prev_files
|
||||
|
||||
|
||||
if not len(eval_files):
|
||||
@ -428,7 +426,7 @@ ptAt50fmt = "%.4f" % ptAt50[1]
|
||||
ptAt90fmt = "%.4f" % ptAt90[1]
|
||||
aucFmt = "%.4f" % auc_precision_recall
|
||||
plt.title(
|
||||
f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...}} ({...} pos), N_test={n_test} ({n_test_pos} pos)",
|
||||
f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test} ({n_test_pos} pos)",
|
||||
size=20
|
||||
)
|
||||
plt.subplots_adjust(top=0.72)
|
||||
|
@ -1,14 +1,16 @@
|
||||
from datetime import datetime
|
||||
from functools import reduce
|
||||
import os
|
||||
import pandas as pd
|
||||
import re
|
||||
from sklearn.metrics import average_precision_score, classification_report, precision_recall_curve, PrecisionRecallDisplay
|
||||
from sklearn.model_selection import train_test_split
|
||||
import tensorflow as tf
|
||||
import matplotlib.pyplot as plt
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
from sklearn.metrics import (
|
||||
average_precision_score,
|
||||
classification_report,
|
||||
precision_recall_curve,
|
||||
)
|
||||
from sklearn.model_selection import train_test_split
|
||||
from twitter.cuad.representation.models.optimization import create_optimizer
|
||||
from twitter.cuad.representation.models.text_encoder import TextEncoder
|
||||
|
||||
|
@ -1,10 +1,8 @@
|
||||
from abc import ABC
|
||||
import re
|
||||
|
||||
from toxicity_ml_pipeline.settings.hcomp_settings import TOXIC_35
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
|
||||
from toxicity_ml_pipeline.settings.hcomp_settings import TOXIC_35
|
||||
|
||||
TOXIC_35_set = set(TOXIC_35)
|
||||
|
||||
@ -84,7 +82,7 @@ class DefaultENNoPreprocessor(DataframeCleaner):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if "filter_low_agreements" in kwargs and kwargs["filter_low_agreements"] == True:
|
||||
if "filter_low_agreements" in kwargs and kwargs["filter_low_agreements"] is True:
|
||||
df.drop(df[(df.agreement_rate <= 0.6)].index, axis=0, inplace=True)
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -287,7 +287,7 @@ class ENLoaderWithSampling(ENLoader):
|
||||
class I18nLoader(DataframeLoader):
|
||||
def __init__(self):
|
||||
super().__init__(project=...)
|
||||
from archive.settings.... import ACCEPTED_LANGUAGES, QUERY_SETTINGS
|
||||
from archive.settings import ACCEPTED_LANGUAGES, QUERY_SETTINGS
|
||||
|
||||
self.accepted_languages = ACCEPTED_LANGUAGES
|
||||
self.query_settings = dict(QUERY_SETTINGS)
|
||||
|
@ -1,6 +1,10 @@
|
||||
from importlib import import_module
|
||||
import os
|
||||
from importlib import import_module
|
||||
|
||||
import numpy as np
|
||||
import pandas
|
||||
import tensorflow as tf
|
||||
from sklearn.model_selection import StratifiedKFold
|
||||
from toxicity_ml_pipeline.settings.default_settings_tox import (
|
||||
INNER_CV,
|
||||
LOCAL_DIR,
|
||||
@ -12,12 +16,6 @@ from toxicity_ml_pipeline.settings.default_settings_tox import (
|
||||
)
|
||||
from toxicity_ml_pipeline.utils.helpers import execute_command
|
||||
|
||||
import numpy as np
|
||||
import pandas
|
||||
from sklearn.model_selection import StratifiedKFold
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer, DataCollatorWithPadding
|
||||
except ModuleNotFoundError:
|
||||
|
@ -1,14 +1,13 @@
|
||||
import os
|
||||
|
||||
from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR, MAX_SEQ_LENGTH
|
||||
|
||||
try:
|
||||
from toxicity_ml_pipeline.optim.losses import MaskedBCE
|
||||
except ImportError:
|
||||
print('No MaskedBCE loss')
|
||||
from toxicity_ml_pipeline.utils.helpers import execute_command
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from toxicity_ml_pipeline.utils.helpers import execute_command
|
||||
|
||||
try:
|
||||
from twitter.cuad.representation.models.text_encoder import TextEncoder
|
||||
@ -102,7 +101,7 @@ def get_loss(loss_name, from_logits, **kwargs):
|
||||
multitask = kwargs.get("multitask", False)
|
||||
if from_logits or multitask:
|
||||
raise NotImplementedError
|
||||
print(f'Masked Binary Cross Entropy')
|
||||
print('Masked Binary Cross Entropy')
|
||||
return MaskedBCE()
|
||||
|
||||
if loss_name == "inv_kl_loss":
|
||||
|
@ -1,14 +1,16 @@
|
||||
from collections import defaultdict
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
from toxicity_ml_pipeline.settings.default_settings_tox import REMOTE_LOGDIR
|
||||
from toxicity_ml_pipeline.settings.default_settings_abs import LABEL_NAMES
|
||||
from toxicity_ml_pipeline.utils.absv_utils import parse_labeled_data
|
||||
from toxicity_ml_pipeline.utils.helpers import compute_precision_fixed_recall, execute_command
|
||||
|
||||
from sklearn.metrics import average_precision_score, roc_auc_score
|
||||
import tensorflow as tf
|
||||
import wandb
|
||||
from sklearn.metrics import average_precision_score, roc_auc_score
|
||||
from toxicity_ml_pipeline.settings.default_settings_abs import LABEL_NAMES
|
||||
from toxicity_ml_pipeline.settings.default_settings_tox import REMOTE_LOGDIR
|
||||
from toxicity_ml_pipeline.utils.absv_utils import parse_labeled_data
|
||||
from toxicity_ml_pipeline.utils.helpers import (
|
||||
compute_precision_fixed_recall,
|
||||
execute_command,
|
||||
)
|
||||
|
||||
|
||||
class NothingCallback(tf.keras.callbacks.Callback):
|
||||
|
@ -1,7 +1,7 @@
|
||||
import tensorflow as tf
|
||||
from keras.utils import tf_utils
|
||||
from keras.utils import losses_utils
|
||||
from keras import backend
|
||||
from keras.utils import losses_utils, tf_utils
|
||||
|
||||
|
||||
def inv_kl_divergence(y_true, y_pred):
|
||||
y_pred = tf.convert_to_tensor(y_pred)
|
||||
|
@ -1,8 +1,7 @@
|
||||
from toxicity_ml_pipeline.load_model import reload_model_weights
|
||||
from toxicity_ml_pipeline.utils.helpers import load_inference_func, upload_model
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from toxicity_ml_pipeline.load_model import reload_model_weights
|
||||
from toxicity_ml_pipeline.utils.helpers import load_inference_func, upload_model
|
||||
|
||||
|
||||
def score(language, df, gcs_model_path, batch_size=64, text_col="text", kw="", **kwargs):
|
||||
|
@ -1,6 +1,5 @@
|
||||
import os
|
||||
|
||||
|
||||
TEAM_PROJECT = "twttr-toxicity-prod"
|
||||
try:
|
||||
from google.cloud import bigquery
|
||||
@ -16,7 +15,7 @@ else:
|
||||
CLIENT = None
|
||||
print("Issue at logging time", e)
|
||||
|
||||
TRAINING_DATA_LOCATION = f"..."
|
||||
TRAINING_DATA_LOCATION = "..."
|
||||
GCS_ADDRESS = "..."
|
||||
LOCAL_DIR = os.getcwd()
|
||||
REMOTE_LOGDIR = "{GCS_ADDRESS}/logs"
|
||||
|
@ -1,14 +1,16 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from importlib import import_module
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from toxicity_ml_pipeline.data.data_preprocessing import (
|
||||
DefaultENNoPreprocessor,
|
||||
DefaultENPreprocessor,
|
||||
)
|
||||
from toxicity_ml_pipeline.data.dataframe_loader import ENLoader, ENLoaderWithSampling
|
||||
from toxicity_ml_pipeline.data.mb_generator import BalancedMiniBatchLoader
|
||||
from toxicity_ml_pipeline.load_model import load, get_last_layer
|
||||
from toxicity_ml_pipeline.load_model import get_last_layer, load
|
||||
from toxicity_ml_pipeline.optim.callbacks import (
|
||||
AdditionalResultLogger,
|
||||
ControlledStoppingCheckpointCallback,
|
||||
@ -19,6 +21,8 @@ from toxicity_ml_pipeline.optim.schedulers import WarmUp
|
||||
from toxicity_ml_pipeline.settings.default_settings_abs import GCS_ADDRESS as ABS_GCS
|
||||
from toxicity_ml_pipeline.settings.default_settings_tox import (
|
||||
GCS_ADDRESS as TOX_GCS,
|
||||
)
|
||||
from toxicity_ml_pipeline.settings.default_settings_tox import (
|
||||
MODEL_DIR,
|
||||
RANDOM_SEED,
|
||||
REMOTE_LOGDIR,
|
||||
@ -26,10 +30,6 @@ from toxicity_ml_pipeline.settings.default_settings_tox import (
|
||||
)
|
||||
from toxicity_ml_pipeline.utils.helpers import check_gpu, set_seeds, upload_model
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
try:
|
||||
from tensorflow_addons.optimizers import AdamW
|
||||
except ModuleNotFoundError:
|
||||
@ -139,9 +139,9 @@ class Trainer(object):
|
||||
)
|
||||
print("------- Experiment name: ", experiment_name)
|
||||
self.logdir = (
|
||||
f"..."
|
||||
"..."
|
||||
if self.test
|
||||
else f"..."
|
||||
else "..."
|
||||
)
|
||||
self.checkpoint_path = f"{self.model_dir}/{experiment_name}"
|
||||
|
||||
|
@ -3,11 +3,9 @@ import os
|
||||
import random as python_random
|
||||
import subprocess
|
||||
|
||||
from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import precision_recall_curve
|
||||
|
||||
from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
|
Loading…
Reference in New Issue
Block a user