mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-11-05 16:25:08 +01:00
core
update
remaning train_pipline.py
This commit is contained in:
parent
b85210863f
commit
799254345f
@ -46,7 +46,24 @@ def get_new_iterator(iterable: Iterable):
|
|||||||
|
|
||||||
|
|
||||||
def _get_step_fn(pipeline, data_iterator, training: bool):
|
def _get_step_fn(pipeline, data_iterator, training: bool):
|
||||||
|
"""
|
||||||
|
Returns a function to perform a single evaluation step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipeline (Pipeline): The pipeline object containing the model.
|
||||||
|
data_iterator (Iterator): The data iterator for evaluation.
|
||||||
|
training (bool): Flag indicating if the model should be in training mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: A function that performs a single evaluation step.
|
||||||
|
"""
|
||||||
def step_fn():
|
def step_fn():
|
||||||
|
"""
|
||||||
|
Perform a single evaluation step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The evaluation results after a single step.
|
||||||
|
"""
|
||||||
# It turns out that model.train() and model.eval() simply switch a single field inside the model
|
# It turns out that model.train() and model.eval() simply switch a single field inside the model
|
||||||
# class,so it's somewhat safer to wrap in here.
|
# class,so it's somewhat safer to wrap in here.
|
||||||
if training:
|
if training:
|
||||||
@ -69,7 +86,21 @@ def _run_evaluation(
|
|||||||
eval_batch_size: int,
|
eval_batch_size: int,
|
||||||
logger=None,
|
logger=None,
|
||||||
):
|
):
|
||||||
"""Runs the evaluation loop over all evaluation iterators."""
|
"""
|
||||||
|
Run the evaluation loop over all evaluation iterators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipeline (Pipeline): The pipeline object containing the model.
|
||||||
|
dataset (Dataset): The dataset to evaluate.
|
||||||
|
eval_steps (int): The number of evaluation steps to perform.
|
||||||
|
metrics (tm.MetricCollection): A collection of evaluation metrics.
|
||||||
|
eval_batch_size (int): Batch size for evaluation.
|
||||||
|
logger (Optional[Logger]): A logger for recording evaluation progress (default: None).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary containing the computed evaluation metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
dataset = get_new_iterator(dataset)
|
dataset = get_new_iterator(dataset)
|
||||||
step_fn = _get_step_fn(pipeline, dataset, training=False)
|
step_fn = _get_step_fn(pipeline, dataset, training=False)
|
||||||
last_time = datetime.datetime.now()
|
last_time = datetime.datetime.now()
|
||||||
@ -109,15 +140,29 @@ def train(
|
|||||||
parameters_to_log: Optional[Dict[str, Callable]] = None,
|
parameters_to_log: Optional[Dict[str, Callable]] = None,
|
||||||
tables_to_log: Optional[List[str]] = None,
|
tables_to_log: Optional[List[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Runs training and eval on the given TrainPipeline
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: data iterator for the training set
|
|
||||||
evaluation_iterators: data iterators for the different evaluation sets
|
|
||||||
scheduler: optional learning rate scheduler
|
|
||||||
output_transform_for_metrics: optional transformation functions to transorm the model
|
|
||||||
output and labels into a format the metrics can understand
|
|
||||||
"""
|
"""
|
||||||
|
Runs training and evaluation on the given TrainPipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The neural network model to train.
|
||||||
|
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
|
||||||
|
device (str): The target device for model training (e.g., 'cuda' or 'cpu').
|
||||||
|
save_dir (str): The directory to save model checkpoints and logs.
|
||||||
|
logging_interval (int): Interval for logging training progress.
|
||||||
|
train_steps (int): The number of training steps to perform.
|
||||||
|
checkpoint_frequency (int): Frequency of saving model checkpoints.
|
||||||
|
dataset (Iterable): Data iterator for the training set.
|
||||||
|
worker_batch_size (int): Batch size for data loading workers.
|
||||||
|
num_workers (Optional[int]): Number of data loading workers (default: 0).
|
||||||
|
enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False).
|
||||||
|
initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None).
|
||||||
|
gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None).
|
||||||
|
logger_initializer (Optional[Callable]): A logger initializer function (default: None).
|
||||||
|
scheduler (_LRScheduler): Optional learning rate scheduler (default: None).
|
||||||
|
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
|
||||||
|
parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None).
|
||||||
|
tables_to_log (Optional[List[str]]): List of tables to log (default: None).
|
||||||
|
"""
|
||||||
|
|
||||||
train_pipeline = TrainPipelineSparseDist(
|
train_pipeline = TrainPipelineSparseDist(
|
||||||
model=model,
|
model=model,
|
||||||
@ -262,6 +307,15 @@ def log_eval_results(
|
|||||||
partition_name: str,
|
partition_name: str,
|
||||||
step: int,
|
step: int,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Logs evaluation results and optionally records them using a provided logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (Any): The evaluation results to log.
|
||||||
|
eval_logger (Callable): A logger for recording evaluation results.
|
||||||
|
partition_name (str): The name of the evaluation partition.
|
||||||
|
step (int): The current step in the evaluation.
|
||||||
|
"""
|
||||||
results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
|
results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
|
||||||
logging.info(f"Step: {step}, evaluation ({partition_name}).")
|
logging.info(f"Step: {step}, evaluation ({partition_name}).")
|
||||||
for metric_name, metric_value in results.items():
|
for metric_name, metric_value in results.items():
|
||||||
@ -285,6 +339,23 @@ def only_evaluate(
|
|||||||
partition_name: str,
|
partition_name: str,
|
||||||
metrics: Optional[tm.MetricCollection] = None,
|
metrics: Optional[tm.MetricCollection] = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Performs evaluation on a given dataset partition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The neural network model for evaluation.
|
||||||
|
optimizer (torch.optim.Optimizer): The optimizer used during evaluation.
|
||||||
|
device (str): The target device for evaluation (e.g., 'cuda' or 'cpu').
|
||||||
|
save_dir (str): The directory containing model checkpoints.
|
||||||
|
num_train_steps (int): The total number of training steps.
|
||||||
|
dataset (Iterable): Data iterator for evaluation.
|
||||||
|
eval_batch_size (int): Batch size for evaluation.
|
||||||
|
num_eval_steps (int): The number of evaluation steps to perform.
|
||||||
|
eval_timeout_in_s (int): Timeout for evaluating checkpoints in seconds.
|
||||||
|
eval_logger (Callable): A logger for recording evaluation results.
|
||||||
|
partition_name (str): The name of the evaluation partition.
|
||||||
|
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
|
||||||
|
"""
|
||||||
logging.info(f"Evaluating on partition {partition_name}.")
|
logging.info(f"Evaluating on partition {partition_name}.")
|
||||||
logging.info("Computing metrics:")
|
logging.info("Computing metrics:")
|
||||||
logging.info(metrics)
|
logging.info(metrics)
|
||||||
|
@ -28,6 +28,18 @@ def train(
|
|||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Debugging training loop. Do not use for actual model training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The neural network model.
|
||||||
|
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
|
||||||
|
train_steps (int): The number of training steps to perform.
|
||||||
|
dataset (Iterable): Data iterator for training data.
|
||||||
|
scheduler (_LRScheduler, optional): Learning rate scheduler (default: None).
|
||||||
|
*args: Additional arguments (ignored).
|
||||||
|
**kwargs: Additional keyword arguments (ignored).
|
||||||
|
"""
|
||||||
|
|
||||||
logging.warning("Running debug training loop, don't use for model training.")
|
logging.warning("Running debug training loop, don't use for model training.")
|
||||||
|
|
||||||
|
@ -10,8 +10,11 @@ import torch
|
|||||||
|
|
||||||
def _maybe_warn(reduction: str):
|
def _maybe_warn(reduction: str):
|
||||||
"""
|
"""
|
||||||
Warning for reduction different than mean.
|
Emit a warning if the reduction method is different from 'mean'.
|
||||||
"""
|
|
||||||
|
Args:
|
||||||
|
reduction (str): The reduction method being used.
|
||||||
|
"""
|
||||||
if reduction != "mean":
|
if reduction != "mean":
|
||||||
logging.warn(
|
logging.warn(
|
||||||
f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal,"
|
f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal,"
|
||||||
@ -24,6 +27,16 @@ def build_loss(
|
|||||||
loss_type: LossType,
|
loss_type: LossType,
|
||||||
reduction="mean",
|
reduction="mean",
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Build a loss function based on the specified loss type and reduction method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss_type (LossType): The type of loss to build.
|
||||||
|
reduction (str): The reduction method for the loss (default: 'mean').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable: A loss function that takes logits and labels as input.
|
||||||
|
"""
|
||||||
_maybe_warn(reduction)
|
_maybe_warn(reduction)
|
||||||
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
|
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
|
||||||
|
|
||||||
@ -35,11 +48,15 @@ def build_loss(
|
|||||||
|
|
||||||
def get_global_loss_detached(local_loss, reduction="mean"):
|
def get_global_loss_detached(local_loss, reduction="mean"):
|
||||||
"""
|
"""
|
||||||
Perform all_reduce to obtain the global loss function using the provided reduction.
|
Perform all_reduce to obtain the global loss function using the provided reduction.
|
||||||
:param local_loss: The local loss of the current rank.
|
|
||||||
:param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
|
Args:
|
||||||
:return: The reduced & detached global loss.
|
local_loss (torch.Tensor): The local loss of the current rank.
|
||||||
"""
|
reduction (str): The reduction to use for all_reduce. Should match the reduction used by DDP.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The reduced and detached global loss.
|
||||||
|
"""
|
||||||
if reduction != "mean":
|
if reduction != "mean":
|
||||||
logging.warn(
|
logging.warn(
|
||||||
f"The reduction used in this function should be the same as the one used by "
|
f"The reduction used in this function should be the same as the one used by "
|
||||||
@ -66,6 +83,19 @@ def build_multi_task_loss(
|
|||||||
global_reduction="mean",
|
global_reduction="mean",
|
||||||
pos_weights=None,
|
pos_weights=None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Build a multi-task loss function based on the specified loss type and configurations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss_type (LossType): The type of loss to build.
|
||||||
|
tasks (typing.List[str]): List of task names.
|
||||||
|
task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean').
|
||||||
|
global_reduction (str): Reduction method for the global loss (default: 'mean').
|
||||||
|
pos_weights (Optional): Positive class weights for tasks (default: None).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable: A multi-task loss function that takes logits, labels, and weights as input.
|
||||||
|
"""
|
||||||
_maybe_warn(global_reduction)
|
_maybe_warn(global_reduction)
|
||||||
_maybe_warn(task_loss_reduction)
|
_maybe_warn(task_loss_reduction)
|
||||||
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
|
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
|
||||||
|
@ -36,9 +36,24 @@ import torchmetrics
|
|||||||
class MetricMixin:
|
class MetricMixin:
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
|
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
|
||||||
|
"""
|
||||||
|
Abstract method to transform model outputs into a dictionary of metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs (Dict[str, torch.Tensor]): Model outputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: A dictionary of computed metrics.
|
||||||
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def update(self, outputs: Dict[str, torch.Tensor]):
|
def update(self, outputs: Dict[str, torch.Tensor]):
|
||||||
|
"""
|
||||||
|
Update the metrics based on model outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs (Dict[str, torch.Tensor]): Model outputs.
|
||||||
|
"""
|
||||||
results = self.transform(outputs)
|
results = self.transform(outputs)
|
||||||
# Do not try to update if any tensor is empty as a result of stratification.
|
# Do not try to update if any tensor is empty as a result of stratification.
|
||||||
for value in results.values():
|
for value in results.values():
|
||||||
@ -49,6 +64,13 @@ class MetricMixin:
|
|||||||
|
|
||||||
class TaskMixin:
|
class TaskMixin:
|
||||||
def __init__(self, task_idx: int = -1, **kwargs):
|
def __init__(self, task_idx: int = -1, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize a TaskMixin instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_idx (int): Index of the task associated with this mixin (default: -1).
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._task_idx = task_idx
|
self._task_idx = task_idx
|
||||||
|
|
||||||
@ -59,13 +81,31 @@ class StratifyMixin:
|
|||||||
stratifier=None,
|
stratifier=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize a StratifyMixin instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stratifier: A stratifier for filtering outputs (default: None).
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._stratifier = stratifier
|
self._stratifier = stratifier
|
||||||
|
|
||||||
def maybe_apply_stratification(
|
def maybe_apply_stratification(
|
||||||
self, outputs: Dict[str, torch.Tensor], value_names: List[str]
|
self, outputs: Dict[str, torch.Tensor], value_names: List[str]
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
"""Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value."""
|
"""
|
||||||
|
Apply stratification to filter examples in the outputs.
|
||||||
|
|
||||||
|
Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs (Dict[str, torch.Tensor]): Model outputs.
|
||||||
|
value_names (List[str]): Names of values to filter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, torch.Tensor]: Filtered outputs.
|
||||||
|
"""
|
||||||
outputs = outputs.copy()
|
outputs = outputs.copy()
|
||||||
if not self._stratifier:
|
if not self._stratifier:
|
||||||
return outputs
|
return outputs
|
||||||
@ -84,12 +124,20 @@ class StratifyMixin:
|
|||||||
|
|
||||||
|
|
||||||
def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
|
def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
|
||||||
"""Returns new class using MetricMixin and given base_metric.
|
|
||||||
|
|
||||||
Functionally the same using inheritance, just saves some lines of code
|
|
||||||
if no need for class attributes.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
Returns a new class using MetricMixin and the given base_metric.
|
||||||
|
|
||||||
|
Functionally the same as using inheritance, but it saves some lines of code
|
||||||
|
if there's no need for class attributes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_metric (torchmetrics.Metric): The base metric class to prepend the transform to.
|
||||||
|
transform (Callable): The transformation function to prepend to the metric.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: A new class that includes MetricMixin and the provided base_metric
|
||||||
|
with the specified transformation method.
|
||||||
|
"""
|
||||||
|
|
||||||
def transform_method(_self, *args, **kwargs):
|
def transform_method(_self, *args, **kwargs):
|
||||||
return transform(*args, **kwargs)
|
return transform(*args, **kwargs)
|
||||||
|
@ -15,6 +15,16 @@ def probs_and_labels(
|
|||||||
outputs: Dict[str, torch.Tensor],
|
outputs: Dict[str, torch.Tensor],
|
||||||
task_idx: int,
|
task_idx: int,
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Extract probabilities and labels from model outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs (Dict[str, torch.Tensor]): Model outputs.
|
||||||
|
task_idx (int): Index of the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, torch.Tensor]: Dictionary containing 'preds' and 'target' tensors.
|
||||||
|
"""
|
||||||
preds = outputs["probabilities"]
|
preds = outputs["probabilities"]
|
||||||
target = outputs["labels"]
|
target = outputs["labels"]
|
||||||
if task_idx >= 0:
|
if task_idx >= 0:
|
||||||
@ -28,6 +38,11 @@ def probs_and_labels(
|
|||||||
|
|
||||||
class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
|
class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
|
||||||
def transform(self, outputs):
|
def transform(self, outputs):
|
||||||
|
"""
|
||||||
|
Count metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and SumMetric.
|
||||||
|
|
||||||
|
This metric counts values after potential stratification and task selection.
|
||||||
|
"""
|
||||||
outputs = self.maybe_apply_stratification(outputs, ["labels"])
|
outputs = self.maybe_apply_stratification(outputs, ["labels"])
|
||||||
value = outputs["labels"]
|
value = outputs["labels"]
|
||||||
if self._task_idx >= 0:
|
if self._task_idx >= 0:
|
||||||
@ -36,6 +51,12 @@ class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
|
|||||||
|
|
||||||
|
|
||||||
class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||||
|
"""
|
||||||
|
Ctr (Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
|
||||||
|
|
||||||
|
This metric calculates the mean metric value after potential stratification and task selection.
|
||||||
|
"""
|
||||||
|
|
||||||
def transform(self, outputs):
|
def transform(self, outputs):
|
||||||
outputs = self.maybe_apply_stratification(outputs, ["labels"])
|
outputs = self.maybe_apply_stratification(outputs, ["labels"])
|
||||||
value = outputs["labels"]
|
value = outputs["labels"]
|
||||||
@ -45,6 +66,11 @@ class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
|||||||
|
|
||||||
|
|
||||||
class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||||
|
"""
|
||||||
|
Pctr (Predicted Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
|
||||||
|
|
||||||
|
This metric calculates the mean metric value using probabilities after potential stratification and task selection.
|
||||||
|
"""
|
||||||
def transform(self, outputs):
|
def transform(self, outputs):
|
||||||
outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
|
outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
|
||||||
value = outputs["probabilities"]
|
value = outputs["probabilities"]
|
||||||
@ -54,12 +80,22 @@ class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
|||||||
|
|
||||||
|
|
||||||
class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
|
class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
|
||||||
|
"""
|
||||||
|
Precision metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Precision.
|
||||||
|
|
||||||
|
This metric computes precision after potential stratification and task selection.
|
||||||
|
"""
|
||||||
def transform(self, outputs):
|
def transform(self, outputs):
|
||||||
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
|
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
|
||||||
return probs_and_labels(outputs, self._task_idx)
|
return probs_and_labels(outputs, self._task_idx)
|
||||||
|
|
||||||
|
|
||||||
class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
|
class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
|
||||||
|
"""
|
||||||
|
Recall metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Recall.
|
||||||
|
|
||||||
|
This metric computes recall after potential stratification and task selection.
|
||||||
|
"""
|
||||||
def transform(self, outputs):
|
def transform(self, outputs):
|
||||||
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
|
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
|
||||||
return probs_and_labels(outputs, self._task_idx)
|
return probs_and_labels(outputs, self._task_idx)
|
||||||
@ -73,6 +109,14 @@ class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
|
|||||||
|
|
||||||
class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||||
"""
|
"""
|
||||||
|
AUC (Area Under the ROC Curve) metric class.
|
||||||
|
|
||||||
|
This metric computes the AUC metric based on the logits and labels in the model outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_samples (int): The number of samples used to compute AUC.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
Based on:
|
Based on:
|
||||||
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
|
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
|
||||||
"""
|
"""
|
||||||
@ -94,8 +138,14 @@ class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
|||||||
|
|
||||||
class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||||
"""
|
"""
|
||||||
The ranks of all positives
|
PosRanks metric class.
|
||||||
Based on:
|
|
||||||
|
This metric computes the ranks of all positive examples based on the logits and labels
|
||||||
|
in the model outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
|
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -112,8 +162,13 @@ class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
|||||||
|
|
||||||
class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||||
"""
|
"""
|
||||||
The reciprocal of the ranks of all
|
ReciprocalRank metric class.
|
||||||
Based on:
|
|
||||||
|
This metric computes the reciprocal of the ranks of all positive examples based on the logits and labels
|
||||||
|
in the model outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
|
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -130,9 +185,14 @@ class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
|||||||
|
|
||||||
class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||||
"""
|
"""
|
||||||
The fraction of positives that rank in the top K among their negatives
|
HitAtK metric class.
|
||||||
Note that this is basically precision@k
|
|
||||||
Based on:
|
This metric computes the fraction of positive examples that rank in the top K among their negatives,
|
||||||
|
which is equivalent to precision@K.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
k (int): The value of K.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
|
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -9,12 +9,26 @@ from torchmetrics import MaxMetric, MetricCollection, SumMetric
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockStratifierConfig:
|
class MockStratifierConfig:
|
||||||
|
"""
|
||||||
|
Configuration dataclass for mocking a stratifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the stratifier.
|
||||||
|
index (int): The index of the stratifier.
|
||||||
|
value (int): The value of the stratifier.
|
||||||
|
"""
|
||||||
name: str
|
name: str
|
||||||
index: int
|
index: int
|
||||||
value: int
|
value: int
|
||||||
|
|
||||||
|
|
||||||
class Count(MetricMixin, SumMetric):
|
class Count(MetricMixin, SumMetric):
|
||||||
|
"""
|
||||||
|
Count metric class that inherits from MetricMixin and SumMetric.
|
||||||
|
|
||||||
|
This metric counts occurrences.
|
||||||
|
|
||||||
|
"""
|
||||||
def transform(self, outputs):
|
def transform(self, outputs):
|
||||||
return {"value": 1}
|
return {"value": 1}
|
||||||
|
|
||||||
@ -23,6 +37,12 @@ Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})
|
|||||||
|
|
||||||
|
|
||||||
def test_count_metric():
|
def test_count_metric():
|
||||||
|
"""
|
||||||
|
Test function for the Count metric.
|
||||||
|
|
||||||
|
It checks if the Count metric correctly counts the number of examples.
|
||||||
|
|
||||||
|
"""
|
||||||
num_examples = 123
|
num_examples = 123
|
||||||
examples = [
|
examples = [
|
||||||
{"stuff": 0},
|
{"stuff": 0},
|
||||||
@ -36,6 +56,12 @@ def test_count_metric():
|
|||||||
|
|
||||||
|
|
||||||
def test_collections():
|
def test_collections():
|
||||||
|
"""
|
||||||
|
Test function for metric collections.
|
||||||
|
|
||||||
|
It tests if metric collections correctly aggregate metrics.
|
||||||
|
|
||||||
|
"""
|
||||||
max_metric = Max()
|
max_metric = Max()
|
||||||
count_metric = Count()
|
count_metric = Count()
|
||||||
metric = MetricCollection([max_metric, count_metric])
|
metric = MetricCollection([max_metric, count_metric])
|
||||||
@ -51,6 +77,12 @@ def test_collections():
|
|||||||
|
|
||||||
|
|
||||||
def test_task_dependent_ctr():
|
def test_task_dependent_ctr():
|
||||||
|
"""
|
||||||
|
Test function for task-dependent Ctr (Click-Through Rate) metric.
|
||||||
|
|
||||||
|
It checks if the Ctr metric computes the correct value for different tasks.
|
||||||
|
|
||||||
|
"""
|
||||||
num_examples = 144
|
num_examples = 144
|
||||||
batch_size = 1024
|
batch_size = 1024
|
||||||
outputs = [
|
outputs = [
|
||||||
@ -69,6 +101,13 @@ def test_task_dependent_ctr():
|
|||||||
|
|
||||||
|
|
||||||
def test_stratified_ctr():
|
def test_stratified_ctr():
|
||||||
|
"""
|
||||||
|
Test function for the Stratified Ctr (Click-Through Rate) metric.
|
||||||
|
|
||||||
|
It checks if the Stratified Ctr metric computes the correct value for different tasks
|
||||||
|
and stratified samples.
|
||||||
|
|
||||||
|
"""
|
||||||
outputs = [
|
outputs = [
|
||||||
{
|
{
|
||||||
"stuff": 0,
|
"stuff": 0,
|
||||||
@ -114,6 +153,12 @@ def test_stratified_ctr():
|
|||||||
|
|
||||||
|
|
||||||
def test_auc():
|
def test_auc():
|
||||||
|
"""
|
||||||
|
Test function for the AUC (Area Under the Curve) metric.
|
||||||
|
|
||||||
|
It checks if the AUC metric correctly computes the Area Under the ROC Curve.
|
||||||
|
|
||||||
|
"""
|
||||||
num_samples = 10000
|
num_samples = 10000
|
||||||
metric = core_metrics.Auc(num_samples)
|
metric = core_metrics.Auc(num_samples)
|
||||||
target = torch.tensor([0, 0, 1, 1, 1])
|
target = torch.tensor([0, 0, 1, 1, 1])
|
||||||
@ -131,6 +176,12 @@ def test_auc():
|
|||||||
|
|
||||||
|
|
||||||
def test_pos_rank():
|
def test_pos_rank():
|
||||||
|
"""
|
||||||
|
Test function for the PosRanks metric.
|
||||||
|
|
||||||
|
It checks if the PosRanks metric correctly computes the ranks of positive samples.
|
||||||
|
|
||||||
|
"""
|
||||||
metric = core_metrics.PosRanks()
|
metric = core_metrics.PosRanks()
|
||||||
target = torch.tensor([0, 0, 1, 1, 1])
|
target = torch.tensor([0, 0, 1, 1, 1])
|
||||||
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
|
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
|
||||||
@ -147,6 +198,12 @@ def test_pos_rank():
|
|||||||
|
|
||||||
|
|
||||||
def test_reciprocal_rank():
|
def test_reciprocal_rank():
|
||||||
|
"""
|
||||||
|
Test function for the Reciprocal Rank metric.
|
||||||
|
|
||||||
|
It checks if the Reciprocal Rank metric correctly computes the reciprocal of ranks.
|
||||||
|
|
||||||
|
"""
|
||||||
metric = core_metrics.ReciprocalRank()
|
metric = core_metrics.ReciprocalRank()
|
||||||
target = torch.tensor([0, 0, 1, 1, 1])
|
target = torch.tensor([0, 0, 1, 1, 1])
|
||||||
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
|
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
|
||||||
@ -163,6 +220,12 @@ def test_reciprocal_rank():
|
|||||||
|
|
||||||
|
|
||||||
def test_hit_k():
|
def test_hit_k():
|
||||||
|
"""
|
||||||
|
Test function for the Hit@K metric.
|
||||||
|
|
||||||
|
It checks if the Hit@K metric correctly computes the fraction of positives that rank in the top K among their negatives.
|
||||||
|
|
||||||
|
"""
|
||||||
hit1_metric = core_metrics.HitAtK(1)
|
hit1_metric = core_metrics.HitAtK(1)
|
||||||
target = torch.tensor([0, 0, 1, 1, 1])
|
target = torch.tensor([0, 0, 1, 1, 1])
|
||||||
preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
|
preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
|
||||||
|
@ -11,23 +11,60 @@ from torchrec.distributed import DistributedModelParallel
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockDataclassBatch(DataclassBatch):
|
class MockDataclassBatch(DataclassBatch):
|
||||||
|
"""
|
||||||
|
Mock data class batch for testing purposes.
|
||||||
|
|
||||||
|
This class represents a batch of data with continuous features and labels.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
continuous_features (torch.Tensor): Tensor containing continuous feature data.
|
||||||
|
labels (torch.Tensor): Tensor containing label data.
|
||||||
|
"""
|
||||||
continuous_features: torch.Tensor
|
continuous_features: torch.Tensor
|
||||||
labels: torch.Tensor
|
labels: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class MockModule(torch.nn.Module):
|
class MockModule(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Mock PyTorch module for testing purposes.
|
||||||
|
|
||||||
|
This module defines a simple neural network model with a linear layer
|
||||||
|
followed by a BCEWithLogitsLoss loss function.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model (torch.nn.Linear): The linear model layer.
|
||||||
|
loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function.
|
||||||
|
"""
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = torch.nn.Linear(10, 1)
|
self.model = torch.nn.Linear(10, 1)
|
||||||
self.loss_fn = torch.nn.BCEWithLogitsLoss()
|
self.loss_fn = torch.nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Forward pass of the mock module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (MockDataclassBatch): Input data batch with continuous features and labels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions.
|
||||||
|
"""
|
||||||
pred = self.model(batch.continuous_features)
|
pred = self.model(batch.continuous_features)
|
||||||
loss = self.loss_fn(pred, batch.labels)
|
loss = self.loss_fn(pred, batch.labels)
|
||||||
return (loss, pred)
|
return (loss, pred)
|
||||||
|
|
||||||
|
|
||||||
def create_batch(bsz: int):
|
def create_batch(bsz: int):
|
||||||
|
"""
|
||||||
|
Create a mock data batch with random continuous features and labels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bsz (int): Batch size.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MockDataclassBatch: A batch of data with continuous features and labels.
|
||||||
|
"""
|
||||||
return MockDataclassBatch(
|
return MockDataclassBatch(
|
||||||
continuous_features=torch.rand(bsz, 10).float(),
|
continuous_features=torch.rand(bsz, 10).float(),
|
||||||
labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
|
labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
|
||||||
@ -35,6 +72,13 @@ def create_batch(bsz: int):
|
|||||||
|
|
||||||
|
|
||||||
def test_sparse_pipeline():
|
def test_sparse_pipeline():
|
||||||
|
"""
|
||||||
|
Test function for the sparse pipeline with distributed model parallelism.
|
||||||
|
|
||||||
|
This function tests the behavior of the sparse training pipeline using
|
||||||
|
a mock module and data.
|
||||||
|
"""
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
model = MockModule().to(device)
|
model = MockModule().to(device)
|
||||||
|
|
||||||
@ -65,6 +109,15 @@ def test_sparse_pipeline():
|
|||||||
|
|
||||||
|
|
||||||
def test_amp():
|
def test_amp():
|
||||||
|
"""
|
||||||
|
Test automatic mixed-precision (AMP) training with the sparse pipeline.
|
||||||
|
|
||||||
|
This function tests the behavior of the sparse training pipeline with
|
||||||
|
automatic mixed-precision (AMP) enabled, using a mock module and data.
|
||||||
|
|
||||||
|
AMP allows for faster training by using lower-precision data types, such as
|
||||||
|
torch.bfloat16, while maintaining model accuracy.
|
||||||
|
"""
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
model = MockModule().to(device)
|
model = MockModule().to(device)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user