37 lines
1.2 KiB
Python

"""Optimization configurations for models."""
import typing
import tml.core.config as base_config
import tml.optimizers.config as optimizers_config_mod
import pydantic
class RecapAdamConfig(base_config.BaseConfig):
beta_1: float = 0.9 # Momentum term.
beta_2: float = 0.999 # Exponential weighted decay factor.
epsilon: float = 1e-7 # Numerical stability in denominator.
class MultiTaskLearningRates(base_config.BaseConfig):
tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
description="Learning rates for different towers of the model."
)
backbone_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field(
None, description="Learning rate for backbone of the model."
)
class RecapOptimizerConfig(base_config.BaseConfig):
multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
None, description="Multiple learning rates for different tasks.", one_of="lr"
)
single_task_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field(
None, description="Single task learning rates", one_of="lr"
)
adam: RecapAdamConfig = pydantic.Field(one_of="optimizer")