mirror of https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-23 14:51:49 +01:00
59 lines · 1.6 KiB · Python
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
from tml.ml_logging.torch_logging import logging

import torch
from torch import nn
import torchrec
from torchrec.modules import embedding_configs
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


class LargeEmbeddings(nn.Module):
  def __init__(
    self,
    large_embeddings_config: LargeEmbeddingsConfig,
  ):
    super().__init__()

    # Translate each configured table into a torchrec EmbeddingBagConfig.
    tables = []
    for table in large_embeddings_config.tables:
      data_type = (
        embedding_configs.DataType.FP32
        if table.data_type == DataType.FP32
        else embedding_configs.DataType.FP16
      )

      tables.append(
        EmbeddingBagConfig(
          embedding_dim=table.embedding_dim,
          feature_names=[table.name],  # restricted to 1 feature per table for now
          name=table.name,
          num_embeddings=table.num_embeddings,
          pooling=torchrec.PoolingType.SUM,
          data_type=data_type,
        )
      )

    # The collection is built on the meta device; real parameter storage is
    # materialized later, when the model is sharded/placed for training.
    self.ebc = EmbeddingBagCollection(
      device="meta",
      tables=tables,
    )

    logging.info("EBC named parameters:")
    logging.info(list(self.ebc.named_parameters()))

    # This hook is used to perform post-processing surgery on large_embedding
    # models to prep them for serving.
    self.surgery_cut_point = torch.nn.Identity()

  def forward(
    self,
    sparse_features: KeyedJaggedTensor,
  ) -> KeyedTensor:
    pooled_embs = self.ebc(sparse_features)

    # pooled_embs is a KeyedTensor of pooled embeddings, keyed by feature name.
    return self.surgery_cut_point(pooled_embs)
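

# --- Illustrative example (not part of the original module) -----------------
# A minimal sketch of the kind of serving-prep "surgery" the cut point above
# enables: because surgery_cut_point is the last module applied in forward(),
# any nn.Module swapped into its place sees, and may transform, the pooled
# KeyedTensor. CastPooledToFP16 is a hypothetical example, not a tml API.
class CastPooledToFP16(nn.Module):
  """Hypothetical surgery module: casts pooled embedding values to fp16."""

  def forward(self, pooled: KeyedTensor) -> KeyedTensor:
    return KeyedTensor(
      keys=pooled.keys(),
      length_per_key=pooled.length_per_key(),
      values=pooled.values().half(),
    )


# Usage at serving-prep time (assuming `model` is a trained LargeEmbeddings):
#   model.surgery_cut_point = CastPooledToFP16()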
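

# A hedged usage sketch. This file does not show how LargeEmbeddingsConfig is
# built, so SimpleNamespace stands in for a parsed config below; only the
# table fields read in __init__ (name, embedding_dim, num_embeddings,
# data_type) are provided. Everything here is illustrative.
if __name__ == "__main__":
  from types import SimpleNamespace

  config = SimpleNamespace(
    tables=[
      SimpleNamespace(
        name="user_id",
        num_embeddings=1_000_000,
        embedding_dim=64,
        data_type=DataType.FP32,
      )
    ]
  )
  model = LargeEmbeddings(config)  # type: ignore[arg-type]

  # The EmbeddingBagCollection was created on the meta device, so materialize
  # (uninitialized) weights on CPU just to let this sketch run end to end; in
  # real use, torchrec's distributed machinery shards and initializes them.
  model = model.to_empty(device="cpu")

  # A batch of two examples for the "user_id" feature: the first example has
  # one id, the second has two (lengths sum to len(values)).
  features = KeyedJaggedTensor.from_lengths_sync(
    keys=["user_id"],
    values=torch.tensor([101, 202, 303]),
    lengths=torch.tensor([1, 2]),
  )

  pooled = model(features)
  print(pooled.keys())                       # ["user_id"]
  print(pooled.to_dict()["user_id"].shape)   # torch.Size([2, 64])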