Merge 858525b9d8aaaf856f2ea91079a2940f3c36a419 into fb54d8b54984f89f7dba90a18e7c3048421464c3

This commit is contained in:
Harshil 2023-05-22 17:38:27 -05:00 committed by GitHub
commit 4a11b533cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 131 additions and 123 deletions

View File

@ -6,8 +6,8 @@ import sys
from urllib.parse import urlsplit from urllib.parse import urlsplit
import apache_beam as beam import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import faiss import faiss
from apache_beam.options.pipeline_options import PipelineOptions
def parse_d6w_config(argv=None): def parse_d6w_config(argv=None):
@ -160,8 +160,8 @@ class MergeAndBuildIndex(beam.CombineFn):
import subprocess import subprocess
import faiss import faiss
from google.cloud import storage
import numpy as np import numpy as np
from google.cloud import storage
client = storage.Client() client = storage.Client()
bucket = client.get_bucket(self.bucket_name) bucket = client.get_bucket(self.bucket_name)

View File

@ -1,5 +1,6 @@
# checkstyle: noqa # checkstyle: noqa
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from .constants import INDEX_BY_LABEL, LABEL_NAMES from .constants import INDEX_BY_LABEL, LABEL_NAMES
# TODO: Read these from command line arguments, since they specify the existing example weights in the input data. # TODO: Read these from command line arguments, since they specify the existing example weights in the input data.

View File

@ -1,7 +1,9 @@
# checkstyle: noqa # checkstyle: noqa
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from ..constants import EB_SCORE_IDX from ..constants import EB_SCORE_IDX
# The rationale behind this logic is available at TQ-9678. # The rationale behind this logic is available at TQ-9678.
def get_lolly_logits(labels): def get_lolly_logits(labels):
''' '''

View File

@ -4,7 +4,6 @@ from .parsers import DBv2DataExampleParser
from .reader import LollyModelReader from .reader import LollyModelReader
from .scorer import LollyModelScorer from .scorer import LollyModelScorer
if __name__ == "__main__": if __name__ == "__main__":
lolly_model_reader = LollyModelReader(lolly_model_file_path=sys.argv[1]) lolly_model_reader = LollyModelReader(lolly_model_file_path=sys.argv[1])
lolly_model_scorer = LollyModelScorer(data_example_parser=DBv2DataExampleParser(lolly_model_reader)) lolly_model_scorer = LollyModelScorer(data_example_parser=DBv2DataExampleParser(lolly_model_reader))

View File

@ -1,10 +1,13 @@
# checkstyle: noqa # checkstyle: noqa
import tensorflow.compat.v1 as tf
from collections import OrderedDict from collections import OrderedDict
import tensorflow.compat.v1 as tf
import twml
from .constants import EB_SCORE_IDX from .constants import EB_SCORE_IDX
from .lolly.data_helpers import get_lolly_scores from .lolly.data_helpers import get_lolly_scores
import twml
def get_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1): def get_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1):
""" """

View File

@ -1,7 +1,8 @@
from .hashing_utils import make_feature_id import numpy as np
from twml.contrib.layers.hashing_discretizer import HashingDiscretizer from twml.contrib.layers.hashing_discretizer import HashingDiscretizer
import numpy as np
from .hashing_utils import make_feature_id
class TFModelDiscretizerBuilder(object): class TFModelDiscretizerBuilder(object):

View File

@ -1,6 +1,5 @@
from twitter.deepbird.io.util import _get_feature_id
import numpy as np import numpy as np
from twitter.deepbird.io.util import _get_feature_id
def numpy_hashing_uniform(the_id, bin_idx, output_bits): def numpy_hashing_uniform(the_id, bin_idx, output_bits):

View File

@ -1,9 +1,10 @@
from .hashing_utils import make_feature_id, numpy_hashing_uniform
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
import twml import twml
from .hashing_utils import make_feature_id, numpy_hashing_uniform
class TFModelWeightsInitializerBuilder(object): class TFModelWeightsInitializerBuilder(object):
def __init__(self, num_bits): def __init__(self, num_bits):

