the-algorithm/trust_and_safety_models/toxicity/optim/callbacks.py

243 lines
8.4 KiB
Python
Raw Normal View History

import os
2023-04-17 06:19:03 +02:00
from collections import defaultdict
import tensorflow as tf
import wandb
2023-04-17 06:19:03 +02:00
from sklearn.metrics import average_precision_score, roc_auc_score
from toxicity_ml_pipeline.settings.default_settings_abs import LABEL_NAMES
from toxicity_ml_pipeline.settings.default_settings_tox import REMOTE_LOGDIR
from toxicity_ml_pipeline.utils.absv_utils import parse_labeled_data
from toxicity_ml_pipeline.utils.helpers import (
compute_precision_fixed_recall,
execute_command,
)
class NothingCallback(tf.keras.callbacks.Callback):
    """Debugging callback that only prints a marker at a few training hooks.

    It does not read or modify the model or the logs; each hook prints a
    fixed tag followed by the index Keras passes in.
    """

    @staticmethod
    def _trace(tag, index):
        """Print ``tag`` followed by ``index`` using print's default separator."""
        print(tag, index)

    def on_epoch_begin(self, epoch, logs=None):
        """Print a marker with the epoch index when an epoch starts."""
        self._trace("ici, ", epoch)

    def on_epoch_end(self, epoch, logs=None):
        """Print a marker with the epoch index when an epoch finishes."""
        self._trace("fin ", epoch)

    def on_train_batch_end(self, batch, logs=None):
        """Print a marker with the batch index after each training batch."""
        self._trace("fin de batch ", batch)
