mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-11-16 21:29:24 +01:00
60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
|
from typing import List
|
||
|
from enum import Enum
|
||
|
|
||
|
import tml.core.config as base_config
|
||
|
from tml.optimizers.config import OptimizerConfig
|
||
|
|
||
|
import pydantic
|
||
|
|
||
|
|
||
|
class DataType(str, Enum):
|
||
|
FP32 = "fp32"
|
||
|
FP16 = "fp16"
|
||
|
|
||
|
|
||
|
class EmbeddingSnapshot(base_config.BaseConfig):
|
||
|
"""Configuration for Embedding snapshot"""
|
||
|
|
||
|
emb_name: str = pydantic.Field(
|
||
|
..., description="Name of the embedding table from the loaded snapshot"
|
||
|
)
|
||
|
embedding_snapshot_uri: str = pydantic.Field(
|
||
|
..., description="Path to torchsnapshot of the embedding"
|
||
|
)
|
||
|
|
||
|
|
||
|
class EmbeddingBagConfig(base_config.BaseConfig):
|
||
|
"""Configuration for EmbeddingBag."""
|
||
|
|
||
|
name: str = pydantic.Field(..., description="name of embedding bag")
|
||
|
num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
|
||
|
embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
|
||
|
pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
|
||
|
vocab: str = pydantic.Field(
|
||
|
None, description="Directory to parquet files of mapping from entity ID to table index."
|
||
|
)
|
||
|
# make sure to use an optimizer that matches:
|
||
|
# https://github.com/pytorch/FBGEMM/blob/4c58137529d221390575e47e88d3c05ce65b66fd/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py#L15
|
||
|
optimizer: OptimizerConfig
|
||
|
data_type: DataType
|
||
|
|
||
|
|
||
|
class LargeEmbeddingsConfig(base_config.BaseConfig):
|
||
|
"""Configuration for EmbeddingBagCollection.
|
||
|
|
||
|
The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
|
||
|
"""
|
||
|
|
||
|
tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
|
||
|
tables_to_log: List[str] = pydantic.Field(
|
||
|
None, description="list of embedding table names that we want to log during training"
|
||
|
)
|
||
|
|
||
|
|
||
|
class Mode(str, Enum):
|
||
|
"""Job modes."""
|
||
|
|
||
|
TRAIN = "train"
|
||
|
EVALUATE = "evaluate"
|
||
|
INFERENCE = "inference"
|