from typing import Dict, Tuple
import math
import bisect

from tml.optimizers.config import (
    LearningRate,
    OptimizerConfig,
)

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from tml.ml_logging.torch_logging import logging


def compute_lr(lr_config: LearningRate, step: int) -> float:
    """Compute the learning rate for `step` from a LearningRate config."""
    if lr_config.constant is not None:
        return lr_config.constant
    elif lr_config.piecewise_constant is not None:
        # Each boundary is the step at which the next learning-rate value takes effect,
        # so `learning_rate_values` has one more entry than `learning_rate_boundaries`.
        return lr_config.piecewise_constant.learning_rate_values[
            bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
        ]
    elif lr_config.linear_ramp_to_constant is not None:
        # Linear warmup from 0, then hold at the target learning rate.
        slope = (
            lr_config.linear_ramp_to_constant.learning_rate
            / lr_config.linear_ramp_to_constant.num_ramp_steps
        )
        return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
    elif lr_config.linear_ramp_to_cosine is not None:
        cfg = lr_config.linear_ramp_to_cosine
        if step < cfg.num_ramp_steps:
            # Linear warmup from 0 to the peak learning rate.
            slope = cfg.learning_rate / cfg.num_ramp_steps
            return slope * step
        elif step <= cfg.final_num_steps:
            # Cosine decay from the peak learning rate down to the final learning rate.
            return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
                1.0
                + math.cos(
                    math.pi
                    * (step - cfg.num_ramp_steps)
                    / (cfg.final_num_steps - cfg.num_ramp_steps)
                )
            )
        else:
            return cfg.final_learning_rate
    else:
        raise ValueError(f"No option selected in lr_config, passed {lr_config}")


class LRShim(_LRScheduler):
    """Shim to get learning rates into an LRScheduler.

    This adheres to the torch.optim scheduler API and can be plugged in anywhere that
    e.g. exponential decay can be used.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_dict: Dict[str, LearningRate],
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        self.optimizer = optimizer
        self.lr_dict = lr_dict
        self.group_names = list(self.lr_dict.keys())

        # Expect exactly one LearningRate config per optimizer parameter group.
        # Combined/keyed optimizers expose their wrapped optimizers via `_optims`;
        # plain torch optimizers only have `param_groups`.
        if hasattr(optimizer, "_optims"):
            num_param_groups = sum(
                1 for _, _optim in optimizer._optims for _ in _optim.param_groups
            )
        else:
            num_param_groups = len(optimizer.param_groups)
        if num_param_groups != len(lr_dict):
            raise ValueError(
                f"Optimizer had {num_param_groups} param groups, but config had {len(lr_dict)}."
            )

        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            logging.warning(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`."
            )
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]


def get_optimizer_class(optimizer_config: OptimizerConfig):
    if optimizer_config.adam is not None:
        return torch.optim.Adam
    elif optimizer_config.sgd is not None:
        return torch.optim.SGD
    elif optimizer_config.adagrad is not None:
        return torch.optim.Adagrad
    raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")


def build_optimizer(
    model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
    """Build an optimizer and LR scheduler from an OptimizerConfig.

    Note: use this when you want the same optimizer and learning rate schedule for all
    of your parameters.
    """
    optimizer_class = get_optimizer_class(optimizer_config)
    # Pull keyword arguments from whichever algorithm sub-config was actually selected,
    # rather than unconditionally reading the SGD sub-config.
    if optimizer_config.adam is not None:
        algorithm_config = optimizer_config.adam
    elif optimizer_config.sgd is not None:
        algorithm_config = optimizer_config.sgd
    else:
        algorithm_config = optimizer_config.adagrad
    optimizer = optimizer_class(model.parameters(), **algorithm_config.dict())
    # We're passing everything in as one group here.
    scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": optimizer_config.learning_rate})
    return optimizer, scheduler
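

# Usage sketch (illustrative only, not part of the library API). It exercises
# `compute_lr` and `LRShim` using `types.SimpleNamespace` objects as stand-ins for the
# pydantic configs in `tml.optimizers.config`; only the attribute names this module
# reads are assumed. Real callers should build `LearningRate` / `OptimizerConfig`
# instances instead.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Linear warmup for 100 steps, then cosine decay to the final learning rate by step 1000.
    ramp_to_cosine = SimpleNamespace(
        constant=None,
        piecewise_constant=None,
        linear_ramp_to_constant=None,
        linear_ramp_to_cosine=SimpleNamespace(
            learning_rate=1e-3,
            final_learning_rate=1e-5,
            num_ramp_steps=100,
            final_num_steps=1000,
        ),
    )
    for step in (0, 50, 100, 550, 1000, 2000):
        print(f"step={step} lr={compute_lr(ramp_to_cosine, step):.2e}")

    # LRShim with a plain torch optimizer: one parameter group, so one entry in lr_dict.
    constant_lr = SimpleNamespace(
        constant=1e-2,
        piecewise_constant=None,
        linear_ramp_to_constant=None,
        linear_ramp_to_cosine=None,
    )
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": constant_lr})
    for _ in range(3):
        optimizer.step()
        scheduler.step()
    print(scheduler.get_last_lr())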