mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-11-19 14:39:22 +01:00
95 lines
3.7 KiB
Python
95 lines
3.7 KiB
Python
from typing import List, Optional
|
|
import tml.core.config as base_config
|
|
from tml.optimizers import config as optimizer_config
|
|
|
|
import pydantic
|
|
|
|
|
|
class EmbeddingSnapshot(base_config.BaseConfig):
    """Configuration for Embedding snapshot.

    Identifies one embedding table inside a previously saved torchsnapshot.
    Both fields are required.
    """

    # Which table to pull out of the loaded snapshot.
    emb_name: str = pydantic.Field(
        default=...,
        description="Name of the embedding table from the loaded snapshot",
    )
    # Where the torchsnapshot holding that table lives.
    embedding_snapshot_uri: str = pydantic.Field(
        default=...,
        description="Path to torchsnapshot of the embedding",
    )
|
|
|
|
|
|
# https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_configs.EmbeddingBagConfig
class EmbeddingBagConfig(base_config.BaseConfig):
    """Configuration for EmbeddingBag.

    See the torchrec ``EmbeddingBagConfig`` documentation linked above for the
    corresponding torchrec type. ``name``, ``num_embeddings`` and
    ``embedding_dim`` are required; ``pretrained`` and ``vocab`` are optional
    and default to ``None``.
    """

    name: str = pydantic.Field(..., description="name of embedding bag")
    num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
    embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
    # Annotated Optional because the default is None; a bare `EmbeddingSnapshot` / `str`
    # annotation misstates the accepted values (pydantic v1 only widened it implicitly).
    pretrained: Optional[EmbeddingSnapshot] = pydantic.Field(
        None, description="Snapshot properties"
    )
    vocab: Optional[str] = pydantic.Field(
        None, description="Directory to parquet files of mapping from entity ID to table index."
    )
|
|
|
|
|
|
class EmbeddingOptimizerConfig(base_config.BaseConfig):
    """Optimizer settings for the embedding tables of the EmbeddingBagCollection (EBC).

    ``init_learning_rate`` is required. ``learning_rate`` optionally supplies a
    learning-rate scheduler and defaults to ``None``.
    """

    # Optional because the default is None.
    learning_rate: Optional[optimizer_config.LearningRate] = pydantic.Field(
        None, description="learning rate scheduler for the EBC"
    )
    # `...` states explicitly that this field is required (a Field() with no
    # default was already required, just implicitly).
    init_learning_rate: float = pydantic.Field(
        ..., description="initial learning rate for the EBC"
    )
    # NB: Only sgd is supported right now and implicitly.
    # FBGemm only supports simple exact_sgd which only takes LR as an argument.
|
|
|
|
|
|
class LargeEmbeddingsConfig(base_config.BaseConfig):
    """Configuration for EmbeddingBagCollection.

    The tables listed in this config are gathered into a single torchrec
    EmbeddingBagCollection. ``tables`` and ``optimizer`` are required;
    ``tables_to_log`` is optional and defaults to ``None``.
    """

    tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
    # Required: no default is given.
    optimizer: EmbeddingOptimizerConfig
    # Optional because the default is None.
    tables_to_log: Optional[List[str]] = pydantic.Field(
        None, description="list of embedding table names that we want to log during training"
    )
|
|
|
|
|
|
class StratifierConfig(base_config.BaseConfig):
    """Configuration for a stratifier: a ``name`` plus integer ``index`` and ``value``.

    All three fields are required.
    """

    # NOTE(review): the meaning of `index` and `value` is not defined in this file;
    # confirm their semantics against the code that consumes StratifierConfig.
    name: str = pydantic.Field(...)
    index: int = pydantic.Field(...)
    value: int = pydantic.Field(...)
|
|
|
|
|
|
class SmallEmbeddingBagConfig(base_config.BaseConfig):
    """Configuration for SmallEmbeddingBag.

    Describes one small embedding table whose lookups come from position
    ``index`` of the discrete feature tensor. All fields are required.
    """

    name: str = pydantic.Field(..., description="name of embedding bag")
    num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
    embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
    index: int = pydantic.Field(..., description="index in the discrete tensor to look for")
|
|
|
|
|
|
class SmallEmbeddingsConfig(base_config.BaseConfig):
    """Configuration for SmallEmbeddingsConfig.

    Embeds a user-defined subset of the discrete features that are already
    present in our TFRecords. Those features are generated via segdense
    conversion as "home_recap_2022_discrete__segdense_vals" and exposed to the
    model as "discrete_features". Each table has configurable dimensions and
    vocabulary size.

    Compared with LargeEmbedding, this config is for small embedding tables
    that can fit inside the model. LargeEmbedding tables, by contrast, are
    usually hydrated outside the model at serving time because of their size
    (>>1 GB).

    These small embedding tables use the same optimizer as the rest of the model.
    """

    # Required; each entry is validated as a SmallEmbeddingBagConfig.
    tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
        default=...,
        description="list of embedding tables",
    )
|