mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-24 15:21:47 +01:00
173 lines
5.3 KiB
Python
173 lines
5.3 KiB
Python
|
from typing import Callable
|
||
|
import math
|
||
|
|
||
|
from tml.projects.twhin.data.edges import EdgeBatch
|
||
|
from tml.projects.twhin.models.config import TwhinModelConfig
|
||
|
from tml.projects.twhin.data.config import TwhinDataConfig
|
||
|
from tml.common.modules.embedding.embedding import LargeEmbeddings
|
||
|
from tml.optimizers.optimizer import get_optimizer_class
|
||
|
from tml.optimizers.config import get_optimizer_algorithm_config
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
|
||
|
|
||
|
|
||
|
class TwhinModel(nn.Module):
    """Translation-based embedding model that scores edges with dot products.

    Each edge in a batch is a (lhs node, rhs node, relation) triple. Node
    embeddings come from ``LargeEmbeddings``; each relation owns one learned
    translation vector that is added to the rhs embedding before the positive
    dot product is taken. Optionally, in-batch negatives are sampled per
    relation and their dot products are appended to the logits.
    """

    def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
        """Build the embedding tables and the per-relation translation table.

        Args:
            model_config: provides ``embeddings.tables`` (names and embedding
                dim; the dim of the first table is used for all) and
                ``relations`` (only its length is used here).
            data_config: provides ``per_replica_batch_size``,
                ``in_batch_negatives`` and ``global_negatives``.
        """
        super().__init__()
        self.batch_size = data_config.per_replica_batch_size
        self.table_names = [table.name for table in model_config.embeddings.tables]
        self.large_embeddings = LargeEmbeddings(model_config.embeddings)
        self.embedding_dim = model_config.embeddings.tables[0].embedding_dim
        self.num_tables = len(model_config.embeddings.tables)
        self.in_batch_negatives = data_config.in_batch_negatives
        # Stored for completeness; not read anywhere in this class.
        self.global_negatives = data_config.global_negatives
        self.num_relations = len(model_config.relations)

        # one bias per relation
        self.all_trans_embs = torch.nn.parameter.Parameter(
            torch.nn.init.uniform_(torch.empty(self.num_relations, self.embedding_dim))
        )

    @staticmethod
    def _sample_rows(matrix: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Return ``num_samples`` rows of ``matrix`` picked via a random row permutation.

        ``matrix`` must have at least one row (the caller skips empty relations).
        """
        num_rows = matrix.shape[0]
        perm = torch.randperm(num_rows)
        # repeat until we have enough negatives
        perm = perm.repeat(math.ceil(num_samples / num_rows))
        return matrix[perm[:num_samples]]

    def forward(self, batch: EdgeBatch):
        """Score the edges in ``batch``.

        Args:
            batch: must expose ``rels`` (relation index per edge, used to index
                the translation table) and ``nodes`` (passed straight to
                ``self.large_embeddings``).

        Returns:
            Dict with ``"logits"`` (positive scores followed by any in-batch
            negative scores, as one flat tensor) and ``"probabilities"``
            (sigmoid of the logits).
        """
        # B x D
        # Index the Parameter itself rather than `.data`: going through `.data`
        # bypasses autograd, so the translation embeddings would never receive
        # a gradient and would stay at their initial values.
        trans_embs = self.all_trans_embs[batch.rels]

        # KeyedTensor
        outs = self.large_embeddings(batch.nodes)

        # 2B x TD
        x = outs.values()

        # 2B x T x D
        x = x.reshape(2 * self.batch_size, -1, self.embedding_dim)

        # 2B x D
        x = torch.sum(x, 1)

        # B x 2 x D
        x = x.reshape(self.batch_size, 2, self.embedding_dim)

        # translated
        translated = x[:, 1, :] + trans_embs

        negs = []
        if self.in_batch_negatives:
            # construct dot products for negatives via matmul
            # NOTE(review): negatives are scored on the untranslated embeddings,
            # unlike the positives below -- confirm this is intended.
            for relation in range(self.num_relations):
                rel_mask = batch.rels == relation
                # Convert once to a Python int so the truthiness test and the
                # sampling arithmetic do not each re-read the tensor.
                rel_count = int(rel_mask.sum())

                if not rel_count:
                    continue

                # R x D
                lhs_matrix = x[rel_mask, 0, :]
                rhs_matrix = x[rel_mask, 1, :]

                # lhs is sampled before rhs to keep the RNG call order stable.
                sampled_lhs = self._sample_rows(lhs_matrix, self.in_batch_negatives)
                sampled_rhs = self._sample_rows(rhs_matrix, self.in_batch_negatives)

                # RS
                negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
                negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))

                negs.append(negs_lhs)
                negs.append(negs_rhs)

        # dot product for positives
        x = (x[:, 0, :] * translated).sum(-1)

        # concat positives and negatives
        x = torch.cat([x, *negs])
        return {
            "logits": x,
            "probabilities": torch.sigmoid(x),
        }
|
||
|
|
||
|
|
||
|
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
    """Register each embedding table's optimizer on that table's parameters.

    For every table in ``model_config.embeddings.tables`` the optimizer class
    and keyword arguments are resolved from ``table.optimizer``, the matching
    parameters of ``model.large_embeddings.ebc`` are selected by name, and the
    three are handed to ``apply_optimizer_in_backward``.

    Args:
        model: model whose ``large_embeddings.ebc`` holds the table parameters.
        model_config: config listing the tables and their optimizer settings.

    Returns:
        The same ``model`` instance that was passed in.
    """
    for table in model_config.embeddings.tables:
        optimizer_class = get_optimizer_class(table.optimizer)
        optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()

        # Match the whole name component. A bare startswith on
        # "embedding_bags.<name>" would also capture tables whose name merely
        # begins with this one (e.g. "user" vs "user_v2"), handing them the
        # wrong optimizer.
        table_key = f"embedding_bags.{table.name}"
        table_prefix = table_key + "."
        params = [
            param
            for name, param in model.large_embeddings.ebc.named_parameters()
            if name == table_key or name.startswith(table_prefix)
        ]
        apply_optimizer_in_backward(
            optimizer_class=optimizer_class,
            params=params,
            optimizer_kwargs=optimizer_kwargs,
        )

    return model
|
||
|
|
||
|
|
||
|
class TwhinModelAndLoss(torch.nn.Module):
    """Wraps a model and a loss function so one forward call yields loss and outputs."""

    def __init__(
        self,
        model,
        loss_fn: Callable,
        data_config: TwhinDataConfig,
        device: torch.device,
    ) -> None:
        """
        Args:
          model: torch module to wrap.
          loss_fn: Function for calculating loss, should accept logits, labels and weights.
          data_config: provides ``per_replica_batch_size`` and ``in_batch_negatives``.
          device: device on which the label/weight padding for negatives is created.
        """
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.batch_size = data_config.per_replica_batch_size
        self.in_batch_negatives = data_config.in_batch_negatives
        self.device = device

    def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]
        """Runs model forward and calculates loss according to given loss_fn.

        NOTE: The input signature here needs to be a Pipelineable object for
        prefetching purposes during training using torchrec's pipeline. However
        the underlying model signature needs to be exportable to onnx, requiring
        generic python types. see https://pytorch.org/docs/stable/onnx.html#types.

        Returns:
          A ``(losses, outputs)`` tuple; ``outputs`` is the model's output dict
          extended with ``"loss"``, ``"labels"`` and ``"weights"``.
        """
        outputs = self.model(batch)
        logits = outputs["logits"]

        # The wrapped model skips negative sampling when `in_batch_negatives` is
        # falsy; mirror that here so there is no padding and no division by zero.
        in_batch_negatives = self.in_batch_negatives or 0
        num_negatives = 2 * self.batch_size * in_batch_negatives
        num_positives = self.batch_size

        # Down-weight negatives so that together they weigh as much as the positives.
        neg_weight = float(num_positives) / num_negatives if num_negatives else 0.0

        # Sampled negatives are non-edges, so their target is 0.
        labels = torch.cat(
            [batch.labels.float(), torch.zeros(num_negatives, device=self.device)]
        )

        weights = torch.cat(
            [batch.weights.float(), torch.full((num_negatives,), neg_weight, device=self.device)]
        )

        losses = self.loss_fn(logits, labels, weights)

        outputs.update(
            {
                "loss": losses,
                "labels": labels,
                "weights": weights,
            }
        )

        # Allow multiple losses.
        return losses, outputs
|