108 lines
2.8 KiB
Python

from tml.projects.twhin.models.config import TwhinEmbeddingsConfig, TwhinModelConfig
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.common.modules.embedding.config import DataType, EmbeddingBagConfig
from tml.optimizers.config import OptimizerConfig, SgdConfig
from tml.model import maybe_shard_model
from tml.projects.twhin.models.models import apply_optimizers, TwhinModel
from tml.projects.twhin.models.config import Operator, Relation
from tml.common.testing_utils import mock_pg
import torch
import torch.nn.functional as F
from pydantic import ValidationError
import pytest
NUM_EMBS = 10_000
EMB_DIM = 128
def twhin_model_config() -> TwhinModelConfig:
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
table0 = EmbeddingBagConfig(
name="table0",
num_embeddings=NUM_EMBS,
embedding_dim=EMB_DIM,
optimizer=sgd_config_0,
data_type=DataType.FP32,
)
table1 = EmbeddingBagConfig(
name="table1",
num_embeddings=NUM_EMBS,
embedding_dim=EMB_DIM,
optimizer=sgd_config_1,
data_type=DataType.FP32,
)
embeddings_config = TwhinEmbeddingsConfig(
tables=[table0, table1],
)
model_config = TwhinModelConfig(
embeddings=embeddings_config,
translation_optimizer=sgd_config_0,
relations=[
Relation(name="rel0", lhs="table0", rhs="table1", operator=Operator.TRANSLATION),
Relation(name="rel1", lhs="table1", rhs="table0", operator=Operator.TRANSLATION),
],
)
return model_config
def twhin_data_config() -> TwhinDataConfig:
data_config = TwhinDataConfig(
data_root="/",
per_replica_batch_size=10,
global_negatives=10,
in_batch_negatives=10,
limit=1,
offset=1,
)
return data_config
def test_twhin_model():
model_config = twhin_model_config()
loss_fn = F.binary_cross_entropy_with_logits
with mock_pg():
data_config = twhin_data_config()
model = TwhinModel(model_config=model_config, data_config=data_config)
apply_optimizers(model, model_config)
for tensor in model.state_dict().values():
if tensor.size() == (NUM_EMBS, EMB_DIM):
assert str(tensor.device) == "meta"
else:
assert str(tensor.device) == "cpu"
model = maybe_shard_model(model, device=torch.device("cpu"))
def test_unequal_dims():
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
table0 = EmbeddingBagConfig(
name="table0",
num_embeddings=10_000,
embedding_dim=128,
optimizer=sgd_config_1,
data_type=DataType.FP32,
)
table1 = EmbeddingBagConfig(
name="table1",
num_embeddings=10_000,
embedding_dim=64,
optimizer=sgd_config_2,
data_type=DataType.FP32,
)
with pytest.raises(ValidationError):
_ = TwhinEmbeddingsConfig(
tables=[table0, table1],
)