106 lines
3.4 KiB
Python

from typing import Dict, Tuple
import math
import bisect
from tml.optimizers.config import (
LearningRate,
OptimizerConfig,
)
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from tml.ml_logging.torch_logging import logging
def compute_lr(lr_config, step):
    """Compute the learning rate for ``step`` from a learning-rate config.

    The first of ``constant``, ``piecewise_constant``, ``linear_ramp_to_constant``
    and ``linear_ramp_to_cosine`` (checked in that order) that is not None
    decides the schedule.

    Args:
      lr_config: object exposing the four optional schedule attributes above.
      step: step count the learning rate is computed for.

    Returns:
      The learning rate at ``step``.

    Raises:
      ValueError: if none of the four schedule attributes is set.
    """
    if lr_config.constant is not None:
        return lr_config.constant

    if lr_config.piecewise_constant is not None:
        piecewise = lr_config.piecewise_constant
        # bisect_right: a step equal to a boundary already uses the value that
        # follows that boundary.
        index = bisect.bisect_right(piecewise.learning_rate_boundaries, step)
        return piecewise.learning_rate_values[index]

    if lr_config.linear_ramp_to_constant is not None:
        ramp = lr_config.linear_ramp_to_constant
        if ramp.num_ramp_steps == 0:
            # Zero ramp steps: skip the division below (it would raise
            # ZeroDivisionError) and return the plateau value directly.
            return ramp.learning_rate
        slope = ramp.learning_rate / ramp.num_ramp_steps
        # Linear ramp from 0, capped at the target learning rate.
        return min(ramp.learning_rate, slope * step)

    if lr_config.linear_ramp_to_cosine is not None:
        cfg = lr_config.linear_ramp_to_cosine
        if step < cfg.num_ramp_steps:
            # Linear warm-up from 0 towards cfg.learning_rate.
            slope = cfg.learning_rate / cfg.num_ramp_steps
            return slope * step
        if step < cfg.final_num_steps:
            # Half-cosine decay from learning_rate down to final_learning_rate.
            # The comparison is strict: at step == final_num_steps the cosine
            # term is exactly zero, so returning final_learning_rate below gives
            # the same value, and a strict bound avoids dividing by zero when
            # final_num_steps == num_ramp_steps.
            return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
                1.0
                + math.cos(
                    math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
                )
            )
        return cfg.final_learning_rate

    raise ValueError(f"No option selected in lr_config, passed {lr_config}")
class LRShim(_LRScheduler):
    """Shim to get learning rates into a LRScheduler.

    This adheres to the torch.optim scheduler API and can be plugged anywhere that
    e.g. exponential decay can be used.

    Learning rates are computed in closed form: one value per entry of
    ``lr_dict``, obtained by calling ``compute_lr`` with that entry and the
    scheduler's ``last_epoch`` counter.
    """

    def __init__(
        self,
        optimizer,
        lr_dict: Dict[str, "LearningRate"],
        last_epoch=-1,
        verbose=False,
    ):
        """Validate the config against the optimizer and initialise the scheduler.

        Args:
          optimizer: the optimizer to schedule. Either a plain optimizer exposing
            ``param_groups``, or a wrapper exposing ``_optims`` as an iterable of
            ``(key, optimizer)`` pairs.
          lr_dict: learning-rate config per param group; must have exactly one
            entry per param group of ``optimizer``.
          last_epoch: forwarded to the base scheduler.
          verbose: forwarded to the base scheduler.

        Raises:
          ValueError: if the number of param groups differs from ``len(lr_dict)``.
        """
        self.optimizer = optimizer
        self.lr_dict = lr_dict
        self.group_names = list(self.lr_dict.keys())

        # Wrapper optimizers are counted through their sub-optimizers; a plain
        # optimizer has no `_optims`, so fall back to its own param groups
        # instead of failing with AttributeError.
        sub_optimizers = getattr(optimizer, "_optims", None)
        if sub_optimizers is None:
            num_param_groups = len(optimizer.param_groups)
        else:
            num_param_groups = sum(len(_optim.param_groups) for _, _optim in sub_optimizers)

        if num_param_groups != len(lr_dict):
            # Report the count that was actually compared.
            raise ValueError(f"Optimizer had {num_param_groups}, but config had {len(lr_dict)}.")
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the current learning rates, one per ``lr_dict`` entry.

        Logs a warning when called outside of ``step()``, pointing the caller
        to ``get_last_lr()``.
        """
        if not self._get_lr_called_within_step:
            # `warn` is a deprecated alias of `warning`, and the message has no
            # %-placeholders, so no extra positional argument may be passed.
            logging.warning(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`."
            )
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        """Compute every learning rate from its config at step ``self.last_epoch``."""
        return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]
def get_optimizer_class(optimizer_config: "OptimizerConfig"):
    """Return the torch optimizer class selected by ``optimizer_config``.

    The ``adam``, ``sgd`` and ``adagrad`` fields are checked in that order and
    the first one that is not None decides the class.

    Args:
      optimizer_config: config exposing optional ``adam``, ``sgd`` and
        ``adagrad`` fields.

    Returns:
      ``torch.optim.Adam``, ``torch.optim.SGD`` or ``torch.optim.Adagrad``.

    Raises:
      ValueError: if none of the three fields is set.
    """
    if optimizer_config.adam is not None:
        return torch.optim.Adam
    if optimizer_config.sgd is not None:
        return torch.optim.SGD
    if optimizer_config.adagrad is not None:
        return torch.optim.Adagrad
    # Fail here with a clear message rather than returning None and letting the
    # caller fail later when it tries to call the result.
    raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")
def build_optimizer(
    model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
    """Build an optimizer and LR scheduler from an OptimizerConfig.

    Note: use this when you want the same optimizer and learning rate schedule
    for all your parameters.

    Args:
      model: module whose ``parameters()`` are handed to the optimizer.
      optimizer_config: config selecting the optimizer (``adam``, ``sgd`` or
        ``adagrad``) and carrying the ``learning_rate`` schedule.

    Returns:
      An ``(optimizer, scheduler)`` tuple.

    Raises:
      ValueError: if none of ``adam``, ``sgd`` or ``adagrad`` is set.
    """
    optimizer_class = get_optimizer_class(optimizer_config)

    # Take the constructor kwargs from whichever sub-config is actually set,
    # checked in the same order get_optimizer_class uses. Always reading `.sgd`
    # fails with AttributeError on None when adam or adagrad is selected.
    # NOTE(review): assumes every sub-config exposes `.dict()` like `sgd` does,
    # with keys that are valid kwargs for its optimizer class; confirm against
    # the config definitions.
    candidates = (optimizer_config.adam, optimizer_config.sgd, optimizer_config.adagrad)
    selected = next((sub for sub in candidates if sub is not None), None)
    if selected is None:
        raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")

    optimizer = optimizer_class(model.parameters(), **selected.dict())
    # We're passing everything in as one group here.
    scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": optimizer_config.learning_rate})
    return optimizer, scheduler