170 lines
5.0 KiB
Python
Raw Normal View History

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple
from tml.common.batch import DataclassBatch
from tml.reader.dataset import Dataset
from tml.projects.twhin.models.config import Relation
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@dataclass
class EdgeBatch(DataclassBatch):
    """One batch of graph edges in tensor form.

    Instances are built by `EdgesDataset.pa_to_batch` from a pyarrow record
    batch holding lhs / rhs / rel / label columns.
    """

    # Embedding lookups for both endpoints of every edge, keyed by table name.
    nodes: KeyedJaggedTensor
    # Per-edge label, read from the "label" column.
    labels: torch.Tensor
    # Per-edge relation index, read from the relation column.
    rels: torch.Tensor
    # Per-edge weight; `EdgesDataset.pa_to_batch` fills this with ones.
    weights: torch.Tensor
class EdgesDataset(Dataset):
    """Dataset of typed edges that are converted into `EdgeBatch` objects.

    Every input row is one edge: an lhs row index, an rhs row index and a
    relation index. The relation selects which embedding table the lhs and
    the rhs index refer to.
    """

    rng = np.random.default_rng()

    def __init__(
        self,
        file_pattern: str,
        table_sizes: Dict[str, int],
        relations: List[Relation],
        lhs_column_name: str = "lhs",
        rhs_column_name: str = "rhs",
        rel_column_name: str = "rel",
        **dataset_kwargs
    ):
        """Store the table / relation layout and initialise the base dataset.

        Args:
          file_pattern: forwarded unchanged to `Dataset.__init__`.
          table_sizes: mapping of table name to size. Only the keys and their
            insertion order are used here: a table's position in this mapping
            is its table index.
          relations: relation definitions; each must expose `lhs` and `rhs`
            attributes naming a key of `table_sizes`.
          lhs_column_name: name of the column holding lhs row indices.
          rhs_column_name: name of the column holding rhs row indices.
          rel_column_name: name of the column holding relation indices.
          **dataset_kwargs: forwarded to `Dataset.__init__`; must contain
            "batch_size".

        Raises:
          KeyError: if "batch_size" is missing from `dataset_kwargs`.
          ValueError: if a relation names a table that is not in `table_sizes`.
        """
        self.batch_size = dataset_kwargs["batch_size"]

        self.table_sizes = table_sizes
        self.num_tables = len(table_sizes)
        self.table_names = list(table_sizes.keys())

        self.relations = relations
        # Row r holds [lhs table index, rhs table index] for relation r, so
        # indexing with a tensor of relation ids resolves both tables at once.
        self.relations_t = torch.tensor(
            [
                [self.table_names.index(relation.lhs), self.table_names.index(relation.rhs)]
                for relation in relations
            ]
        )

        self.lhs_column_name = lhs_column_name
        self.rhs_column_name = rhs_column_name
        self.rel_column_name = rel_column_name
        self.label_column_name = "label"

        super().__init__(file_pattern=file_pattern, **dataset_kwargs)

    def pa_to_batch(self, batch: pa.RecordBatch) -> EdgeBatch:
        """Convert a record batch of edges into an `EdgeBatch`.

        Args:
          batch: record batch containing the lhs, rhs, relation and label
            columns (as produced by `to_batches`).

        Returns:
          An `EdgeBatch` whose `nodes` hold the embedding lookups for both
          endpoints of every edge and whose `weights` are all ones.
        """
        lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
        rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
        rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
        label = torch.from_numpy(batch.column(self.label_column_name).to_numpy())

        nodes = self._to_kjt(lhs, rhs, rel)
        return EdgeBatch(
            nodes=nodes,
            rels=rel,
            labels=label,
            weights=torch.ones(batch.num_rows),
        )

    def _to_kjt(
        self, lhs: torch.Tensor, rhs: torch.Tensor, rel: torch.Tensor
    ) -> KeyedJaggedTensor:
        """Process edges that contain lhs index, rhs index, relation index.

        Example:

        ```
        tables = ["f0", "f1", "f2", "f3"]
        relations = [["f0", "f1"], ["f1", "f2"], ["f1", "f0"], ["f2", "f1"], ["f0", "f2"]]
        self.relations_t = torch.tensor([[0, 1], [1, 2], [1, 0], [2, 1], [0, 2]])
        lhs = [1, 6, 3, 1, 8]
        rhs = [6, 3, 4, 4, 9]
        rel = [0, 2, 1, 3, 4]
        ```

        This corresponds to the following "edges":

        ```
        edges = [
          {"lhs": 1, "rhs": 6, "relation": ["f0", "f1"]},
          {"lhs": 6, "rhs": 3, "relation": ["f1", "f0"]},
          {"lhs": 3, "rhs": 4, "relation": ["f1", "f2"]},
          {"lhs": 1, "rhs": 4, "relation": ["f2", "f1"]},
          {"lhs": 8, "rhs": 9, "relation": ["f0", "f2"]},
        ]
        ```

        Returns a KeyedJaggedTensor used to look up all embeddings.

        Note: We treat the lhs and rhs as though they're separate lookups:
        `len(lengths) == 2 * bsz * len(tables)`. This differs from the DLRM
        pattern where we have `len(lengths) == bsz * len(tables)`.

        For the example above:

        ```
        lookups = tensor([
          [0, 1],
          [1, 6],
          [1, 6],
          [0, 3],
          [1, 3],
          [2, 4],
          [2, 1],
          [1, 4],
          [0, 8],
          [2, 9],
        ])

        kjt = KeyedJaggedTensor(
          keys=["f0", "f1", "f2", "f3"],
          values=[
            1, 3, 8,     # f0
            6, 6, 3, 4,  # f1
            4, 1, 9,     # f2
                         # f3 (no lookups)
          ],
          lengths=[
            1, 0, 0, 1, 0, 0, 0, 0, 1, 0,  # f0
            0, 1, 1, 0, 1, 0, 0, 1, 0, 0,  # f1
            0, 0, 0, 0, 0, 1, 1, 0, 0, 1,  # f2
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  # f3
          ],
        )
        ```

        Note:
          - values = [values for f0] + [values for f1] + [values for f2] + ...
          - lengths are always 0 or 1, and sum(lengths) = len(values) = 2 * bsz
          - bsz is the number of edges actually passed in, which may be smaller
            than the configured batch size (e.g. a short final batch).

        Args:
          lhs: 1-D tensor of lhs row indices, one per edge.
          rhs: 1-D tensor of rhs row indices, one per edge.
          rel: 1-D integer tensor of relation indices into `self.relations_t`.

        Returns:
          A KeyedJaggedTensor keyed by `self.table_names`.
        """
        # Columns per edge: [lhs_row, lhs_table, rhs_table, rhs_row].
        lookups = torch.concat((lhs[:, None], self.relations_t[rel], rhs[:, None]), dim=1)
        # Reorder to [lhs_table, lhs_row, rhs_table, rhs_row] so that splitting
        # each row in two yields one (table, row) pair per endpoint.
        index = torch.LongTensor([1, 0, 2, 3])
        # Size from the input, not self.batch_size, so that a batch with fewer
        # rows than the configured batch size is still handled.
        num_lookups = 2 * lhs.shape[0]
        lookups = lookups[:, index].reshape(num_lookups, 2)

        # values is just the row indices into each table, ordered by the table
        # indices; the stable sort keeps the original lookup order per table.
        _, indices = torch.sort(lookups[:, 0], dim=0, stable=True)
        values = lookups[indices][:, 1].int()

        # lengths[table_idx * num_lookups + i] == 1 iff the ith lookup is for
        # the table with index table_idx.
        lengths = torch.arange(self.num_tables)[:, None].eq(lookups[:, 0])
        lengths = lengths.reshape(-1).int()

        return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)

    def to_batches(self):
        """Yield record batches of positive edges with an all-ones label column.

        Each batch produced by the parent `to_batches()` is re-emitted with
        only the lhs, rhs and relation columns plus a new int64 label column
        of ones.

        Yields:
          `pa.RecordBatch` with columns (lhs, rhs, rel, label).
        """
        ds = super().to_batches()

        names = [
            self.lhs_column_name,
            self.rhs_column_name,
            self.rel_column_name,
            self.label_column_name,
        ]
        for batch in ds:
            # Pass along positive edges
            lhs = batch.column(self.lhs_column_name)
            rhs = batch.column(self.rhs_column_name)
            rel = batch.column(self.rel_column_name)
            # One label per row actually present: all columns of a RecordBatch
            # must have equal length, so a short batch must get a short label.
            label = pa.array(np.ones(batch.num_rows, dtype=np.int64))

            yield pa.RecordBatch.from_arrays(
                arrays=[lhs, rhs, rel, label],
                names=names,
            )