# Source: the-algorithm-ml/projects/twhin/metrics.py (18 lines, 287 B, Python)

import torch
import torchmetrics as tm
import tml.core.metrics as core_metrics
def create_metrics(
    device: torch.device,
) -> "tm.MetricCollection":
    """Create the evaluation metric collection and place it on ``device``.

    Args:
        device: Device the collection is moved to via ``.to(device)``.

    Returns:
        A ``torchmetrics.MetricCollection`` holding a single metric,
        ``core_metrics.Auc``, under the key ``"AUC"``.
    """
    # NOTE(review): 128 is forwarded positionally to core_metrics.Auc; its
    # meaning (presumably a sample/bucket count) is defined in
    # tml.core.metrics -- confirm before changing it.
    metrics = {
        "AUC": core_metrics.Auc(128),
    }
    return tm.MetricCollection(metrics).to(device)