View File

@ -1,26 +1,32 @@
# checkstyle: noqa # checkstyle: noqa
from datetime import datetime
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn import tensorflow_hub as hub
from tensorflow.compat.v1 import logging
from tensorflow.python.estimator.export.export import (
build_raw_serving_input_receiver_fn,
)
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
import tensorflow_hub as hub
from datetime import datetime
from tensorflow.compat.v1 import logging
from twitter.deepbird.projects.timelines.configs import all_configs from twitter.deepbird.projects.timelines.configs import all_configs
import twml
from twml.contrib.calibrators.common_calibrators import (
build_percentile_discretizer_graph,
calibrate_discretizer_and_export,
)
from twml.trainers import DataRecordTrainer from twml.trainers import DataRecordTrainer
from twml.contrib.calibrators.common_calibrators import build_percentile_discretizer_graph
from twml.contrib.calibrators.common_calibrators import calibrate_discretizer_and_export from .constants import PREDICTED_CLASSES, TARGET_LABEL_IDX
from .metrics import get_multi_binary_class_metric_fn
from .constants import TARGET_LABEL_IDX, PREDICTED_CLASSES
from .example_weights import add_weight_arguments, make_weights_tensor from .example_weights import add_weight_arguments, make_weights_tensor
from .lolly.data_helpers import get_lolly_logits from .lolly.data_helpers import get_lolly_logits
from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder
from .lolly.reader import LollyModelReader from .lolly.reader import LollyModelReader
from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder
from .metrics import get_multi_binary_class_metric_fn
from .tf_model.discretizer_builder import TFModelDiscretizerBuilder from .tf_model.discretizer_builder import TFModelDiscretizerBuilder
from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuilder from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuilder
import twml
def get_feature_values(features_values, params): def get_feature_values(features_values, params):
if params.lolly_model_tsv: if params.lolly_model_tsv:

View File

