173 lines
5.3 KiB
Python

from typing import Callable
import math
from tml.projects.twhin.data.edges import EdgeBatch
from tml.projects.twhin.models.config import TwhinModelConfig
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.common.modules.embedding.embedding import LargeEmbeddings
from tml.optimizers.optimizer import get_optimizer_class
from tml.optimizers.config import get_optimizer_algorithm_config
import torch
from torch import nn
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
class TwhinModel(nn.Module):
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
super().__init__()
self.batch_size = data_config.per_replica_batch_size
self.table_names = [table.name for table in model_config.embeddings.tables]
self.large_embeddings = LargeEmbeddings(model_config.embeddings)
self.embedding_dim = model_config.embeddings.tables[0].embedding_dim
self.num_tables = len(model_config.embeddings.tables)
self.in_batch_negatives = data_config.in_batch_negatives
self.global_negatives = data_config.global_negatives
self.num_relations = len(model_config.relations)
# one bias per relation
self.all_trans_embs = torch.nn.parameter.Parameter(
torch.nn.init.uniform_(torch.empty(self.num_relations, self.embedding_dim))
)
def forward(self, batch: EdgeBatch):
# B x D
trans_embs = self.all_trans_embs.data[batch.rels]
# KeyedTensor
outs = self.large_embeddings(batch.nodes)
# 2B x TD
x = outs.values()
# 2B x T x D
x = x.reshape(2 * self.batch_size, -1, self.embedding_dim)
# 2B x D
x = torch.sum(x, 1)
# B x 2 x D
x = x.reshape(self.batch_size, 2, self.embedding_dim)
# translated
translated = x[:, 1, :] + trans_embs
negs = []
if self.in_batch_negatives:
# construct dot products for negatives via matmul
for relation in range(self.num_relations):
rel_mask = batch.rels == relation
rel_count = rel_mask.sum()
if not rel_count:
continue
# R x D
lhs_matrix = x[rel_mask, 0, :]
rhs_matrix = x[rel_mask, 1, :]
lhs_perm = torch.randperm(lhs_matrix.shape[0])
# repeat until we have enough negatives
lhs_perm = lhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
lhs_indices = lhs_perm[: self.in_batch_negatives]
sampled_lhs = lhs_matrix[lhs_indices]
rhs_perm = torch.randperm(rhs_matrix.shape[0])
# repeat until we have enough negatives
rhs_perm = rhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
rhs_indices = rhs_perm[: self.in_batch_negatives]
sampled_rhs = rhs_matrix[rhs_indices]
# RS
negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))
negs.append(negs_lhs)
negs.append(negs_rhs)
# dot product for positives
x = (x[:, 0, :] * translated).sum(-1)
# concat positives and negatives
x = torch.cat([x, *negs])
return {
"logits": x,
"probabilities": torch.sigmoid(x),
}
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
for table in model_config.embeddings.tables:
optimizer_class = get_optimizer_class(table.optimizer)
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
params = [
param
for name, param in model.large_embeddings.ebc.named_parameters()
if (name.startswith(f"embedding_bags.{table.name}"))
]
apply_optimizer_in_backward(
optimizer_class=optimizer_class,
params=params,
optimizer_kwargs=optimizer_kwargs,
)
return model
class TwhinModelAndLoss(torch.nn.Module):
def __init__(
self,
model,
loss_fn: Callable,
data_config: TwhinDataConfig,
device: torch.device,
) -> None:
"""
Args:
model: torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
self.batch_size = data_config.per_replica_batch_size
self.in_batch_negatives = data_config.in_batch_negatives
self.device = device
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
outputs = self.model(batch)
logits = outputs["logits"]
num_negatives = 2 * self.batch_size * self.in_batch_negatives
num_positives = self.batch_size
neg_weight = float(num_positives) / num_negatives
labels = torch.cat([batch.labels.float(), torch.ones(num_negatives).to(self.device)])
weights = torch.cat(
[batch.weights.float(), (torch.ones(num_negatives) * neg_weight).to(self.device)]
)
losses = self.loss_fn(logits, labels, weights)
outputs.update(
{
"loss": losses,
"labels": labels,
"weights": weights,
}
)
# Allow multiple losses.
return losses, outputs