mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-11-18 01:19:21 +01:00
ef4c5eb65e
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
100 lines
2.5 KiB
Python
100 lines
2.5 KiB
Python
import bisect
|
|
import os
|
|
import random as python_random
|
|
import subprocess
|
|
|
|
from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR
|
|
|
|
import numpy as np
|
|
from sklearn.metrics import precision_recall_curve
|
|
|
|
|
|
try:
|
|
import tensorflow as tf
|
|
except ModuleNotFoundError:
|
|
pass
|
|
|
|
|
|
def upload_model(full_gcs_model_path):
    """Copy a model (or one checkpoint epoch) from GCS into ``LOCAL_DIR/models``.

    Despite its name, this downloads *from* GCS to the local filesystem.

    Args:
        full_gcs_model_path: GCS path of the model; ``gs://`` is prepended if
            missing. If the last path component parses as an int it is
            treated as an epoch prefix. In that case, the sibling
            ``checkpoint`` file and every object whose name starts with
            ``<parent>/<epoch>`` are copied. Otherwise the whole folder is
            mirrored with ``gsutil rsync -r``.

    Returns:
        The local weights path. For a full folder this is the local model
        directory. For an epoch prefix it is ``<local model dir>/<epoch>``.

    Raises:
        subprocess.CalledProcessError: if a ``gsutil`` command fails.
    """
    import shlex

    folder_name = full_gcs_model_path
    if not folder_name.startswith("gs://"):
        folder_name = "gs://" + folder_name

    dirname = os.path.dirname(folder_name)
    epoch = os.path.basename(folder_name)

    # Local layout: LOCAL_DIR/models/<name of the parent GCS folder>.
    # exist_ok replaces the former `mkdir` calls whose CalledProcessError was
    # swallowed wholesale, which also hid genuine failures.
    model_dir = os.path.join(LOCAL_DIR, "models", os.path.basename(dirname))
    os.makedirs(model_dir, exist_ok=True)
    local_target = shlex.quote(model_dir + "/")

    try:
        int(epoch)
    except ValueError:
        # Not an epoch number: mirror the whole model folder.
        cmd = f"gsutil rsync -r {shlex.quote(folder_name)} {shlex.quote(model_dir)}"
        weights_dir = model_dir
    else:
        # Epoch prefix: fetch the shared `checkpoint` file first (presumably the
        # TF checkpoint index — confirm), then every object with that prefix.
        checkpoint_src = shlex.quote(dirname + "/checkpoint")
        execute_command(f"gsutil cp {checkpoint_src} {local_target}")
        # The wildcard stays inside the quotes so the shell leaves it alone
        # and gsutil expands it against the bucket.
        shards_src = shlex.quote(os.path.join(dirname, epoch) + "*")
        cmd = f"gsutil cp {shards_src} {local_target}"
        weights_dir = f"{model_dir}/{epoch}"

    execute_command(cmd)
    return weights_dir
|
|
|
|
def compute_precision_fixed_recall(labels, preds, fixed_recall):
    """Return precision and threshold at the operating point reaching ``fixed_recall``.

    Consider the precision-recall curve points whose recall is at least
    ``fixed_recall``. The one furthest along the curve (highest threshold)
    is chosen.

    Args:
        labels: Binary ground-truth labels.
        preds: Scores for the positive class, aligned with ``labels``.
        fixed_recall: Target recall, expected in [0, 1].

    Returns:
        ``(precision, threshold)`` at the chosen operating point. The chosen
        recall and precision are also printed.
    """
    # Scores are passed positionally: scikit-learn renamed the keyword from
    # `probas_pred` to `y_score`.
    precision_values, recall_values, thresholds = precision_recall_curve(labels, preds)

    # recall_values is non-increasing (per scikit-learn), so its negation is
    # sorted ascending. bisect_right(...) - 1 is therefore the last index with
    # recall >= fixed_recall. The previous bisect_left excluded points that hit
    # the target exactly. For fixed_recall >= 1 it also produced index -1, which
    # silently wrapped to the recall-0 end of the curve.
    index = bisect.bisect_right(-recall_values, -fixed_recall) - 1
    # Clamp to the valid range: the final curve point has no threshold, and a
    # target above every recall value would otherwise give -1.
    index = min(max(index, 0), len(thresholds) - 1)

    result = precision_values[index]
    print(f"Precision at {recall_values[index]} recall: {result}")
    return result, thresholds[index]
|
|
|
|
def load_inference_func(model_folder):
    """Load a SavedModel with the ``serve`` tag set and return its default signature.

    Args:
        model_folder: Path of the exported SavedModel directory.

    Returns:
        The callable stored under ``signatures["serving_default"]``.
    """
    serve_tags = ["serve"]
    saved_model = tf.saved_model.load(model_folder, tags=serve_tags)
    return saved_model.signatures["serving_default"]
|
|
|
|
|
|
def execute_query(client, query):
    """Submit ``query`` through ``client`` and return its result as a DataFrame.

    ``client`` is presumably a BigQuery client (verify against callers). Any
    object exposing ``query(q).result().to_dataframe()`` works.
    """
    return client.query(query).result().to_dataframe()
|
|
|
|
|
|
def execute_command(cmd, print_=True):
    """Run ``cmd`` through the shell and raise if it exits non-zero.

    Args:
        cmd: Shell command string, executed with ``shell=True`` because callers
            rely on shell quoting and wildcards. Never pass untrusted input.
        print_: If True, capture stderr and stdout and print them (in that
            order) after the command finishes. This now also happens when the
            command fails, so the error output is not lost. If False, output is
            not captured and goes directly to this process's streams.

    Returns:
        The ``subprocess.CompletedProcess`` for the command.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    completed = subprocess.run(cmd, shell=True, capture_output=print_)
    if print_:
        # Print before checking the status: with check=True the run() call
        # raised first and the captured diagnostics were never shown.
        # errors="replace" because tool output is not guaranteed to be UTF-8.
        for stream in (completed.stderr, completed.stdout):
            print(stream.decode("utf-8", errors="replace"))
    completed.check_returncode()
    return completed
|
|
|
|
|
|
def check_gpu():
    """Check that a GPU is reachable via ``nvidia-smi`` and visible to TensorFlow.

    Raises:
        AttributeError: if the ``nvidia-smi`` command fails.
        ModuleNotFoundError: if TensorFlow lists no physical GPU devices.
    """
    # First gate: the NVIDIA command-line tool must run successfully.
    try:
        execute_command("nvidia-smi")
    except subprocess.CalledProcessError:
        print("There is no GPU when there should be one.")
        raise AttributeError

    # Second gate: TensorFlow itself must report at least one GPU.
    gpu_devices = tf.config.list_physical_devices("GPU")
    if not gpu_devices:
        raise ModuleNotFoundError("Tensorflow has not found the GPU. Check your installation")
    print(gpu_devices)
|
|
|
|
|
|
def set_seeds(seed):
    """Seed NumPy, Python's ``random`` module and, if imported, TensorFlow.

    This module treats TensorFlow as optional: a ``ModuleNotFoundError`` on
    import is swallowed, leaving ``tf`` undefined. Previously this function
    then failed with ``NameError`` after seeding the other generators. Now
    TensorFlow seeding is simply skipped when ``tf`` is absent.

    Args:
        seed: Seed value passed to every generator.
    """
    np.random.seed(seed)
    python_random.seed(seed)

    tf_module = globals().get("tf")
    if tf_module is not None:
        tf_module.random.set_seed(seed)
|