class ControlledStoppingCheckpointCallback(tf.keras.callbacks.ModelCheckpoint):
    """``ModelCheckpoint`` that additionally requests a stop at a given epoch.

    Args:
        stopping_epoch: Epoch index (as received by ``on_epoch_end``) at which
            ``model.stop_training`` is set to ``True``.
        *args: Forwarded unchanged to ``ModelCheckpoint``.
        **kwargs: Forwarded unchanged to ``ModelCheckpoint``.
    """

    def __init__(self, stopping_epoch, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.stopping_epoch = stopping_epoch

    def on_epoch_end(self, epoch, logs=None):
        """Run the regular checkpoint logic, then stop if this is the target epoch."""
        # Checkpointing happens first so the stopping epoch is handled by the
        # parent before training is asked to stop.
        super().on_epoch_end(epoch, logs)
        reached_target = epoch == self.stopping_epoch
        if not reached_target:
            return
        self.model.stop_training = True
class SyncingTensorBoard(tf.keras.callbacks.TensorBoard):
    """``TensorBoard`` callback that rsyncs its logs to a remote location.

    Args:
        remote_logdir: Destination passed to ``gsutil rsync``. When ``None``,
            the module-level ``REMOTE_LOGDIR`` setting is used.
        *args: Forwarded unchanged to ``TensorBoard``.
        **kwargs: Forwarded unchanged to ``TensorBoard``.
    """

    def __init__(self, remote_logdir=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if remote_logdir is None:
            remote_logdir = REMOTE_LOGDIR
        self.remote_logdir = remote_logdir

    def on_epoch_end(self, epoch, logs=None):
        """Let TensorBoard write the epoch summaries, then push them remotely."""
        super().on_epoch_end(epoch, logs=logs)
        self.synchronize()

    def synchronize(self):
        """Rsync the parent directory of ``log_dir`` to ``remote_logdir``."""
        # The parent of log_dir is synced, not log_dir itself.
        local_root = os.path.dirname(self.log_dir)
        execute_command(f"gsutil -m rsync -r {local_root} {self.remote_logdir}")
class GradientLoggingTensorBoard(SyncingTensorBoard):
    """``SyncingTensorBoard`` that also logs the global gradient norm.

    A single fixed batch is drawn once at construction time; the gradient norm
    of the compiled loss on that batch is written as the ``gradient_norm``
    scalar whenever the batch index is a multiple of ``freq``.

    Args:
        loader: Object exposing ``get_balanced_dataset(training_data=...,
            size_limit=..., return_as_batch=...)``.
        val_data: Data handed to ``loader.get_balanced_dataset``.
        freq: Gradient norms are logged when ``batch % freq == 0``.
        *args: Forwarded unchanged to ``SyncingTensorBoard``.
        **kwargs: Forwarded unchanged to ``SyncingTensorBoard``.
    """

    def __init__(self, loader, val_data, freq, *args, **kwargs):
        super().__init__(*args, **kwargs)
        val_dataset = loader.get_balanced_dataset(
            training_data=val_data, size_limit=50, return_as_batch=False
        )
        # Keep only the first batch (at most 32 elements) as the fixed probe.
        data_args = list(val_dataset.batch(32).take(1))[0]
        self.x_batch, self.y_batch = data_args[0], data_args[1]
        self.freq = freq
        # Number of training batches seen so far; used as the summary step.
        self.counter = 0

    def _log_gradients(self):
        """Write the global gradient norm on the probe batch to the train writer."""
        writer = self._train_writer
        with writer.as_default():
            with tf.GradientTape() as tape:
                y_pred = self.model(self.x_batch)
                loss = self.model.compiled_loss(y_true=self.y_batch, y_pred=y_pred)
            gradient_norm = tf.linalg.global_norm(
                tape.gradient(loss, self.model.trainable_weights)
            )
            tf.summary.scalar("gradient_norm", data=gradient_norm, step=self.counter)
        writer.flush()

    def on_train_batch_end(self, batch, logs=None):
        """Run the parent batch-end logic, then log gradients every ``freq`` batches."""
        # Forward to the same hook on the parent. Calling ``on_batch_end``
        # here would bypass a parent ``on_train_batch_end`` override, while
        # the base ``Callback.on_train_batch_end`` still delegates to
        # ``on_batch_end`` for parents that only implement the legacy hook.
        super().on_train_batch_end(batch, logs=logs)
        self.counter += 1
        if batch % self.freq == 0:
            self._log_gradients()
class AdditionalResultLogger(tf.keras.callbacks.Callback):
    """Callback computing extra evaluation metrics and reporting them to wandb.

    Metrics are computed at the end of every epoch and every 2000 training
    batches, by running ``model.predict`` on the data given at construction.

    Args:
        data: Evaluation data, or ``None`` to create an inert callback.
            Two kinds of object are handled:
            * one exposing ``int_label`` (plus ``text`` and, when ``dual_head``
              is given, ``<head>_label`` entries);
            * otherwise one exposing ``to_dataframe()`` and ``to_tf_dataset()``,
              evaluated against the ``LABEL_NAMES`` columns.
        set_: Name of the evaluated set; used as the top-level key of the
            reported metrics. ``"validation"`` is logged with ``commit=False``.
        fixed_recall: Recall at which precision is reported. Values in
            ``(0, 1]`` are used as-is; values in ``(1, 100]`` are treated as
            percentages and divided by 100.
        from_logits: When true, predictions are read from ``preds.logits`` and
            passed through a sigmoid.
        dataset_transform_func: Optional callable invoked as
            ``dataset_transform_func(data, mb_size=batch_size, shuffle=False)``
            to build the prediction input; otherwise ``data.text.values`` is used.
        batch_size: Batch size for dataset construction and ``model.predict``.
        dual_head: Optional iterable of head names; enables dual-head mode.
        *args: Forwarded unchanged to ``Callback``.
        **kwargs: Forwarded unchanged to ``Callback``.

    Raises:
        ValueError: If ``fixed_recall`` is outside ``(0, 100]``.
    """

    def __init__(
        self,
        data,
        set_,
        fixed_recall=0.85,
        from_logits=False,
        dataset_transform_func=None,
        batch_size=64,
        dual_head=None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.set_ = set_
        # Initialise the state read by the training hooks before the early
        # return, so a callback built with ``data=None`` is a harmless no-op
        # instead of raising AttributeError on the first batch.
        self.data = None
        self.counter = 0
        self.best_metrics = defaultdict(float)
        if data is None:
            return

        self.single_head = True
        try:
            self.labels = data.int_label.values
        except AttributeError:
            # No ``int_label``: use the to_dataframe / to_tf_dataset interface.
            self.labels = data.to_dataframe()[LABEL_NAMES].values.astype("int")
            self.data = data.to_tf_dataset().map(parse_labeled_data).batch(batch_size)
            self.label_names = LABEL_NAMES
        else:
            self.label_names = [""]
            if dual_head:
                self.label_names = [f"{e}_label" for e in dual_head]
                self.labels = {
                    f"{e}_output": data[f"{e}_label"].values for e in dual_head
                }
                self.single_head = False
            if dataset_transform_func is None:
                self.data = data.text.values
            else:
                self.data = dataset_transform_func(
                    data, mb_size=batch_size, shuffle=False
                )

        # Plain sequential code rather than a ``finally`` clause: a failure
        # above must surface as-is, not be masked by a secondary
        # AttributeError on ``self.label_names``.
        # With several labels, sklearn is asked for one score per label.
        self.metric_kw = {} if len(self.label_names) == 1 else {"average": None}
        self.from_logits = from_logits
        print(
            f"Loaded callback for {set_}, from_logits: {from_logits}, labels {self.label_names}"
        )

        if 1 < fixed_recall <= 100:
            fixed_recall = fixed_recall / 100
        elif not (0 < fixed_recall <= 100):
            raise ValueError("Threshold should be between 0 and 1, or 0 and 100")
        self.fixed_recall = fixed_recall
        self.batch_size = batch_size

    def compute_precision_fixed_recall(self, labels, preds):
        """Return the first element of the helper's result at ``self.fixed_recall``."""
        # The name below resolves to the module-level helper, not this method.
        result, _ = compute_precision_fixed_recall(
            labels=labels, preds=preds, fixed_recall=self.fixed_recall
        )
        return result

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate at the end of every epoch."""
        self.additional_evaluations(step=epoch, eval_time="epoch")

    def on_train_batch_end(self, batch, logs=None):
        """Evaluate every 2000 training batches, counted across epochs."""
        self.counter += 1
        if self.counter % 2000 == 0:
            self.additional_evaluations(step=self.counter, eval_time="batch")

    def _binary_evaluations(self, preds, label_name=None, class_index=None):
        """Compute precision-at-recall, PR-AUC and ROC-AUC for one binary task.

        Args:
            preds: Scores aligned with the labels.
            label_name: Key into ``self.labels`` when it is a dict (dual head).
            class_index: When given, labels are turned into a one-vs-rest
                target ``label == class_index``.

        Returns:
            Dict with keys ``precision_recall<fixed_recall>``, ``pr_auc`` and
            ``roc_auc``.
        """
        curr_labels = self.labels
        if label_name is not None:
            curr_labels = self.labels[label_name]

        # Drop rows labelled -1 BEFORE the one-vs-rest conversion. Once the
        # labels are 0/1 the -1 marker is gone and those rows would silently
        # be scored as negatives.
        if -1 in curr_labels:
            keep = curr_labels != -1
            curr_labels = curr_labels[keep]
            preds = preds[keep]

        if class_index is not None:
            curr_labels = (curr_labels == class_index).astype(int)

        return {
            f"precision_recall{self.fixed_recall}": self.compute_precision_fixed_recall(
                labels=curr_labels, preds=preds
            ),
            "pr_auc": average_precision_score(y_true=curr_labels, y_score=preds),
            "roc_auc": roc_auc_score(y_true=curr_labels, y_score=preds),
        }

    def _multiclass_evaluations(self, preds):
        """Return per-label PR-AUC and ROC-AUC keyed by ``self.label_names``."""
        pr_auc_l = average_precision_score(
            y_true=self.labels, y_score=preds, **self.metric_kw
        )
        roc_auc_l = roc_auc_score(y_true=self.labels, y_score=preds, **self.metric_kw)
        metrics = {}
        for i, label in enumerate(self.label_names):
            metrics[f"pr_auc_{label}"] = pr_auc_l[i]
            metrics[f"roc_auc_{label}"] = roc_auc_l[i]
        return metrics

    def additional_evaluations(self, step, eval_time):
        """Predict on the stored data, compute metrics and log them.

        Args:
            step: Epoch index or batch counter reported with the metrics.
            eval_time: ``"epoch"`` or ``"batch"``, depending on the caller.
        """
        if self.data is None:
            # Built with ``data=None``: nothing to evaluate.
            return

        print("Evaluating ", self.set_, eval_time, step)
        preds = self.model.predict(x=self.data, batch_size=self.batch_size)
        if self.from_logits:
            preds = tf.keras.activations.sigmoid(preds.logits).numpy()

        if self.single_head:
            if len(self.label_names) == 1:
                metrics = self._binary_evaluations(preds)
            else:
                metrics = self._multiclass_evaluations(preds)
        else:
            # The output with a single column is taken as the binary head,
            # the other one as the multi-class head.
            if preds[0].shape[1] == 1:
                binary_preds = preds[0]
                multic_preds = preds[1]
            else:
                binary_preds = preds[1]
                multic_preds = preds[0]

            binary_metrics = self._binary_evaluations(
                binary_preds, label_name="target_output"
            )
            metrics = {f"{k}_target": v for k, v in binary_metrics.items()}

            num_classes = multic_preds.shape[1]
            for class_ in range(num_classes):
                binary_metrics = self._binary_evaluations(
                    multic_preds[:, class_],
                    label_name="content_output",
                    class_index=class_,
                )
                metrics.update(
                    {f"{k}_content_{class_}": v for k, v in binary_metrics.items()}
                )

        # Track the running maximum of every metric under a ``max_`` prefix.
        for k, v in metrics.items():
            self.best_metrics[f"max_{k}"] = max(v, self.best_metrics[f"max_{k}"])

        self.log_metrics(metrics, step=step, eval_time=eval_time)

    def log_metrics(self, metrics_d, step, eval_time):
        """Send the current and best metrics to wandb under the set name."""
        # The "validation" set is logged with commit=False; every other set
        # is logged with commit=True.
        commit = self.set_ != "validation"
        to_report = {self.set_: {**metrics_d, **self.best_metrics}}
        if eval_time == "epoch":
            to_report["epoch"] = step
        wandb.log(to_report, commit=commit)