167 lines
6.0 KiB
Python

"""Build optimizers and learning rate schedules."""
import bisect
from collections import defaultdict
import functools
import math
import typing
from typing import Optional
import warnings
# from large_embeddings.config import EmbeddingOptimizerConfig
from tml.projects.home.recap import model as model_mod
from tml.optimizers import config
from tml.optimizers import compute_lr
from absl import logging # type: ignore[attr-defined]
import torch
from torchrec.optim import keyed
_DEFAULT_LR = 24601.0 # NaN the model if we're not using the learning rate.
_BACKBONE = "backbone"
_DENSE_EMBEDDINGS = "dense_ebc"
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
"""Shim to get learning rates into a LRScheduler.
This adheres to the torch.optim scheduler API and can be plugged anywhere that
e.g. exponential decay can be used.
"""
def __init__(
self,
optimizer,
lr_dict: typing.Dict[str, config.LearningRate],
emb_learning_rate,
last_epoch=-1,
verbose=False,
):
self.optimizer = optimizer
self.lr_dict = lr_dict
self.group_names = list(self.lr_dict.keys())
self.emb_learning_rate = emb_learning_rate
# We handle sparse LR scheduling separately, so only validate LR groups against dense param groups
num_dense_param_groups = sum(
1
for _, _optim in optimizer._optims
for _ in _optim.param_groups
if isinstance(_optim, keyed.KeyedOptimizerWrapper)
)
if num_dense_param_groups != len(lr_dict):
raise ValueError(
f"Optimizer had {len(optimizer.param_groups)}, but config had {len(lr_dict)}."
)
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
UserWarning,
)
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
learning_rates = []
for lr_config in self.lr_dict.values():
learning_rates.append(compute_lr(lr_config, self.last_epoch))
# WARNING: The order of appending is important.
if self.emb_learning_rate:
learning_rates.append(compute_lr(self.emb_learning_rate, self.last_epoch))
return learning_rates
def build_optimizer(
model: torch.nn.Module,
optimizer_config: config.OptimizerConfig,
emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None,
):
"""Builds an optimizer and scheduler.
Args:
model: A torch model, probably with DDP/DMP.
optimizer_config: An OptimizerConfig object that specifies learning rates per tower.
Returns:
A torch.optim instance, and a scheduler instance.
"""
optimizer_fn = functools.partial(
torch.optim.Adam,
lr=_DEFAULT_LR,
betas=(optimizer_config.adam.beta_1, optimizer_config.adam.beta_2),
eps=optimizer_config.adam.epsilon,
maximize=False,
)
if optimizer_config.multi_task_learning_rates:
logging.info("***** Parameter groups for optimization *****")
# Importantly, we preserve insertion order in dictionaries here.
parameter_groups: typing.Dict[str, typing.Dict] = defaultdict(dict)
added_parameters: typing.Set[str] = set()
for task in optimizer_config.multi_task_learning_rates.tower_learning_rates:
for name, parameter in model.named_parameters():
if f".{model_mod.sanitize(task)}." in name:
parameter_groups[task][name] = parameter
logging.info(f"{task}: {name}")
if name in added_parameters:
raise ValueError(f"Parameter {name} matched multiple tasks.")
added_parameters.add(name)
for name, parameter in model.named_parameters():
if name not in added_parameters and "embedding_bags" not in name:
parameter_groups[_BACKBONE][name] = parameter
added_parameters.add(name)
logging.info(f"{_BACKBONE}: {name}")
for name, parameter in model.named_parameters():
if name not in added_parameters and "embedding_bags" in name:
parameter_groups[_DENSE_EMBEDDINGS][name] = parameter
logging.info(f"{_DENSE_EMBEDDINGS}: {name}")
all_learning_rates = optimizer_config.multi_task_learning_rates.tower_learning_rates.copy()
if optimizer_config.multi_task_learning_rates.backbone_learning_rate is not None:
all_learning_rates[
_BACKBONE
] = optimizer_config.multi_task_learning_rates.backbone_learning_rate
if _DENSE_EMBEDDINGS in parameter_groups and emb_optimizer_config:
all_learning_rates[_DENSE_EMBEDDINGS] = emb_optimizer_config.learning_rate.copy()
else:
parameter_groups = dict(model.named_parameters())
all_learning_rates = {"single_task": optimizer_config.single_task_learning_rate}
optimizers = [
keyed.KeyedOptimizerWrapper(param_group, optimizer_fn)
for param_name, param_group in parameter_groups.items()
if param_name != _DENSE_EMBEDDINGS
]
# Making EBC optimizer to be SGD to match fused optimiser
if _DENSE_EMBEDDINGS in parameter_groups:
optimizers.append(
keyed.KeyedOptimizerWrapper(
parameter_groups[_DENSE_EMBEDDINGS],
functools.partial(torch.optim.SGD, lr=_DEFAULT_LR, maximize=False, momentum=False),
)
)
if not parameter_groups.keys() == all_learning_rates.keys():
raise ValueError("Learning rates do not match optimizers")
# If the optimiser is dense, model.fused_optimizer will be empty (but not None)
emb_learning_rate = None
if hasattr(model, "fused_optimizer") and model.fused_optimizer.optimizers:
logging.info(f"Model fused optimiser: {model.fused_optimizer}")
optimizers.append(model.fused_optimizer)
if emb_optimizer_config:
emb_learning_rate = emb_optimizer_config.learning_rate.copy()
else:
raise ValueError("Fused kernel exists, but LR is not set")
logging.info(f"***** Combining optimizers: {optimizers} *****")
optimizer = keyed.CombinedOptimizer(optimizers)
scheduler = RecapLRShim(optimizer, all_learning_rates, emb_learning_rate)
logging.info(f"***** Combined optimizer after init: {optimizer} *****")
return optimizer, scheduler