the-algorithm-ml/projects/home/recap/optimizer/optimizer.py


"""Build optimizers and learning rate schedules."""
import bisect
from collections import defaultdict
import functools
import math
import typing
from typing import Optional
import warnings
# from large_embeddings.config import EmbeddingOptimizerConfig
from tml.projects.home.recap import model as model_mod
from tml.optimizers import config
from tml.optimizers import compute_lr
from absl import logging # type: ignore[attr-defined]
import torch
from torchrec.optim import keyed
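
# _DEFAULT_LR is a deliberately absurd sentinel (assumption drawn from the comment below):
# param groups are created with it and RecapLRShim overwrites it as soon as the scheduler
# steps, so any group the scheduler misses blows up training loudly rather than silently
# using an unintended learning rate.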
_DEFAULT_LR = 24601.0 # NaN the model if we're not using the learning rate.
_BACKBONE = "backbone"
_DENSE_EMBEDDINGS = "dense_ebc"
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
"""
A shim to get learning rates into a LRScheduler.
2023-09-14 08:00:10 +02:00
This class adheres to the torch.optim scheduler API and can be plugged into any scheduler that supports
learning rate schedules, such as exponential decay.
2023-09-14 08:00:10 +02:00
Args:
optimizer: The optimizer to which this scheduler is applied.
lr_dict (Dict[str, config.LearningRate]): A dictionary mapping group names to learning rate configurations.
emb_learning_rate: The learning rate for embeddings (optional).
last_epoch (int): The index of the last epoch (default: -1).
verbose (bool): If True, print warnings for deprecated functions (default: False).
Example:
To create a RecapLRShim scheduler for an optimizer and a dictionary of learning rates, use:
```python
scheduler = RecapLRShim(optimizer, lr_dict, emb_learning_rate)
```
Warning:
This class is intended for internal use to handle learning rate scheduling within Recap training and should not
be directly accessed by external code.
"""
  def __init__(
    self,
    optimizer,
    lr_dict: typing.Dict[str, config.LearningRate],
    emb_learning_rate,
    last_epoch=-1,
    verbose=False,
  ):
    self.optimizer = optimizer
    self.lr_dict = lr_dict
    self.group_names = list(self.lr_dict.keys())
    self.emb_learning_rate = emb_learning_rate

    # We handle sparse LR scheduling separately, so only validate LR groups against dense param groups.
    num_dense_param_groups = sum(
      1
      for _, _optim in optimizer._optims
      for _ in _optim.param_groups
      if isinstance(_optim, keyed.KeyedOptimizerWrapper)
    )
    if num_dense_param_groups != len(lr_dict):
      raise ValueError(
        f"Optimizer had {num_dense_param_groups} dense param groups, but config had {len(lr_dict)}."
      )
    super().__init__(optimizer, last_epoch, verbose)

  def get_lr(self):
    if not self._get_lr_called_within_step:
      warnings.warn(
        "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
        UserWarning,
      )
    return self._get_closed_form_lr()

  def _get_closed_form_lr(self):
    learning_rates = []
    for lr_config in self.lr_dict.values():
      learning_rates.append(compute_lr(lr_config, self.last_epoch))
    # WARNING: The order of appending is important.
    if self.emb_learning_rate:
      learning_rates.append(compute_lr(self.emb_learning_rate, self.last_epoch))
    return learning_rates
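

# Note (assumed from the code in this module, not from upstream documentation): the list
# returned by RecapLRShim._get_closed_form_lr is zipped positionally against the param groups
# of the CombinedOptimizer built below -- dense tower/backbone/EBC groups first, in lr_dict
# insertion order, then the fused embedding learning rate last. Reordering either side would
# silently apply the wrong schedule to the wrong parameters.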
def build_optimizer(
  model: torch.nn.Module,
  optimizer_config: config.OptimizerConfig,
  emb_optimizer_config: None = None,  # Optional[EmbeddingOptimizerConfig] = None,
):
"""
Build an optimizer and scheduler for training.
2023-09-14 08:00:10 +02:00
Args:
model: The torch model, possibly with DDP/DMP.
optimizer_config (config.OptimizerConfig): Configuration settings for the optimizer.
emb_optimizer_config: Configuration settings for embedding optimization (optional).
2023-09-14 08:00:10 +02:00
Returns:
torch.optim.Optimizer: The optimizer for training.
RecapLRShim: The learning rate scheduler for the optimizer.
Example:
To build an optimizer and scheduler for training, use:
```python
optimizer, scheduler = build_optimizer(model, optimizer_config, emb_optimizer_config)
```
"""
  optimizer_fn = functools.partial(
    torch.optim.Adam,
    lr=_DEFAULT_LR,
    betas=(optimizer_config.adam.beta_1, optimizer_config.adam.beta_2),
    eps=optimizer_config.adam.epsilon,
    maximize=False,
  )
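
  # Parameters are partitioned into named groups so each can get its own schedule: one group
  # per task tower, a shared "backbone" group, and (when present) a "dense_ebc" group for
  # EmbeddingBagCollection parameters handled outside the fused sparse optimizer.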
  if optimizer_config.multi_task_learning_rates:
    logging.info("***** Parameter groups for optimization *****")
    # Importantly, we preserve insertion order in dictionaries here.
    parameter_groups: typing.Dict[str, typing.Dict] = defaultdict(dict)
    added_parameters: typing.Set[str] = set()

    for task in optimizer_config.multi_task_learning_rates.tower_learning_rates:
      for name, parameter in model.named_parameters():
        if f".{model_mod.sanitize(task)}." in name:
          parameter_groups[task][name] = parameter
          logging.info(f"{task}: {name}")
          if name in added_parameters:
            raise ValueError(f"Parameter {name} matched multiple tasks.")
          added_parameters.add(name)

    for name, parameter in model.named_parameters():
      if name not in added_parameters and "embedding_bags" not in name:
        parameter_groups[_BACKBONE][name] = parameter
        added_parameters.add(name)
        logging.info(f"{_BACKBONE}: {name}")

    for name, parameter in model.named_parameters():
      if name not in added_parameters and "embedding_bags" in name:
        parameter_groups[_DENSE_EMBEDDINGS][name] = parameter
        logging.info(f"{_DENSE_EMBEDDINGS}: {name}")

    all_learning_rates = optimizer_config.multi_task_learning_rates.tower_learning_rates.copy()
    if optimizer_config.multi_task_learning_rates.backbone_learning_rate is not None:
      all_learning_rates[
        _BACKBONE
      ] = optimizer_config.multi_task_learning_rates.backbone_learning_rate
    if _DENSE_EMBEDDINGS in parameter_groups and emb_optimizer_config:
      all_learning_rates[_DENSE_EMBEDDINGS] = emb_optimizer_config.learning_rate.copy()
  else:
    # Single-task models use one named parameter group so the keys match all_learning_rates
    # below and each KeyedOptimizerWrapper receives a name-to-parameter mapping.
    parameter_groups = {"single_task": dict(model.named_parameters())}
    all_learning_rates = {"single_task": optimizer_config.single_task_learning_rate}
  optimizers = [
    keyed.KeyedOptimizerWrapper(param_group, optimizer_fn)
    for param_name, param_group in parameter_groups.items()
    if param_name != _DENSE_EMBEDDINGS
  ]
  # Make the EBC optimizer SGD to match the fused optimizer.
  if _DENSE_EMBEDDINGS in parameter_groups:
    optimizers.append(
      keyed.KeyedOptimizerWrapper(
        parameter_groups[_DENSE_EMBEDDINGS],
        functools.partial(torch.optim.SGD, lr=_DEFAULT_LR, maximize=False, momentum=0.0),
      )
    )
  if parameter_groups.keys() != all_learning_rates.keys():
    raise ValueError(
      f"Learning rate groups {sorted(all_learning_rates)} do not match "
      f"parameter groups {sorted(parameter_groups)}."
    )

  # If the optimizer is purely dense, model.fused_optimizer will be empty (but not None).
  emb_learning_rate = None
  if hasattr(model, "fused_optimizer") and model.fused_optimizer.optimizers:
    logging.info(f"Model fused optimizer: {model.fused_optimizer}")
    optimizers.append(model.fused_optimizer)
    if emb_optimizer_config:
      emb_learning_rate = emb_optimizer_config.learning_rate.copy()
    else:
      raise ValueError("Fused kernel exists, but the embedding learning rate is not set.")

  logging.info(f"***** Combining optimizers: {optimizers} *****")
  optimizer = keyed.CombinedOptimizer(optimizers)
  scheduler = RecapLRShim(optimizer, all_learning_rates, emb_learning_rate)
  logging.info(f"***** Combined optimizer after init: {optimizer} *****")

  return optimizer, scheduler
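

# A minimal usage sketch (assumptions, not part of this module): `model`, `optimizer_config`,
# and `train_dataloader` are taken to come from the surrounding Recap training setup, and the
# loss is assumed to be readable from the model output.
#
#   optimizer, scheduler = build_optimizer(model, optimizer_config)
#   for batch in train_dataloader:
#     optimizer.zero_grad()
#     loss = model(batch)["loss"]
#     loss.backward()
#     optimizer.step()
#     scheduler.step()  # recomputes every group's LR in closed form from last_epoch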