core update

remaining train_pipeline.py
This commit is contained in:
rajveer43 2023-09-11 16:26:29 +05:30
parent b85210863f
commit 799254345f
7 changed files with 366 additions and 29 deletions

View File

@@ -46,7 +46,24 @@ def get_new_iterator(iterable: Iterable):
def _get_step_fn(pipeline, data_iterator, training: bool):
"""
Returns a function that performs a single training or evaluation step.
Args:
pipeline (Pipeline): The pipeline object containing the model.
data_iterator (Iterator): The data iterator to draw batches from.
training (bool): Flag indicating if the model should be in training mode.
Returns:
function: A function that performs a single step.
"""
def step_fn():
"""
Perform a single training or evaluation step.
Returns:
Any: The results of the step.
"""
# It turns out that model.train() and model.eval() simply switch a single field inside the model
# class, so it's somewhat safer to wrap in here.
if training:
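For context, a minimal sketch of the pattern this hunk documents: the train/eval mode switch is wrapped inside the step function itself. The helper name make_step_fn and the (loss, outputs) return convention are assumptions, not the repo's exact code.

import torch

def make_step_fn(model: torch.nn.Module, data_iterator, training: bool):
    def step_fn():
        # model.train()/model.eval() only flip the module's `training` flag,
        # so toggling it here keeps the mode tied to each individual step.
        if training:
            model.train()
            loss, outputs = model(next(data_iterator))
            loss.backward()
            return outputs
        model.eval()
        with torch.no_grad():
            _, outputs = model(next(data_iterator))
        return outputs

    return step_fn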
@@ -69,7 +86,21 @@ def _run_evaluation(
eval_batch_size: int,
logger=None,
):
"""Runs the evaluation loop over all evaluation iterators.""" """
Run the evaluation loop over all evaluation iterators.
Args:
pipeline (Pipeline): The pipeline object containing the model.
dataset (Dataset): The dataset to evaluate.
eval_steps (int): The number of evaluation steps to perform.
metrics (tm.MetricCollection): A collection of evaluation metrics.
eval_batch_size (int): Batch size for evaluation.
logger (Optional[Logger]): A logger for recording evaluation progress (default: None).
Returns:
dict: A dictionary containing the computed evaluation metrics.
"""
dataset = get_new_iterator(dataset)
step_fn = _get_step_fn(pipeline, dataset, training=False)
last_time = datetime.datetime.now()
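A sketch of how such an evaluation loop typically drives step_fn and a torchmetrics collection, assuming metrics.update accepts the step outputs the way the repo's MetricMixin classes do; this is not the repo's exact loop.

def run_evaluation_sketch(step_fn, eval_steps, metrics):
    # Run a fixed number of eval steps, folding each result into the
    # metric collection, then compute the aggregates once at the end.
    metrics.reset()
    for _ in range(eval_steps):
        outputs = step_fn()
        metrics.update(outputs)
    return metrics.compute()  # dict of metric name -> tensor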
@@ -109,15 +140,29 @@ def train(
parameters_to_log: Optional[Dict[str, Callable]] = None,
tables_to_log: Optional[List[str]] = None,
) -> None:
"""Runs training and eval on the given TrainPipeline
Args:
dataset: data iterator for the training set
evaluation_iterators: data iterators for the different evaluation sets
scheduler: optional learning rate scheduler
output_transform_for_metrics: optional transformation functions to transorm the model
output and labels into a format the metrics can understand
""" """
Runs training and evaluation on the given TrainPipeline.
Args:
model (torch.nn.Module): The neural network model to train.
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
device (str): The target device for model training (e.g., 'cuda' or 'cpu').
save_dir (str): The directory to save model checkpoints and logs.
logging_interval (int): Interval for logging training progress.
train_steps (int): The number of training steps to perform.
checkpoint_frequency (int): Frequency of saving model checkpoints.
dataset (Iterable): Data iterator for the training set.
worker_batch_size (int): Batch size for data loading workers.
num_workers (Optional[int]): Number of data loading workers (default: 0).
enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False).
initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None).
gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None).
logger_initializer (Optional[Callable]): A logger initializer function (default: None).
scheduler (_LRScheduler): Optional learning rate scheduler (default: None).
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None).
tables_to_log (Optional[List[str]]): List of tables to log (default: None).
"""
train_pipeline = TrainPipelineSparseDist(
model=model,
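A hypothetical call against this signature, using only the keyword arguments documented above; the import path is an assumption and may differ in the repo.

from tml.core.train_pipeline import train  # hypothetical module path

train(
    model=model,
    optimizer=optimizer,
    device="cuda",
    save_dir="/tmp/checkpoints",
    logging_interval=100,
    train_steps=10_000,
    checkpoint_frequency=1_000,
    dataset=train_dataset,
    worker_batch_size=512,
    enable_amp=False,
)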
@@ -262,6 +307,15 @@ def log_eval_results(
partition_name: str,
step: int,
):
"""
Logs evaluation results and optionally records them using a provided logger.
Args:
results (Any): The evaluation results to log.
eval_logger (Callable): A logger for recording evaluation results.
partition_name (str): The name of the evaluation partition.
step (int): The current step in the evaluation.
"""
results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
logging.info(f"Step: {step}, evaluation ({partition_name}).")
for metric_name, metric_value in results.items():
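The tree.map_structure line above recursively walks a nested results structure. A small self-contained illustration:

import torch
import tree  # the dm-tree package

results = {"loss": torch.tensor(0.25), "metrics": {"auc": 0.91}}
# Convert every leaf to a CPU tensor, exactly as the logging helper does.
on_cpu = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)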
@@ -285,6 +339,23 @@ def only_evaluate(
partition_name: str,
metrics: Optional[tm.MetricCollection] = None,
):
"""
Performs evaluation on a given dataset partition.
Args:
model (torch.nn.Module): The neural network model for evaluation.
optimizer (torch.optim.Optimizer): The optimizer used during evaluation.
device (str): The target device for evaluation (e.g., 'cuda' or 'cpu').
save_dir (str): The directory containing model checkpoints.
num_train_steps (int): The total number of training steps.
dataset (Iterable): Data iterator for evaluation.
eval_batch_size (int): Batch size for evaluation.
num_eval_steps (int): The number of evaluation steps to perform.
eval_timeout_in_s (int): Timeout for evaluating checkpoints in seconds.
eval_logger (Callable): A logger for recording evaluation results.
partition_name (str): The name of the evaluation partition.
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
"""
logging.info(f"Evaluating on partition {partition_name}.") logging.info(f"Evaluating on partition {partition_name}.")
logging.info("Computing metrics:") logging.info("Computing metrics:")
logging.info(metrics) logging.info(metrics)

View File

@@ -28,6 +28,18 @@ def train(
*args,
**kwargs,
) -> None:
"""
Debugging training loop. Do not use for actual model training.
Args:
model (torch.nn.Module): The neural network model.
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
train_steps (int): The number of training steps to perform.
dataset (Iterable): Data iterator for training data.
scheduler (_LRScheduler, optional): Learning rate scheduler (default: None).
*args: Additional arguments (ignored).
**kwargs: Additional keyword arguments (ignored).
"""
logging.warning("Running debug training loop, don't use for model training.") logging.warning("Running debug training loop, don't use for model training.")

View File

@@ -10,8 +10,11 @@ import torch
def _maybe_warn(reduction: str):
""" """
Warning for reduction different than mean. Emit a warning if the reduction method is different from 'mean'.
"""
Args:
reduction (str): The reduction method being used.
"""
if reduction != "mean": if reduction != "mean":
logging.warn( logging.warn(
f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal," f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal,"
@@ -24,6 +27,16 @@ def build_loss(
loss_type: LossType,
reduction="mean",
):
"""
Build a loss function based on the specified loss type and reduction method.
Args:
loss_type (LossType): The type of loss to build.
reduction (str): The reduction method for the loss (default: 'mean').
Returns:
Callable: A loss function that takes logits and labels as input.
"""
_maybe_warn(reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
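A sketch of intended usage, assuming a binary cross-entropy member on the LossType enum; the exact member names live in the repo's enum definition.

import torch

# LossType.BCE_WITH_LOGITS is an assumed enum member for illustration.
loss_fn = build_loss(LossType.BCE_WITH_LOGITS, reduction="mean")
logits = torch.randn(8, 1)
labels = torch.randint(0, 2, (8, 1)).float()
loss = loss_fn(logits, labels)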
@@ -35,11 +48,15 @@ def build_loss(
def get_global_loss_detached(local_loss, reduction="mean"):
"""
Perform all_reduce to obtain the global loss using the provided reduction.
Args:
local_loss (torch.Tensor): The local loss of the current rank.
reduction (str): The reduction to use for all_reduce. Should match the reduction used by DDP.
Returns:
torch.Tensor: The reduced and detached global loss.
"""
if reduction != "mean": if reduction != "mean":
logging.warn( logging.warn(
f"The reduction used in this function should be the same as the one used by " f"The reduction used in this function should be the same as the one used by "
@@ -66,6 +83,19 @@ def build_multi_task_loss(
global_reduction="mean",
pos_weights=None,
):
"""
Build a multi-task loss function based on the specified loss type and configurations.
Args:
loss_type (LossType): The type of loss to build.
tasks (typing.List[str]): List of task names.
task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean').
global_reduction (str): Reduction method for the global loss (default: 'mean').
pos_weights (Optional): Positive class weights for tasks (default: None).
Returns:
Callable: A multi-task loss function that takes logits, labels, and weights as input.
"""
_maybe_warn(global_reduction)
_maybe_warn(task_loss_reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type]

View File

@@ -36,9 +36,24 @@ import torchmetrics
class MetricMixin:
@abstractmethod
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
"""
Abstract method to transform model outputs into a dictionary of metrics.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
Returns:
Dict: A dictionary of computed metrics.
"""
...

def update(self, outputs: Dict[str, torch.Tensor]):
"""
Update the metrics based on model outputs.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
"""
results = self.transform(outputs)
# Do not try to update if any tensor is empty as a result of stratification.
for value in results.values():
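The transform/update split means a concrete metric only has to map raw model outputs to the base metric's update arguments. A minimal sketch in the style of the repo's own test metrics; the class name and label convention here are assumptions.

import torchmetrics as tm

class PositiveCount(MetricMixin, tm.SumMetric):
    # transform() returns the kwargs that SumMetric.update expects;
    # MetricMixin.update forwards them after the emptiness check above.
    def transform(self, outputs):
        return {"value": (outputs["labels"] > 0).sum()}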
@@ -49,6 +64,13 @@ class MetricMixin:
class TaskMixin:
def __init__(self, task_idx: int = -1, **kwargs):
"""
Initialize a TaskMixin instance.
Args:
task_idx (int): Index of the task associated with this mixin (default: -1).
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self._task_idx = task_idx
@@ -59,13 +81,31 @@ class StratifyMixin:
stratifier=None,
**kwargs,
):
"""
Initialize a StratifyMixin instance.
Args:
stratifier: A stratifier for filtering outputs (default: None).
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self._stratifier = stratifier

def maybe_apply_stratification(
self, outputs: Dict[str, torch.Tensor], value_names: List[str]
) -> Dict[str, torch.Tensor]:
"""Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.""" """
Apply stratification to filter examples in the outputs.
Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
value_names (List[str]): Names of values to filter.
Returns:
Dict[str, torch.Tensor]: Filtered outputs.
"""
outputs = outputs.copy()
if not self._stratifier:
return outputs
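A sketch of the filtering this method performs, with assumed field names and layout; the real stratifier config carries its own index and indicator value.

import torch

def stratify_sketch(outputs, value_names, stratifier_index, indicator_value):
    # Keep only the examples whose stratifier feature matches the indicator.
    mask = outputs["stratifiers"][:, stratifier_index] == indicator_value
    filtered = dict(outputs)
    for name in value_names:
        filtered[name] = outputs[name][mask]
    return filtered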
@@ -84,12 +124,20 @@ class StratifyMixin:
def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
"""Returns new class using MetricMixin and given base_metric.
Functionally the same using inheritance, just saves some lines of code
if no need for class attributes.
""" """
Returns a new class using MetricMixin and the given base_metric.
Functionally the same as using inheritance, but it saves some lines of code
if there's no need for class attributes.
Args:
base_metric (torchmetrics.Metric): The base metric class to prepend the transform to.
transform (Callable): The transformation function to prepend to the metric.
Returns:
Type: A new class that includes MetricMixin and the provided base_metric
with the specified transformation method.
"""
def transform_method(_self, *args, **kwargs):
return transform(*args, **kwargs)
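Usage mirrors the repo's own tests, which build a Max metric this way:

from torchmetrics import MaxMetric

Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})
metric = Max()
metric.update({"value": 3.5})
metric.update({"value": 1.0})
print(metric.compute())  # tensor(3.5000)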

View File

@@ -15,6 +15,16 @@ def probs_and_labels(
outputs: Dict[str, torch.Tensor],
task_idx: int,
) -> Dict[str, torch.Tensor]:
"""
Extract probabilities and labels from model outputs.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
task_idx (int): Index of the task.
Returns:
Dict[str, torch.Tensor]: Dictionary containing 'preds' and 'target' tensors.
"""
preds = outputs["probabilities"]
target = outputs["labels"]
if task_idx >= 0:
@@ -28,6 +38,11 @@ def probs_and_labels(
class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
"""
Count metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and SumMetric.
This metric counts values after potential stratification and task selection.
"""

def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"]
if self._task_idx >= 0:
@@ -36,6 +51,12 @@ class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
Ctr (Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
This metric calculates the mean metric value after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"]
@@ -45,6 +66,11 @@ class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
Pctr (Predicted Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
This metric calculates the mean metric value using probabilities after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
value = outputs["probabilities"]
@@ -54,12 +80,22 @@ class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
"""
Precision metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Precision.
This metric computes precision after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)

class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
"""
Recall metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Recall.
This metric computes recall after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
@@ -73,6 +109,14 @@ class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
AUC (Area Under the ROC Curve) metric class.
This metric computes the AUC metric based on the logits and labels in the model outputs.
Args:
num_samples (int): The number of samples used to compute AUC.
**kwargs: Additional keyword arguments.
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
"""
@@ -94,8 +138,14 @@ class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
PosRanks metric class.
This metric computes the ranks of all positive examples based on the logits and labels
in the model outputs.
Args:
**kwargs: Additional keyword arguments.
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
"""
@@ -112,8 +162,13 @@ class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
ReciprocalRank metric class.
This metric computes the reciprocal of the ranks of all positive examples based on the
logits and labels in the model outputs.
Args:
**kwargs: Additional keyword arguments.
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
"""
@@ -130,9 +185,14 @@ class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
HitAtK metric class.
This metric computes the fraction of positive examples that rank in the top K among
their negatives, which is equivalent to precision@K.
Args:
k (int): The value of K.
**kwargs: Additional keyword arguments.
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
"""

View File

@@ -9,12 +9,26 @@ from torchmetrics import MaxMetric, MetricCollection, SumMetric
@dataclass
class MockStratifierConfig:
"""
Configuration dataclass for mocking a stratifier.
Args:
name (str): The name of the stratifier.
index (int): The index of the stratifier.
value (int): The value of the stratifier.
"""
name: str
index: int
value: int

class Count(MetricMixin, SumMetric):
"""
Count metric class that inherits from MetricMixin and SumMetric.
This metric counts occurrences.
"""
def transform(self, outputs):
return {"value": 1}

@@ -23,6 +37,12 @@ Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})
def test_count_metric():
"""
Test function for the Count metric.
It checks if the Count metric correctly counts the number of examples.
"""
num_examples = 123
examples = [
{"stuff": 0},

@@ -36,6 +56,12 @@ def test_count_metric():
def test_collections():
"""
Test function for metric collections.
It tests if metric collections correctly aggregate metrics.
"""
max_metric = Max()
count_metric = Count()
metric = MetricCollection([max_metric, count_metric])

@@ -51,6 +77,12 @@ def test_collections():
def test_task_dependent_ctr():
"""
Test function for task-dependent Ctr (Click-Through Rate) metric.
It checks if the Ctr metric computes the correct value for different tasks.
"""
num_examples = 144
batch_size = 1024
outputs = [

@@ -69,6 +101,13 @@ def test_task_dependent_ctr():
def test_stratified_ctr():
"""
Test function for the Stratified Ctr (Click-Through Rate) metric.
It checks if the Stratified Ctr metric computes the correct value for different tasks
and stratified samples.
"""
outputs = [
{
"stuff": 0,

@@ -114,6 +153,12 @@ def test_stratified_ctr():
def test_auc():
"""
Test function for the AUC (Area Under the Curve) metric.
It checks if the AUC metric correctly computes the Area Under the ROC Curve.
"""
num_samples = 10000
metric = core_metrics.Auc(num_samples)
target = torch.tensor([0, 0, 1, 1, 1])

@@ -131,6 +176,12 @@ def test_auc():
def test_pos_rank():
"""
Test function for the PosRanks metric.
It checks if the PosRanks metric correctly computes the ranks of positive samples.
"""
metric = core_metrics.PosRanks()
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])

@@ -147,6 +198,12 @@ def test_pos_rank():
def test_reciprocal_rank():
"""
Test function for the Reciprocal Rank metric.
It checks if the Reciprocal Rank metric correctly computes the reciprocal of ranks.
"""
metric = core_metrics.ReciprocalRank()
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])

@@ -163,6 +220,12 @@ def test_reciprocal_rank():
def test_hit_k():
"""
Test function for the Hit@K metric.
It checks if the Hit@K metric correctly computes the fraction of positives that rank in the top K among their negatives.
"""
hit1_metric = core_metrics.HitAtK(1)
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])

View File

@@ -11,23 +11,60 @@ from torchrec.distributed import DistributedModelParallel
@dataclass
class MockDataclassBatch(DataclassBatch):
"""
Mock data class batch for testing purposes.
This class represents a batch of data with continuous features and labels.
Attributes:
continuous_features (torch.Tensor): Tensor containing continuous feature data.
labels (torch.Tensor): Tensor containing label data.
"""
continuous_features: torch.Tensor
labels: torch.Tensor

class MockModule(torch.nn.Module):
"""
Mock PyTorch module for testing purposes.
This module defines a simple neural network model with a linear layer
followed by a BCEWithLogitsLoss loss function.
Attributes:
model (torch.nn.Linear): The linear model layer.
loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function.
"""
def __init__(self) -> None:
super().__init__()
self.model = torch.nn.Linear(10, 1)
self.loss_fn = torch.nn.BCEWithLogitsLoss()

def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of the mock module.
Args:
batch (MockDataclassBatch): Input data batch with continuous features and labels.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions.
"""
pred = self.model(batch.continuous_features)
loss = self.loss_fn(pred, batch.labels)
return (loss, pred)

def create_batch(bsz: int):
"""
Create a mock data batch with random continuous features and labels.
Args:
bsz (int): Batch size.
Returns:
MockDataclassBatch: A batch of data with continuous features and labels.
"""
return MockDataclassBatch(
continuous_features=torch.rand(bsz, 10).float(),
labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),

@@ -35,6 +72,13 @@ def create_batch(bsz: int):
def test_sparse_pipeline():
"""
Test function for the sparse pipeline with distributed model parallelism.
This function tests the behavior of the sparse training pipeline using
a mock module and data.
"""
device = torch.device("cpu")
model = MockModule().to(device)

@@ -65,6 +109,15 @@ def test_sparse_pipeline():
def test_amp():
"""
Test automatic mixed-precision (AMP) training with the sparse pipeline.
This function tests the behavior of the sparse training pipeline with
automatic mixed-precision (AMP) enabled, using a mock module and data.
AMP allows for faster training by using lower-precision data types, such as
torch.bfloat16, while maintaining model accuracy.
"""
device = torch.device("cpu")
model = MockModule().to(device)
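A sketch of what enable_amp implies on CPU, reusing the MockModule and create_batch helpers defined above; the real test wires this into the sparse training pipeline rather than a bare forward pass.

import torch

model = MockModule()
batch = create_batch(16)

# bfloat16 autocast around the forward pass; linear layers run in
# reduced precision while the loss computation stays numerically safe.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss, pred = model(batch)
print(pred.dtype)  # typically torch.bfloat16 for the linear layer's output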