# File-viewer metadata captured along with the source (not part of the module):
# 55 lines, 2.0 KiB, Python

from tml.core import config as config_mod
import tml.projects.home.recap.data.config as data_config
import tml.projects.home.recap.model.config as model_config
import tml.projects.home.recap.optimizer.config as optimizer_config
from enum import Enum
from typing import Dict, Optional
import pydantic
class TrainingConfig(config_mod.BaseConfig):
    """Training-loop settings: step counts, checkpointing, logging and evaluation.

    Fields whose default is ``None`` are annotated ``Optional[...]`` so the
    declared type matches the value the field actually holds when unset.
    """

    # Output directory. NOTE(review): presumably where checkpoints/model
    # artifacts are written -- confirm against the code that consumes it.
    save_dir: str = "/tmp/model"

    # Number of training steps; must be strictly positive.
    num_train_steps: pydantic.PositiveInt = 1000000

    # `at_most_one_of` is not a standard pydantic argument; it is passed through
    # as extra field metadata. NOTE(review): presumably interpreted by
    # config_mod.BaseConfig to group mutually exclusive fields -- confirm.
    initial_checkpoint_dir: Optional[str] = pydantic.Field(
        None,
        description="Directory of initial checkpoints",
        at_most_one_of="initialization",
    )

    # Checkpoint cadence, in steps; must be strictly positive.
    checkpoint_every_n: pydantic.PositiveInt = 1000

    # None means "no limit" per the field description below.
    checkpoint_max_to_keep: Optional[pydantic.PositiveInt] = pydantic.Field(
        None,
        description="Maximum number of checkpoints to keep. Defaults to keeping all.",
    )

    # Training log cadence, in steps; must be strictly positive.
    train_log_every_n: pydantic.PositiveInt = 1000

    # Plain int (not PositiveInt) because negative values are meaningful here.
    num_eval_steps: int = pydantic.Field(
        16384,
        description="Number of evaluation steps. If < 0 the entire dataset " "will be used.",
    )

    # Evaluation log cadence, in steps; must be strictly positive.
    eval_log_every_n: pydantic.PositiveInt = 5000

    # 60 * 60 seconds, i.e. one hour.
    eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60

    # None when unset. NOTE(review): how an unset value is treated is decided
    # by the consumer of this config, which is not visible here.
    gradient_accumulation: Optional[int] = pydantic.Field(
        None,
        description="Number of replica steps to accumulate gradients.",
    )
class RecapConfig(config_mod.BaseConfig):
    """Top-level configuration bundling training, model, data and optimizer settings.

    `model`, `train_data`, `validation_data` and `optimizer` have no default
    and therefore must be supplied; `training` and `which_metrics` are optional.
    """

    # The default TrainingConfig() is built once, when this class body executes.
    training: TrainingConfig = pydantic.Field(TrainingConfig())
    model: model_config.ModelConfig
    train_data: data_config.RecapDataConfig
    # One data config per string key. NOTE(review): the meaning of the keys is
    # defined by the callers and is not visible here.
    validation_data: Dict[str, data_config.RecapDataConfig]
    optimizer: optimizer_config.RecapOptimizerConfig

    which_metrics: Optional[str] = pydantic.Field(None, description="which metrics to pick.")

    # DANGER DANGER! You might expect validators here to ensure that multi task learning setups are
    # the same as the data. Unfortunately, this throws opaque errors when the model configuration is
    # invalid. In our judgement, that is a more frequent and worse occurrence than tasks not matching
    # the data.
class JobMode(str, Enum):
    """Enumerates the supported job modes.

    The ``str`` mixin makes every member a string, so a member compares equal
    to its raw value (``JobMode.TRAIN == "train"``) and can be looked up from
    that value (``JobMode("train")``).
    """

    TRAIN = "train"
    EVALUATE = "evaluate"
    INFERENCE = "inference"