from dataclasses import dataclass
from typing import Tuple

from tml.common.batch import DataclassBatch
from tml.common.testing_utils import mock_pg
from tml.core import train_pipeline

import torch
from torchrec.distributed import DistributedModelParallel


@dataclass
class MockDataclassBatch(DataclassBatch):
  """
  Mock data class batch for testing purposes.

  This class represents a batch of data with continuous features and labels.

  Attributes:
    continuous_features (torch.Tensor): Tensor containing continuous feature data.
    labels (torch.Tensor): Tensor containing label data.
  """

  continuous_features: torch.Tensor
  labels: torch.Tensor


class MockModule(torch.nn.Module):
  """
  Mock PyTorch module for testing purposes.

  This module defines a simple neural network model with a linear layer
  followed by a BCEWithLogitsLoss loss function.

  Attributes:
    model (torch.nn.Linear): The linear model layer.
    loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function.
  """

  def __init__(self) -> None:
    super().__init__()
    self.model = torch.nn.Linear(10, 1)
    self.loss_fn = torch.nn.BCEWithLogitsLoss()

  def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass of the mock module.

    Args:
      batch (MockDataclassBatch): Input data batch with continuous features and labels.

    Returns:
      Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions.
    """
    pred = self.model(batch.continuous_features)
    loss = self.loss_fn(pred, batch.labels)
    return (loss, pred)


def create_batch(bsz: int) -> MockDataclassBatch:
  """
  Create a mock data batch with random continuous features and labels.

  Args:
    bsz (int): Batch size.

  Returns:
    MockDataclassBatch: A batch of data with continuous features and labels.
  """
  return MockDataclassBatch(
    continuous_features=torch.rand(bsz, 10).float(),
    labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
  )


def test_sparse_pipeline():
  """
  Test function for the sparse pipeline with distributed model parallelism and
  gradient accumulation.

  This function tests the behavior of the sparse training pipeline using
  a mock module and data.
  """
  device = torch.device("cpu")
  model = MockModule().to(device)

  steps = 8
  example = create_batch(1)
  dataloader = iter(example for _ in range(steps + 2))
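  # The pipeline runs a couple of batches ahead of the optimizer step, which is
  # presumably why the dataloader yields `steps + 2` batches rather than `steps`.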

  results = []
  with mock_pg():
    d_model = DistributedModelParallel(model)
    pipeline = train_pipeline.TrainPipelineSparseDist(
      model=d_model,
      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
      device=device,
      grad_accum=2,
    )
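    # With grad_accum=2 the optimizer only applies an update every second call to
    # progress(), so consecutive pairs of steps should report identical losses.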
    for _ in range(steps):
      results.append(pipeline.progress(dataloader))

  results = [elem.detach().numpy() for elem in results]
  # Check that gradients are accumulated: the loss should not change between the
  # 0th and 1st step of each accumulation pair.
  for first, second in zip(results[::2], results[1::2]):
    assert first == second, results

  # Check that the weights do get updated: the loss should change between the 1st
  # and 2nd step, i.e. across an accumulation boundary.
  for first, second in zip(results[1::2], results[2::2]):
    assert first != second, results


def test_amp():
  """
  Test automatic mixed-precision (AMP) training with the sparse pipeline.

  This function tests the behavior of the sparse training pipeline with
  automatic mixed-precision (AMP) enabled, using a mock module and data.

  AMP allows for faster training by using lower-precision data types, such as
  torch.bfloat16, while maintaining model accuracy.
  """
  device = torch.device("cpu")
  model = MockModule().to(device)

  steps = 8
  example = create_batch(1)
  dataloader = iter(example for _ in range(steps + 2))

  results = []
  with mock_pg():
    d_model = DistributedModelParallel(model)
    pipeline = train_pipeline.TrainPipelineSparseDist(
      model=d_model,
      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
      device=device,
      enable_amp=True,
      # Grad scaling is not supported on CPU.
      enable_grad_scaling=False,
    )
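    # With enable_amp=True the forward pass runs in mixed precision; on CPU the
    # lower-precision dtype is torch.bfloat16, which the assertion below checks.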
    for _ in range(steps):
      results.append(pipeline.progress(dataloader))

  results = [elem.detach() for elem in results]
  for value in results:
    assert value.dtype == torch.bfloat16