Mirror of https://github.com/twitter/the-algorithm-ml.git (synced 2025-06-16 18:58:30 +02:00)
Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings
optimizers/__init__.py (new file, 1 line)

from tml.optimizers.optimizer import compute_lr

optimizers/config.py (new file, 82 lines)
"""Optimization configurations for models."""

import typing

import pydantic

import tml.core.config as base_config


class PiecewiseConstant(base_config.BaseConfig):
  """Piecewise-constant schedule: learning_rate_values[i] applies until step reaches learning_rate_boundaries[i]."""

  learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
  learning_rate_values: typing.List[float] = pydantic.Field(None)


class LinearRampToConstant(base_config.BaseConfig):
  """Linear warm-up from zero, then hold at `learning_rate`."""

  learning_rate: float
  num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
    description="Number of steps to ramp this up from zero."
  )


class LinearRampToCosine(base_config.BaseConfig):
  """Linear warm-up from zero, then cosine decay down to `final_learning_rate`."""

  learning_rate: float
  final_learning_rate: float
  num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
    description="Number of steps to ramp this up from zero."
  )
  final_num_steps: pydantic.PositiveInt = pydantic.Field(
    description="Final number of steps where decay stops."
  )


class LearningRate(base_config.BaseConfig):
  """Learning-rate schedule; only one of these `one_of="lr"` fields should be set."""

  constant: float = pydantic.Field(None, one_of="lr")
  linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
  linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
  piecewise_constant: PiecewiseConstant = pydantic.Field(None, one_of="lr")


class OptimizerAlgorithmConfig(base_config.BaseConfig):
  """Base class for optimizer configurations."""

  lr: float
  ...


class AdamConfig(OptimizerAlgorithmConfig):
  # See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
  lr: float
  betas: typing.Tuple[float, float] = (0.9, 0.999)
  eps: float = 1e-7  # Numerical stability in denominator.


class SgdConfig(OptimizerAlgorithmConfig):
  lr: float
  momentum: float = 0.0


class AdagradConfig(OptimizerAlgorithmConfig):
  lr: float
  eps: float = 0


class OptimizerConfig(base_config.BaseConfig):
  """Selects an optimizer algorithm and its learning-rate schedule."""

  learning_rate: LearningRate = pydantic.Field(
    None,
    description="Learning rate schedule.",
  )
  adam: AdamConfig = pydantic.Field(None, one_of="optimizer")
  sgd: SgdConfig = pydantic.Field(None, one_of="optimizer")
  adagrad: AdagradConfig = pydantic.Field(None, one_of="optimizer")


def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
  """Return whichever optimizer algorithm config (adam/sgd/adagrad) is set."""
  if optimizer_config.adam is not None:
    return optimizer_config.adam
  elif optimizer_config.sgd is not None:
    return optimizer_config.sgd
  elif optimizer_config.adagrad is not None:
    return optimizer_config.adagrad
  else:
    raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")
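
For orientation, a minimal usage sketch (not part of the commit) of how a schedule and an algorithm are selected together. It assumes base_config.BaseConfig behaves like a regular pydantic model, so keyword construction and .dict() work as shown; the 0.01 and 0.9 values are arbitrary.

# Hypothetical usage sketch, not part of the commit.
from tml.optimizers.config import (
  LearningRate,
  OptimizerConfig,
  SgdConfig,
  get_optimizer_algorithm_config,
)

config = OptimizerConfig(
  learning_rate=LearningRate(constant=0.01),  # pick exactly one schedule
  sgd=SgdConfig(lr=0.01, momentum=0.9),       # pick exactly one algorithm
)

algo = get_optimizer_algorithm_config(config)  # returns the SgdConfig instance
print(algo.dict())                             # e.g. {'lr': 0.01, 'momentum': 0.9}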

optimizers/optimizer.py (new file, 105 lines)
from typing import Dict, Tuple
import math
import bisect

from tml.optimizers.config import (
  LearningRate,
  OptimizerConfig,
  get_optimizer_algorithm_config,
)

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from tml.ml_logging.torch_logging import logging


def compute_lr(lr_config, step):
  """Compute the learning rate for `step` from whichever schedule is set in `lr_config`."""
  if lr_config.constant is not None:
    return lr_config.constant
  elif lr_config.piecewise_constant is not None:
    return lr_config.piecewise_constant.learning_rate_values[
      bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
    ]
  elif lr_config.linear_ramp_to_constant is not None:
    slope = (
      lr_config.linear_ramp_to_constant.learning_rate
      / lr_config.linear_ramp_to_constant.num_ramp_steps
    )
    return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
  elif lr_config.linear_ramp_to_cosine is not None:
    cfg = lr_config.linear_ramp_to_cosine
    if step < cfg.num_ramp_steps:
      # Linear warm-up from zero to the peak learning rate.
      slope = cfg.learning_rate / cfg.num_ramp_steps
      return slope * step
    elif step <= cfg.final_num_steps:
      # Cosine decay from the peak learning rate down to the final learning rate.
      return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
        1.0
        + math.cos(
          math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
        )
      )
    else:
      return cfg.final_learning_rate
  else:
    raise ValueError(f"No option selected in lr_config, passed {lr_config}")


class LRShim(_LRScheduler):
  """Shim to get learning rates into a LRScheduler.

  This adheres to the torch.optim scheduler API and can be plugged in anywhere that
  e.g. exponential decay can be used.
  """

  def __init__(
    self,
    optimizer,
    lr_dict: Dict[str, LearningRate],
    last_epoch=-1,
    verbose=False,
  ):
    self.optimizer = optimizer
    self.lr_dict = lr_dict
    self.group_names = list(self.lr_dict.keys())

    num_param_groups = sum(1 for _, _optim in optimizer._optims for _ in _optim.param_groups)
    if num_param_groups != len(lr_dict):
      raise ValueError(
        f"Optimizer had {num_param_groups} param groups, but config had {len(lr_dict)}."
      )

    super().__init__(optimizer, last_epoch, verbose)

  def get_lr(self):
    if not self._get_lr_called_within_step:
      logging.warn(
        "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.",
        UserWarning,
      )
    return self._get_closed_form_lr()

  def _get_closed_form_lr(self):
    return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]


def get_optimizer_class(optimizer_config: OptimizerConfig):
  if optimizer_config.adam is not None:
    return torch.optim.Adam
  elif optimizer_config.sgd is not None:
    return torch.optim.SGD
  elif optimizer_config.adagrad is not None:
    return torch.optim.Adagrad
  else:
    raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")


def build_optimizer(
  model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
  """Builds an optimizer and LR scheduler from an OptimizerConfig.

  Note: use this when you want the same optimizer and learning rate schedule for all your
  parameters.
  """
  optimizer_class = get_optimizer_class(optimizer_config)
  # Use the hyperparameters of whichever algorithm (adam/sgd/adagrad) is actually selected,
  # rather than always reading the sgd block.
  algorithm_config = get_optimizer_algorithm_config(optimizer_config)
  optimizer = optimizer_class(model.parameters(), **algorithm_config.dict())
  # We're passing everything in as one group here.
  scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": optimizer_config.learning_rate})
  return optimizer, scheduler
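
A short, hedged sketch of the schedule math in compute_lr (again not part of the commit, and again assuming pydantic-style construction of the configs). The concrete 0.01/0.001/100/1000 values are arbitrary and only illustrate the warm-up and cosine-decay endpoints; build_optimizer pairs such a schedule with the selected optimizer class under the single "ALL_PARAMS" group.

# Hypothetical usage sketch, not part of the commit.
from tml.optimizers.config import LearningRate, LinearRampToCosine
from tml.optimizers.optimizer import compute_lr

schedule = LearningRate(
  linear_ramp_to_cosine=LinearRampToCosine(
    learning_rate=0.01,         # peak LR reached at the end of the ramp
    final_learning_rate=0.001,  # LR held after decay finishes
    num_ramp_steps=100,
    final_num_steps=1000,
  )
)

for step in (0, 50, 100, 550, 1000, 5000):
  print(step, compute_lr(schedule, step))
# Expected: 0.0 at step 0, 0.005 mid-ramp, 0.01 at the end of the ramp,
# ~0.0055 halfway through the decay, and 0.001 from step 1000 onward.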