# Mirror of https://github.com/twitter/the-algorithm-ml.git
# Synced 2025-01-11 15:19:06 +01:00
# 67 lines, 1.7 KiB, Python
|
"""Tests edges dataset functionality."""
|
||
|
|
||
|
from unittest.mock import patch
|
||
|
import os
|
||
|
import tempfile
|
||
|
|
||
|
from tml.projects.twhin.data.edges import EdgesDataset
|
||
|
from tml.projects.twhin.models.config import Relation
|
||
|
|
||
|
from fsspec.implementations.local import LocalFileSystem
|
||
|
import numpy as np
|
||
|
import pyarrow as pa
|
||
|
import pyarrow.compute as pc
|
||
|
import pyarrow.parquet as pq
|
||
|
import torch
|
||
|
|
||
|
|
||
|
# Per-entity table sizes handed to EdgesDataset in test_gen.
# NOTE(review): presumably the number of rows in each entity's embedding
# table -- confirm against EdgesDataset.
TABLE_SIZES = {
    "user": 16,
    "author": 32,
}

# Relation definitions handed to EdgesDataset in test_gen. The two entries
# have mirrored lhs/rhs entity types.
# NOTE(review): the "rel" column written by test_gen holds 0 and 1, which
# presumably index into this list -- confirm against EdgesDataset.
RELATIONS = [
    Relation(name="fav", lhs="user", rhs="author"),
    Relation(name="engaged_with_reply", lhs="author", rhs="user"),
]
|
||
|
|
||
|
|
||
|
def test_gen():
    """Check the first batch EdgesDataset produces for a four-edge parquet file.

    Writes one parquet file with ``lhs``, ``rhs`` and ``rel`` columns into a
    temporary directory, builds an ``EdgesDataset`` over that directory with
    ``batch_size=4``, pulls a single batch from its dataloader, and asserts on
    the batch's ``labels`` and on the values of ``batch.nodes``.
    """
    import os
    import tempfile

    from fsspec.implementations.local import LocalFileSystem
    import pyarrow as pa
    import pyarrow.parquet as pq

    # Four edges: lhs = [0, 1, 2, 3], rhs = [3, 2, 1, 0]. Only the second
    # edge carries relation value 1; the others carry 0.
    lhs = pa.array(np.arange(4))
    rhs = pa.array(np.flip(np.arange(4)))
    rel = pa.array([0, 1, 0, 0])
    names = ["lhs", "rhs", "rel"]

    with tempfile.TemporaryDirectory() as tmpdir:
        table = pa.Table.from_arrays([lhs, rhs, rel], names=names)
        # Use the writer as a context manager so it is closed (and the file
        # finalized) even if write_table raises.
        with pq.ParquetWriter(
            os.path.join(tmpdir, "example.parquet"),
            table.schema,
        ) as writer:
            writer.write_table(table)

        ds = EdgesDataset(
            file_pattern=os.path.join(tmpdir, "*"),
            table_sizes=TABLE_SIZES,
            relations=RELATIONS,
            batch_size=4,
        )
        # Point the dataset at the local filesystem so the glob above is
        # resolved against the temp directory.
        # NOTE(review): the dataset's default FS is not visible here.
        ds.FS = LocalFileSystem()

        dl = ds.dataloader()
        # Fetch the batch while the temp directory still exists.
        batch = next(iter(dl))

    # labels should be positive
    labels = batch.labels
    assert (labels[:4] == 1).sum() == 4

    # make sure positive examples are what we expect
    kjt_values = batch.nodes.values()
    # Split into chunks of 4 along dim 0; unpacking into two names requires
    # exactly two chunks.
    users, authors = torch.split(kjt_values, 4, dim=0)
    # NOTE(review): presumably the edge with rel == 1 ("author" -> "user")
    # contributes its rhs as the user and its lhs as the author, which is why
    # position 1 reads user 2 / author 1 rather than user 1 / author 2.
    assert torch.equal(users[:4], torch.tensor([0, 2, 2, 3]))
    assert torch.equal(authors[:4], torch.tensor([3, 1, 1, 0]))
|