# Mirror of https://github.com/twitter/the-algorithm-ml.git
# Synced 2025-01-25 21:31:14 +01:00
# 69 lines, 2.1 KiB, Python
import functools
|
|
|
|
from tml.projects.twhin.models.config import TwhinModelConfig
|
|
from tml.projects.twhin.models.models import TwhinModel
|
|
from tml.optimizers.optimizer import get_optimizer_class, LRShim
|
|
from tml.optimizers.config import get_optimizer_algorithm_config, LearningRate
|
|
from tml.ml_logging.torch_logging import logging
|
|
|
|
from torchrec.optim.optimizers import in_backward_optimizer_filter
|
|
from torchrec.optim import keyed
|
|
|
|
|
|
# Key under which build_optimizer registers the model's fused optimizer
# inside the combined optimizer.
FUSED_OPT_KEY = "fused_opt"
# Key under which build_optimizer registers the translation optimizer; also
# used as the translation entry in the learning-rate dict.
TRANSLATION_OPT_KEY = "operator_opt"
|
|
|
|
|
|
def _lr_from_config(optimizer_config):
    """Return the learning-rate setting for an optimizer config.

    Args:
      optimizer_config: Object exposing a ``learning_rate`` attribute. When
        that attribute is None it is also passed to
        ``get_optimizer_algorithm_config`` to look up the algorithm's ``lr``.

    Returns:
      ``optimizer_config.learning_rate`` unchanged when it is not None;
      otherwise a ``LearningRate`` built with ``constant`` set to the ``lr``
      of the algorithm config.
    """
    # Only None triggers the fallback: an explicit falsy value is returned as is.
    if optimizer_config.learning_rate is not None:
        return optimizer_config.learning_rate

    # treat None as constant lr
    lr_value = get_optimizer_algorithm_config(optimizer_config).lr
    return LearningRate(constant=lr_value)
|
|
|
|
|
|
def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
    """Builds an optimizer for a Twhin model.

    Combines the model's fused embeddings optimizer with a separately built
    optimizer for the per-relation translations.

    Args:
      model: TwhinModel to build the optimizer for. Its ``named_parameters()``
        and ``fused_optimizer`` are read.
      config: TwhinModelConfig for the model. Its ``translation_optimizer`` and
        ``embeddings.tables`` are read.

    Returns:
      A ``(optimizer, scheduler)`` tuple. ``optimizer`` is a
      ``keyed.CombinedOptimizer`` holding the fused optimizer under
      ``FUSED_OPT_KEY`` and the translation optimizer under
      ``TRANSLATION_OPT_KEY``. ``scheduler`` is currently always None.
    """
    translation_config = config.translation_optimizer

    # Bind the algorithm's hyperparameters now; the wrapper below calls the
    # factory itself to create the optimizer.
    translation_optimizer_fn = functools.partial(
        get_optimizer_class(translation_config),
        **get_optimizer_algorithm_config(translation_config).dict(),
    )

    # NOTE(review): in_backward_optimizer_filter presumably drops parameters
    # already handled by the fused optimizer - confirm against torchrec.
    translation_optimizer = keyed.KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        optim_factory=translation_optimizer_fn,
    )

    # One learning-rate entry per embedding table, plus one for translations.
    lr_dict = {
        table.name: _lr_from_config(table.optimizer)
        for table in config.embeddings.tables
    }
    lr_dict[TRANSLATION_OPT_KEY] = _lr_from_config(translation_config)

    logging.info(f"***** LR dict: {lr_dict} *****")

    logging.info(
        f"***** Combining fused optimizer {model.fused_optimizer} with operator optimizer: {translation_optimizer} *****"
    )
    optimizer = keyed.CombinedOptimizer(
        [
            (FUSED_OPT_KEY, model.fused_optimizer),
            (TRANSLATION_OPT_KEY, translation_optimizer),
        ]
    )

    # NOTE: learning-rate scheduling via LRShim(optimizer, lr_dict) is
    # disabled, so lr_dict is only logged and no scheduler is returned.
    scheduler = None

    logging.info(f"***** Combined optimizer after init: {optimizer} *****")

    return optimizer, scheduler
|