mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-26 05:35:27 +01:00
210 lines
7.4 KiB
Python
210 lines
7.4 KiB
Python
from typing import Dict, Tuple
|
|
import math
|
|
import bisect
|
|
|
|
from tml.optimizers.config import (
|
|
LearningRate,
|
|
OptimizerConfig,
|
|
)
|
|
|
|
import torch
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
from tml.ml_logging.torch_logging import logging
|
|
|
|
|
|
def compute_lr(lr_config, step):
    """Compute the learning rate for ``step`` from a ``LearningRate`` configuration.

    Exactly one schedule option is expected to be set. If several are set, the first
    non-None option in this order is used: ``constant``, ``piecewise_constant``,
    ``linear_ramp_to_constant``, ``linear_ramp_to_cosine``.

    Schedules:
        - ``constant``: the same value at every step.
        - ``piecewise_constant``: ``learning_rate_values[i]``, where ``i`` is the number
          of entries in ``learning_rate_boundaries`` (assumed sorted ascending) that are
          <= ``step``. ``learning_rate_values`` therefore needs one more entry than
          ``learning_rate_boundaries``.
        - ``linear_ramp_to_constant``: rises linearly from 0 to ``learning_rate`` over
          ``num_ramp_steps`` steps, then stays at ``learning_rate``. A non-positive
          ``num_ramp_steps`` means no ramp: ``learning_rate`` is returned directly.
        - ``linear_ramp_to_cosine``: rises linearly from 0 to ``learning_rate`` over
          ``num_ramp_steps`` steps, then follows a half cosine from ``learning_rate``
          to ``final_learning_rate`` at ``final_num_steps``, and stays at
          ``final_learning_rate`` from then on.

    Args:
        lr_config (LearningRate): The learning rate configuration to evaluate.
        step (int): The current training step or iteration.

    Returns:
        float: The learning rate for ``step``.

    Raises:
        ValueError: If no schedule option is set in ``lr_config``.

    Example:
        ```python
        lr_schedule = LearningRate(
            piecewise_constant=PiecewiseConstant(
                learning_rate_boundaries=[1000, 2000, 3000],
                learning_rate_values=[0.1, 0.05, 0.01, 0.001],
            )
        )
        learning_rate = compute_lr(lr_schedule, 2500)  # -> 0.01
        ```
    """
    if lr_config.constant is not None:
        return lr_config.constant

    elif lr_config.piecewise_constant is not None:
        cfg = lr_config.piecewise_constant
        # bisect_right counts the boundaries <= step, i.e. the index of the active interval.
        return cfg.learning_rate_values[bisect.bisect_right(cfg.learning_rate_boundaries, step)]

    elif lr_config.linear_ramp_to_constant is not None:
        cfg = lr_config.linear_ramp_to_constant
        if cfg.num_ramp_steps <= 0:
            # No ramp requested; avoids dividing by zero below.
            return cfg.learning_rate
        slope = cfg.learning_rate / cfg.num_ramp_steps
        return min(cfg.learning_rate, slope * step)

    elif lr_config.linear_ramp_to_cosine is not None:
        cfg = lr_config.linear_ramp_to_cosine
        if step < cfg.num_ramp_steps:
            slope = cfg.learning_rate / cfg.num_ramp_steps
            return slope * step
        elif step < cfg.final_num_steps:
            # Strict `<`: at step == final_num_steps the cosine term is exactly
            # final_learning_rate anyway, and this avoids a zero denominator when
            # final_num_steps == num_ramp_steps.
            progress = (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
            return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
                1.0 + math.cos(math.pi * progress)
            )
        else:
            return cfg.final_learning_rate

    else:
        raise ValueError(f"No option selected in lr_config, passed {lr_config}")
