2023-04-01 00:36:31 +02:00
|
|
|
import os
|
2023-04-17 06:19:03 +02:00
|
|
|
from collections import defaultdict
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
import tensorflow as tf
|
|
|
|
import wandb
|
2023-04-17 06:19:03 +02:00
|
|
|
from sklearn.metrics import average_precision_score, roc_auc_score
|
|
|
|
from toxicity_ml_pipeline.settings.default_settings_abs import LABEL_NAMES
|
|
|
|
from toxicity_ml_pipeline.settings.default_settings_tox import REMOTE_LOGDIR
|
|
|
|
from toxicity_ml_pipeline.utils.absv_utils import parse_labeled_data
|
|
|
|
from toxicity_ml_pipeline.utils.helpers import (
|
|
|
|
compute_precision_fixed_recall,
|
|
|
|
execute_command,
|
|
|
|
)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
class NothingCallback(tf.keras.callbacks.Callback):
|
2023-04-17 06:19:03 +02:00
|
|
|
def on_epoch_begin(self, epoch, logs=None):
|
|
|
|
print("ici, ", epoch)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
def on_epoch_end(self, epoch, logs=None):
|
|
|
|
print("fin ", epoch)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
def on_train_batch_end(self, batch, logs=None):
|
|
|
|
print("fin de batch ", batch)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
class ControlledStoppingCheckpointCallback(tf.keras.callbacks.ModelCheckpoint):
|
2023-04-17 06:19:03 +02:00
|
|
|
def __init__(self, stopping_epoch, *args, **kwargs):
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
self.stopping_epoch = stopping_epoch
|
2023-04-01 00:36:31 +02:00
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
def on_epoch_end(self, epoch, logs=None):
|
|
|
|
super().on_epoch_end(epoch, logs)
|
|
|
|
if epoch == self.stopping_epoch:
|
|
|
|
self.model.stop_training = True
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
class SyncingTensorBoard(tf.keras.callbacks.TensorBoard):
|
2023-04-17 06:19:03 +02:00
|
|
|
def __init__(self, remote_logdir=None, *args, **kwargs):
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
self.remote_logdir = (
|
|
|
|
remote_logdir if remote_logdir is not None else REMOTE_LOGDIR
|
|
|
|
)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
def on_epoch_end(self, epoch, logs=None):
|
|
|
|
super().on_epoch_end(epoch, logs=logs)
|
|
|
|
self.synchronize()
|
2023-04-01 00:36:31 +02:00
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
def synchronize(self):
|
|
|
|
base_dir = os.path.dirname(self.log_dir)
|
|
|
|
cmd = f"gsutil -m rsync -r {base_dir} {self.remote_logdir}"
|
|
|
|
execute_command(cmd)
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
class GradientLoggingTensorBoard(SyncingTensorBoard):
|
2023-04-17 06:19:03 +02:00
|
|
|
def __init__(self, loader, val_data, freq, *args, **kwargs):
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
val_dataset = loader.get_balanced_dataset(
|
|
|
|
training_data=val_data, size_limit=50, return_as_batch=False
|
|
|
|
)
|
|
|
|
data_args = list(val_dataset.batch(32).take(1))[0]
|
|
|
|
self.x_batch, self.y_batch = data_args[0], data_args[1]
|
|
|
|
self.freq = freq
|
|
|
|
self.counter = 0
|
|
|
|
|
|
|
|
def _log_gradients(self):
|
|
|
|
writer = self._train_writer
|
|
|
|
|
|
|
|
with writer.as_default():
|
|
|
|
with tf.GradientTape() as tape:
|
|
|
|
y_pred = self.model(self.x_batch)
|
|
|
|
loss = self.model.compiled_loss(y_true=self.y_batch, y_pred=y_pred)
|
|
|
|
gradient_norm = tf.linalg.global_norm(
|
|
|
|
tape.gradient(loss, self.model.trainable_weights)
|
|
|
|
)
|
|
|
|
|
|
|
|
tf.summary.scalar("gradient_norm", data=gradient_norm, step=self.counter)
|
|
|
|
writer.flush()
|
|
|
|
|
|
|
|
def on_train_batch_end(self, batch, logs=None):
|
|
|
|
super().on_batch_end(batch, logs=logs)
|
|
|
|
self.counter += 1
|
|
|
|
if batch % self.freq == 0:
|
|
|
|
self._log_gradients()
|
2023-04-01 00:36:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
class AdditionalResultLogger(tf.keras.callbacks.Callback):
|
2023-04-17 06:19:03 +02:00
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
data,
|
|
|
|
set_,
|
|
|
|
fixed_recall=0.85,
|
|
|
|
from_logits=False,
|
|
|
|
dataset_transform_func=None,
|
|
|
|
batch_size=64,
|
|
|
|
dual_head=None,
|
|
|
|
*args,
|
|
|
|
**kwargs,
|
|
|
|
):
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
self.set_ = set_
|
|
|
|
if data is None:
|
|
|
|
return None
|
|
|
|
|
|
|
|
self.single_head = True
|
|
|
|
try:
|
|
|
|
self.labels = data.int_label.values
|
|
|
|
except AttributeError:
|
|
|
|
self.labels = data.to_dataframe()[LABEL_NAMES].values.astype("int")
|
|
|
|
self.data = data.to_tf_dataset().map(parse_labeled_data).batch(batch_size)
|
|
|
|
self.label_names = LABEL_NAMES
|
|
|
|
else:
|
|
|
|
self.label_names = [""]
|
|
|
|
if dual_head:
|
|
|
|
self.label_names = [f"{e}_label" for e in dual_head]
|
|
|
|
self.labels = {
|
|
|
|
f"{e}_output": data[f"{e}_label"].values for e in dual_head
|
|
|
|
}
|
|
|
|
self.single_head = False
|
|
|
|
if dataset_transform_func is None:
|
|
|
|
self.data = data.text.values
|
|
|
|
else:
|
|
|
|
self.data = dataset_transform_func(
|
|
|
|
data, mb_size=batch_size, shuffle=False
|
|
|
|
)
|
|
|
|
|
|
|
|
finally:
|
|
|
|
if len(self.label_names) == 1:
|
|
|
|
self.metric_kw = {}
|
|
|
|
else:
|
|
|
|
self.metric_kw = {"average": None}
|
|
|
|
|
|
|
|
self.counter = 0
|
|
|
|
self.best_metrics = defaultdict(float)
|
|
|
|
self.from_logits = from_logits
|
|
|
|
print(
|
|
|
|
f"Loaded callback for {set_}, from_logits: {from_logits}, labels {self.label_names}"
|
|
|
|
)
|
|
|
|
|
|
|
|
if 1 < fixed_recall <= 100:
|
|
|
|
fixed_recall = fixed_recall / 100
|
|
|
|
elif not (0 < fixed_recall <= 100):
|
|
|
|
raise ValueError("Threshold should be between 0 and 1, or 0 and 100")
|
|
|
|
self.fixed_recall = fixed_recall
|
|
|
|
self.batch_size = batch_size
|
|
|
|
|
|
|
|
def compute_precision_fixed_recall(self, labels, preds):
|
|
|
|
result, _ = compute_precision_fixed_recall(
|
|
|
|
labels=labels, preds=preds, fixed_recall=self.fixed_recall
|
|
|
|
)
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
def on_epoch_end(self, epoch, logs=None):
|
|
|
|
self.additional_evaluations(step=epoch, eval_time="epoch")
|
|
|
|
|
|
|
|
def on_train_batch_end(self, batch, logs=None):
|
|
|
|
self.counter += 1
|
|
|
|
if self.counter % 2000 == 0:
|
|
|
|
self.additional_evaluations(step=self.counter, eval_time="batch")
|
|
|
|
|
|
|
|
def _binary_evaluations(self, preds, label_name=None, class_index=None):
|
|
|
|
mask = None
|
|
|
|
curr_labels = self.labels
|
|
|
|
if label_name is not None:
|
|
|
|
curr_labels = self.labels[label_name]
|
|
|
|
if class_index is not None:
|
|
|
|
curr_labels = (curr_labels == class_index).astype(int)
|
|
|
|
|
|
|
|
if -1 in curr_labels:
|
|
|
|
mask = curr_labels != -1
|
|
|
|
curr_labels = curr_labels[mask]
|
|
|
|
preds = preds[mask]
|
|
|
|
|
|
|
|
return {
|
|
|
|
f"precision_recall{self.fixed_recall}": self.compute_precision_fixed_recall(
|
|
|
|
labels=curr_labels, preds=preds
|
|
|
|
),
|
|
|
|
"pr_auc": average_precision_score(y_true=curr_labels, y_score=preds),
|
|
|
|
"roc_auc": roc_auc_score(y_true=curr_labels, y_score=preds),
|
|
|
|
}
|
|
|
|
|
|
|
|
def _multiclass_evaluations(self, preds):
|
|
|
|
pr_auc_l = average_precision_score(
|
|
|
|
y_true=self.labels, y_score=preds, **self.metric_kw
|
|
|
|
)
|
|
|
|
roc_auc_l = roc_auc_score(y_true=self.labels, y_score=preds, **self.metric_kw)
|
|
|
|
metrics = {}
|
|
|
|
for i, label in enumerate(self.label_names):
|
|
|
|
metrics[f"pr_auc_{label}"] = pr_auc_l[i]
|
|
|
|
metrics[f"roc_auc_{label}"] = roc_auc_l[i]
|
|
|
|
|
|
|
|
return metrics
|
|
|
|
|
|
|
|
def additional_evaluations(self, step, eval_time):
|
|
|
|
print("Evaluating ", self.set_, eval_time, step)
|
|
|
|
|
|
|
|
preds = self.model.predict(x=self.data, batch_size=self.batch_size)
|
|
|
|
if self.from_logits:
|
|
|
|
preds = tf.keras.activations.sigmoid(preds.logits).numpy()
|
|
|
|
|
|
|
|
if self.single_head:
|
|
|
|
if len(self.label_names) == 1:
|
|
|
|
metrics = self._binary_evaluations(preds)
|
|
|
|
else:
|
|
|
|
metrics = self._multiclass_evaluations(preds)
|
|
|
|
else:
|
|
|
|
if preds[0].shape[1] == 1:
|
|
|
|
binary_preds = preds[0]
|
|
|
|
multic_preds = preds[1]
|
|
|
|
else:
|
|
|
|
binary_preds = preds[1]
|
|
|
|
multic_preds = preds[0]
|
|
|
|
|
|
|
|
binary_metrics = self._binary_evaluations(
|
|
|
|
binary_preds, label_name="target_output"
|
|
|
|
)
|
|
|
|
metrics = {f"{k}_target": v for k, v in binary_metrics.items()}
|
|
|
|
num_classes = multic_preds.shape[1]
|
|
|
|
for class_ in range(num_classes):
|
|
|
|
binary_metrics = self._binary_evaluations(
|
|
|
|
multic_preds[:, class_],
|
|
|
|
label_name="content_output",
|
|
|
|
class_index=class_,
|
|
|
|
)
|
|
|
|
metrics.update(
|
|
|
|
{f"{k}_content_{class_}": v for k, v in binary_metrics.items()}
|
|
|
|
)
|
|
|
|
|
|
|
|
for k, v in metrics.items():
|
|
|
|
self.best_metrics[f"max_{k}"] = max(v, self.best_metrics[f"max_{k}"])
|
|
|
|
|
|
|
|
self.log_metrics(metrics, step=step, eval_time=eval_time)
|
|
|
|
|
|
|
|
def log_metrics(self, metrics_d, step, eval_time):
|
|
|
|
commit = False if self.set_ == "validation" else True
|
|
|
|
to_report = {self.set_: {**metrics_d, **self.best_metrics}}
|
|
|
|
|
|
|
|
if eval_time == "epoch":
|
|
|
|
to_report["epoch"] = step
|
|
|
|
|
|
|
|
wandb.log(to_report, commit=commit)
|