the-algorithm-ml/projects/twhin/models/models.py

from typing import Callable
import math

import torch
from torch import nn
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

from tml.common.modules.embedding.embedding import LargeEmbeddings
from tml.optimizers.config import get_optimizer_algorithm_config
from tml.optimizers.optimizer import get_optimizer_class
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.projects.twhin.data.edges import EdgeBatch
from tml.projects.twhin.models.config import TwhinModelConfig


class TwhinModel(nn.Module):
"""
Twhin model for graph-based entity embeddings and translation.
This class defines the Twhin model, which is used for learning embeddings of entities in a graph
and applying translations to these embeddings based on graph relationships.
Args:
model_config (TwhinModelConfig): Configuration for the Twhin model.
data_config (TwhinDataConfig): Configuration for the data used by the model.
Attributes:
batch_size (int): The batch size used for training.
table_names (List[str]): Names of tables in the model's embeddings.
large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings.
embedding_dim (int): Dimensionality of entity embeddings.
num_tables (int): Number of tables in the model's embeddings.
in_batch_negatives (int): Number of in-batch negative samples to use during training.
global_negatives (int): Number of global negative samples to use during training.
num_relations (int): Number of graph relationships in the model.
all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings.
"""

  def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
    super().__init__()
    self.batch_size = data_config.per_replica_batch_size
    self.table_names = [table.name for table in model_config.embeddings.tables]
    self.large_embeddings = LargeEmbeddings(model_config.embeddings)
    self.embedding_dim = model_config.embeddings.tables[0].embedding_dim
    self.num_tables = len(model_config.embeddings.tables)
    self.in_batch_negatives = data_config.in_batch_negatives
    self.global_negatives = data_config.global_negatives
    self.num_relations = len(model_config.relations)

    # One translation embedding per relation, added to the rhs entity
    # embedding before it is scored against the lhs embedding.
    self.all_trans_embs = torch.nn.parameter.Parameter(
      torch.nn.init.uniform_(torch.empty(self.num_relations, self.embedding_dim))
    )

  def forward(self, batch: EdgeBatch):
    """
    Forward pass of the Twhin model.

    Args:
      batch (EdgeBatch): Input batch containing graph edge information.

    Returns:
      dict: A dictionary containing the model output.
        - "logits" (torch.Tensor): Scores for the positive edges followed by
          the in-batch negatives.
        - "probabilities" (torch.Tensor): Sigmoid of the logits.
    """
    # B x D. Index the parameter directly (not its .data) so the translation
    # embeddings stay in the autograd graph and receive gradients.
    trans_embs = self.all_trans_embs[batch.rels]

    # KeyedTensor
    outs = self.large_embeddings(batch.nodes)

    # 2B x TD
    x = outs.values()

    # 2B x T x D
    x = x.reshape(2 * self.batch_size, -1, self.embedding_dim)

    # 2B x D: sum each node's embeddings across the T tables
    x = torch.sum(x, 1)

    # B x 2 x D: pair the lhs and rhs embeddings of each edge
    x = x.reshape(self.batch_size, 2, self.embedding_dim)

    # translate the rhs embedding by its relation's embedding
    translated = x[:, 1, :] + trans_embs

    negs = []
    if self.in_batch_negatives:
      # construct dot products for negatives via matmul
      for relation in range(self.num_relations):
        rel_mask = batch.rels == relation
        rel_count = rel_mask.sum()
        if not rel_count:
          continue

        # R x D
        lhs_matrix = x[rel_mask, 0, :]
        rhs_matrix = x[rel_mask, 1, :]

        lhs_perm = torch.randperm(lhs_matrix.shape[0])
        # repeat the permutation until we have enough negatives
        lhs_perm = lhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
        lhs_indices = lhs_perm[: self.in_batch_negatives]
        sampled_lhs = lhs_matrix[lhs_indices]

        rhs_perm = torch.randperm(rhs_matrix.shape[0])
        # repeat the permutation until we have enough negatives
        rhs_perm = rhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
        rhs_indices = rhs_perm[: self.in_batch_negatives]
        sampled_rhs = rhs_matrix[rhs_indices]

        # R x S dot products of each edge against its S sampled negatives
        negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
        negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))

        negs.append(negs_lhs)
        negs.append(negs_rhs)

    # dot product for positives
    x = (x[:, 0, :] * translated).sum(-1)

    # concat positives and negatives
    x = torch.cat([x, *negs])

    return {
      "logits": x,
      "probabilities": torch.sigmoid(x),
    }
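

# A minimal sketch (not part of the original file) of the logits layout that
# TwhinModel.forward produces: with per-replica batch size B and S in-batch
# negatives sampled per side of each relation, the output holds B positive
# scores followed by 2 * B * S negative scores.
def _example_logits_layout(model: TwhinModel, batch: EdgeBatch) -> torch.Tensor:
  """Hypothetical helper illustrating the expected shape of forward()'s logits."""
  logits = model(batch)["logits"]
  expected = model.batch_size * (1 + 2 * model.in_batch_negatives)
  assert logits.shape == (expected,)
  return logits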


def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
  """
  Apply per-table optimizers to the Twhin model's embeddings.

  For each embedding table in the configuration, this function resolves the
  optimizer class and its hyperparameters, then fuses the optimizer step into
  the backward pass of that table's parameters via apply_optimizer_in_backward.

  Args:
    model (TwhinModel): The Twhin model to apply optimizers to.
    model_config (TwhinModelConfig): Configuration for the Twhin model.

  Returns:
    TwhinModel: The Twhin model with optimizers applied to its embeddings.
  """
  for table in model_config.embeddings.tables:
    optimizer_class = get_optimizer_class(table.optimizer)
    optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
    params = [
      param
      for name, param in model.large_embeddings.ebc.named_parameters()
      if name.startswith(f"embedding_bags.{table.name}")
    ]
    apply_optimizer_in_backward(
      optimizer_class=optimizer_class,
      params=params,
      optimizer_kwargs=optimizer_kwargs,
    )
  return model
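

# A minimal sketch of what apply_optimizer_in_backward does, using a toy
# nn.Embedding rather than the sharded torchrec tables above (an assumption
# for illustration): the optimizer step for the tagged parameters is fused
# into backward(), so they are updated as soon as their gradients are
# produced and should be excluded from the main training-loop optimizer.
def _example_fused_optimizer() -> nn.Embedding:
  """Hypothetical helper demonstrating fused-backward optimization on a toy table."""
  emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
  apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=emb.parameters(),
    optimizer_kwargs={"lr": 0.01},
  )
  # backward() both computes gradients and applies the SGD update in place.
  emb(torch.tensor([1, 2, 3])).sum().backward()
  return emb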


class TwhinModelAndLoss(torch.nn.Module):
  def __init__(
    self,
    model,
    loss_fn: Callable,
    data_config: TwhinDataConfig,
    device: torch.device,
  ) -> None:
    """
    Initialize a TwhinModelAndLoss module.

    Args:
      model: The torch module to wrap.
      loss_fn: A function for calculating loss; it should accept logits,
        labels, and per-example weights.
      data_config: Configuration for Twhin data.
      device: The torch device to use for calculations.
    """
    super().__init__()
    self.model = model
    self.loss_fn = loss_fn
    self.batch_size = data_config.per_replica_batch_size
    self.in_batch_negatives = data_config.in_batch_negatives
    self.device = device

  def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]
    """
    Run the model forward and calculate the loss according to the given loss_fn.

    NOTE: The input signature here needs to be a Pipelineable object for
    prefetching purposes during training using torchrec's pipeline. However
    the underlying model signature needs to be exportable to onnx, requiring
    generic python types. See https://pytorch.org/docs/stable/onnx.html#types

    Args:
      batch ("RecapBatch"): The input batch for model inference.

    Returns:
      Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss
        tensor and a dictionary of additional outputs including logits, labels,
        and weights.
    """
    outputs = self.model(batch)
    logits = outputs["logits"]

    num_negatives = 2 * self.batch_size * self.in_batch_negatives
    num_positives = self.batch_size

    # Down-weight the negatives so that, in aggregate, they contribute as much
    # to the loss as the positives.
    neg_weight = float(num_positives) / num_negatives

    # In-batch negatives are labeled 0; positive labels come from the batch.
    labels = torch.cat([batch.labels.float(), torch.zeros(num_negatives).to(self.device)])
    weights = torch.cat(
      [batch.weights.float(), (torch.ones(num_negatives) * neg_weight).to(self.device)]
    )

    losses = self.loss_fn(logits, labels, weights)

    outputs.update(
      {
        "loss": losses,
        "labels": labels,
        "weights": weights,
      }
    )

    # Allow multiple losses.
    return losses, outputs
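

# TwhinModelAndLoss calls loss_fn(logits, labels, weights). A compatible
# sketch, assuming a weighted binary cross-entropy (an assumption for
# illustration, not necessarily the loss used by the training jobs):
def _example_bce_loss(
  logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
  """Hypothetical loss_fn matching the signature expected by TwhinModelAndLoss."""
  return torch.nn.functional.binary_cross_entropy_with_logits(
    logits, labels, weight=weights
  )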