Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings

This commit is contained in:
twitter-team
2023-03-31 13:05:14 -05:00
commit 78c3235eee
111 changed files with 11876 additions and 0 deletions

View File

@ -0,0 +1,54 @@
import typing
import enum
from tml.common.modules.embedding.config import LargeEmbeddingsConfig
from tml.core.config import base_config
from tml.optimizers.config import OptimizerConfig
import pydantic
from pydantic import validator
class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
@validator("tables")
def embedding_dims_match(cls, tables):
embedding_dim = tables[0].embedding_dim
data_type = tables[0].data_type
for table in tables:
assert table.embedding_dim == embedding_dim, "Embedding dimensions for all nodes must match."
assert table.data_type == data_type, "Data types for all nodes must match."
return tables
class Operator(str, enum.Enum):
TRANSLATION = "translation"
class Relation(pydantic.BaseModel):
"""graph relationship properties and operator"""
name: str = pydantic.Field(..., description="Relationship name.")
lhs: str = pydantic.Field(
...,
description="Name of the entity on the left-hand-side of this relation. Must match a table name.",
)
rhs: str = pydantic.Field(
...,
description="Name of the entity on the right-hand-side of this relation. Must match a table name.",
)
operator: Operator = pydantic.Field(
Operator.TRANSLATION, description="Transformation to apply to lhs embedding before dot product."
)
class TwhinModelConfig(base_config.BaseConfig):
embeddings: TwhinEmbeddingsConfig
relations: typing.List[Relation]
translation_optimizer: OptimizerConfig
@validator("relations", each_item=True)
def valid_node_types(cls, relation, values, **kwargs):
table_names = [table.name for table in values["embeddings"].tables]
assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"
return relation

View File

