Mirror of https://github.com/twitter/the-algorithm-ml.git
Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings
projects/twhin/models/config.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import typing
import enum

from tml.common.modules.embedding.config import LargeEmbeddingsConfig
from tml.core.config import base_config
from tml.optimizers.config import OptimizerConfig

import pydantic
from pydantic import validator


class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
  @validator("tables")
  def embedding_dims_match(cls, tables):
    embedding_dim = tables[0].embedding_dim
    data_type = tables[0].data_type
    for table in tables:
      assert table.embedding_dim == embedding_dim, "Embedding dimensions for all nodes must match."
      assert table.data_type == data_type, "Data types for all nodes must match."
    return tables


class Operator(str, enum.Enum):
  TRANSLATION = "translation"


class Relation(pydantic.BaseModel):
  """Graph relationship properties and operator."""

  name: str = pydantic.Field(..., description="Relationship name.")
  lhs: str = pydantic.Field(
    ...,
    description="Name of the entity on the left-hand-side of this relation. Must match a table name.",
  )
  rhs: str = pydantic.Field(
    ...,
    description="Name of the entity on the right-hand-side of this relation. Must match a table name.",
  )
  operator: Operator = pydantic.Field(
    Operator.TRANSLATION, description="Transformation to apply to lhs embedding before dot product."
  )


class TwhinModelConfig(base_config.BaseConfig):
  embeddings: TwhinEmbeddingsConfig
  relations: typing.List[Relation]
  translation_optimizer: OptimizerConfig

  @validator("relations", each_item=True)
  def valid_node_types(cls, relation, values, **kwargs):
    table_names = [table.name for table in values["embeddings"].tables]
    assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
    assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"
    return relation
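The config above describes the TwHIN graph declaratively: each embedding table is a node type, and each Relation is an edge type whose lhs embedding is translated before the dot product. Below is a minimal, hypothetical sketch (not part of this commit) of how the valid_node_types validator surfaces a mistyped endpoint; embeddings_config and sgd_config are assumed stand-ins, with the real construction shown in test_models.py further down.

# Hypothetical usage sketch. Assumes embeddings_config was built with tables
# named "user" and "tweet", and sgd_config is any valid OptimizerConfig.
import pydantic

try:
  TwhinModelConfig(
    embeddings=embeddings_config,
    translation_optimizer=sgd_config,
    relations=[Relation(name="follows", lhs="user", rhs="author")],  # "author" is not a table name
  )
except pydantic.ValidationError as err:
  # pydantic converts the assert in valid_node_types into a ValidationError.
  print(err)  # ... Invalid rhs node type: author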
projects/twhin/models/models.py (new file, 172 lines)
@@ -0,0 +1,172 @@
from typing import Callable
import math

from tml.projects.twhin.data.edges import EdgeBatch
from tml.projects.twhin.models.config import TwhinModelConfig
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.common.modules.embedding.embedding import LargeEmbeddings
from tml.optimizers.optimizer import get_optimizer_class
from tml.optimizers.config import get_optimizer_algorithm_config

import torch
from torch import nn
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward


class TwhinModel(nn.Module):
  def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
    super().__init__()
    self.batch_size = data_config.per_replica_batch_size
    self.table_names = [table.name for table in model_config.embeddings.tables]
    self.large_embeddings = LargeEmbeddings(model_config.embeddings)
    self.embedding_dim = model_config.embeddings.tables[0].embedding_dim
    self.num_tables = len(model_config.embeddings.tables)
    self.in_batch_negatives = data_config.in_batch_negatives
    self.global_negatives = data_config.global_negatives
    self.num_relations = len(model_config.relations)

    # one bias per relation
    self.all_trans_embs = torch.nn.parameter.Parameter(
      torch.nn.init.uniform_(torch.empty(self.num_relations, self.embedding_dim))
    )

  def forward(self, batch: EdgeBatch):
    # B x D
    trans_embs = self.all_trans_embs.data[batch.rels]

    # KeyedTensor
    outs = self.large_embeddings(batch.nodes)

    # 2B x TD
    x = outs.values()

    # 2B x T x D
    x = x.reshape(2 * self.batch_size, -1, self.embedding_dim)

    # 2B x D
    x = torch.sum(x, 1)

    # B x 2 x D
    x = x.reshape(self.batch_size, 2, self.embedding_dim)

    # translated
    translated = x[:, 1, :] + trans_embs

    negs = []
    if self.in_batch_negatives:
      # construct dot products for negatives via matmul
      for relation in range(self.num_relations):
        rel_mask = batch.rels == relation
        rel_count = rel_mask.sum()

        if not rel_count:
          continue

        # R x D
        lhs_matrix = x[rel_mask, 0, :]
        rhs_matrix = x[rel_mask, 1, :]

        lhs_perm = torch.randperm(lhs_matrix.shape[0])
        # repeat until we have enough negatives
        lhs_perm = lhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
        lhs_indices = lhs_perm[: self.in_batch_negatives]
        sampled_lhs = lhs_matrix[lhs_indices]

        rhs_perm = torch.randperm(rhs_matrix.shape[0])
        # repeat until we have enough negatives
        rhs_perm = rhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
        rhs_indices = rhs_perm[: self.in_batch_negatives]
        sampled_rhs = rhs_matrix[rhs_indices]

        # RS
        negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
        negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))

        negs.append(negs_lhs)
        negs.append(negs_rhs)

    # dot product for positives
    x = (x[:, 0, :] * translated).sum(-1)

    # concat positives and negatives
    x = torch.cat([x, *negs])
    return {
      "logits": x,
      "probabilities": torch.sigmoid(x),
    }


def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
  for table in model_config.embeddings.tables:
    optimizer_class = get_optimizer_class(table.optimizer)
    optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
    params = [
      param
      for name, param in model.large_embeddings.ebc.named_parameters()
      if (name.startswith(f"embedding_bags.{table.name}"))
    ]
    apply_optimizer_in_backward(
      optimizer_class=optimizer_class,
      params=params,
      optimizer_kwargs=optimizer_kwargs,
    )

  return model


class TwhinModelAndLoss(torch.nn.Module):
  def __init__(
    self,
    model,
    loss_fn: Callable,
    data_config: TwhinDataConfig,
    device: torch.device,
  ) -> None:
    """
    Args:
      model: torch module to wrap.
      loss_fn: Function for calculating loss, should accept logits and labels.
    """
    super().__init__()
    self.model = model
    self.loss_fn = loss_fn
    self.batch_size = data_config.per_replica_batch_size
    self.in_batch_negatives = data_config.in_batch_negatives
    self.device = device

  def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]
    """Runs model forward and calculates loss according to given loss_fn.

    NOTE: The input signature here needs to be a Pipelineable object for
    prefetching purposes during training using torchrec's pipeline. However
    the underlying model signature needs to be exportable to onnx, requiring
    generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
    """
    outputs = self.model(batch)
    logits = outputs["logits"]

    num_negatives = 2 * self.batch_size * self.in_batch_negatives
    num_positives = self.batch_size

    neg_weight = float(num_positives) / num_negatives

    labels = torch.cat([batch.labels.float(), torch.ones(num_negatives).to(self.device)])

    weights = torch.cat(
      [batch.weights.float(), (torch.ones(num_negatives) * neg_weight).to(self.device)]
    )

    losses = self.loss_fn(logits, labels, weights)

    outputs.update(
      {
        "loss": losses,
        "labels": labels,
        "weights": weights,
      }
    )

    # Allow multiple losses.
    return losses, outputs
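For context on the shapes annotated above: each positive edge is scored as a dot product between the lhs embedding and the translated rhs embedding (rhs plus the per-relation translation vector), while in-batch negatives are built per relation by permuting the lhs and rhs rows within the batch. TwhinModelAndLoss then appends 2 * B * in_batch_negatives negative logits after the B positive ones and downweights them by neg_weight. A small illustrative sketch of that bookkeeping, using the values from twhin_data_config() below (arithmetic only, not part of the commit):

# Illustrative arithmetic mirroring TwhinModelAndLoss.forward, assuming the
# test config values below (per_replica_batch_size=10, in_batch_negatives=10).
per_replica_batch_size = 10   # B positive edges per replica
in_batch_negatives = 10       # negatives sampled per edge side

num_positives = per_replica_batch_size
# forward() emits lhs-vs-sampled-rhs and rhs-vs-sampled-lhs scores, hence the factor of 2
num_negatives = 2 * per_replica_batch_size * in_batch_negatives

neg_weight = float(num_positives) / num_negatives  # 10 / 200 = 0.05
assert num_negatives == 200
assert neg_weight == 0.05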
projects/twhin/models/test_models.py (new file, 107 lines)
@@ -0,0 +1,107 @@
from tml.projects.twhin.models.config import TwhinEmbeddingsConfig, TwhinModelConfig
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.common.modules.embedding.config import DataType, EmbeddingBagConfig
from tml.optimizers.config import OptimizerConfig, SgdConfig
from tml.model import maybe_shard_model
from tml.projects.twhin.models.models import apply_optimizers, TwhinModel
from tml.projects.twhin.models.config import Operator, Relation
from tml.common.testing_utils import mock_pg

import torch
import torch.nn.functional as F
from pydantic import ValidationError
import pytest


NUM_EMBS = 10_000
EMB_DIM = 128


def twhin_model_config() -> TwhinModelConfig:
  sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
  sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))

  table0 = EmbeddingBagConfig(
    name="table0",
    num_embeddings=NUM_EMBS,
    embedding_dim=EMB_DIM,
    optimizer=sgd_config_0,
    data_type=DataType.FP32,
  )
  table1 = EmbeddingBagConfig(
    name="table1",
    num_embeddings=NUM_EMBS,
    embedding_dim=EMB_DIM,
    optimizer=sgd_config_1,
    data_type=DataType.FP32,
  )
  embeddings_config = TwhinEmbeddingsConfig(
    tables=[table0, table1],
  )

  model_config = TwhinModelConfig(
    embeddings=embeddings_config,
    translation_optimizer=sgd_config_0,
    relations=[
      Relation(name="rel0", lhs="table0", rhs="table1", operator=Operator.TRANSLATION),
      Relation(name="rel1", lhs="table1", rhs="table0", operator=Operator.TRANSLATION),
    ],
  )

  return model_config


def twhin_data_config() -> TwhinDataConfig:
  data_config = TwhinDataConfig(
    data_root="/",
    per_replica_batch_size=10,
    global_negatives=10,
    in_batch_negatives=10,
    limit=1,
    offset=1,
  )

  return data_config


def test_twhin_model():
  model_config = twhin_model_config()
  loss_fn = F.binary_cross_entropy_with_logits

  with mock_pg():
    data_config = twhin_data_config()
    model = TwhinModel(model_config=model_config, data_config=data_config)

    apply_optimizers(model, model_config)

    for tensor in model.state_dict().values():
      if tensor.size() == (NUM_EMBS, EMB_DIM):
        assert str(tensor.device) == "meta"
      else:
        assert str(tensor.device) == "cpu"

    model = maybe_shard_model(model, device=torch.device("cpu"))


def test_unequal_dims():
  sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
  sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
  table0 = EmbeddingBagConfig(
    name="table0",
    num_embeddings=10_000,
    embedding_dim=128,
    optimizer=sgd_config_1,
    data_type=DataType.FP32,
  )
  table1 = EmbeddingBagConfig(
    name="table1",
    num_embeddings=10_000,
    embedding_dim=64,
    optimizer=sgd_config_2,
    data_type=DataType.FP32,
  )

  with pytest.raises(ValidationError):
    _ = TwhinEmbeddingsConfig(
      tables=[table0, table1],
    )
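These tests build the two-table, two-relation config end to end: test_twhin_model checks that the large embedding tables stay on the meta device until maybe_shard_model materializes them, and test_unequal_dims checks that TwhinEmbeddingsConfig rejects tables with mismatched embedding dimensions. Assuming a checkout of the repository with its dependencies installed (torch, torchrec, pydantic, pytest), the file would typically be run from the repo root with: pytest projects/twhin/models/test_models.py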