2023-03-31 20:05:14 +02:00
|
|
|
from tml.projects.twhin.models.config import TwhinEmbeddingsConfig, TwhinModelConfig
|
|
|
|
from tml.projects.twhin.data.config import TwhinDataConfig
|
|
|
|
from tml.common.modules.embedding.config import DataType, EmbeddingBagConfig
|
|
|
|
from tml.optimizers.config import OptimizerConfig, SgdConfig
|
|
|
|
from tml.model import maybe_shard_model
|
|
|
|
from tml.projects.twhin.models.models import apply_optimizers, TwhinModel
|
|
|
|
from tml.projects.twhin.models.config import Operator, Relation
|
|
|
|
from tml.common.testing_utils import mock_pg
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from pydantic import ValidationError
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
|
|
NUM_EMBS = 10_000
|
|
|
|
EMB_DIM = 128
|
|
|
|
|
|
|
|
|
|
|
|
def twhin_model_config() -> TwhinModelConfig:
|
2023-09-13 07:52:13 +02:00
|
|
|
"""
|
|
|
|
Create a configuration for the Twhin model.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
TwhinModelConfig: The Twhin model configuration.
|
|
|
|
"""
|
2023-03-31 20:05:14 +02:00
|
|
|
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
|
|
|
|
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
|
|
|
|
|
|
|
|
table0 = EmbeddingBagConfig(
|
|
|
|
name="table0",
|
|
|
|
num_embeddings=NUM_EMBS,
|
|
|
|
embedding_dim=EMB_DIM,
|
|
|
|
optimizer=sgd_config_0,
|
|
|
|
data_type=DataType.FP32,
|
|
|
|
)
|
|
|
|
table1 = EmbeddingBagConfig(
|
|
|
|
name="table1",
|
|
|
|
num_embeddings=NUM_EMBS,
|
|
|
|
embedding_dim=EMB_DIM,
|
|
|
|
optimizer=sgd_config_1,
|
|
|
|
data_type=DataType.FP32,
|
|
|
|
)
|
|
|
|
embeddings_config = TwhinEmbeddingsConfig(
|
|
|
|
tables=[table0, table1],
|
|
|
|
)
|
|
|
|
|
|
|
|
model_config = TwhinModelConfig(
|
|
|
|
embeddings=embeddings_config,
|
|
|
|
translation_optimizer=sgd_config_0,
|
|
|
|
relations=[
|
|
|
|
Relation(name="rel0", lhs="table0", rhs="table1", operator=Operator.TRANSLATION),
|
|
|
|
Relation(name="rel1", lhs="table1", rhs="table0", operator=Operator.TRANSLATION),
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
return model_config
|
|
|
|
|
|
|
|
|
|
|
|
def twhin_data_config() -> TwhinDataConfig:
|
2023-09-13 07:52:13 +02:00
|
|
|
"""
|
|
|
|
Create a configuration for the Twhin data.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
TwhinDataConfig: The Twhin data configuration.
|
|
|
|
"""
|
2023-03-31 20:05:14 +02:00
|
|
|
data_config = TwhinDataConfig(
|
|
|
|
data_root="/",
|
|
|
|
per_replica_batch_size=10,
|
|
|
|
global_negatives=10,
|
|
|
|
in_batch_negatives=10,
|
|
|
|
limit=1,
|
|
|
|
offset=1,
|
|
|
|
)
|
|
|
|
|
|
|
|
return data_config
|
|
|
|
|
|
|
|
|
|
|
|
def test_twhin_model():
|
2023-09-13 07:52:13 +02:00
|
|
|
"""
|
|
|
|
Test the Twhin model creation and optimization.
|
|
|
|
|
|
|
|
This function creates a Twhin model using the specified configuration and tests its optimization. It also checks
|
|
|
|
the device placement of model parameters.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
None
|
|
|
|
"""
|
2023-03-31 20:05:14 +02:00
|
|
|
model_config = twhin_model_config()
|
|
|
|
loss_fn = F.binary_cross_entropy_with_logits
|
|
|
|
|
|
|
|
with mock_pg():
|
|
|
|
data_config = twhin_data_config()
|
|
|
|
model = TwhinModel(model_config=model_config, data_config=data_config)
|
|
|
|
|
|
|
|
apply_optimizers(model, model_config)
|
|
|
|
|
|
|
|
for tensor in model.state_dict().values():
|
|
|
|
if tensor.size() == (NUM_EMBS, EMB_DIM):
|
|
|
|
assert str(tensor.device) == "meta"
|
|
|
|
else:
|
|
|
|
assert str(tensor.device) == "cpu"
|
|
|
|
|
|
|
|
model = maybe_shard_model(model, device=torch.device("cpu"))
|
|
|
|
|
|
|
|
|
|
|
|
def test_unequal_dims():
|
2023-09-14 08:00:10 +02:00
|
|
|
"""
|
|
|
|
Test function for validating unequal embedding dimensions in TwhinEmbeddingsConfig.
|
|
|
|
|
|
|
|
This function tests whether the validation logic correctly raises a `ValidationError` when
|
|
|
|
embedding dimensions in the `TwhinEmbeddingsConfig` are not equal for all tables.
|
|
|
|
|
|
|
|
The test includes the following steps:
|
|
|
|
1. Create two embedding configurations with different embedding dimensions.
|
|
|
|
2. Attempt to create a `TwhinEmbeddingsConfig` instance with the unequal embedding dimensions.
|
|
|
|
3. Assert that a `ValidationError` is raised, indicating that embedding dimensions must match.
|
|
|
|
|
|
|
|
This function serves as a test case to ensure that the validation logic enforces equal embedding dimensions
|
|
|
|
in the `TwhinEmbeddingsConfig` for all tables.
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
AssertionError: If the expected `ValidationError` is not raised.
|
|
|
|
"""
|
2023-03-31 20:05:14 +02:00
|
|
|
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
|
|
|
|
sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
|
|
|
|
table0 = EmbeddingBagConfig(
|
|
|
|
name="table0",
|
|
|
|
num_embeddings=10_000,
|
|
|
|
embedding_dim=128,
|
|
|
|
optimizer=sgd_config_1,
|
|
|
|
data_type=DataType.FP32,
|
|
|
|
)
|
|
|
|
table1 = EmbeddingBagConfig(
|
|
|
|
name="table1",
|
|
|
|
num_embeddings=10_000,
|
|
|
|
embedding_dim=64,
|
|
|
|
optimizer=sgd_config_2,
|
|
|
|
data_type=DataType.FP32,
|
|
|
|
)
|
|
|
|
|
|
|
|
with pytest.raises(ValidationError):
|
|
|
|
_ = TwhinEmbeddingsConfig(
|
|
|
|
tables=[table0, table1],
|
|
|
|
)
|