# Mirror of https://github.com/twitter/the-algorithm-ml.git
# (synced 2025-01-10 14:59:21 +01:00). 108 lines, 2.8 KiB, Python.
from tml.projects.twhin.models.config import TwhinEmbeddingsConfig, TwhinModelConfig
|
|
from tml.projects.twhin.data.config import TwhinDataConfig
|
|
from tml.common.modules.embedding.config import DataType, EmbeddingBagConfig
|
|
from tml.optimizers.config import OptimizerConfig, SgdConfig
|
|
from tml.model import maybe_shard_model
|
|
from tml.projects.twhin.models.models import apply_optimizers, TwhinModel
|
|
from tml.projects.twhin.models.config import Operator, Relation
|
|
from tml.common.testing_utils import mock_pg
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from pydantic import ValidationError
|
|
import pytest
|
|
|
|
|
|
# Row count given to every embedding table built by twhin_model_config().
NUM_EMBS = 10_000
# Embedding width given to every embedding table built by twhin_model_config().
EMB_DIM = 128
|
|
|
|
|
|
def twhin_model_config() -> TwhinModelConfig:
    """Build the TwhinModelConfig shared by the tests in this module.

    The config holds two FP32 embedding tables, "table0" and "table1", each
    with NUM_EMBS rows of width EMB_DIM. "table0" is optimized with SGD at
    lr=0.01 and "table1" with SGD at lr=0.02. Two TRANSLATION relations connect
    the tables, one in each direction, and the translation optimizer reuses
    the lr=0.01 SGD config.
    """
    slow_sgd = OptimizerConfig(sgd=SgdConfig(lr=0.01))
    fast_sgd = OptimizerConfig(sgd=SgdConfig(lr=0.02))

    # (table name, optimizer) pairs; every other table field is identical.
    table_specs = (("table0", slow_sgd), ("table1", fast_sgd))
    tables = [
        EmbeddingBagConfig(
            name=table_name,
            num_embeddings=NUM_EMBS,
            embedding_dim=EMB_DIM,
            optimizer=table_optimizer,
            data_type=DataType.FP32,
        )
        for table_name, table_optimizer in table_specs
    ]
    embeddings = TwhinEmbeddingsConfig(tables=tables)

    # One relation per direction between the two tables.
    relations = [
        Relation(name="rel0", lhs="table0", rhs="table1", operator=Operator.TRANSLATION),
        Relation(name="rel1", lhs="table1", rhs="table0", operator=Operator.TRANSLATION),
    ]

    return TwhinModelConfig(
        embeddings=embeddings,
        translation_optimizer=slow_sgd,
        relations=relations,
    )
|
|
|
|
|
|
def twhin_data_config() -> TwhinDataConfig:
    """Build the TwhinDataConfig used when constructing the test model.

    The config points at data root "/", uses a per-replica batch size of 10,
    10 global negatives and 10 in-batch negatives, and sets both ``limit``
    and ``offset`` to 1.
    """
    settings = {
        "data_root": "/",
        "per_replica_batch_size": 10,
        "global_negatives": 10,
        "in_batch_negatives": 10,
        "limit": 1,
        "offset": 1,
    }
    return TwhinDataConfig(**settings)
|
|
|
|
|
|
def test_twhin_model():
    """Smoke-test TwhinModel construction, optimizer wiring and sharding.

    Builds a TwhinModel from the module's config helpers under ``mock_pg()``,
    applies the optimizers, and checks where each state-dict tensor lives:
    tensors shaped ``(NUM_EMBS, EMB_DIM)`` must report device "meta", every
    other tensor must report "cpu". Finally the model is passed through
    ``maybe_shard_model`` targeting the CPU device; the test only requires
    that this call does not raise.
    """
    model_config = twhin_model_config()

    # NOTE(review): the text this block was recovered from had lost its
    # indentation; keeping every step below inside mock_pg() is an
    # assumption -- confirm against the upstream file.
    with mock_pg():
        data_config = twhin_data_config()
        model = TwhinModel(model_config=model_config, data_config=data_config)

        apply_optimizers(model, model_config)

        table_shape = (NUM_EMBS, EMB_DIM)
        for name, tensor in model.state_dict().items():
            expected_device = "meta" if tensor.size() == table_shape else "cpu"
            # Include the entry name so a failure identifies the tensor.
            assert str(tensor.device) == expected_device, (
                f"{name}: expected device {expected_device!r}, got {tensor.device}"
            )

        model = maybe_shard_model(model, device=torch.device("cpu"))
|
|
|
|
|
|
def test_unequal_dims():
    """Expect TwhinEmbeddingsConfig to reject the two tables built here.

    Both tables have 10_000 rows and FP32 data, but "table0" has
    embedding_dim 128 (SGD lr=0.02) while "table1" has embedding_dim 64
    (SGD lr=0.05). Constructing a TwhinEmbeddingsConfig from them must raise
    pydantic's ValidationError; per the test name, the differing embedding
    dimensions are presumably the mismatch being rejected.
    """
    first_optimizer = OptimizerConfig(sgd=SgdConfig(lr=0.02))
    second_optimizer = OptimizerConfig(sgd=SgdConfig(lr=0.05))

    # (table name, embedding dim, optimizer) for each table.
    table_specs = (
        ("table0", 128, first_optimizer),
        ("table1", 64, second_optimizer),
    )
    mismatched_tables = [
        EmbeddingBagConfig(
            name=table_name,
            num_embeddings=10_000,
            embedding_dim=dim,
            optimizer=table_optimizer,
            data_type=DataType.FP32,
        )
        for table_name, dim, table_optimizer in table_specs
    ]

    with pytest.raises(ValidationError):
        TwhinEmbeddingsConfig(tables=mismatched_tables)
|