@ -1,19 +1,45 @@
import datetime
import os
from dataclasses import asdict
import numpy as np
import pandas as pd
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub
import utils
import wandb
try:
wandb_key = ...
wandb.login(...)
run = wandb.init(project='ptos_with_media',
group='new-split-trains',
notes='tweet text with only (num_media, precision_nsfw). on full train set, new split.',
entity='absv',
config=params,
name='tweet-text-w-nsfw-1.1',
sync_tensorboard=True)
except FileNotFoundError:
print('Wandb key not found')
run = wandb.init(mode='disabled')
from notebook_eval_utils import EvalConfig, SparseMultilabelEvaluator
from twitter.cuad.representation.models.optimization import create_optimizer
from twitter.cuad.representation.models.text_encoder import TextEncoder
from twitter.hmli.nimbus.modeling.feature_encoder import FeatureEncoder
from twitter.hmli.nimbus.modeling.feature_loader import BigQueryFeatureLoader
from twitter.hmli.nimbus.modeling.model_config import (
EncodingType,
Feature,
FeatureType,
Model,
)
physical_devices = tf.config.list_physical_devices('GPU') physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices: for device in physical_devices:
tf.config.experimental.set_memory_growth(device, True) tf.config.experimental.set_memory_growth(device, True)
from twitter.hmli.nimbus.modeling.model_config import FeatureType, EncodingType, Feature, Model, LogType
from twitter.hmli.nimbus.modeling.feature_loader import BigQueryFeatureLoader
from twitter.cuad.representation.models.text_encoder import TextEncoder
from twitter.cuad.representation.models.optimization import create_optimizer
from twitter.hmli.nimbus.modeling.feature_encoder import FeatureEncoder
import numpy as np
import pandas as pd
import utils
cat_names = [ cat_names = [
... ...
] ]
@ -75,7 +101,6 @@ params = {
'model_type': 'twitter_multilingual_bert_base_cased_mlm', 'model_type': 'twitter_multilingual_bert_base_cased_mlm',
'mixed_precision': True, 'mixed_precision': True,
} }
params
def parse_labeled_data(row_dict): def parse_labeled_data(row_dict):
label = [row_dict.pop(l) for l in labels] label = [row_dict.pop(l) for l in labels]
@ -134,7 +159,9 @@ with mirrored_strategy.scope():
) )
pr_auc = tf.keras.metrics.AUC(curve="PR", num_thresholds=1000, multi_label=True, from_logits=True) pr_auc = tf.keras.metrics.AUC(curve="PR", num_thresholds=1000, multi_label=True, from_logits=True)
custom_loss = lambda y_true, y_pred: utils.multilabel_weighted_loss(y_true, y_pred, weights=pos_weight_tensor) def custom_loss(y_true, y_pred):
return utils.multilabel_weighted_loss(y_true, y_pred, weights=pos_weight_tensor)
optimizer = create_optimizer( optimizer = create_optimizer(
init_lr=params["lr"], init_lr=params["lr"],
num_train_steps=(params["epochs"] * params["steps_per_epoch"]), num_train_steps=(params["epochs"] * params["steps_per_epoch"]),
@ -154,25 +181,6 @@ model.weights
model.summary() model.summary()
pr_auc.name pr_auc.name
import getpass
import wandb
from wandb.keras import WandbCallback
try:
wandb_key = ...
wandb.login(...)
run = wandb.init(project='ptos_with_media',
group='new-split-trains',
notes='tweet text with only (num_media, precision_nsfw). on full train set, new split.',
entity='absv',
config=params,
name='tweet-text-w-nsfw-1.1',
sync_tensorboard=True)
except FileNotFoundError:
print('Wandb key not found')
run = wandb.init(mode='disabled')
import datetime
import os
start_train_time = datetime.datetime.now() start_train_time = datetime.datetime.now()
print(start_train_time.strftime("%m-%d-%Y (%H:%M:%S)")) print(start_train_time.strftime("%m-%d-%Y (%H:%M:%S)"))
checkpoint_path = os.path.join("...") checkpoint_path = os.path.join("...")
@ -195,8 +203,6 @@ model.fit(train_ds, epochs=params["epochs"], validation_data=val_ds, callbacks=[
steps_per_epoch=params["steps_per_epoch"], steps_per_epoch=params["steps_per_epoch"],
verbose=2) verbose=2)
import tensorflow_hub as hub
gs_model_path = ... gs_model_path = ...
reloaded_keras_layer = hub.KerasLayer(gs_model_path) reloaded_keras_layer = hub.KerasLayer(gs_model_path)
inputs = tf.keras.layers.Input(name="tweet__core__tweet__text", shape=(1,), dtype=tf.string) inputs = tf.keras.layers.Input(name="tweet__core__tweet__text", shape=(1,), dtype=tf.string)
@ -233,9 +239,6 @@ test_media_not_nsfw = test.filter(lambda x, y: tf.logical_and(tf.equal(x["has_me
for d in [test, test_only_media, test_only_nsfw, test_no_media, test_media_not_nsfw]: for d in [test, test_only_media, test_only_nsfw, test_no_media, test_media_not_nsfw]:
print(d.reduce(0, lambda x, _: x + 1).numpy()) print(d.reduce(0, lambda x, _: x + 1).numpy())
from notebook_eval_utils import SparseMultilabelEvaluator, EvalConfig
from dataclasses import asdict
def display_metrics(probs, targets, labels=labels): def display_metrics(probs, targets, labels=labels):
eval_config = EvalConfig(prediction_threshold=0.5, precision_k=0.9) eval_config = EvalConfig(prediction_threshold=0.5, precision_k=0.9)
for eval_mode, y_mask in [("implicit", np.ones(targets.shape))]: for eval_mode, y_mask in [("implicit", np.ones(targets.shape))]:

View File

@ -1,21 +1,19 @@
import glob
import os
import random
import kerastuner as kt import kerastuner as kt
import math
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import random
import sklearn.metrics import sklearn.metrics
import tensorflow as tf import tensorflow as tf
import os
import glob
from tqdm import tqdm
from matplotlib import pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from google.cloud import storage from google.cloud import storage
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from tqdm import tqdm
physical_devices = tf.config.list_physical_devices('GPU') physical_devices = tf.config.list_physical_devices('GPU')
physical_devices
tf.config.set_visible_devices([tf.config.PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')], 'GPU') tf.config.set_visible_devices([tf.config.PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')], 'GPU')
tf.config.get_visible_devices('GPU') tf.config.get_visible_devices('GPU')
@ -428,7 +426,7 @@ ptAt50fmt = "%.4f" % ptAt50[1]
ptAt90fmt = "%.4f" % ptAt90[1] ptAt90fmt = "%.4f" % ptAt90[1]
aucFmt = "%.4f" % auc_precision_recall aucFmt = "%.4f" % auc_precision_recall
plt.title( plt.title(
f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...}} ({...} pos), N_test={n_test} ({n_test_pos} pos)", f"Keras (nsfw MU test)\nAUC={aucFmt}\np={ptAt50fmt} @ r=0.5\np={ptAt90fmt} @ r=0.9\nN_train={...} ({...} pos), N_test={n_test} ({n_test_pos} pos)",
size=20 size=20
) )
plt.subplots_adjust(top=0.72) plt.subplots_adjust(top=0.72)

View File

@ -1,14 +1,16 @@
from datetime import datetime
from functools import reduce
import os import os
import pandas as pd
import re
from sklearn.metrics import average_precision_score, classification_report, precision_recall_curve, PrecisionRecallDisplay
from sklearn.model_selection import train_test_split
import tensorflow as tf
import matplotlib.pyplot as plt
import re import re
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
from sklearn.metrics import (
average_precision_score,
classification_report,
precision_recall_curve,
)
from sklearn.model_selection import train_test_split
from twitter.cuad.representation.models.optimization import create_optimizer from twitter.cuad.representation.models.optimization import create_optimizer
from twitter.cuad.representation.models.text_encoder import TextEncoder from twitter.cuad.representation.models.text_encoder import TextEncoder

View File

@ -1,10 +1,8 @@
from abc import ABC
import re import re
from abc import ABC
from toxicity_ml_pipeline.settings.hcomp_settings import TOXIC_35
import numpy as np import numpy as np
from toxicity_ml_pipeline.settings.hcomp_settings import TOXIC_35
TOXIC_35_set = set(TOXIC_35) TOXIC_35_set = set(TOXIC_35)
@ -84,7 +82,7 @@ class DefaultENNoPreprocessor(DataframeCleaner):
else: else:
raise NotImplementedError raise NotImplementedError
if "filter_low_agreements" in kwargs and kwargs["filter_low_agreements"] == True: if "filter_low_agreements" in kwargs and kwargs["filter_low_agreements"] is True:
df.drop(df[(df.agreement_rate <= 0.6)].index, axis=0, inplace=True) df.drop(df[(df.agreement_rate <= 0.6)].index, axis=0, inplace=True)
raise NotImplementedError raise NotImplementedError

View File

@ -287,7 +287,7 @@ class ENLoaderWithSampling(ENLoader):
class I18nLoader(DataframeLoader): class I18nLoader(DataframeLoader):
def __init__(self): def __init__(self):
super().__init__(project=...) super().__init__(project=...)
from archive.settings.... import ACCEPTED_LANGUAGES, QUERY_SETTINGS from archive.settings import ACCEPTED_LANGUAGES, QUERY_SETTINGS
self.accepted_languages = ACCEPTED_LANGUAGES self.accepted_languages = ACCEPTED_LANGUAGES
self.query_settings = dict(QUERY_SETTINGS) self.query_settings = dict(QUERY_SETTINGS)

View File

@ -1,6 +1,10 @@
from importlib import import_module
import os import os
from importlib import import_module
import numpy as np
import pandas
import tensorflow as tf
from sklearn.model_selection import StratifiedKFold
from toxicity_ml_pipeline.settings.default_settings_tox import ( from toxicity_ml_pipeline.settings.default_settings_tox import (
INNER_CV, INNER_CV,
LOCAL_DIR, LOCAL_DIR,
@ -12,12 +16,6 @@ from toxicity_ml_pipeline.settings.default_settings_tox import (
) )
from toxicity_ml_pipeline.utils.helpers import execute_command from toxicity_ml_pipeline.utils.helpers import execute_command
import numpy as np
import pandas
from sklearn.model_selection import StratifiedKFold
import tensorflow as tf
try: try:
from transformers import AutoTokenizer, DataCollatorWithPadding from transformers import AutoTokenizer, DataCollatorWithPadding
except ModuleNotFoundError: except ModuleNotFoundError:

View File

@ -1,14 +1,13 @@
import os import os
from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR, MAX_SEQ_LENGTH from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR, MAX_SEQ_LENGTH
try: try:
from toxicity_ml_pipeline.optim.losses import MaskedBCE from toxicity_ml_pipeline.optim.losses import MaskedBCE
except ImportError: except ImportError:
print('No MaskedBCE loss') print('No MaskedBCE loss')
from toxicity_ml_pipeline.utils.helpers import execute_command
import tensorflow as tf import tensorflow as tf
from toxicity_ml_pipeline.utils.helpers import execute_command
try: try:
from twitter.cuad.representation.models.text_encoder import TextEncoder from twitter.cuad.representation.models.text_encoder import TextEncoder
@ -102,7 +101,7 @@ def get_loss(loss_name, from_logits, **kwargs):
multitask = kwargs.get("multitask", False) multitask = kwargs.get("multitask", False)
if from_logits or multitask: if from_logits or multitask:
raise NotImplementedError raise NotImplementedError
print(f'Masked Binary Cross Entropy') print('Masked Binary Cross Entropy')
return MaskedBCE() return MaskedBCE()
if loss_name == "inv_kl_loss": if loss_name == "inv_kl_loss":

View File

@ -1,14 +1,16 @@
from collections import defaultdict
import os import os
from collections import defaultdict
from toxicity_ml_pipeline.settings.default_settings_tox import REMOTE_LOGDIR
from toxicity_ml_pipeline.settings.default_settings_abs import LABEL_NAMES
from toxicity_ml_pipeline.utils.absv_utils import parse_labeled_data
from toxicity_ml_pipeline.utils.helpers import compute_precision_fixed_recall, execute_command
from sklearn.metrics import average_precision_score, roc_auc_score
import tensorflow as tf import tensorflow as tf
import wandb import wandb
from sklearn.metrics import average_precision_score, roc_auc_score
from toxicity_ml_pipeline.settings.default_settings_abs import LABEL_NAMES
from toxicity_ml_pipeline.settings.default_settings_tox import REMOTE_LOGDIR
from toxicity_ml_pipeline.utils.absv_utils import parse_labeled_data
from toxicity_ml_pipeline.utils.helpers import (
compute_precision_fixed_recall,
execute_command,
)
class NothingCallback(tf.keras.callbacks.Callback): class NothingCallback(tf.keras.callbacks.Callback):

View File

@ -1,7 +1,7 @@
import tensorflow as tf import tensorflow as tf
from keras.utils import tf_utils
from keras.utils import losses_utils
from keras import backend from keras import backend
from keras.utils import losses_utils, tf_utils
def inv_kl_divergence(y_true, y_pred): def inv_kl_divergence(y_true, y_pred):
y_pred = tf.convert_to_tensor(y_pred) y_pred = tf.convert_to_tensor(y_pred)

View File

@ -1,8 +1,7 @@
from toxicity_ml_pipeline.load_model import reload_model_weights
from toxicity_ml_pipeline.utils.helpers import load_inference_func, upload_model
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from toxicity_ml_pipeline.load_model import reload_model_weights
from toxicity_ml_pipeline.utils.helpers import load_inference_func, upload_model
def score(language, df, gcs_model_path, batch_size=64, text_col="text", kw="", **kwargs): def score(language, df, gcs_model_path, batch_size=64, text_col="text", kw="", **kwargs):

View File

@ -1,6 +1,5 @@
import os import os
TEAM_PROJECT = "twttr-toxicity-prod" TEAM_PROJECT = "twttr-toxicity-prod"
try: try:
from google.cloud import bigquery from google.cloud import bigquery
@ -16,7 +15,7 @@ else:
CLIENT = None CLIENT = None
print("Issue at logging time", e) print("Issue at logging time", e)
TRAINING_DATA_LOCATION = f"..." TRAINING_DATA_LOCATION = "..."
GCS_ADDRESS = "..." GCS_ADDRESS = "..."
LOCAL_DIR = os.getcwd() LOCAL_DIR = os.getcwd()
REMOTE_LOGDIR = "{GCS_ADDRESS}/logs" REMOTE_LOGDIR = "{GCS_ADDRESS}/logs"

View File

@ -1,14 +1,16 @@
import os
from datetime import datetime from datetime import datetime
from importlib import import_module from importlib import import_module
import os
import numpy as np
import tensorflow as tf
from toxicity_ml_pipeline.data.data_preprocessing import ( from toxicity_ml_pipeline.data.data_preprocessing import (
DefaultENNoPreprocessor, DefaultENNoPreprocessor,
DefaultENPreprocessor, DefaultENPreprocessor,
) )
from toxicity_ml_pipeline.data.dataframe_loader import ENLoader, ENLoaderWithSampling from toxicity_ml_pipeline.data.dataframe_loader import ENLoader, ENLoaderWithSampling
from toxicity_ml_pipeline.data.mb_generator import BalancedMiniBatchLoader from toxicity_ml_pipeline.data.mb_generator import BalancedMiniBatchLoader
from toxicity_ml_pipeline.load_model import load, get_last_layer from toxicity_ml_pipeline.load_model import get_last_layer, load
from toxicity_ml_pipeline.optim.callbacks import ( from toxicity_ml_pipeline.optim.callbacks import (
AdditionalResultLogger, AdditionalResultLogger,
ControlledStoppingCheckpointCallback, ControlledStoppingCheckpointCallback,
@ -19,6 +21,8 @@ from toxicity_ml_pipeline.optim.schedulers import WarmUp
from toxicity_ml_pipeline.settings.default_settings_abs import GCS_ADDRESS as ABS_GCS from toxicity_ml_pipeline.settings.default_settings_abs import GCS_ADDRESS as ABS_GCS
from toxicity_ml_pipeline.settings.default_settings_tox import ( from toxicity_ml_pipeline.settings.default_settings_tox import (
GCS_ADDRESS as TOX_GCS, GCS_ADDRESS as TOX_GCS,
)
from toxicity_ml_pipeline.settings.default_settings_tox import (
MODEL_DIR, MODEL_DIR,
RANDOM_SEED, RANDOM_SEED,
REMOTE_LOGDIR, REMOTE_LOGDIR,
@ -26,10 +30,6 @@ from toxicity_ml_pipeline.settings.default_settings_tox import (
) )
from toxicity_ml_pipeline.utils.helpers import check_gpu, set_seeds, upload_model from toxicity_ml_pipeline.utils.helpers import check_gpu, set_seeds, upload_model
import numpy as np
import tensorflow as tf
try: try:
from tensorflow_addons.optimizers import AdamW from tensorflow_addons.optimizers import AdamW
except ModuleNotFoundError: except ModuleNotFoundError:
@ -139,9 +139,9 @@ class Trainer(object):
) )
print("------- Experiment name: ", experiment_name) print("------- Experiment name: ", experiment_name)
self.logdir = ( self.logdir = (
f"..." "..."
if self.test if self.test
else f"..." else "..."
) )
self.checkpoint_path = f"{self.model_dir}/{experiment_name}" self.checkpoint_path = f"{self.model_dir}/{experiment_name}"

View File

@ -3,11 +3,9 @@ import os
import random as python_random import random as python_random
import subprocess import subprocess
from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR
import numpy as np import numpy as np
from sklearn.metrics import precision_recall_curve from sklearn.metrics import precision_recall_curve
from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR
try: try:
import tensorflow as tf import tensorflow as tf