2
0
mirror of https://github.com/twitter/the-algorithm.git synced 2025-01-14 21:19:10 +01:00

100 lines
2.5 KiB
Python
Raw Normal View History

import bisect
import os
import random as python_random
import subprocess
from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR
import numpy as np
from sklearn.metrics import precision_recall_curve
try:
import tensorflow as tf
except ModuleNotFoundError:
pass
def upload_model(full_gcs_model_path):
    """Copy a model from Google Cloud Storage into a local models directory.

    The local target is ``LOCAL_DIR/models/<parent>``, where ``<parent>`` is
    the name of the folder that contains the last component of the GCS path.
    The last path component decides what is copied:

    * if it parses as an integer it is treated as an epoch number: the
      ``checkpoint`` file of the parent folder and every object whose name
      starts with ``<parent folder>/<epoch>`` are copied with ``gsutil cp``;
    * otherwise the whole folder is mirrored with ``gsutil rsync -r``.

    Args:
        full_gcs_model_path: GCS path, with or without the ``gs://`` prefix.

    Returns:
        ``<local target>/<epoch>`` in the epoch case, otherwise the local
        target directory itself.

    Raises:
        OSError: if the local target directory cannot be created.
        subprocess.CalledProcessError: if a ``gsutil`` command fails.
    """
    folder_name = full_gcs_model_path
    if not folder_name.startswith("gs://"):
        folder_name = "gs://" + folder_name

    dirname = os.path.dirname(folder_name)
    epoch = os.path.basename(folder_name)

    # Create the whole local hierarchy in one call. Shelling out to a plain
    # `mkdir` (no -p) and ignoring its failure left no directory at all when
    # a parent was missing, which only surfaced later as a gsutil error.
    model_dir = os.path.join(LOCAL_DIR, "models", os.path.basename(dirname))
    os.makedirs(model_dir, exist_ok=True)

    try:
        int(epoch)
    except ValueError:
        # Not an epoch number: mirror the full folder.
        cmd = f"gsutil rsync -r '{folder_name}' {model_dir}"
        weights_dir = model_dir
    else:
        # Epoch number: fetch the checkpoint index first, then the files of
        # that epoch (the wildcard is quoted so gsutil expands it, not the shell).
        execute_command(f"gsutil cp '{dirname}/checkpoint' {model_dir}/")
        cmd = f"gsutil cp '{os.path.join(dirname, epoch)}*' {model_dir}/"
        weights_dir = f"{model_dir}/{epoch}"

    execute_command(cmd)
    return weights_dir
def compute_precision_fixed_recall(labels, preds, fixed_recall):
    """Return the precision (and its threshold) at roughly ``fixed_recall``.

    The precision/recall curve is computed with scikit-learn and the selected
    operating point is the last one whose recall is strictly greater than
    ``fixed_recall``. When no such point exists (``fixed_recall >= 1``) the
    first curve point is used; when the lookup would run past the available
    thresholds (``fixed_recall < 0``) the last thresholded point is used.

    Args:
        labels: ground-truth labels, forwarded to ``precision_recall_curve``.
        preds: predicted scores, forwarded to ``precision_recall_curve``.
        fixed_recall: target recall.

    Returns:
        A ``(precision, threshold)`` tuple for the selected operating point.
    """
    # Arguments are passed positionally: newer scikit-learn releases renamed
    # the `probas_pred` keyword, positional use works on old and new versions.
    precision_values, recall_values, thresholds = precision_recall_curve(labels, preds)

    # bisect needs ascending data; scikit-learn documents recall as
    # decreasing, so search on the negated values.
    index_recall = bisect.bisect_left(-recall_values, -1 * fixed_recall)

    # Clamp the index: without the lower bound, `index_recall - 1` is -1 for
    # fixed_recall >= 1 and silently wraps to the final (recall 0) point;
    # without the upper bound it overruns `thresholds`, which is one element
    # shorter than the precision/recall arrays.
    index = min(max(index_recall - 1, 0), len(thresholds) - 1)

    result = precision_values[index]
    print(f"Precision at {recall_values[index]} recall: {result}")

    return result, thresholds[index]
def load_inference_func(model_folder):
    """Load a SavedModel and return its ``serving_default`` signature.

    Args:
        model_folder: location passed to ``tf.saved_model.load``; the model is
            loaded with the ``"serve"`` tag.

    Returns:
        The entry stored under ``"serving_default"`` in the loaded model's
        ``signatures`` mapping.
    """
    saved_model = tf.saved_model.load(model_folder, ["serve"])
    return saved_model.signatures["serving_default"]
def execute_query(client, query):
    """Run ``query`` through ``client`` and return the result as a dataframe.

    Args:
        client: object exposing ``query(query)``; the returned job must offer
            ``result()`` whose value has ``to_dataframe()`` (presumably a
            BigQuery client -- confirm against callers).
        query: the query to submit.

    Returns:
        Whatever ``to_dataframe()`` yields for the finished job.
    """
    return client.query(query).result().to_dataframe()
def execute_command(cmd, print_=True):
    """Run ``cmd`` through the shell and optionally echo its output.

    Args:
        cmd: command line as a single string; it is interpreted by the shell,
            so quoting and wildcards inside it are honoured. Only pass
            trusted strings.
        print_: when true, stderr and stdout are captured and printed (stderr
            first) once the command has finished; when false the output is
            not captured.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status.
    """
    completed = subprocess.run(cmd, shell=True, capture_output=print_, check=True)
    if print_:
        # Tools are not guaranteed to emit valid UTF-8; replace undecodable
        # bytes instead of raising after the command already succeeded.
        print(completed.stderr.decode("utf-8", errors="replace"))
        print(completed.stdout.decode("utf-8", errors="replace"))
def check_gpu():
    """Verify that a GPU is visible to both the system and TensorFlow.

    First runs ``nvidia-smi``; then asks TensorFlow for its physical GPU
    devices and prints the list.

    Raises:
        AttributeError: if the ``nvidia-smi`` command fails.
        ModuleNotFoundError: if TensorFlow reports no GPU device.
    """
    try:
        execute_command("nvidia-smi")
    except subprocess.CalledProcessError:
        print("There is no GPU when there should be one.")
        raise AttributeError

    gpu_devices = tf.config.list_physical_devices("GPU")
    if not gpu_devices:
        raise ModuleNotFoundError("Tensorflow has not found the GPU. Check your installation")
    print(gpu_devices)
def set_seeds(seed):
    """Seed the NumPy, Python ``random`` and TensorFlow random generators.

    Args:
        seed: value handed unchanged to each of the three seeding calls.
    """
    # Order is kept as NumPy, stdlib, TensorFlow so that the first two are
    # still seeded even if the TensorFlow call is the one that fails.
    np.random.seed(seed)
    python_random.seed(seed)
    tf.random.set_seed(seed)