"""Build optimizers and learning rate schedules.""" import bisect from collections import defaultdict import functools import math import typing from typing import Optional import warnings # from large_embeddings.config import EmbeddingOptimizerConfig from tml.projects.home.recap import model as model_mod from tml.optimizers import config from tml.optimizers import compute_lr from absl import logging # type: ignore[attr-defined] import torch from torchrec.optim import keyed _DEFAULT_LR = 24601.0 # NaN the model if we're not using the learning rate. _BACKBONE = "backbone" _DENSE_EMBEDDINGS = "dense_ebc" class RecapLRShim(torch.optim.lr_scheduler._LRScheduler): """ A shim to get learning rates into a LRScheduler. This class adheres to the torch.optim scheduler API and can be plugged into any scheduler that supports learning rate schedules, such as exponential decay. Args: optimizer: The optimizer to which this scheduler is applied. lr_dict (Dict[str, config.LearningRate]): A dictionary mapping group names to learning rate configurations. emb_learning_rate: The learning rate for embeddings (optional). last_epoch (int): The index of the last epoch (default: -1). verbose (bool): If True, print warnings for deprecated functions (default: False). Example: To create a RecapLRShim scheduler for an optimizer and a dictionary of learning rates, use: ```python scheduler = RecapLRShim(optimizer, lr_dict, emb_learning_rate) ``` Warning: This class is intended for internal use to handle learning rate scheduling within Recap training and should not be directly accessed by external code. """ def __init__( self, optimizer, lr_dict: typing.Dict[str, config.LearningRate], emb_learning_rate, last_epoch=-1, verbose=False, ): self.optimizer = optimizer self.lr_dict = lr_dict self.group_names = list(self.lr_dict.keys()) self.emb_learning_rate = emb_learning_rate # We handle sparse LR scheduling separately, so only validate LR groups against dense param groups num_dense_param_groups = sum( 1 for _, _optim in optimizer._optims for _ in _optim.param_groups if isinstance(_optim, keyed.KeyedOptimizerWrapper) ) if num_dense_param_groups != len(lr_dict): raise ValueError( f"Optimizer had {len(optimizer.param_groups)}, but config had {len(lr_dict)}." ) super().__init__(optimizer, last_epoch, verbose) def get_lr(self): if not self._get_lr_called_within_step: warnings.warn( "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning, ) return self._get_closed_form_lr() def _get_closed_form_lr(self): learning_rates = [] for lr_config in self.lr_dict.values(): learning_rates.append(compute_lr(lr_config, self.last_epoch)) # WARNING: The order of appending is important. if self.emb_learning_rate: learning_rates.append(compute_lr(self.emb_learning_rate, self.last_epoch)) return learning_rates def build_optimizer( model: torch.nn.Module, optimizer_config: config.OptimizerConfig, emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None, ): """ Build an optimizer and scheduler for training. Args: model: The torch model, possibly with DDP/DMP. optimizer_config (config.OptimizerConfig): Configuration settings for the optimizer. emb_optimizer_config: Configuration settings for embedding optimization (optional). Returns: torch.optim.Optimizer: The optimizer for training. RecapLRShim: The learning rate scheduler for the optimizer. 
Example: To build an optimizer and scheduler for training, use: ```python optimizer, scheduler = build_optimizer(model, optimizer_config, emb_optimizer_config) ``` """ optimizer_fn = functools.partial( torch.optim.Adam, lr=_DEFAULT_LR, betas=(optimizer_config.adam.beta_1, optimizer_config.adam.beta_2), eps=optimizer_config.adam.epsilon, maximize=False, ) if optimizer_config.multi_task_learning_rates: logging.info("***** Parameter groups for optimization *****") # Importantly, we preserve insertion order in dictionaries here. parameter_groups: typing.Dict[str, typing.Dict] = defaultdict(dict) added_parameters: typing.Set[str] = set() for task in optimizer_config.multi_task_learning_rates.tower_learning_rates: for name, parameter in model.named_parameters(): if f".{model_mod.sanitize(task)}." in name: parameter_groups[task][name] = parameter logging.info(f"{task}: {name}") if name in added_parameters: raise ValueError(f"Parameter {name} matched multiple tasks.") added_parameters.add(name) for name, parameter in model.named_parameters(): if name not in added_parameters and "embedding_bags" not in name: parameter_groups[_BACKBONE][name] = parameter added_parameters.add(name) logging.info(f"{_BACKBONE}: {name}") for name, parameter in model.named_parameters(): if name not in added_parameters and "embedding_bags" in name: parameter_groups[_DENSE_EMBEDDINGS][name] = parameter logging.info(f"{_DENSE_EMBEDDINGS}: {name}") all_learning_rates = optimizer_config.multi_task_learning_rates.tower_learning_rates.copy() if optimizer_config.multi_task_learning_rates.backbone_learning_rate is not None: all_learning_rates[ _BACKBONE ] = optimizer_config.multi_task_learning_rates.backbone_learning_rate if _DENSE_EMBEDDINGS in parameter_groups and emb_optimizer_config: all_learning_rates[_DENSE_EMBEDDINGS] = emb_optimizer_config.learning_rate.copy() else: parameter_groups = dict(model.named_parameters()) all_learning_rates = {"single_task": optimizer_config.single_task_learning_rate} optimizers = [ keyed.KeyedOptimizerWrapper(param_group, optimizer_fn) for param_name, param_group in parameter_groups.items() if param_name != _DENSE_EMBEDDINGS ] # Making EBC optimizer to be SGD to match fused optimiser if _DENSE_EMBEDDINGS in parameter_groups: optimizers.append( keyed.KeyedOptimizerWrapper( parameter_groups[_DENSE_EMBEDDINGS], functools.partial(torch.optim.SGD, lr=_DEFAULT_LR, maximize=False, momentum=False), ) ) if not parameter_groups.keys() == all_learning_rates.keys(): raise ValueError("Learning rates do not match optimizers") # If the optimiser is dense, model.fused_optimizer will be empty (but not None) emb_learning_rate = None if hasattr(model, "fused_optimizer") and model.fused_optimizer.optimizers: logging.info(f"Model fused optimiser: {model.fused_optimizer}") optimizers.append(model.fused_optimizer) if emb_optimizer_config: emb_learning_rate = emb_optimizer_config.learning_rate.copy() else: raise ValueError("Fused kernel exists, but LR is not set") logging.info(f"***** Combining optimizers: {optimizers} *****") optimizer = keyed.CombinedOptimizer(optimizers) scheduler = RecapLRShim(optimizer, all_learning_rates, emb_learning_rate) logging.info(f"***** Combined optimizer after init: {optimizer} *****") return optimizer, scheduler
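

# ----------------------------------------------------------------------------------------------
# Usage sketch (illustrative only). A minimal example of how the optimizer/scheduler pair
# returned by build_optimizer is typically driven in a training loop. The `model`, `batches`,
# and scalar-loss assumptions below are hypothetical placeholders, not part of the Recap
# pipeline; only the standard zero_grad/backward/step calls on the combined optimizer and the
# torch LR scheduler are assumed.
def _example_training_loop(model, optimizer_config, batches):  # pragma: no cover
  optimizer, scheduler = build_optimizer(model, optimizer_config)
  for batch in batches:
    optimizer.zero_grad()  # clear gradients across all combined parameter groups
    loss = model(batch)  # assumes the model returns a scalar loss for this sketch
    loss.backward()
    optimizer.step()  # steps dense Adam/SGD groups plus any fused sparse optimizer
    scheduler.step()  # overwrites the sentinel _DEFAULT_LR with values from compute_lr
  return optimizer, scheduler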