mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-11-16 21:29:24 +01:00
799254345f
remaning train_pipline.py
251 lines
6.2 KiB
Python
from dataclasses import dataclass

from tml.core import metrics as core_metrics
from tml.core.metric_mixin import MetricMixin, prepend_transform

import torch
from torchmetrics import MaxMetric, MetricCollection, SumMetric

@dataclass
class MockStratifierConfig:
  """
  Configuration dataclass for mocking a stratifier.

  Args:
    name (str): The name of the stratifier.
    index (int): The index of the stratifier.
    value (int): The value of the stratifier.
  """

  name: str
  index: int
  value: int


class Count(MetricMixin, SumMetric):
  """
  Count metric that inherits from MetricMixin and SumMetric.

  Every update contributes 1 regardless of the outputs, so computing the
  metric yields the number of examples seen.
  """

  def transform(self, outputs):
    return {"value": 1}


Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})

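# Count and Max above demonstrate the two ways this file wires a transform
# onto a torchmetrics metric: subclassing MetricMixin and overriding
# `transform`, or wrapping the metric class with `prepend_transform`. Either
# way, `update` receives the raw outputs dict and the transform maps it to
# the kwargs the wrapped metric expects. As a sketch of that equivalence
# (a hypothetical name, not used by the tests below), Count could also have
# been written as:
#
#   CountViaTransform = prepend_transform(SumMetric, lambda outputs: {"value": 1})
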
def test_count_metric():
  """
  Test function for the Count metric.

  It checks if the Count metric correctly counts the number of examples.
  """
  num_examples = 123
  examples = [
    {"stuff": 0},
  ] * num_examples

  metric = Count()
  for outputs in examples:
    metric.update(outputs)

  assert metric.compute().item() == num_examples


def test_collections():
  """
  Test function for metric collections.

  It tests if metric collections correctly aggregate metrics.
  """
  max_metric = Max()
  count_metric = Count()
  metric = MetricCollection([max_metric, count_metric])

  examples = [{"value": idx} for idx in range(123)]
  for outputs in examples:
    metric.update(outputs)

  assert metric.compute() == {
    max_metric.__class__.__name__: len(examples) - 1,
    count_metric.__class__.__name__: len(examples),
  }

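# Note: torchmetrics' MetricCollection keys its compute() result by each
# metric's class name, which is why the asserts above look the values up
# via __class__.__name__.
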
def test_task_dependent_ctr():
  """
  Test function for the task-dependent Ctr (Click-Through Rate) metric.

  It checks if the Ctr metric computes the correct value for different tasks.
  """
  num_examples = 144
  batch_size = 1024
  outputs = [
    {
      "stuff": 0,
      "labels": torch.arange(0, 6).repeat(batch_size, 1),
    }
    for idx in range(num_examples)
  ]

  for task_idx in range(5):
    metric = core_metrics.Ctr(task_idx=task_idx)
    for output in outputs:
      metric.update(output)
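    # Each row of labels is [0, 1, 2, 3, 4, 5], so the column selected by
    # task_idx is constant and equal to task_idx; Ctr, the mean label for
    # the selected task, should therefore compute exactly task_idx.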
    assert metric.compute().item() == task_idx


def test_stratified_ctr():
  """
  Test function for the stratified Ctr (Click-Through Rate) metric.

  It checks if the stratified Ctr metric computes the correct value for a
  task when restricted to the stratified samples.
  """
  outputs = [
    {
      "stuff": 0,
      # [bsz, tasks]
      "labels": torch.tensor(
        [
          [0, 1, 2, 3],
          [1, 2, 3, 4],
          [2, 3, 4, 0],
        ]
      ),
      "stratifiers": {
        # [bsz]
        "level": torch.tensor(
          [9, 0, 9],
        ),
      },
    }
  ]

  stratifier = MockStratifierConfig(name="level", index=2, value=9)
  # The metric is pinned to task_idx=1; the loop simply repeats the same
  # check on a fresh metric instance each time.
  for task_idx in range(5):
    metric = core_metrics.Ctr(task_idx=1, stratifier=stratifier)
    for output in outputs:
      metric.update(output)
    # From the dataset of:
    # [
    #   [0, 1, 2, 3],
    #   [1, 2, 3, 4],
    #   [2, 3, 4, 0],
    # ]
    # the stratifier keeps the rows whose "level" equals 9:
    # [
    #   [0, 1, 2, 3],
    #   [2, 3, 4, 0],
    # ]
    # and with task_idx=1, Ctr picks out column 1:
    # [
    #   [1,],
    #   [3,],
    # ]
    assert metric.compute().item() == (1 + 3) / 2


def test_auc():
  """
  Test function for the AUC (Area Under the Curve) metric.

  It checks if the AUC metric correctly computes the Area Under the ROC Curve.
  """
  num_samples = 10000
  metric = core_metrics.Auc(num_samples)
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 1.0, 1.0, 1.0])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.0, -1.0, -1.0])
  outputs_bad = {"logits": preds_bad, "labels": target}

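  # Scoring every positive above every negative is a perfect ranking
  # (AUC = 1.0); fully inverting the scores is the worst one (AUC = 0.0).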
  metric.update(outputs_correct)
  assert metric.compute().item() == 1.0

  metric.reset()
  metric.update(outputs_bad)
  assert metric.compute().item() == 0.0


def test_pos_rank():
  """
  Test function for the PosRanks metric.

  It checks if the PosRanks metric correctly computes the mean rank of the
  positive samples.
  """
  metric = core_metrics.PosRanks()
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
  outputs_bad = {"logits": preds_bad, "labels": target}

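  # Assuming ranks are 1-indexed by descending score: the well-ordered
  # positives occupy ranks 1, 2, 3 (mean 2.0), while the inverted ones
  # occupy ranks 3, 4, 5 (mean 4.0).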
  metric.update(outputs_correct)
  assert metric.compute().item() == 2.0

  metric.reset()
  metric.update(outputs_bad)
  assert metric.compute().item() == 4.0


def test_reciprocal_rank():
  """
  Test function for the ReciprocalRank metric.

  It checks if the ReciprocalRank metric correctly computes the mean
  reciprocal rank of the positive samples.
  """
  metric = core_metrics.ReciprocalRank()
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
  outputs_bad = {"logits": preds_bad, "labels": target}

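  # Mean reciprocal rank, under the same 1-indexed ranking: the correct case
  # gives (1/1 + 1/2 + 1/3) / 3 ~= 0.6111 and the inverted case gives
  # (1/3 + 1/4 + 1/5) / 3 ~= 0.2611.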
  metric.update(outputs_correct)
  assert abs(metric.compute().item() - 0.6111) < 0.001

  metric.reset()
  metric.update(outputs_bad)
  assert abs(metric.compute().item() - 0.2611) < 0.001


def test_hit_k():
  """
  Test function for the Hit@K metric.

  It checks if the Hit@K metric correctly computes the fraction of positives
  that rank in the top K among all scored examples.
  """
  hit1_metric = core_metrics.HitAtK(1)
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
  outputs_bad = {"logits": preds_bad, "labels": target}

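  # Ranking all five scores in descending order, the positives sit at ranks
  # 1, 3 and 4 in the correct case, so Hit@1 = 1/3 and Hit@3 = 2/3; in the
  # inverted case they sit at ranks 3, 4 and 5, so Hit@1 = 0 and Hit@3 = 1/3.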
  hit1_metric.update(outputs_correct)
  assert abs(hit1_metric.compute().item() - 0.3333) < 0.0001

  hit1_metric.reset()
  hit1_metric.update(outputs_bad)
  assert hit1_metric.compute().item() == 0

  hit3_metric = core_metrics.HitAtK(3)
  hit3_metric.update(outputs_correct)
  assert abs(hit3_metric.compute().item() - 0.66666) < 0.0001

  hit3_metric.reset()
  hit3_metric.update(outputs_bad)
  assert abs(hit3_metric.compute().item() - 0.3333) < 0.0001