From 799254345fa542ff6bd645d5eccfcaf767901644 Mon Sep 17 00:00:00 2001
From: rajveer43
Date: Mon, 11 Sep 2023 16:26:29 +0530
Subject: [PATCH] `core` update remaining train_pipeline.py

---
 core/custom_training_loop.py | 89 ++++++++++++++++++++++++++++++++----
 core/debug_training_loop.py  | 12 +++++
 core/losses.py               | 44 +++++++++++++++---
 core/metric_mixin.py         | 60 +++++++++++++++++++++---
 core/metrics.py              | 74 +++++++++++++++++++++++++++---
 core/test_metrics.py         | 63 +++++++++++++++++++++++++
 core/test_train_pipeline.py  | 53 +++++++++++++++++++++
 7 files changed, 366 insertions(+), 29 deletions(-)

diff --git a/core/custom_training_loop.py b/core/custom_training_loop.py
index 0241145..73b2cf1 100644
--- a/core/custom_training_loop.py
+++ b/core/custom_training_loop.py
@@ -46,7 +46,24 @@ def get_new_iterator(iterable: Iterable):
 
 
 def _get_step_fn(pipeline, data_iterator, training: bool):
+  """
+  Returns a function that performs a single train or evaluation step.
+
+  Args:
+    pipeline (Pipeline): The pipeline object containing the model.
+    data_iterator (Iterator): The data iterator to draw batches from.
+    training (bool): Flag indicating whether the model should be in training mode.
+
+  Returns:
+    function: A function that performs a single step.
+  """
   def step_fn():
+    """
+    Perform a single step.
+
+    Returns:
+      Any: The results of the step.
+    """
     # It turns out that model.train() and model.eval() simply switch a single field inside the model
     # class, so it's somewhat safer to wrap them in here.
     if training:
@@ -69,7 +86,21 @@ def _run_evaluation(
   eval_batch_size: int,
   logger=None,
 ):
-  """Runs the evaluation loop over all evaluation iterators."""
+  """
+  Run the evaluation loop over all evaluation iterators.
+
+  Args:
+    pipeline (Pipeline): The pipeline object containing the model.
+    dataset (Iterable): Data iterator for the evaluation set.
+    eval_steps (int): The number of evaluation steps to perform.
+    metrics (tm.MetricCollection): A collection of evaluation metrics.
+    eval_batch_size (int): Batch size for evaluation.
+    logger (Optional[Logger]): A logger for recording evaluation progress (default: None).
+
+  Returns:
+    dict: A dictionary containing the computed evaluation metrics.
+  """
+
   dataset = get_new_iterator(dataset)
   step_fn = _get_step_fn(pipeline, dataset, training=False)
   last_time = datetime.datetime.now()
@@ -109,15 +140,29 @@ def train(
   parameters_to_log: Optional[Dict[str, Callable]] = None,
   tables_to_log: Optional[List[str]] = None,
 ) -> None:
-  """Runs training and eval on the given TrainPipeline
-
-  Args:
-    dataset: data iterator for the training set
-    evaluation_iterators: data iterators for the different evaluation sets
-    scheduler: optional learning rate scheduler
-    output_transform_for_metrics: optional transformation functions to transorm the model
-    output and labels into a format the metrics can understand
   """
+  Runs training and evaluation on the given TrainPipeline.
+
+  Args:
+    model (torch.nn.Module): The neural network model to train.
+    optimizer (torch.optim.Optimizer): The optimizer for model optimization.
+    device (str): The target device for model training (e.g., 'cuda' or 'cpu').
+    save_dir (str): The directory to save model checkpoints and logs.
+    logging_interval (int): Interval for logging training progress.
+    train_steps (int): The number of training steps to perform.
+    checkpoint_frequency (int): Frequency of saving model checkpoints.
+    dataset (Iterable): Data iterator for the training set.
+    worker_batch_size (int): Batch size for data loading workers.
+    num_workers (Optional[int]): Number of data loading workers (default: 0).
+    enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False).
+    initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None).
+    gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None).
+    logger_initializer (Optional[Callable]): A logger initializer function (default: None).
+    scheduler (_LRScheduler): Optional learning rate scheduler (default: None).
+    metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
+    parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None).
+    tables_to_log (Optional[List[str]]): List of tables to log (default: None).
+  """
 
   train_pipeline = TrainPipelineSparseDist(
     model=model,
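A note on the pattern documented in `_get_step_fn`: `model.train()` and `model.eval()` only flip a field on the module, which is why the mode is re-asserted inside every step. A minimal sketch of that idea outside the TorchRec pipeline (the `make_step_fn` helper and the direct `model(batch)` call are illustrative assumptions, not this repo's API):

```python
from typing import Callable, Iterator

import torch


def make_step_fn(model: torch.nn.Module, batches: Iterator, training: bool) -> Callable:
  def step_fn():
    # Re-assert the mode inside the step: train()/eval() merely toggle a
    # flag on the module, so other code may have switched it between steps.
    if training:
      model.train()
      return model(next(batches))
    model.eval()
    with torch.no_grad():
      return model(next(batches))

  return step_fn
```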
@@ -262,6 +307,15 @@ def log_eval_results(
   partition_name: str,
   step: int,
 ):
+  """
+  Logs evaluation results and optionally records them using a provided logger.
+
+  Args:
+    results (Any): The evaluation results to log.
+    eval_logger (Callable): A logger for recording evaluation results.
+    partition_name (str): The name of the evaluation partition.
+    step (int): The current training step.
+  """
   results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
   logging.info(f"Step: {step}, evaluation ({partition_name}).")
   for metric_name, metric_value in results.items():
@@ -285,6 +339,23 @@ def only_evaluate(
   partition_name: str,
   metrics: Optional[tm.MetricCollection] = None,
 ):
+  """
+  Performs evaluation on a given dataset partition.
+
+  Args:
+    model (torch.nn.Module): The neural network model to evaluate.
+    optimizer (torch.optim.Optimizer): The optimizer used during evaluation.
+    device (str): The target device for evaluation (e.g., 'cuda' or 'cpu').
+    save_dir (str): The directory containing model checkpoints.
+    num_train_steps (int): The total number of training steps.
+    dataset (Iterable): Data iterator for evaluation.
+    eval_batch_size (int): Batch size for evaluation.
+    num_eval_steps (int): The number of evaluation steps to perform.
+    eval_timeout_in_s (int): Timeout, in seconds, when waiting for checkpoints to evaluate.
+    eval_logger (Callable): A logger for recording evaluation results.
+    partition_name (str): The name of the evaluation partition.
+    metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
+  """
   logging.info(f"Evaluating on partition {partition_name}.")
   logging.info("Computing metrics:")
   logging.info(metrics)
diff --git a/core/debug_training_loop.py b/core/debug_training_loop.py
index 610eea9..bced83f 100644
--- a/core/debug_training_loop.py
+++ b/core/debug_training_loop.py
@@ -28,6 +28,18 @@ def train(
   *args,
   **kwargs,
 ) -> None:
+  """
+  Debugging training loop. Do not use for actual model training.
+
+  Args:
+    model (torch.nn.Module): The neural network model.
+    optimizer (torch.optim.Optimizer): The optimizer for model optimization.
+    train_steps (int): The number of training steps to perform.
+    dataset (Iterable): Data iterator for training data.
+    scheduler (_LRScheduler, optional): Learning rate scheduler (default: None).
+    *args: Additional arguments (ignored).
+    **kwargs: Additional keyword arguments (ignored).
+  """
 
   logging.warning("Running debug training loop, don't use for model training.")
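The debug loop documented above deliberately skips checkpointing, distributed setup, and mixed precision. A rough sketch of such a loop, assuming the model returns a `(loss, outputs)` tuple the way the mock modules later in this patch do:

```python
import logging
from typing import Iterable, Optional

import torch


def debug_train_loop(
  model: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  train_steps: int,
  dataset: Iterable,
  scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
  model.train()
  # One optimizer step per batch; no checkpoints, no DDP, no AMP.
  for step, batch in zip(range(train_steps), dataset):
    loss, _ = model(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
      scheduler.step()
    logging.info(f"step={step} loss={loss.item():.4f}")
```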
+ """ logging.warning("Running debug training loop, don't use for model training.") diff --git a/core/losses.py b/core/losses.py index 6ef9a4a..7cf0cf1 100644 --- a/core/losses.py +++ b/core/losses.py @@ -10,8 +10,11 @@ import torch def _maybe_warn(reduction: str): """ - Warning for reduction different than mean. - """ + Emit a warning if the reduction method is different from 'mean'. + + Args: + reduction (str): The reduction method being used. + """ if reduction != "mean": logging.warn( f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal," @@ -24,6 +27,16 @@ def build_loss( loss_type: LossType, reduction="mean", ): + """ + Build a loss function based on the specified loss type and reduction method. + + Args: + loss_type (LossType): The type of loss to build. + reduction (str): The reduction method for the loss (default: 'mean'). + + Returns: + Callable: A loss function that takes logits and labels as input. + """ _maybe_warn(reduction) f = _LOSS_TYPE_TO_FUNCTION[loss_type] @@ -35,11 +48,15 @@ def build_loss( def get_global_loss_detached(local_loss, reduction="mean"): """ - Perform all_reduce to obtain the global loss function using the provided reduction. - :param local_loss: The local loss of the current rank. - :param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP. - :return: The reduced & detached global loss. - """ + Perform all_reduce to obtain the global loss function using the provided reduction. + + Args: + local_loss (torch.Tensor): The local loss of the current rank. + reduction (str): The reduction to use for all_reduce. Should match the reduction used by DDP. + + Returns: + torch.Tensor: The reduced and detached global loss. + """ if reduction != "mean": logging.warn( f"The reduction used in this function should be the same as the one used by " @@ -66,6 +83,19 @@ def build_multi_task_loss( global_reduction="mean", pos_weights=None, ): + """ + Build a multi-task loss function based on the specified loss type and configurations. + + Args: + loss_type (LossType): The type of loss to build. + tasks (typing.List[str]): List of task names. + task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean'). + global_reduction (str): Reduction method for the global loss (default: 'mean'). + pos_weights (Optional): Positive class weights for tasks (default: None). + + Returns: + Callable: A multi-task loss function that takes logits, labels, and weights as input. + """ _maybe_warn(global_reduction) _maybe_warn(task_loss_reduction) f = _LOSS_TYPE_TO_FUNCTION[loss_type] diff --git a/core/metric_mixin.py b/core/metric_mixin.py index a716ca7..def38cb 100644 --- a/core/metric_mixin.py +++ b/core/metric_mixin.py @@ -36,9 +36,24 @@ import torchmetrics class MetricMixin: @abstractmethod def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: + """ + Abstract method to transform model outputs into a dictionary of metrics. + + Args: + outputs (Dict[str, torch.Tensor]): Model outputs. + + Returns: + Dict: A dictionary of computed metrics. + """ ... def update(self, outputs: Dict[str, torch.Tensor]): + """ + Update the metrics based on model outputs. + + Args: + outputs (Dict[str, torch.Tensor]): Model outputs. + """ results = self.transform(outputs) # Do not try to update if any tensor is empty as a result of stratification. for value in results.values(): @@ -49,6 +64,13 @@ class MetricMixin: class TaskMixin: def __init__(self, task_idx: int = -1, **kwargs): + """ + Initialize a TaskMixin instance. 
@@ -59,13 +81,31 @@ class StratifyMixin:
     stratifier=None,
     **kwargs,
   ):
+    """
+    Initialize a StratifyMixin instance.
+
+    Args:
+      stratifier: A stratifier for filtering outputs (default: None).
+      **kwargs: Additional keyword arguments.
+    """
     super().__init__(**kwargs)
     self._stratifier = stratifier
 
   def maybe_apply_stratification(
     self, outputs: Dict[str, torch.Tensor], value_names: List[str]
   ) -> Dict[str, torch.Tensor]:
-    """Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value."""
+    """
+    Apply stratification to filter examples in the outputs.
+
+    Picks out examples whose stratifier feature is equal to a specific stratifier indicator value.
+
+    Args:
+      outputs (Dict[str, torch.Tensor]): Model outputs.
+      value_names (List[str]): Names of the values to filter.
+
+    Returns:
+      Dict[str, torch.Tensor]: Filtered outputs.
+    """
     outputs = outputs.copy()
     if not self._stratifier:
       return outputs
@@ -84,12 +124,20 @@
 
 
 def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
-  """Returns new class using MetricMixin and given base_metric.
-
-  Functionally the same using inheritance, just saves some lines of code
-  if no need for class attributes.
-  """
+  """
+  Returns a new class using MetricMixin and the given base_metric.
+
+  Functionally the same as using inheritance, but it saves some lines of code
+  if there's no need for class attributes.
+
+  Args:
+    base_metric (torchmetrics.Metric): The base metric class to prepend the transform to.
+    transform (Callable): The transformation function to prepend to the metric.
+
+  Returns:
+    Type: A new class combining MetricMixin and the provided base_metric with the
+      specified transform method.
+  """
 
   def transform_method(_self, *args, **kwargs):
     return transform(*args, **kwargs)
diff --git a/core/metrics.py b/core/metrics.py
index 2384e4d..0a9c38b 100644
--- a/core/metrics.py
+++ b/core/metrics.py
@@ -15,6 +15,16 @@ def probs_and_labels(
   outputs: Dict[str, torch.Tensor],
   task_idx: int,
 ) -> Dict[str, torch.Tensor]:
+  """
+  Extract probabilities and labels from model outputs.
+
+  Args:
+    outputs (Dict[str, torch.Tensor]): Model outputs.
+    task_idx (int): Index of the task.
+
+  Returns:
+    Dict[str, torch.Tensor]: Dictionary containing 'preds' and 'target' tensors.
+  """
   preds = outputs["probabilities"]
   target = outputs["labels"]
   if task_idx >= 0:
@@ -28,6 +38,11 @@
 
 class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
   def transform(self, outputs):
+    """
+    Transform model outputs into the values to count.
+
+    Applies stratification and task selection when configured.
+    """
     outputs = self.maybe_apply_stratification(outputs, ["labels"])
     value = outputs["labels"]
     if self._task_idx >= 0:
@@ -36,6 +51,12 @@
 
 
 class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
+  """
+  Ctr (click-through rate) metric built from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
+
+  Computes the mean of the labels after optional stratification and task selection.
+  """
+
   def transform(self, outputs):
     outputs = self.maybe_apply_stratification(outputs, ["labels"])
     value = outputs["labels"]
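All of these mixin-based metrics follow the same split: `transform` reshapes model outputs, and the torchmetrics base class does the aggregation. For Ctr, feeding raw 0/1 labels into a plain `MeanMetric` already yields the positive rate, which is what the class above computes after stratification:

```python
import torch
import torchmetrics as tm

ctr = tm.MeanMetric()
ctr.update(torch.tensor([1.0, 0.0, 0.0, 1.0]))  # two clicks in four examples
assert ctr.compute().item() == 0.5
```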
+ """ + def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["labels"]) value = outputs["labels"] @@ -45,6 +66,11 @@ class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): + """ + Pctr (Predicted Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric. + + This metric calculates the mean metric value using probabilities after potential stratification and task selection. + """ def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["probabilities"]) value = outputs["probabilities"] @@ -54,12 +80,22 @@ class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision): + """ + Precision metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Precision. + + This metric computes precision after potential stratification and task selection. + """ def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) return probs_and_labels(outputs, self._task_idx) class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall): + """ + Recall metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Recall. + + This metric computes recall after potential stratification and task selection. + """ def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) return probs_and_labels(outputs, self._task_idx) @@ -73,6 +109,14 @@ class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC): class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ + AUC (Area Under the ROC Curve) metric class. + + This metric computes the AUC metric based on the logits and labels in the model outputs. + + Args: + num_samples (int): The number of samples used to compute AUC. + **kwargs: Additional keyword arguments. + Based on: https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420 """ @@ -94,8 +138,14 @@ class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ - The ranks of all positives - Based on: + PosRanks metric class. + + This metric computes the ranks of all positive examples based on the logits and labels + in the model outputs. + + Args: + **kwargs: Additional keyword arguments. + https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73 """ @@ -112,8 +162,13 @@ class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ - The reciprocal of the ranks of all - Based on: + ReciprocalRank metric class. + + This metric computes the reciprocal of the ranks of all positive examples based on the logits and labels + in the model outputs. + + Args: + **kwargs: Additional keyword arguments. 
@@ -112,8 +162,13 @@
 
 
 class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
   """
-  The reciprocal of the ranks of all
-  Based on:
+  ReciprocalRank metric class.
+
+  Computes the reciprocal of the ranks of all positive examples based on the
+  logits and labels in the model outputs.
+
+  Args:
+    **kwargs: Additional keyword arguments.
   https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
   """
@@ -130,9 +185,14 @@
 
 
 class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
   """
-  The fraction of positives that rank in the top K among their negatives
-  Note that this is basically precision@k
-  Based on:
+  HitAtK metric class.
+
+  Computes the fraction of positive examples that rank in the top K among their
+  negatives, which is equivalent to precision@K.
+
+  Args:
+    k (int): The cutoff rank K.
+    **kwargs: Additional keyword arguments.
   https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
   """
diff --git a/core/test_metrics.py b/core/test_metrics.py
index ac29819..c9a34e8 100644
--- a/core/test_metrics.py
+++ b/core/test_metrics.py
@@ -9,12 +9,26 @@
 
 
 @dataclass
 class MockStratifierConfig:
+  """
+  Configuration dataclass for mocking a stratifier.
+
+  Args:
+    name (str): The name of the stratifier.
+    index (int): The index of the stratifier.
+    value (int): The value of the stratifier.
+  """
   name: str
   index: int
   value: int
 
 
 class Count(MetricMixin, SumMetric):
+  """
+  Count metric that combines MetricMixin and SumMetric.
+
+  Counts one occurrence per transformed example.
+  """
+
   def transform(self, outputs):
     return {"value": 1}
@@ -23,6 +37,12 @@ Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})
 
 
 def test_count_metric():
+  """
+  Test function for the Count metric.
+
+  Checks that the Count metric correctly counts the number of examples.
+  """
+
   num_examples = 123
   examples = [
     {"stuff": 0},
@@ -36,6 +56,12 @@
 
 
 def test_collections():
+  """
+  Test function for metric collections.
+
+  Checks that metric collections correctly aggregate their metrics.
+  """
+
   max_metric = Max()
   count_metric = Count()
   metric = MetricCollection([max_metric, count_metric])
@@ -51,6 +77,12 @@
 
 
 def test_task_dependent_ctr():
+  """
+  Test function for the task-dependent Ctr (click-through rate) metric.
+
+  Checks that the Ctr metric computes the correct value for different tasks.
+  """
+
   num_examples = 144
   batch_size = 1024
   outputs = [
@@ -69,6 +101,13 @@
 
 
 def test_stratified_ctr():
+  """
+  Test function for the stratified Ctr (click-through rate) metric.
+
+  Checks that the stratified Ctr metric computes the correct value for
+  different tasks and stratified samples.
+  """
+
   outputs = [
     {
       "stuff": 0,
@@ -114,6 +153,12 @@
 
 
 def test_auc():
+  """
+  Test function for the AUC (Area Under the ROC Curve) metric.
+
+  Checks that the Auc metric correctly computes the area under the ROC curve.
+  """
+
   num_samples = 10000
   metric = core_metrics.Auc(num_samples)
   target = torch.tensor([0, 0, 1, 1, 1])
@@ -131,6 +176,12 @@
 
 
 def test_pos_rank():
+  """
+  Test function for the PosRanks metric.
+
+  Checks that the PosRanks metric correctly computes the ranks of positive samples.
+  """
+
   metric = core_metrics.PosRanks()
   target = torch.tensor([0, 0, 1, 1, 1])
   preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
@@ -147,6 +198,12 @@
 
 
 def test_reciprocal_rank():
+  """
+  Test function for the ReciprocalRank metric.
+
+  Checks that the ReciprocalRank metric correctly computes the reciprocal of the ranks.
+  """
+
   metric = core_metrics.ReciprocalRank()
   target = torch.tensor([0, 0, 1, 1, 1])
   preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
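As a sanity check on the arithmetic this test exercises: with positives at ranks 1, 2, and 4, the mean reciprocal rank is (1 + 1/2 + 1/4) / 3 = 7/12:

```python
import torch

ranks = torch.tensor([1.0, 2.0, 4.0])
mrr = (1.0 / ranks).mean()  # (1.0 + 0.5 + 0.25) / 3
assert torch.isclose(mrr, torch.tensor(7.0 / 12.0))
```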
+ + """ metric = core_metrics.ReciprocalRank() target = torch.tensor([0, 0, 1, 1, 1]) preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5]) @@ -163,6 +220,12 @@ def test_reciprocal_rank(): def test_hit_k(): + """ + Test function for the Hit@K metric. + + It checks if the Hit@K metric correctly computes the fraction of positives that rank in the top K among their negatives. + + """ hit1_metric = core_metrics.HitAtK(1) target = torch.tensor([0, 0, 1, 1, 1]) preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5]) diff --git a/core/test_train_pipeline.py b/core/test_train_pipeline.py index 8e2f6f5..ee9f401 100644 --- a/core/test_train_pipeline.py +++ b/core/test_train_pipeline.py @@ -11,23 +11,60 @@ from torchrec.distributed import DistributedModelParallel @dataclass class MockDataclassBatch(DataclassBatch): + """ + Mock data class batch for testing purposes. + + This class represents a batch of data with continuous features and labels. + + Attributes: + continuous_features (torch.Tensor): Tensor containing continuous feature data. + labels (torch.Tensor): Tensor containing label data. + """ continuous_features: torch.Tensor labels: torch.Tensor class MockModule(torch.nn.Module): + """ + Mock PyTorch module for testing purposes. + + This module defines a simple neural network model with a linear layer + followed by a BCEWithLogitsLoss loss function. + + Attributes: + model (torch.nn.Linear): The linear model layer. + loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function. + """ def __init__(self) -> None: super().__init__() self.model = torch.nn.Linear(10, 1) self.loss_fn = torch.nn.BCEWithLogitsLoss() def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of the mock module. + + Args: + batch (MockDataclassBatch): Input data batch with continuous features and labels. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions. + """ pred = self.model(batch.continuous_features) loss = self.loss_fn(pred, batch.labels) return (loss, pred) def create_batch(bsz: int): + """ + Create a mock data batch with random continuous features and labels. + + Args: + bsz (int): Batch size. + + Returns: + MockDataclassBatch: A batch of data with continuous features and labels. + """ return MockDataclassBatch( continuous_features=torch.rand(bsz, 10).float(), labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(), @@ -35,6 +72,13 @@ def create_batch(bsz: int): def test_sparse_pipeline(): + """ + Test function for the sparse pipeline with distributed model parallelism. + + This function tests the behavior of the sparse training pipeline using + a mock module and data. + """ + device = torch.device("cpu") model = MockModule().to(device) @@ -65,6 +109,15 @@ def test_sparse_pipeline(): def test_amp(): + """ + Test automatic mixed-precision (AMP) training with the sparse pipeline. + + This function tests the behavior of the sparse training pipeline with + automatic mixed-precision (AMP) enabled, using a mock module and data. + + AMP allows for faster training by using lower-precision data types, such as + torch.bfloat16, while maintaining model accuracy. + """ device = torch.device("cpu") model = MockModule().to(device)