"""Torch and torchrec specific training and evaluation loops. Features (go/100_enablements): - CUDA data-fetch, compute, gradient-push overlap - Large learnable embeddings through torchrec - On/off-chief evaluation - Warmstart/checkpoint management - go/dataset-service 0-copy integration """ import datetime import os from typing import Callable, Dict, Iterable, List, Mapping, Optional from tml.common import log_weights import tml.common.checkpointing.snapshot as snapshot_lib from tml.core.losses import get_global_loss_detached from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] from tml.core.train_pipeline import TrainPipelineSparseDist import tree import torch import torch.distributed as dist from torch.optim.lr_scheduler import _LRScheduler import torchmetrics as tm def get_new_iterator(iterable: Iterable): """ This obtain a new iterator from the iterable. If the iterable uses tf.data.Dataset internally, getting a new iterator each N steps will avoid memory leak. To avoid the memory leak calling iter(iterable) should return a "fresh" iterator using a fresh (new instance of) tf.data.Iterator. In particular, iterable can be a torch.utils.data.IterableDataset or a torch.utils.data.DataLoader. When using DDS, performing this reset does not change the order in which elements are received (excluding elements already prefetched) provided that iter(iterable) internally uses a new instance of tf.data.Dataset created by calling from_dataset_id. This requirement is satisfied by RecapDataset. :param iterable: :return: """ return iter(iterable) def _get_step_fn(pipeline, data_iterator, training: bool): """ Returns a function to perform a single evaluation step. Args: pipeline (Pipeline): The pipeline object containing the model. data_iterator (Iterator): The data iterator for evaluation. training (bool): Flag indicating if the model should be in training mode. Returns: function: A function that performs a single evaluation step. """ def step_fn(): """ Perform a single evaluation step. Returns: Any: The evaluation results after a single step. """ # It turns out that model.train() and model.eval() simply switch a single field inside the model # class,so it's somewhat safer to wrap in here. if training: pipeline._model.train() else: pipeline._model.eval() outputs = pipeline.progress(data_iterator) return tree.map_structure(lambda elem: elem.detach(), outputs) return step_fn @torch.no_grad() def _run_evaluation( pipeline, dataset, eval_steps: int, metrics: tm.MetricCollection, eval_batch_size: int, logger=None, ): """ Run the evaluation loop over all evaluation iterators. Args: pipeline (Pipeline): The pipeline object containing the model. dataset (Dataset): The dataset to evaluate. eval_steps (int): The number of evaluation steps to perform. metrics (tm.MetricCollection): A collection of evaluation metrics. eval_batch_size (int): Batch size for evaluation. logger (Optional[Logger]): A logger for recording evaluation progress (default: None). Returns: dict: A dictionary containing the computed evaluation metrics. 
""" dataset = get_new_iterator(dataset) step_fn = _get_step_fn(pipeline, dataset, training=False) last_time = datetime.datetime.now() logging.info(f"Starting {eval_steps} steps of evaluation.") for _ in range(eval_steps): outputs = step_fn() metrics.update(outputs) eval_ex_per_s = ( eval_batch_size * eval_steps / (datetime.datetime.now() - last_time).total_seconds() ) logging.info(f"eval examples_per_s : {eval_ex_per_s}") metrics_result = metrics.compute() # Resetting at end to release metrics memory not in use. # Reset metrics to prevent accumulation between multiple evaluation splits and not report a # running average. metrics.reset() return metrics_result def train( model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: str, save_dir: str, logging_interval: int, train_steps: int, checkpoint_frequency: int, dataset: Iterable, worker_batch_size: int, num_workers: Optional[int] = 0, enable_amp: bool = False, initial_checkpoint_dir: Optional[str] = None, gradient_accumulation: Optional[int] = None, logger_initializer: Optional[Callable] = None, scheduler: _LRScheduler = None, metrics: Optional[tm.MetricCollection] = None, parameters_to_log: Optional[Dict[str, Callable]] = None, tables_to_log: Optional[List[str]] = None, ) -> None: """ Runs training and evaluation on the given TrainPipeline. Args: model (torch.nn.Module): The neural network model to train. optimizer (torch.optim.Optimizer): The optimizer for model optimization. device (str): The target device for model training (e.g., 'cuda' or 'cpu'). save_dir (str): The directory to save model checkpoints and logs. logging_interval (int): Interval for logging training progress. train_steps (int): The number of training steps to perform. checkpoint_frequency (int): Frequency of saving model checkpoints. dataset (Iterable): Data iterator for the training set. worker_batch_size (int): Batch size for data loading workers. num_workers (Optional[int]): Number of data loading workers (default: 0). enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False). initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None). gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None). logger_initializer (Optional[Callable]): A logger initializer function (default: None). scheduler (_LRScheduler): Optional learning rate scheduler (default: None). metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None). parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None). tables_to_log (Optional[List[str]]): List of tables to log (default: None). """ train_pipeline = TrainPipelineSparseDist( model=model, optimizer=optimizer, device=device, enable_amp=enable_amp, grad_accum=gradient_accumulation, ) # type: ignore[var-annotated] # We explicitly initialize optimizer state here so that checkpoint will work properly. if hasattr(train_pipeline._optimizer, "init_state"): train_pipeline._optimizer.init_state() save_state = { "model": train_pipeline._model, "optimizer": train_pipeline._optimizer, "scaler": train_pipeline._grad_scaler, } chosen_checkpoint = None checkpoint_handler = snapshot_lib.Snapshot( save_dir=save_dir, state=save_state, ) if save_dir: chosen_checkpoint = snapshot_lib.get_checkpoint(save_dir=save_dir, missing_ok=True) start_step = 0 start_walltime = 0.0 if chosen_checkpoint: # Skip restoration and exit if we should be finished. 
def train(
  model: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  device: str,
  save_dir: str,
  logging_interval: int,
  train_steps: int,
  checkpoint_frequency: int,
  dataset: Iterable,
  worker_batch_size: int,
  num_workers: Optional[int] = 0,
  enable_amp: bool = False,
  initial_checkpoint_dir: Optional[str] = None,
  gradient_accumulation: Optional[int] = None,
  logger_initializer: Optional[Callable] = None,
  scheduler: Optional[_LRScheduler] = None,
  metrics: Optional[tm.MetricCollection] = None,
  parameters_to_log: Optional[Dict[str, Callable]] = None,
  tables_to_log: Optional[List[str]] = None,
) -> None:
  """
  Runs training and evaluation on the given TrainPipeline.

  Args:
    model (torch.nn.Module): The neural network model to train.
    optimizer (torch.optim.Optimizer): The optimizer for model optimization.
    device (str): The target device for model training (e.g., 'cuda' or 'cpu').
    save_dir (str): The directory to save model checkpoints and logs.
    logging_interval (int): Interval for logging training progress.
    train_steps (int): The number of training steps to perform.
    checkpoint_frequency (int): Frequency of saving model checkpoints.
    dataset (Iterable): Data iterator for the training set.
    worker_batch_size (int): Batch size for data loading workers.
    num_workers (Optional[int]): Number of data loading workers (default: 0).
    enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False).
    initial_checkpoint_dir (Optional[str]): Directory to warmstart training from (default: None).
    gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None).
    logger_initializer (Optional[Callable]): A logger initializer function (default: None).
    scheduler (Optional[_LRScheduler]): Optional learning rate scheduler (default: None).
    metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
    parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None).
    tables_to_log (Optional[List[str]]): List of embedding tables to log (default: None).
  """
  train_pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=device,
    enable_amp=enable_amp,
    grad_accum=gradient_accumulation,
  )  # type: ignore[var-annotated]

  # We explicitly initialize optimizer state here so that checkpointing will work properly.
  if hasattr(train_pipeline._optimizer, "init_state"):
    train_pipeline._optimizer.init_state()

  save_state = {
    "model": train_pipeline._model,
    "optimizer": train_pipeline._optimizer,
    "scaler": train_pipeline._grad_scaler,
  }

  chosen_checkpoint = None
  checkpoint_handler = snapshot_lib.Snapshot(
    save_dir=save_dir,
    state=save_state,
  )

  if save_dir:
    chosen_checkpoint = snapshot_lib.get_checkpoint(save_dir=save_dir, missing_ok=True)

  start_step = 0
  start_walltime = 0.0
  if chosen_checkpoint:
    # Skip restoration and exit if we should already be finished.
    chosen_checkpoint_global_step = snapshot_lib.step_from_checkpoint(chosen_checkpoint)
    if chosen_checkpoint_global_step >= dist.get_world_size() * train_steps:
      logging.info(
        "Not restoring and finishing training as latest checkpoint "
        f"{chosen_checkpoint} found "
        f"at global_step ({chosen_checkpoint_global_step}) >= "
        f"train_steps ({dist.get_world_size() * train_steps})"
      )
      return
    logging.info(f"Restoring latest checkpoint from global_step {chosen_checkpoint_global_step}")
    checkpoint_handler.restore(chosen_checkpoint)
    start_step = checkpoint_handler.step
    start_walltime = checkpoint_handler.walltime
  elif initial_checkpoint_dir:
    base, ckpt_step = os.path.split(initial_checkpoint_dir)
    warmstart_handler = snapshot_lib.Snapshot(
      save_dir=base,
      state=save_state,
    )
    ckpt = snapshot_lib.get_checkpoint(save_dir=base, missing_ok=False, global_step=int(ckpt_step))
    logging.info(
      f"Restoring from initial_checkpoint_dir: {initial_checkpoint_dir}, but keeping the starting step at 0."
    )
    warmstart_handler.restore(ckpt)

  train_logger = logger_initializer(mode="train") if logger_initializer else None
  train_step_fn = _get_step_fn(train_pipeline, get_new_iterator(dataset), training=True)

  # Count the number of model parameters directly, right after the pipeline is created.
  nb_param = sum(p.numel() for p in model.parameters())
  logging.info(f"Model has {nb_param} parameters")

  last_time = datetime.datetime.now()
  start_time = last_time
  last_pending_snapshot = None
  for step in range(start_step, train_steps + 1):
    checkpoint_handler.step = step
    outputs = train_step_fn()
    step_done_time = datetime.datetime.now()
    checkpoint_handler.walltime = (step_done_time - start_time).total_seconds() + start_walltime
    if scheduler:
      scheduler.step()

    if step % logging_interval == 0:
      interval_time = (step_done_time - last_time).total_seconds()
      steps_per_s = logging_interval / interval_time
      worker_example_per_s = steps_per_s * worker_batch_size
      global_example_per_s = worker_example_per_s * (1 + (num_workers or 0))
      global_step = step

      log_values = {
        "global_step": global_step,
        "loss": get_global_loss_detached(outputs["loss"]),
        "steps_per_s": steps_per_s,
        "global_example_per_s": global_example_per_s,
        "worker_examples_per_s": worker_example_per_s,
        "active_training_walltime": checkpoint_handler.walltime,
      }
      if parameters_to_log:
        log_values.update(
          log_weights.weights_to_log(
            model=model,
            how_to_log=parameters_to_log,
          )
        )
      log_values = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), log_values)

      if tables_to_log:
        log_values.update(
          log_weights.log_ebc_norms(
            model_state_dict=train_pipeline._model.state_dict(),
            ebc_keys=tables_to_log,
          )
        )
      if train_logger:
        train_logger.log(log_values, step=global_step)
      log_line = ", ".join(f"{name}: {value}" for name, value in log_values.items())
      logging.info(f"Step: {step}, training. {log_line}")
      last_time = step_done_time

    # If we just restored, do not save again.
    if checkpoint_frequency and step > start_step and step % checkpoint_frequency == 0:
      if last_pending_snapshot and not last_pending_snapshot.done():
        logging.warning(
          "Beginning a new snapshot while the last one hasn't finished. That probably indicates "
          "either that you're snapshotting too often or that something is wrong. Will now block "
          "and wait for the previous snapshot to finish before beginning the next one."
        )
        last_pending_snapshot.wait()
      last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())

  # Save a final checkpoint if the last loop iteration did not already save one.
  if checkpoint_frequency and step % checkpoint_frequency != 0:
    # For the final save, wait for the checkpoint write (below) so the process doesn't exit
    # before it's completed.
    last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())

  logging.info(f"Finished training steps: {step}, global_steps: {step * dist.get_world_size()}")

  if last_pending_snapshot:
    logging.info("Waiting for any pending checkpoints to finish.")
    last_pending_snapshot.wait()
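

# Illustrative sketch of the bookkeeping `train` uses above: checkpoints are written with
# `global_step = worker_step * world_size`, restoration is skipped once that value reaches
# `world_size * train_steps`, and the logged examples/s scale the per-worker rate by the number
# of extra workers. `_example_step_accounting` is a hypothetical helper, not called anywhere.
def _example_step_accounting(
  worker_step: int,
  world_size: int,
  train_steps: int,
  logging_interval: int,
  interval_seconds: float,
  worker_batch_size: int,
  num_workers: int,
) -> Dict[str, float]:
  # Checkpoints store the global step, i.e. the per-worker step scaled by the world size.
  global_step = worker_step * world_size
  # `train` only restores (rather than exiting early) while global_step < world_size * train_steps.
  steps_per_s = logging_interval / interval_seconds
  worker_example_per_s = steps_per_s * worker_batch_size
  # Examples/s across the fleet: the per-worker rate scaled by (1 + num_workers), as in `train`.
  global_example_per_s = worker_example_per_s * (1 + num_workers)
  return {
    "global_step": float(global_step),
    "steps_per_s": steps_per_s,
    "worker_examples_per_s": worker_example_per_s,
    "global_example_per_s": global_example_per_s,
  }
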
def log_eval_results(
  results,
  eval_logger,
  partition_name: str,
  step: int,
):
  """
  Logs evaluation results and optionally records them using a provided logger.

  Args:
    results (Any): The evaluation results to log.
    eval_logger (Callable): A logger for recording evaluation results.
    partition_name (str): The name of the evaluation partition.
    step (int): The current step in the evaluation.
  """
  results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
  logging.info(f"Step: {step}, evaluation ({partition_name}).")
  for metric_name, metric_value in results.items():
    logging.info(f"\t{metric_name}: {metric_value:1.4e}")

  if eval_logger:
    eval_logger.log(results, step=step, commit=True)


def only_evaluate(
  model: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  device: str,
  save_dir: str,
  num_train_steps: int,
  dataset: Iterable,
  eval_batch_size: int,
  num_eval_steps: int,
  eval_timeout_in_s: int,
  eval_logger: Callable,
  partition_name: str,
  metrics: Optional[tm.MetricCollection] = None,
):
  """
  Performs evaluation on a given dataset partition.

  Args:
    model (torch.nn.Module): The neural network model for evaluation.
    optimizer (torch.optim.Optimizer): The optimizer used during evaluation.
    device (str): The target device for evaluation (e.g., 'cuda' or 'cpu').
    save_dir (str): The directory containing model checkpoints.
    num_train_steps (int): The total number of training steps.
    dataset (Iterable): Data iterator for evaluation.
    eval_batch_size (int): Batch size for evaluation.
    num_eval_steps (int): The number of evaluation steps to perform.
    eval_timeout_in_s (int): Timeout, in seconds, when waiting for new checkpoints to evaluate.
    eval_logger (Callable): A logger for recording evaluation results.
    partition_name (str): The name of the evaluation partition.
    metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
  """
  logging.info(f"Evaluating on partition {partition_name}.")
  logging.info("Computing metrics:")
  logging.info(metrics)
  eval_pipeline = TrainPipelineSparseDist(model, optimizer, device)  # type: ignore[var-annotated]
  save_state = {
    "model": eval_pipeline._model,
    "optimizer": eval_pipeline._optimizer,
  }
  checkpoint_handler = snapshot_lib.Snapshot(
    save_dir=save_dir,
    state=save_state,
  )
  for checkpoint_path in snapshot_lib.checkpoints_iterator(save_dir, timeout=eval_timeout_in_s):
    checkpoint_handler.restore(checkpoint_path)
    step = checkpoint_handler.step
    dataset = get_new_iterator(dataset)
    results = _run_evaluation(
      pipeline=eval_pipeline,
      dataset=dataset,
      eval_steps=num_eval_steps,
      eval_batch_size=eval_batch_size,
      metrics=metrics,
    )
    log_eval_results(results, eval_logger, partition_name, step=step)
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
      snapshot_lib.mark_done_eval(checkpoint_path, partition_name)
    if step >= num_train_steps:
      return
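

# Illustrative usage sketch for `log_eval_results` defined above (not invoked anywhere in this
# module). The dict below stands in for the output of `metrics.compute()`; passing
# `eval_logger=None` skips the optional logger, so results only go to the process logs.
def _example_log_eval_results():
  fake_results = {"auroc": torch.tensor(0.5), "loss": torch.tensor(0.6931)}
  log_eval_results(fake_results, eval_logger=None, partition_name="holdout", step=0)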