@ -0,0 +1,172 @@
from typing import Callable
import math
from tml.projects.twhin.data.edges import EdgeBatch
from tml.projects.twhin.models.config import TwhinModelConfig
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.common.modules.embedding.embedding import LargeEmbeddings
from tml.optimizers.optimizer import get_optimizer_class
from tml.optimizers.config import get_optimizer_algorithm_config
import torch
from torch import nn
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
class TwhinModel(nn.Module):
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
super().__init__()
self.batch_size = data_config.per_replica_batch_size
self.table_names = [table.name for table in model_config.embeddings.tables]
self.large_embeddings = LargeEmbeddings(model_config.embeddings)
self.embedding_dim = model_config.embeddings.tables[0].embedding_dim
self.num_tables = len(model_config.embeddings.tables)
self.in_batch_negatives = data_config.in_batch_negatives
self.global_negatives = data_config.global_negatives
self.num_relations = len(model_config.relations)
# one bias per relation
self.all_trans_embs = torch.nn.parameter.Parameter(
torch.nn.init.uniform_(torch.empty(self.num_relations, self.embedding_dim))
)
def forward(self, batch: EdgeBatch):
# B x D
trans_embs = self.all_trans_embs.data[batch.rels]
# KeyedTensor
outs = self.large_embeddings(batch.nodes)
# 2B x TD
x = outs.values()
# 2B x T x D
x = x.reshape(2 * self.batch_size, -1, self.embedding_dim)
# 2B x D
x = torch.sum(x, 1)
# B x 2 x D
x = x.reshape(self.batch_size, 2, self.embedding_dim)
# translated
translated = x[:, 1, :] + trans_embs
negs = []
if self.in_batch_negatives:
# construct dot products for negatives via matmul
for relation in range(self.num_relations):
rel_mask = batch.rels == relation
rel_count = rel_mask.sum()
if not rel_count:
continue
# R x D
lhs_matrix = x[rel_mask, 0, :]
rhs_matrix = x[rel_mask, 1, :]
lhs_perm = torch.randperm(lhs_matrix.shape[0])
# repeat until we have enough negatives
lhs_perm = lhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
lhs_indices = lhs_perm[: self.in_batch_negatives]
sampled_lhs = lhs_matrix[lhs_indices]
rhs_perm = torch.randperm(rhs_matrix.shape[0])
# repeat until we have enough negatives
rhs_perm = rhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
rhs_indices = rhs_perm[: self.in_batch_negatives]
sampled_rhs = rhs_matrix[rhs_indices]
# RS
negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))
negs.append(negs_lhs)
negs.append(negs_rhs)
# dot product for positives
x = (x[:, 0, :] * translated).sum(-1)
# concat positives and negatives
x = torch.cat([x, *negs])
return {
"logits": x,
"probabilities": torch.sigmoid(x),
}
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
for table in model_config.embeddings.tables:
optimizer_class = get_optimizer_class(table.optimizer)
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
params = [
param
for name, param in model.large_embeddings.ebc.named_parameters()
if (name.startswith(f"embedding_bags.{table.name}"))
]
apply_optimizer_in_backward(
optimizer_class=optimizer_class,
params=params,
optimizer_kwargs=optimizer_kwargs,
)
return model
class TwhinModelAndLoss(torch.nn.Module):
def __init__(
self,
model,
loss_fn: Callable,
data_config: TwhinDataConfig,
device: torch.device,
) -> None:
"""
Args:
model: torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
self.batch_size = data_config.per_replica_batch_size
self.in_batch_negatives = data_config.in_batch_negatives
self.device = device
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
outputs = self.model(batch)
logits = outputs["logits"]
num_negatives = 2 * self.batch_size * self.in_batch_negatives
num_positives = self.batch_size
neg_weight = float(num_positives) / num_negatives
labels = torch.cat([batch.labels.float(), torch.ones(num_negatives).to(self.device)])
weights = torch.cat(
[batch.weights.float(), (torch.ones(num_negatives) * neg_weight).to(self.device)]
)
losses = self.loss_fn(logits, labels, weights)
outputs.update(
{
"loss": losses,
"labels": labels,
"weights": weights,
}
)
# Allow multiple losses.
return losses, outputs

View File

@ -0,0 +1,107 @@
from tml.projects.twhin.models.config import TwhinEmbeddingsConfig, TwhinModelConfig
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.common.modules.embedding.config import DataType, EmbeddingBagConfig
from tml.optimizers.config import OptimizerConfig, SgdConfig
from tml.model import maybe_shard_model
from tml.projects.twhin.models.models import apply_optimizers, TwhinModel
from tml.projects.twhin.models.config import Operator, Relation
from tml.common.testing_utils import mock_pg
import torch
import torch.nn.functional as F
from pydantic import ValidationError
import pytest
NUM_EMBS = 10_000
EMB_DIM = 128
def twhin_model_config() -> TwhinModelConfig:
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
table0 = EmbeddingBagConfig(
name="table0",
num_embeddings=NUM_EMBS,
embedding_dim=EMB_DIM,
optimizer=sgd_config_0,
data_type=DataType.FP32,
)
table1 = EmbeddingBagConfig(
name="table1",
num_embeddings=NUM_EMBS,
embedding_dim=EMB_DIM,
optimizer=sgd_config_1,
data_type=DataType.FP32,
)
embeddings_config = TwhinEmbeddingsConfig(
tables=[table0, table1],
)
model_config = TwhinModelConfig(
embeddings=embeddings_config,
translation_optimizer=sgd_config_0,
relations=[
Relation(name="rel0", lhs="table0", rhs="table1", operator=Operator.TRANSLATION),
Relation(name="rel1", lhs="table1", rhs="table0", operator=Operator.TRANSLATION),
],
)
return model_config
def twhin_data_config() -> TwhinDataConfig:
data_config = TwhinDataConfig(
data_root="/",
per_replica_batch_size=10,
global_negatives=10,
in_batch_negatives=10,
limit=1,
offset=1,
)
return data_config
def test_twhin_model():
model_config = twhin_model_config()
loss_fn = F.binary_cross_entropy_with_logits
with mock_pg():
data_config = twhin_data_config()
model = TwhinModel(model_config=model_config, data_config=data_config)
apply_optimizers(model, model_config)
for tensor in model.state_dict().values():
if tensor.size() == (NUM_EMBS, EMB_DIM):
assert str(tensor.device) == "meta"
else:
assert str(tensor.device) == "cpu"
model = maybe_shard_model(model, device=torch.device("cpu"))
def test_unequal_dims():
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
table0 = EmbeddingBagConfig(
name="table0",
num_embeddings=10_000,
embedding_dim=128,
optimizer=sgd_config_1,
data_type=DataType.FP32,
)
table1 = EmbeddingBagConfig(
name="table1",
num_embeddings=10_000,
embedding_dim=64,
optimizer=sgd_config_2,
data_type=DataType.FP32,
)
with pytest.raises(ValidationError):
_ = TwhinEmbeddingsConfig(
tables=[table0, table1],
)