|
|
|
|
|
|
class LRShim(_LRScheduler):
    """Learning-rate scheduler that drives each parameter group from a ``LearningRate`` config.

    Implements the ``torch.optim`` scheduler API. The base class's ``step()`` asks this
    class for the new rates, via ``get_lr()`` or, when an explicit epoch is passed,
    ``_get_closed_form_lr()``. Both evaluate :func:`compute_lr` for every entry of
    ``lr_dict`` at the current step (``self.last_epoch``).

    Learning rates are matched to parameter groups by position, not by name: the i-th
    value of ``lr_dict`` (in insertion order) is applied to the i-th parameter group.
    The keys are only kept, in order, in ``self.group_names``.

    Args:
        optimizer: The optimizer whose learning rates are scheduled. Either a regular
            ``torch.optim.Optimizer``, or a wrapper exposing ``_optims`` (an iterable
            of ``(key, optimizer)`` pairs). For a wrapper, parameter groups are counted
            across all wrapped optimizers.
        lr_dict (Dict[str, LearningRate]): One learning-rate configuration per
            parameter group, in parameter-group order.
        last_epoch (int, optional): Index of the last step. Default: -1.
        verbose (bool, optional): Passed through unchanged to the base scheduler.
            Default: False.

    Raises:
        ValueError: If the optimizer's number of parameter groups differs from
            ``len(lr_dict)``.

    Note:
        Calling ``get_lr()`` directly logs a warning. Use ``get_last_lr()`` to read the
        most recently computed learning rates, which it returns as a list in
        parameter-group order.

    Example:
        ```python
        optimizer = torch.optim.SGD(
            [{"params": main_params}, {"params": aux_params}], lr=0.01
        )
        lr_schedule = {
            "main": LearningRate(constant=0.01),
            "auxiliary": LearningRate(
                piecewise_constant=PiecewiseConstant(
                    learning_rate_boundaries=[1000, 2000],
                    learning_rate_values=[0.01, 0.001, 0.0001],
                )
            ),
        }
        lr_shim = LRShim(optimizer, lr_schedule)

        for batch in batches:
            train_step(batch)
            lr_shim.step()

        main_lr, auxiliary_lr = lr_shim.get_last_lr()
        ```

    See Also:
        - ``LearningRate``: Configuration for specifying learning rates.
        - ``PiecewiseConstant``: Configuration for piecewise constant learning rate schedules.
    """

    def __init__(
        self,
        optimizer,
        lr_dict: Dict[str, LearningRate],
        last_epoch=-1,
        verbose=False,
    ):
        self.optimizer = optimizer
        self.lr_dict = lr_dict
        self.group_names = list(self.lr_dict.keys())

        num_param_groups = self._count_param_groups(optimizer)
        if num_param_groups != len(lr_dict):
            raise ValueError(
                f"Optimizer had {num_param_groups} param groups, but config had {len(lr_dict)}."
            )

        # The base constructor performs the initial step, which calls get_lr(); the
        # attributes above must therefore be set first.
        super().__init__(optimizer, last_epoch, verbose)

    @staticmethod
    def _count_param_groups(optimizer) -> int:
        """Count parameter groups, looking inside wrappers that expose ``_optims``."""
        wrapped = getattr(optimizer, "_optims", None)
        if wrapped is None:
            # Plain torch.optim optimizer (e.g. the one built by build_optimizer).
            return len(optimizer.param_groups)
        return sum(len(inner.param_groups) for _, inner in wrapped)

    def get_lr(self):
        """Return the learning rates for the current step, one per parameter group."""
        if not self._get_lr_called_within_step:
            # Same advice torch's built-in schedulers give on direct get_lr() calls.
            # No %-args are passed: the message has no placeholders.
            logging.warning(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`."
            )
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        """Evaluate every configured schedule at ``self.last_epoch``, in ``lr_dict`` order."""
        return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]
|
|
|
|
|
|
def get_optimizer_class(optimizer_config: OptimizerConfig):
    """Return the ``torch.optim`` optimizer class selected by ``optimizer_config``.

    The first non-None option wins, checked in the order ``adam``, ``sgd``, ``adagrad``.

    Args:
        optimizer_config (OptimizerConfig): Configuration whose ``adam`` / ``sgd`` /
            ``adagrad`` fields select the optimizer algorithm.

    Returns:
        type: ``torch.optim.Adam``, ``torch.optim.SGD`` or ``torch.optim.Adagrad``.

    Raises:
        ValueError: If none of the supported optimizer options is set. Previously the
            function fell through and returned None.
    """
    if optimizer_config.adam is not None:
        return torch.optim.Adam
    elif optimizer_config.sgd is not None:
        return torch.optim.SGD
    elif optimizer_config.adagrad is not None:
        return torch.optim.Adagrad
    # Fail loudly here instead of handing None to a caller that will try to call it.
    raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")
|
|
|
|
|
|
def build_optimizer(
    model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
    """Build an optimizer and learning-rate scheduler from ``optimizer_config``.

    The optimizer class is chosen by :func:`get_optimizer_class`. Its constructor
    keyword arguments come from the matching sub-config (``adam``, ``sgd`` or
    ``adagrad``) via ``.dict()``. All model parameters go into a single parameter
    group, scheduled by ``optimizer_config.learning_rate`` through :class:`LRShim`.

    Args:
        model (torch.nn.Module): The model whose ``parameters()`` are optimized.
        optimizer_config (OptimizerConfig): Selects the optimizer algorithm and its
            settings, plus the learning-rate schedule.

    Returns:
        Tuple[Optimizer, _LRScheduler]: The optimizer and its learning-rate scheduler.

    Raises:
        ValueError: If no optimizer algorithm is selected in ``optimizer_config``.

    Note:
        Use this when the same optimizer and learning-rate schedule should apply to
        all model parameters.

    Example:
        ```python
        model = MyModel()
        optimizer_config = OptimizerConfig(
            learning_rate=LearningRate(constant=0.01),
            sgd=SgdConfig(lr=0.01, momentum=0.9),
        )
        optimizer, scheduler = build_optimizer(model, optimizer_config)

        for batch in batches:
            train_step(model, optimizer, batch)
            scheduler.step()
        ```

    See Also:
        - ``OptimizerConfig``: Configuration for specifying optimizer settings.
        - ``LRShim``: Learning rate scheduler shim for fine-grained learning rate control.
    """
    optimizer_class = get_optimizer_class(optimizer_config)
    # Take constructor kwargs from the sub-config that matches the selected optimizer,
    # checked in the same adam -> sgd -> adagrad order as get_optimizer_class. The
    # `sgd` field is None whenever adam or adagrad is selected.
    algorithm_config = next(
        (
            cfg
            for cfg in (optimizer_config.adam, optimizer_config.sgd, optimizer_config.adagrad)
            if cfg is not None
        ),
        None,
    )
    if optimizer_class is None or algorithm_config is None:
        raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")

    optimizer = optimizer_class(model.parameters(), **algorithm_config.dict())
    # Everything is passed in as one parameter group, so a single schedule covers it.
    scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": optimizer_config.learning_rate})
    return optimizer, scheduler
|