2023-03-31 13:05:14 -05:00
|
|
|
from collections import defaultdict
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
|
|
from tml.common.batch import DataclassBatch
|
|
|
|
from tml.reader.dataset import Dataset
|
|
|
|
from tml.projects.twhin.models.config import Relation
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pyarrow as pa
|
|
|
|
import pyarrow.compute as pc
|
|
|
|
import torch
|
|
|
|
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class EdgeBatch(DataclassBatch):
    """One batch of edges, ready to be consumed by an edge-based model.

    Attributes:
        nodes (KeyedJaggedTensor): Jagged lookup structure holding the node
            indices of both endpoints of every edge (see
            ``EdgesDataset._to_kjt`` for how it is laid out).
        labels (torch.Tensor): Label of each edge.
        rels (torch.Tensor): Relation index of each edge.
        weights (torch.Tensor): Weight of each edge.
    """

    # NOTE: field order defines the positional order of the generated
    # ``__init__``; keep it stable.
    nodes: KeyedJaggedTensor
    labels: torch.Tensor
    rels: torch.Tensor
    weights: torch.Tensor
|
|
|
|
|
|
|
|
|
|
|
|
class EdgesDataset(Dataset):
    """Dataset that converts stored edge records into ``EdgeBatch`` objects.

    Every record carries a left-hand-side node index, a right-hand-side node
    index and a relation index. The relation index selects a row of
    ``relations_t``, which gives the table index of each endpoint.

    Args:
        file_pattern (str): The file pattern for the dataset (forwarded to
            the parent ``Dataset``).
        table_sizes (Dict[str, int]): Table names mapped to their sizes. The
            key order defines the table index used by ``relations_t`` and by
            the keys of the produced ``KeyedJaggedTensor``.
        relations (List[Relation]): Relations between tables; each one names
            its ``lhs`` and ``rhs`` table.
        lhs_column_name (str): The name of the left-hand-side column.
        rhs_column_name (str): The name of the right-hand-side column.
        rel_column_name (str): The name of the relation column.
        **dataset_kwargs: Additional keyword arguments for the parent
            ``Dataset`` class. Must contain ``"batch_size"``.

    Raises:
        KeyError: If ``"batch_size"`` is missing from ``dataset_kwargs``.
        ValueError: If a relation refers to a table name that is not a key of
            ``table_sizes``.
    """

    rng = np.random.default_rng()

    def __init__(
        self,
        file_pattern: str,
        table_sizes: Dict[str, int],
        relations: List[Relation],
        lhs_column_name: str = "lhs",
        rhs_column_name: str = "rhs",
        rel_column_name: str = "rel",
        **dataset_kwargs
    ):
        self.batch_size = dataset_kwargs["batch_size"]

        self.table_sizes = table_sizes
        self.num_tables = len(table_sizes)
        self.table_names = list(table_sizes.keys())

        self.relations = relations
        # Row r holds (lhs table index, rhs table index) for relation r.
        self.relations_t = torch.tensor(
            [
                [self.table_names.index(relation.lhs), self.table_names.index(relation.rhs)]
                for relation in relations
            ]
        )

        self.lhs_column_name = lhs_column_name
        self.rhs_column_name = rhs_column_name
        self.rel_column_name = rel_column_name
        self.label_column_name = "label"

        super().__init__(file_pattern=file_pattern, **dataset_kwargs)

    def pa_to_batch(self, batch: pa.RecordBatch):
        """Convert a pyarrow RecordBatch into an ``EdgeBatch``.

        Args:
            batch (pa.RecordBatch): Record batch providing the lhs, rhs,
                relation and label columns.

        Returns:
            EdgeBatch: The lookup structure for both endpoints, the relation
            indices, the labels, and a weight of one for every row.
        """
        lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
        rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
        rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
        label = torch.from_numpy(batch.column(self.label_column_name).to_numpy())

        nodes = self._to_kjt(lhs, rhs, rel)
        return EdgeBatch(
            nodes=nodes,
            rels=rel,
            labels=label,
            weights=torch.ones(batch.num_rows),
        )

    def _to_kjt(
        self, lhs: torch.Tensor, rhs: torch.Tensor, rel: torch.Tensor
    ) -> KeyedJaggedTensor:
        """Process edges that contain lhs index, rhs index, relation index.

        Args:
            lhs (torch.Tensor): Tensor containing left-hand-side indices.
            rhs (torch.Tensor): Tensor containing right-hand-side indices.
            rel (torch.Tensor): Tensor containing relation indices.

        Returns:
            KeyedJaggedTensor: Keyed by ``self.table_names``; used to look up
            all embeddings.

        Example:

        ```
        tables = ["f0", "f1", "f2", "f3"]
        relations = [["f0", "f1"], ["f1", "f2"], ["f1", "f0"], ["f2", "f1"], ["f0", "f2"]]
        self.relations_t = torch.tensor([[0, 1], [1, 2], [1, 0], [2, 1], [0, 2]])
        lhs = [1, 6, 3, 1, 8]
        rhs = [6, 3, 4, 4, 9]
        rel = [0, 2, 1, 3, 4]

        This corresponds to the following "edges":
        edges = [
          {"lhs": 1, "rhs": 6, "relation": ["f0", "f1"]},
          {"lhs": 6, "rhs": 3, "relation": ["f1", "f0"]},
          {"lhs": 3, "rhs": 4, "relation": ["f1", "f2"]},
          {"lhs": 1, "rhs": 4, "relation": ["f2", "f1"]},
          {"lhs": 8, "rhs": 9, "relation": ["f0", "f2"]},
        ]
        ```

        Note: We treat the lhs and rhs as though they're separate lookups:
        `len(lengths) == 2 * bsz * len(tables)`.
        This differs from the DLRM pattern where we have
        `len(lengths) = bsz * len(tables)`.

        For the example above (each row is `[table index, row index]`):
        ```
        lookups = tensor([
          [0, 1],
          [1, 6],
          [1, 6],
          [0, 3],
          [1, 3],
          [2, 4],
          [2, 1],
          [1, 4],
          [0, 8],
          [2, 9]
        ])

        kjt = KeyedJaggedTensor(
          keys=["f0", "f1", "f2", "f3"]
          values=[
            1, 3, 8,     # f0
            6, 6, 3, 4,  # f1
            4, 1, 9      # f2 (f3 has no values)
          ]
          lengths=[
            1, 0, 0, 1, 0, 0, 0, 0, 1, 0,  # f0
            0, 1, 1, 0, 1, 0, 0, 1, 0, 0,  # f1
            0, 0, 0, 0, 0, 1, 1, 0, 0, 1,  # f2
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  # f3
          ]
        )
        ```

        Note:
          - values = [values for f0] + [values for f1] + [values for f2] + ...
          - lengths are always 0 or 1, and sum(lengths) = len(values) = 2 * bsz
        """
        # Columns after the concat are (lhs, lhs_table, rhs_table, rhs).
        lookups = torch.concat((lhs[:, None], self.relations_t[rel], rhs[:, None]), dim=1)
        # Reorder to (lhs_table, lhs, rhs_table, rhs) so that splitting each
        # row in two gives one (table index, row index) pair per endpoint.
        index = torch.LongTensor([1, 0, 2, 3])
        # -1 lets the number of lookups follow the actual number of edges
        # (2 per edge) instead of assuming every batch has `batch_size` rows.
        lookups = lookups[:, index].reshape(-1, 2)

        # values is just the row indices into each table, ordered by the table indices
        # (stable sort keeps the original lookup order within a table).
        _, indices = torch.sort(lookups[:, 0], dim=0, stable=True)
        values = lookups[indices][:, 1].int()

        # lengths[table_idx * num_lookups + i] == whether the ith lookup is for
        # the table with index table_idx, where num_lookups == 2 * number of edges.
        lengths = torch.arange(self.num_tables)[:, None].eq(lookups[:, 0])
        lengths = lengths.reshape(-1).int()

        return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)

    def to_batches(self):
        """Yield the parent dataset's batches with a positive label column added.

        Only the lhs, rhs and relation columns of each incoming batch are
        passed along; a label column of ones (int64) with one entry per row
        is appended.

        Yields:
            pa.RecordBatch: A record batch with the lhs, rhs, relation and
            label columns, in that order.
        """
        names = [
            self.lhs_column_name,
            self.rhs_column_name,
            self.rel_column_name,
            self.label_column_name,
        ]
        for batch in super().to_batches():
            # Pass along positive edges
            lhs = batch.column(self.lhs_column_name)
            rhs = batch.column(self.rhs_column_name)
            rel = batch.column(self.rel_column_name)
            # Size the labels from the batch itself: RecordBatch.from_arrays
            # requires all arrays to have the same length.
            label = pa.array(np.ones(batch.num_rows, dtype=np.int64))

            yield pa.RecordBatch.from_arrays(
                arrays=[lhs, rhs, rel, label],
                names=names,
            )
|