Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings

twitter-team
2023-03-31 13:05:14 -05:00
commit 78c3235eee
111 changed files with 11876 additions and 0 deletions


@@ -0,0 +1,59 @@
from typing import List
from enum import Enum

import tml.core.config as base_config
from tml.optimizers.config import OptimizerConfig

import pydantic


class DataType(str, Enum):
  FP32 = "fp32"
  FP16 = "fp16"


class EmbeddingSnapshot(base_config.BaseConfig):
  """Configuration for Embedding snapshot"""

  emb_name: str = pydantic.Field(
    ..., description="Name of the embedding table from the loaded snapshot"
  )
  embedding_snapshot_uri: str = pydantic.Field(
    ..., description="Path to torchsnapshot of the embedding"
  )


class EmbeddingBagConfig(base_config.BaseConfig):
  """Configuration for EmbeddingBag."""

  name: str = pydantic.Field(..., description="name of embedding bag")
  num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
  embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
  pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
  vocab: str = pydantic.Field(
    None, description="Directory to parquet files of mapping from entity ID to table index."
  )
  # make sure to use an optimizer that matches:
  # https://github.com/pytorch/FBGEMM/blob/4c58137529d221390575e47e88d3c05ce65b66fd/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py#L15
  optimizer: OptimizerConfig
  data_type: DataType


class LargeEmbeddingsConfig(base_config.BaseConfig):
  """Configuration for EmbeddingBagCollection.

  The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
  """

  tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
  tables_to_log: List[str] = pydantic.Field(
    None, description="list of embedding table names that we want to log during training"
  )


class Mode(str, Enum):
  """Job modes."""

  TRAIN = "train"
  EVALUATE = "evaluate"
  INFERENCE = "inference"
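
As a rough sketch of how these classes might be filled in, the snippet below builds a LargeEmbeddingsConfig with a single invented table. The table name and sizes are made up, and the optimizer block is only a guess at the OptimizerConfig schema in tml/optimizers/config.py (an SGD sub-config keyed by "sgd" with an "lr" field), so treat it as illustrative rather than a known-good config.

from tml.common.modules.embedding.config import (
  DataType,
  EmbeddingBagConfig,
  LargeEmbeddingsConfig,
)
from tml.optimizers.config import OptimizerConfig

# Hypothetical optimizer settings; the real OptimizerConfig fields may differ.
optimizer = OptimizerConfig.parse_obj({"sgd": {"lr": 0.01}})

config = LargeEmbeddingsConfig(
  tables=[
    EmbeddingBagConfig(
      name="user_id",            # invented table / feature name
      num_embeddings=1_000_000,  # vocabulary size
      embedding_dim=128,
      optimizer=optimizer,
      data_type=DataType.FP16,
    )
  ],
)
print(config.tables[0].name, config.tables[0].data_type.value)  # user_id fp16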


@@ -0,0 +1,58 @@
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
from tml.ml_logging.torch_logging import logging

import torch
from torch import nn
import torchrec
from torchrec.modules import embedding_configs
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
import numpy as np


class LargeEmbeddings(nn.Module):
  """Pools sparse id features through a single torchrec EmbeddingBagCollection
  built from the tables listed in a LargeEmbeddingsConfig."""

  def __init__(
    self,
    large_embeddings_config: LargeEmbeddingsConfig,
  ):
    super().__init__()

    tables = []
    for table in large_embeddings_config.tables:
      data_type = (
        embedding_configs.DataType.FP32
        if (table.data_type == DataType.FP32)
        else embedding_configs.DataType.FP16
      )

      tables.append(
        EmbeddingBagConfig(
          embedding_dim=table.embedding_dim,
          feature_names=[table.name],  # restricted to 1 feature per table for now
          name=table.name,
          num_embeddings=table.num_embeddings,
          pooling=torchrec.PoolingType.SUM,
          data_type=data_type,
        )
      )

    # "meta" device: table weights are not allocated here; they are materialized
    # (and typically sharded) later by torchrec.
    self.ebc = EmbeddingBagCollection(
      device="meta",
      tables=tables,
    )

    logging.info("********************** EBC named params are **********")
    logging.info(list(self.ebc.named_parameters()))

    # This hook is used to perform post-processing surgery
    # on large_embedding models to prep them for serving
    self.surgery_cut_point = torch.nn.Identity()

  def forward(
    self,
    sparse_features: KeyedJaggedTensor,
  ) -> KeyedTensor:
    pooled_embs = self.ebc(sparse_features)

    # a KeyedTensor
    return self.surgery_cut_point(pooled_embs)
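
For intuition about what forward returns, here is a small self-contained sketch of the same KeyedJaggedTensor -> pooled KeyedTensor flow, using one tiny CPU table instead of the "meta"-device collection above (which in practice is materialized and sharded by torchrec before it can run). The table name, sizes, and id values are invented for the example.

import torch
import torchrec
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# One small table on CPU, analogous to a single entry of LargeEmbeddingsConfig.tables.
ebc = EmbeddingBagCollection(
  tables=[
    EmbeddingBagConfig(
      name="user_id",             # hypothetical table / feature name
      feature_names=["user_id"],  # one feature per table, as above
      num_embeddings=100,
      embedding_dim=8,
      pooling=torchrec.PoolingType.SUM,
    )
  ],
  device=torch.device("cpu"),
)

# Two examples: the first has two "user_id" ids (7, 42), the second has one (3).
sparse_features = KeyedJaggedTensor.from_lengths_sync(
  keys=["user_id"],
  values=torch.tensor([7, 42, 3]),
  lengths=torch.tensor([2, 1]),
)

pooled = ebc(sparse_features)             # a KeyedTensor
print(pooled.keys())                      # ['user_id']
print(pooled.to_dict()["user_id"].shape)  # torch.Size([2, 8]) -- sum-pooled per example

In LargeEmbeddings the pooled KeyedTensor additionally passes through surgery_cut_point; since it is an nn.Identity it leaves the tensor unchanged, and per the comment above it presumably gives the serving-time post-processing a stable point at which to cut the model.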