Mirror of https://github.com/twitter/the-algorithm-ml.git, synced 2025-06-13 12:58:39 +02:00
Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings
common/modules/embedding/config.py (new file, 59 lines)
from typing import List
from enum import Enum

import tml.core.config as base_config
from tml.optimizers.config import OptimizerConfig

import pydantic


class DataType(str, Enum):
  FP32 = "fp32"
  FP16 = "fp16"


class EmbeddingSnapshot(base_config.BaseConfig):
  """Configuration for Embedding snapshot"""

  emb_name: str = pydantic.Field(
    ..., description="Name of the embedding table from the loaded snapshot"
  )
  embedding_snapshot_uri: str = pydantic.Field(
    ..., description="Path to torchsnapshot of the embedding"
  )


class EmbeddingBagConfig(base_config.BaseConfig):
  """Configuration for EmbeddingBag."""

  name: str = pydantic.Field(..., description="name of embedding bag")
  num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
  embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
  pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
  vocab: str = pydantic.Field(
    None, description="Directory to parquet files of mapping from entity ID to table index."
  )
  # make sure to use an optimizer that matches:
  # https://github.com/pytorch/FBGEMM/blob/4c58137529d221390575e47e88d3c05ce65b66fd/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py#L15
  optimizer: OptimizerConfig
  data_type: DataType


class LargeEmbeddingsConfig(base_config.BaseConfig):
  """Configuration for EmbeddingBagCollection.

  The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
  """

  tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
  tables_to_log: List[str] = pydantic.Field(
    None, description="list of embedding table names that we want to log during training"
  )


class Mode(str, Enum):
  """Job modes."""

  TRAIN = "train"
  EVALUATE = "evaluate"
  INFERENCE = "inference"
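For readers unfamiliar with the pydantic.Field pattern used throughout this config, here is a minimal, self-contained sketch (plain pydantic.BaseModel rather than the repo's base_config.BaseConfig) showing how a field declared with `...` is required while one declared with a `None` default stays optional. The Snapshot class and the values below are illustrative only, not part of the commit.

# Illustrative sketch only: plain pydantic.BaseModel instead of the repo's
# base_config.BaseConfig; the field names loosely mirror the classes above.
import pydantic


class Snapshot(pydantic.BaseModel):
  # "..." marks the field as required.
  emb_name: str = pydantic.Field(..., description="Name of the embedding table")
  # A default of None makes the field optional.
  vocab: str = pydantic.Field(None, description="Optional vocab directory")


ok = Snapshot(emb_name="user_id_table")  # validates: required field supplied
print(ok.vocab)                          # None (optional field left unset)

try:
  Snapshot()                             # missing emb_name
except pydantic.ValidationError as err:
  print(err)                             # reports the missing required field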
common/modules/embedding/embedding.py (new file, 58 lines)
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
from tml.ml_logging.torch_logging import logging

import torch
from torch import nn
import torchrec
from torchrec.modules import embedding_configs
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
import numpy as np


class LargeEmbeddings(nn.Module):
  def __init__(
    self,
    large_embeddings_config: LargeEmbeddingsConfig,
  ):
    super().__init__()

    tables = []
    for table in large_embeddings_config.tables:
      data_type = (
        embedding_configs.DataType.FP32
        if (table.data_type == DataType.FP32)
        else embedding_configs.DataType.FP16
      )

      tables.append(
        EmbeddingBagConfig(
          embedding_dim=table.embedding_dim,
          feature_names=[table.name],  # restricted to 1 feature per table for now
          name=table.name,
          num_embeddings=table.num_embeddings,
          pooling=torchrec.PoolingType.SUM,
          data_type=data_type,
        )
      )

    self.ebc = EmbeddingBagCollection(
      device="meta",
      tables=tables,
    )

    logging.info("********************** EBC named params are **********")
    logging.info(list(self.ebc.named_parameters()))

    # This hook is used to perform post-processing surgery
    # on large_embedding models to prep them for serving
    self.surgery_cut_point = torch.nn.Identity()

  def forward(
    self,
    sparse_features: KeyedJaggedTensor,
  ) -> KeyedTensor:
    pooled_embs = self.ebc(sparse_features)

    # a KeyedTensor
    return self.surgery_cut_point(pooled_embs)
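As a rough usage sketch (not part of the commit): the forward pass takes a torchrec KeyedJaggedTensor of sparse ids and returns a KeyedTensor of sum-pooled embeddings. The snippet below builds a tiny EmbeddingBagCollection directly on CPU, whereas the module above builds it on the "meta" device and materializes weights elsewhere; the table name, sizes, and ids here are made up for illustration.

# Minimal sketch of feeding sparse features through an EmbeddingBagCollection
# like the one constructed in LargeEmbeddings; all names and sizes are made up.
import torch
import torchrec
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

ebc = EmbeddingBagCollection(
  tables=[
    EmbeddingBagConfig(
      name="user_id",
      feature_names=["user_id"],
      num_embeddings=100,
      embedding_dim=8,
      pooling=torchrec.PoolingType.SUM,
    )
  ],
)  # defaults to CPU so the forward pass below actually runs

# Batch of 2 examples for the "user_id" feature: the first has ids [3, 7],
# the second has id [42] (lengths encode how many ids each example has).
features = KeyedJaggedTensor.from_lengths_sync(
  keys=["user_id"],
  values=torch.tensor([3, 7, 42]),
  lengths=torch.tensor([2, 1]),
)

pooled = ebc(features)                     # a KeyedTensor, as in LargeEmbeddings.forward
print(pooled.keys())                       # ['user_id']
print(pooled.to_dict()["user_id"].shape)   # torch.Size([2, 8])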