the-algorithm-ml/core/test_train_pipeline.py

from dataclasses import dataclass
from typing import Tuple

from tml.common.batch import DataclassBatch
from tml.common.testing_utils import mock_pg
from tml.core import train_pipeline

import torch
from torchrec.distributed import DistributedModelParallel


@dataclass
class MockDataclassBatch(DataclassBatch):
  """
  Mock data class batch for testing purposes.

  This class represents a batch of data with continuous features and labels.

  Attributes:
    continuous_features (torch.Tensor): Tensor containing continuous feature data.
    labels (torch.Tensor): Tensor containing label data.
  """

  continuous_features: torch.Tensor
  labels: torch.Tensor
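

# A minimal usage sketch (ours, not part of the original tests): DataclassBatch
# subclasses behave like plain dataclasses, so a batch is built by passing the
# tensors as keyword arguments. The helper name `_example_batch_usage` is
# hypothetical.
def _example_batch_usage() -> None:
  batch = MockDataclassBatch(
    continuous_features=torch.rand(2, 10).float(),
    labels=torch.ones(2, 1).float(),
  )
  assert batch.continuous_features.shape == (2, 10)
  assert batch.labels.shape == (2, 1)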


class MockModule(torch.nn.Module):
  """
  Mock PyTorch module for testing purposes.

  This module defines a simple neural network model: a single linear layer
  followed by a BCEWithLogitsLoss loss function.

  Attributes:
    model (torch.nn.Linear): The linear model layer (10 inputs, 1 output).
    loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function.
  """

  def __init__(self) -> None:
    super().__init__()
    self.model = torch.nn.Linear(10, 1)
    self.loss_fn = torch.nn.BCEWithLogitsLoss()

  def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass of the mock module.

    Args:
      batch (MockDataclassBatch): Input data batch with continuous features and labels.

    Returns:
      Tuple[torch.Tensor, torch.Tensor]: A tuple containing the (scalar) loss
        and the predictions.
    """
    pred = self.model(batch.continuous_features)
    loss = self.loss_fn(pred, batch.labels)
    return (loss, pred)
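

# Illustrative sketch (ours, not from the repo): one forward pass through the
# mock module. With BCEWithLogitsLoss's default mean reduction the returned
# loss is a 0-d (scalar) tensor, and `pred` has shape (batch_size, 1).
def _example_forward_pass() -> None:
  module = MockModule()
  batch = MockDataclassBatch(
    continuous_features=torch.rand(4, 10).float(),
    labels=torch.ones(4, 1).float(),
  )
  loss, pred = module(batch)
  assert loss.ndim == 0
  assert pred.shape == (4, 1)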


def create_batch(bsz: int) -> MockDataclassBatch:
  """
  Create a mock data batch with random continuous features and labels.

  Args:
    bsz (int): Batch size.

  Returns:
    MockDataclassBatch: A batch of data with continuous features and labels.
  """
  return MockDataclassBatch(
    continuous_features=torch.rand(bsz, 10).float(),
    labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
  )
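

# Quick sanity sketch (ours): torch.bernoulli draws 0/1 samples, so every label
# produced by create_batch is exactly 0.0 or 1.0, which is the form
# BCEWithLogitsLoss expects.
def _example_create_batch() -> None:
  batch = create_batch(bsz=3)
  assert batch.continuous_features.shape == (3, 10)
  assert set(batch.labels.unique().tolist()) <= {0.0, 1.0}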


def test_sparse_pipeline():
  """
  Test the sparse training pipeline with distributed model parallelism.

  Feeds the same example repeatedly through TrainPipelineSparseDist with
  grad_accum=2 and checks that results only change after every second batch,
  i.e. once per optimizer step.
  """
  device = torch.device("cpu")
  model = MockModule().to(device)

  steps = 8
  example = create_batch(1)
  dataloader = iter(example for _ in range(steps + 2))

  results = []
  with mock_pg():
    d_model = DistributedModelParallel(model)
    pipeline = train_pipeline.TrainPipelineSparseDist(
      model=d_model,
      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
      device=device,
      grad_accum=2,
    )
    for _ in range(steps):
      results.append(pipeline.progress(dataloader))

  results = [elem.detach().numpy() for elem in results]
  # Check gradients are accumulated: with grad_accum=2 there is no optimizer
  # step between the 0th and 1st result of each pair, so they must be equal.
  for first, second in zip(results[::2], results[1::2]):
    assert first == second, results

  # Check gradients are applied: an optimizer step occurs between the 1st and
  # 2nd result of each pair, so they must differ.
  for first, second in zip(results[1::2], results[2::2]):
    assert first != second, results
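

# For reference, the behavior test_sparse_pipeline exercises follows the usual
# gradient-accumulation pattern sketched below in plain PyTorch. This is our
# sketch of the concept, not the actual TrainPipelineSparseDist internals, and
# scaling the loss by the accumulation factor is a common convention assumed
# here, not something the pipeline is known to do.
def _sketch_grad_accum(model: MockModule, batches, accum: int = 2) -> None:
  optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  for i, batch in enumerate(batches):
    loss, _ = model(batch)
    (loss / accum).backward()  # gradients sum in .grad across `accum` batches
    if (i + 1) % accum == 0:
      optimizer.step()  # parameters move only once per `accum` batches
      optimizer.zero_grad()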


def test_amp():
  """
  Test automatic mixed-precision (AMP) training with the sparse pipeline.

  Runs the sparse training pipeline with AMP enabled, using the mock module
  and data. AMP speeds up training by running parts of the computation in a
  lower-precision dtype such as torch.bfloat16 while aiming to preserve model
  accuracy; here we verify that the pipeline's results come back in bfloat16.
  """
  device = torch.device("cpu")
  model = MockModule().to(device)

  steps = 8
  example = create_batch(1)
  dataloader = iter(example for _ in range(steps + 2))

  results = []
  with mock_pg():
    d_model = DistributedModelParallel(model)
    pipeline = train_pipeline.TrainPipelineSparseDist(
      model=d_model,
      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
      device=device,
      enable_amp=True,
      # Gradient scaling is not supported on CPU.
      enable_grad_scaling=False,
    )
    for _ in range(steps):
      results.append(pipeline.progress(dataloader))

  results = [elem.detach() for elem in results]
  for value in results:
    assert value.dtype == torch.bfloat16
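

# For context, CPU autocast with bfloat16 looks like the stock-PyTorch sketch
# below. We are assuming enable_amp=True wraps the forward pass in something
# like this; the real wiring lives inside tml.core.train_pipeline. Under
# autocast, matmul-heavy ops such as torch.nn.Linear run in bfloat16, which is
# why test_amp can assert on the result dtype above.
def _sketch_cpu_autocast() -> None:
  module = MockModule()
  batch = create_batch(bsz=1)
  with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    _, pred = module(batch)
  assert pred.dtype == torch.bfloat16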