Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings

twitter-team
2023-03-31 13:05:14 -05:00
commit 78c3235eee
111 changed files with 11876 additions and 0 deletions

0
core/__init__.py Normal file
View File

5
core/config/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from tml.core.config.base_config import BaseConfig
from tml.core.config.config_load import load_config_from_yaml
# Make mypy happy by explicitly re-exporting the symbols intended for end-user use.
__all__ = ["BaseConfig", "load_config_from_yaml"]

66
core/config/base_config.py Normal file
View File

@@ -0,0 +1,66 @@
"""Base class for all config (forbids extra fields)."""
import collections
import functools
import yaml
import pydantic
class BaseConfig(pydantic.BaseModel):
"""Base class for all derived config classes.
This class provides some convenient functionality:
- Disallows extra fields when constructing an object, which reduces user
error from misspelled or unexpected arguments.
- "one_of" fields. A subclass can group optional fields and enforce
that only one of the fields be set. For example:
```
class ExampleConfig(BaseConfig):
x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1")
ExampleConfig(x=1) # ok
ExampleConfig(y=1) # ok
ExampleConfig(x=1, y=1) # throws error
```
"""
class Config:
"""Forbids extras."""
extra = pydantic.Extra.forbid # noqa
@classmethod
@functools.lru_cache()
def _field_data_map(cls, field_data_name):
"""Create a map from the provided field-data value (e.g. the one_of group name) to the fields that declare it."""
schema = cls.schema()
one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items():
if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
@pydantic.root_validator
def _one_of_check(cls, values):
"""Validate that exactly one field in each 'one_of' group is set."""
one_of_map = cls._field_data_map("one_of")
for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) != 1:
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
return values
@pydantic.root_validator
def _at_most_one_of_check(cls, values):
"""Validate that at most one field in each 'at_most_one_of' group is set."""
at_most_one_of_map = cls._field_data_map("at_most_one_of")
for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values
def pretty_print(self) -> str:
"""Return a human-legible (yaml) representation of the config, useful for logging."""
return yaml.dump(self.dict())
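As a quick usage illustration of the validators above (a hedged sketch; `JobConfig` and its fields are invented for this example and are not part of the codebase):
```
import pydantic
from tml.core.config import BaseConfig


class JobConfig(BaseConfig):
  # Exactly one of these must be set (one_of group "input").
  table_path: str = pydantic.Field(None, one_of="input")
  query: str = pydantic.Field(None, one_of="input")
  # At most one of these may be set (at_most_one_of group "cache").
  cache_dir: str = pydantic.Field(None, at_most_one_of="cache")
  cache_size_gb: int = pydantic.Field(None, at_most_one_of="cache")


JobConfig(table_path="/tmp/data")               # ok
JobConfig(query="select 1", cache_size_gb=8)    # ok
# JobConfig()                                   # raises: exactly one of table_path,query required
# JobConfig(table_path="/tmp/data", query="x")  # raises: same one_of violation
# JobConfig(table_path="/tmp/data", foo=1)      # raises: extra fields are forbidden
print(JobConfig(table_path="/tmp/data").pretty_print())
```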

View File

@@ -0,0 +1,38 @@
from unittest import TestCase
from tml.core.config import BaseConfig
import pydantic
class BaseConfigTest(TestCase):
def test_extra_forbidden(self):
class Config(BaseConfig):
x: int
Config(x=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=2)
def test_one_of(self):
class Config(BaseConfig):
x: int = pydantic.Field(None, one_of="f")
y: int = pydantic.Field(None, one_of="f")
with self.assertRaises(pydantic.ValidationError):
Config()
Config(x=1)
Config(y=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=3)
def test_at_most_one_of(self):
class Config(BaseConfig):
x: int = pydantic.Field(None, at_most_one_of="f")
y: str = pydantic.Field(None, at_most_one_of="f")
Config()
Config(x=1)
Config(y="a")
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y="a")

24
core/config/config_load.py Normal file
View File

@@ -0,0 +1,24 @@
import yaml
import string
import getpass
import os
from typing import Type
from tml.core.config.base_config import BaseConfig
def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
"""Recommended method to load and parse a config file (a yaml file).
Because we have a shared filesystem, the recommended way to run jobs is to put modified config
files with the desired parameters somewhere on the filesystem and run jobs pointing to them.
"""
def _substitute(s):
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
with open(yaml_path, "r") as f:
raw_contents = f.read()
obj = yaml.safe_load(_substitute(raw_contents))
return config_type.parse_obj(obj)
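A minimal sketch of the substitution behaviour (the `_DemoConfig` class, the temp path, and the `RUN_NAME` variable are invented for this example): `string.Template.safe_substitute` fills `$VAR`/`${VAR}` references from the environment plus the current `USER`, and leaves unknown references untouched rather than raising.
```
import getpass
import os

from tml.core.config import BaseConfig, load_config_from_yaml


class _DemoConfig(BaseConfig):
  save_dir: str
  owner: str


os.environ["RUN_NAME"] = "demo"
with open("/tmp/demo_config.yaml", "w") as f:
  f.write("save_dir: /tmp/${RUN_NAME}\nowner: ${USER}\n")

cfg = load_config_from_yaml(_DemoConfig, "/tmp/demo_config.yaml")
assert cfg.save_dir == "/tmp/demo"
assert cfg.owner == getpass.getuser()
```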

View File

@@ -0,0 +1,23 @@
from unittest import TestCase
from tml.core.config import BaseConfig, load_config_from_yaml
import pydantic
import getpass
class _PointlessConfig(BaseConfig):
a: int
user: str
def test_load_config_from_yaml(tmp_path):
yaml_path = tmp_path.joinpath("test.yaml").as_posix()
with open(yaml_path, "w") as yaml_file:
yaml_file.write("""a: 3\nuser: ${USER}\n""")
pointless_config = load_config_from_yaml(_PointlessConfig, yaml_path)
assert pointless_config.a == 3
assert pointless_config.user == getpass.getuser()

40
core/config/training.py Normal file
View File

@@ -0,0 +1,40 @@
from typing import Any, Dict, List, Optional
from tml.common.wandb import WandbConfig
from tml.core.config import base_config
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.projects.twhin.models.config import TwhinModelConfig
import pydantic
class RuntimeConfig(base_config.BaseConfig):
wandb: WandbConfig = pydantic.Field(None)
enable_tensorfloat32: bool = pydantic.Field(
False, description="Use tensorfloat32 if on Ampere devices."
)
enable_amp: bool = pydantic.Field(False, description="Enable automatic mixed precision.")
class TrainingConfig(base_config.BaseConfig):
save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
num_train_steps: pydantic.PositiveInt = 10000
initial_checkpoint_dir: str = pydantic.Field(
None, description="Directory of initial checkpoints", at_most_one_of="initialization"
)
checkpoint_every_n: pydantic.PositiveInt = 1000
checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(
None, description="Maximum number of checkpoints to keep. Defaults to keeping all."
)
train_log_every_n: pydantic.PositiveInt = 1000
num_eval_steps: int = pydantic.Field(
16384, description="Number of evaluation steps. If < 0 the entire dataset will be used."
)
eval_log_every_n: pydantic.PositiveInt = 5000
eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60
gradient_accumulation: int = pydantic.Field(
None, description="Number of replica steps to accumulate gradients."
)
num_epochs: pydantic.PositiveInt = 1
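For reference, a hedged sketch of populating these configs from an in-memory dict (values are invented; this assumes the classes are importable as `tml.core.config.training`). In practice the same dict would come from a yaml file via `load_config_from_yaml`:
```
from tml.core.config.training import RuntimeConfig, TrainingConfig

# Illustrative values only; anything omitted falls back to the Field defaults above.
training = TrainingConfig.parse_obj(
  {
    "save_dir": "/tmp/model",
    "num_train_steps": 100000,
    "checkpoint_every_n": 5000,
    "train_log_every_n": 500,
    "num_eval_steps": 1024,
    "gradient_accumulation": 2,
  }
)
runtime = RuntimeConfig(enable_tensorfloat32=True, enable_amp=False)
```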

View File

@@ -0,0 +1,316 @@
"""Torch and torchrec specific training and evaluation loops.
Features (go/100_enablements):
- CUDA data-fetch, compute, gradient-push overlap
- Large learnable embeddings through torchrec
- On/off-chief evaluation
- Warmstart/checkpoint management
- go/dataset-service 0-copy integration
"""
import datetime
import os
from typing import Callable, Dict, Iterable, List, Mapping, Optional
from tml.common import log_weights
import tml.common.checkpointing.snapshot as snapshot_lib
from tml.core.losses import get_global_loss_detached
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
from tml.core.train_pipeline import TrainPipelineSparseDist
import tree
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import _LRScheduler
import torchmetrics as tm
def get_new_iterator(iterable: Iterable):
"""
This obtains a new iterator from the iterable. If the iterable uses tf.data.Dataset internally,
getting a new iterator every N steps avoids a memory leak. To avoid the leak,
calling iter(iterable) should return a "fresh" iterator using a fresh
(new instance of) tf.data.Iterator.
In particular, iterable can be a torch.utils.data.IterableDataset or a
torch.utils.data.DataLoader.
When using DDS, performing this reset does not change the order in which elements are received
(excluding elements already prefetched) provided that iter(iterable) internally uses
a new instance of tf.data.Dataset created by calling from_dataset_id.
This requirement is satisfied by RecapDataset.
:param iterable:
:return:
"""
return iter(iterable)
def _get_step_fn(pipeline, data_iterator, training: bool):
def step_fn():
# It turns out that model.train() and model.eval() simply switch a single field inside the model
# class, so it's somewhat safer to wrap it in here.
if training:
pipeline._model.train()
else:
pipeline._model.eval()
outputs = pipeline.progress(data_iterator)
return tree.map_structure(lambda elem: elem.detach(), outputs)
return step_fn
@torch.no_grad()
def _run_evaluation(
pipeline,
dataset,
eval_steps: int,
metrics: tm.MetricCollection,
eval_batch_size: int,
logger=None,
):
"""Runs the evaluation loop over all evaluation iterators."""
dataset = get_new_iterator(dataset)
step_fn = _get_step_fn(pipeline, dataset, training=False)
last_time = datetime.datetime.now()
logging.info(f"Starting {eval_steps} steps of evaluation.")
for _ in range(eval_steps):
outputs = step_fn()
metrics.update(outputs)
eval_ex_per_s = (
eval_batch_size * eval_steps / (datetime.datetime.now() - last_time).total_seconds()
)
logging.info(f"eval examples_per_s : {eval_ex_per_s}")
metrics_result = metrics.compute()
# Resetting at end to release metrics memory not in use.
# Reset metrics to prevent accumulation between multiple evaluation splits and not report a
# running average.
metrics.reset()
return metrics_result
def train(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: str,
save_dir: str,
logging_interval: int,
train_steps: int,
checkpoint_frequency: int,
dataset: Iterable,
worker_batch_size: int,
num_workers: Optional[int] = 0,
enable_amp: bool = False,
initial_checkpoint_dir: Optional[str] = None,
gradient_accumulation: Optional[int] = None,
logger_initializer: Optional[Callable] = None,
scheduler: _LRScheduler = None,
metrics: Optional[tm.MetricCollection] = None,
parameters_to_log: Optional[Dict[str, Callable]] = None,
tables_to_log: Optional[List[str]] = None,
) -> None:
"""Runs training and eval on the given TrainPipeline
Args:
dataset: data iterator for the training set
evaluation_iterators: data iterators for the different evaluation sets
scheduler: optional learning rate scheduler
output_transform_for_metrics: optional transformation functions to transform the model
output and labels into a format the metrics can understand
"""
train_pipeline = TrainPipelineSparseDist(
model=model,
optimizer=optimizer,
device=device,
enable_amp=enable_amp,
grad_accum=gradient_accumulation,
) # type: ignore[var-annotated]
# We explicitly initialize optimizer state here so that checkpoint will work properly.
if hasattr(train_pipeline._optimizer, "init_state"):
train_pipeline._optimizer.init_state()
save_state = {
"model": train_pipeline._model,
"optimizer": train_pipeline._optimizer,
"scaler": train_pipeline._grad_scaler,
}
chosen_checkpoint = None
checkpoint_handler = snapshot_lib.Snapshot(
save_dir=save_dir,
state=save_state,
)
if save_dir:
chosen_checkpoint = snapshot_lib.get_checkpoint(save_dir=save_dir, missing_ok=True)
start_step = 0
start_walltime = 0.0
if chosen_checkpoint:
# Skip restoration and exit if we should be finished.
chosen_checkpoint_global_step = snapshot_lib.step_from_checkpoint(chosen_checkpoint)
if not chosen_checkpoint_global_step < dist.get_world_size() * train_steps:
logging.info(
"Not restoring and finishing training as latest checkpoint "
f"{chosen_checkpoint} found "
f"at global_step ({chosen_checkpoint_global_step}) >= "
f"train_steps ({dist.get_world_size() * train_steps})"
)
return
logging.info(f"Restoring latest checkpoint from global_step {chosen_checkpoint_global_step}")
checkpoint_handler.restore(chosen_checkpoint)
start_step = checkpoint_handler.step
start_walltime = checkpoint_handler.walltime
elif initial_checkpoint_dir:
base, ckpt_step = os.path.split(initial_checkpoint_dir)
warmstart_handler = snapshot_lib.Snapshot(
save_dir=base,
state=save_state,
)
ckpt = snapshot_lib.get_checkpoint(save_dir=base, missing_ok=False, global_step=int(ckpt_step))
logging.info(
f"Restoring from initial_checkpoint_dir: {initial_checkpoint_dir}, but keeping starting step as 0."
)
warmstart_handler.restore(ckpt)
train_logger = logger_initializer(mode="train") if logger_initializer else None
train_step_fn = _get_step_fn(train_pipeline, get_new_iterator(dataset), training=True)
# Count the number of parameters in the model.
nb_param = sum(p.numel() for p in model.parameters())
logging.info(f"Model has {nb_param} parameters")
last_time = datetime.datetime.now()
start_time = last_time
last_pending_snapshot = None
for step in range(start_step, train_steps + 1):
checkpoint_handler.step = step
outputs = train_step_fn()
step_done_time = datetime.datetime.now()
checkpoint_handler.walltime = (step_done_time - start_time).total_seconds() + start_walltime
if scheduler:
scheduler.step()
if step % logging_interval == 0:
interval_time = (step_done_time - last_time).total_seconds()
steps_per_s = logging_interval / interval_time
worker_example_per_s = steps_per_s * worker_batch_size
global_example_per_s = worker_example_per_s * (1 + (num_workers or 0))
global_step = step
log_values = {
"global_step": global_step,
"loss": get_global_loss_detached(outputs["loss"]),
"steps_per_s": steps_per_s,
"global_example_per_s": global_example_per_s,
"worker_examples_per_s": worker_example_per_s,
"active_training_walltime": checkpoint_handler.walltime,
}
if parameters_to_log:
log_values.update(
log_weights.weights_to_log(
model=model,
how_to_log=parameters_to_log,
)
)
log_values = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), log_values)
if tables_to_log:
log_values.update(
log_weights.log_ebc_norms(
model_state_dict=train_pipeline._model.state_dict(),
ebc_keys=tables_to_log,
)
)
if train_logger:
train_logger.log(log_values, step=global_step)
log_line = ", ".join(f"{name}: {value}" for name, value in log_values.items())
logging.info(f"Step: {step}, training. {log_line}")
last_time = step_done_time
# If we just restored, do not save again.
if checkpoint_frequency and step > start_step and step % checkpoint_frequency == 0:
if last_pending_snapshot and not last_pending_snapshot.done():
logging.warning(
"Beginning a new snapshot while the last one hasn't finished. That probably indicates "
"either you're snapshotting too often or something is wrong. Will now block and "
"wait for the snapshot to finish before beginning the next one."
)
last_pending_snapshot.wait()
last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())
# Save if we did not just save.
if checkpoint_frequency and step % checkpoint_frequency != 0:
# For the final save, wait for the checkpoint write to make sure the process doesn't finish
# before it's completed.
last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())
logging.info(f"Finished training steps: {step}, global_steps: {step * dist.get_world_size()}")
if last_pending_snapshot:
logging.info("Waiting for any checkpoints to finish.")
last_pending_snapshot.wait()
def log_eval_results(
results,
eval_logger,
partition_name: str,
step: int,
):
results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
logging.info(f"Step: {step}, evaluation ({partition_name}).")
for metric_name, metric_value in results.items():
logging.info(f"\t{metric_name}: {metric_value:1.4e}")
if eval_logger:
eval_logger.log(results, step=step, commit=True)
def only_evaluate(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: str,
save_dir: str,
num_train_steps: int,
dataset: Iterable,
eval_batch_size: int,
num_eval_steps: int,
eval_timeout_in_s: int,
eval_logger: Callable,
partition_name: str,
metrics: Optional[tm.MetricCollection] = None,
):
logging.info(f"Evaluating on partition {partition_name}.")
logging.info("Computing metrics:")
logging.info(metrics)
eval_pipeline = TrainPipelineSparseDist(model, optimizer, device) # type: ignore[var-annotated]
save_state = {
"model": eval_pipeline._model,
"optimizer": eval_pipeline._optimizer,
}
checkpoint_handler = snapshot_lib.Snapshot(
save_dir=save_dir,
state=save_state,
)
for checkpoint_path in snapshot_lib.checkpoints_iterator(save_dir, timeout=eval_timeout_in_s):
checkpoint_handler.restore(checkpoint_path)
step = checkpoint_handler.step
dataset = get_new_iterator(dataset)
results = _run_evaluation(
pipeline=eval_pipeline,
dataset=dataset,
eval_steps=num_eval_steps,
eval_batch_size=eval_batch_size,
metrics=metrics,
)
log_eval_results(results, eval_logger, partition_name, step=step)
rank = dist.get_rank() if dist.is_initialized() else 0
if rank == 0:
snapshot_lib.mark_done_eval(checkpoint_path, partition_name)
if step >= num_train_steps:
return
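A small worked example of the checkpoint bookkeeping used above (numbers are illustrative): the loop counts per-replica steps but names checkpoints by a global step of `step * world_size`, and on restart it treats training as finished once the restored global step reaches `world_size * train_steps`.
```
world_size = 4        # dist.get_world_size()
train_steps = 10_000  # per-replica steps requested for this run

# A checkpoint written at per-replica step 9_000 is saved under global step 36_000.
checkpoint_global_step = 9_000 * world_size

# On restart, training is only considered finished once the restored global step
# reaches world_size * train_steps (40_000 here), so this run restores and resumes.
finished = checkpoint_global_step >= world_size * train_steps
assert not finished
```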

45
core/debug_training_loop.py Normal file
View File

@@ -0,0 +1,45 @@
"""This is a very limited feature training loop useful for interactive debugging.
It is not intended for actual model training (it is not fast and doesn't compile the model).
It does not support checkpointing.
Suggested use:
from tml.core import debug_training_loop
debug_training_loop.train(...)
"""
from typing import Iterable, Optional, Dict, Callable, List
import torch
from torch.optim.lr_scheduler import _LRScheduler
import torchmetrics as tm
from tml.ml_logging.torch_logging import logging
def train(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
train_steps: int,
dataset: Iterable,
scheduler: _LRScheduler = None,
# Accept any arguments (to be compatible with the real training loop)
# but just ignore them.
*args,
**kwargs,
) -> None:
logging.warning("Running debug training loop, don't use for model training.")
data_iter = iter(dataset)
for step in range(0, train_steps + 1):
x = next(data_iter)
optimizer.zero_grad()
loss, outputs = model.forward(x)
loss.backward()
optimizer.step()
if scheduler:
scheduler.step()
logging.info(f"Step {step} completed. Loss = {loss}")
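A hedged sketch of driving this loop interactively (the toy model and dataset are invented; the only contract assumed is the one used above, i.e. `model.forward(batch)` returns a `(loss, outputs)` tuple):
```
import torch

from tml.core import debug_training_loop


class TinyModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(8, 1)
    self.loss_fn = torch.nn.BCEWithLogitsLoss()

  def forward(self, batch):
    features, labels = batch
    logits = self.linear(features)
    # Return (loss, outputs), matching what the loop above unpacks.
    return self.loss_fn(logits, labels), logits


def toy_dataset():
  while True:
    yield torch.rand(32, 8), torch.randint(0, 2, (32, 1)).float()


model = TinyModel()
debug_training_loop.train(
  model=model,
  optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
  train_steps=10,
  dataset=toy_dataset(),
)
```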

7
core/loss_type.py Normal file
View File

@@ -0,0 +1,7 @@
"""Loss type enums."""
from enum import Enum
class LossType(str, Enum):
CROSS_ENTROPY = "cross_entropy"
BCE_WITH_LOGITS = "bce_with_logits"

109
core/losses.py Normal file
View File

@@ -0,0 +1,109 @@
"""Loss functions -- including multi task ones."""
import typing
from tml.core.loss_type import LossType
from tml.ml_logging.torch_logging import logging
import torch
def _maybe_warn(reduction: str):
"""
Warn when the reduction differs from mean.
"""
if reduction != "mean":
logging.warn(
"For the same global_batch_size, the gradient in DDP is guaranteed to be equal "
"to the gradient without DDP only for mean reduction. If you need this property for "
f"the provided reduction {reduction}, it needs to be implemented."
)
def build_loss(
loss_type: LossType,
reduction="mean",
):
_maybe_warn(reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
def loss_fn(logits, labels):
return f(logits, labels.type_as(logits), reduction=reduction)
return loss_fn
def get_global_loss_detached(local_loss, reduction="mean"):
"""
Perform all_reduce to obtain the global loss function using the provided reduction.
:param local_loss: The local loss of the current rank.
:param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
:return: The reduced & detached global loss.
"""
if reduction != "mean":
logging.warn(
"The reduction used in this function should be the same as the one used by "
"the DDP model. By default DDP uses mean, so ensure that DDP is appropriately "
f"modified for reduction {reduction}."
)
if reduction not in ["mean", "sum"]:
raise ValueError(f"Reduction {reduction} is currently unsupported.")
global_loss = local_loss.detach()
if reduction == "mean":
global_loss.div_(torch.distributed.get_world_size())
torch.distributed.all_reduce(global_loss)
return global_loss
def build_multi_task_loss(
loss_type: LossType,
tasks: typing.List[str],
task_loss_reduction="mean",
global_reduction="mean",
pos_weights=None,
):
_maybe_warn(global_reduction)
_maybe_warn(task_loss_reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
loss_reduction_fns = {
"mean": torch.mean,
"sum": torch.sum,
"min": torch.min,
"max": torch.max,
"median": torch.median,
}
def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor):
if pos_weights is None:
torch_weights = torch.ones([len(tasks)])
else:
torch_weights = torch.tensor(pos_weights)
losses = {}
for task_idx, task in enumerate(tasks):
task_logits = logits[:, task_idx]
label = labels[:, task_idx].type_as(task_logits)
loss = f(
task_logits,
label,
reduction=task_loss_reduction,
pos_weight=torch_weights[task_idx],
weight=weights[:, task_idx],
)
losses[f"loss/{task}"] = loss
losses["loss"] = loss_reduction_fns[global_reduction](torch.stack(list(losses.values())))
return losses
return loss_fn
_LOSS_TYPE_TO_FUNCTION = {
LossType.BCE_WITH_LOGITS: torch.nn.functional.binary_cross_entropy_with_logits
}
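A hedged sketch of how these factories are used (shapes and task names are invented). `get_global_loss_detached` is omitted here because it requires an initialized `torch.distributed` process group.
```
import torch

from tml.core.loss_type import LossType
from tml.core.losses import build_loss, build_multi_task_loss

logits = torch.randn(4, 3)            # [batch_size, num_tasks]
labels = torch.randint(0, 2, (4, 3))  # [batch_size, num_tasks]
weights = torch.ones(4, 3)            # per-example, per-task weights

# Single-task BCE-with-logits on one task column.
loss_fn = build_loss(LossType.BCE_WITH_LOGITS)
single_loss = loss_fn(logits[:, 0], labels[:, 0])

# Multi-task loss: one "loss/<task>" entry per task plus a reduced "loss" entry.
mt_loss_fn = build_multi_task_loss(
  LossType.BCE_WITH_LOGITS,
  tasks=["fav", "reply", "retweet"],
  pos_weights=[1.0, 2.0, 2.0],
)
losses = mt_loss_fn(logits, labels, weights)
assert set(losses) == {"loss/fav", "loss/reply", "loss/retweet", "loss"}
```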

104
core/metric_mixin.py Normal file
View File

@@ -0,0 +1,104 @@
"""
Mixin that requires a transform to munge the output dictionary of tensors a
model produces into the form that torchmetrics.Metric.update expects.
By unifying on our signature for `update`, we can also now use
torchmetrics.MetricCollection which requires all metrics have
the same call signature.
To use, override `transform` with a function that munges `outputs`
into a kwargs dict that the inherited metric's `update` accepts.
Here are two examples of how to extend torchmetrics.SumMetric so that it accepts
an output dictionary of tensors and munges it to what SumMetric expects (single `value`)
for its update method.
1. Using as a mixin to inherit from or define a new metric class.
class Count(MetricMixin, SumMetric):
def transform(self, outputs):
return {'value': 1}
2. Redefine an existing metric class.
SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1})
"""
from abc import abstractmethod
from typing import Callable, Dict, List
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
import torch
import torchmetrics
class MetricMixin:
@abstractmethod
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
...
def update(self, outputs: Dict[str, torch.Tensor]):
results = self.transform(outputs)
# Do not try to update if any tensor is empty as a result of stratification.
for value in results.values():
if torch.is_tensor(value) and not value.nelement():
return
super().update(**results)
class TaskMixin:
def __init__(self, task_idx: int = -1, **kwargs):
super().__init__(**kwargs)
self._task_idx = task_idx
class StratifyMixin:
def __init__(
self,
stratifier=None,
**kwargs,
):
super().__init__(**kwargs)
self._stratifier = stratifier
def maybe_apply_stratification(
self, outputs: Dict[str, torch.Tensor], value_names: List[str]
) -> Dict[str, torch.Tensor]:
"""Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value."""
outputs = outputs.copy()
if not self._stratifier:
return outputs
stratifiers = outputs.get("stratifiers")
if not stratifiers:
return outputs
if stratifiers.get(self._stratifier.name) is None:
return outputs
mask = torch.flatten(outputs["stratifiers"][self._stratifier.name] == self._stratifier.value)
target_slice = torch.squeeze(mask.nonzero(), -1)
for value_name in value_names:
target = outputs[value_name]
outputs[value_name] = torch.index_select(target, 0, target_slice)
return outputs
def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
"""Returns new class using MetricMixin and given base_metric.
Functionally the same as using inheritance; it just saves some lines of code
when there is no need for class attributes.
"""
def transform_method(_self, *args, **kwargs):
return transform(*args, **kwargs)
return type(
base_metric.__name__,
(
MetricMixin,
base_metric,
),
{"transform": transform_method},
)

148
core/metrics.py Normal file
View File

@@ -0,0 +1,148 @@
"""Common metrics that also support multi task.
We assume multi-task models output predictions of shape [batch_size, num_tasks]; task_idx selects a task along the last dimension.
"""
from typing import Any, Dict
from tml.core.metric_mixin import MetricMixin, StratifyMixin, TaskMixin
import torch
import torchmetrics as tm
def probs_and_labels(
outputs: Dict[str, torch.Tensor],
task_idx: int,
) -> Dict[str, torch.Tensor]:
preds = outputs["probabilities"]
target = outputs["labels"]
if task_idx >= 0:
preds = preds[:, task_idx]
target = target[:, task_idx]
return {
"preds": preds,
"target": target.int(),
}
class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"]
if self._task_idx >= 0:
value = value[:, self._task_idx]
return {"value": value}
class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"]
if self._task_idx >= 0:
value = value[:, self._task_idx]
return {"value": value}
class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
value = outputs["probabilities"]
if self._task_idx >= 0:
value = value[:, self._task_idx]
return {"value": value}
class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
"""
def __init__(self, num_samples, **kwargs):
super().__init__(**kwargs)
self.num_samples = num_samples
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
scores, labels = outputs["logits"], outputs["labels"]
pos_scores = scores[labels == 1]
neg_scores = scores[labels == 0]
result = {
"value": pos_scores[torch.randint(len(pos_scores), (self.num_samples,))]
> neg_scores[torch.randint(len(neg_scores), (self.num_samples,))]
}
return result
class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The ranks of all positives
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
scores, labels = outputs["logits"], outputs["labels"]
_, sorted_indices = scores.sort(descending=True)
pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1
result = {"value": pos_ranks}
return result
class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The reciprocal of the ranks of all positives
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
scores, labels = outputs["logits"], outputs["labels"]
_, sorted_indices = scores.sort(descending=True)
pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1
result = {"value": torch.div(torch.ones_like(pos_ranks), pos_ranks)}
return result
class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The fraction of positives that rank in the top K among their negatives
Note that this is basically precision@k
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
"""
def __init__(self, k: int, **kwargs):
super().__init__(**kwargs)
self.k = k
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
scores, labels = outputs["logits"], outputs["labels"]
_, sorted_indices = scores.sort(descending=True)
pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1
result = {"value": (pos_ranks <= self.k).float()}
return result
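The sampled `Auc` above estimates AUC as the probability that a randomly drawn positive score exceeds a randomly drawn negative score; the boolean `value` tensor it returns is then averaged by the underlying `tm.MeanMetric`. A self-contained sketch of the same estimator (toy scores invented for illustration):
```
import torch

torch.manual_seed(0)
num_samples = 10_000

scores = torch.tensor([-1.0, -0.2, 0.1, 0.7, 1.5])
labels = torch.tensor([0, 0, 1, 0, 1])

pos_scores = scores[labels == 1]
neg_scores = scores[labels == 0]

# Compare random positive/negative pairs; the mean of the indicator converges
# to the exact pairwise AUC, which is 5/6 ~= 0.833 for these scores.
pairs = (
  pos_scores[torch.randint(len(pos_scores), (num_samples,))]
  > neg_scores[torch.randint(len(neg_scores), (num_samples,))]
)
print(pairs.float().mean())
```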

187
core/test_metrics.py Normal file
View File

@@ -0,0 +1,187 @@
from dataclasses import dataclass
from tml.core import metrics as core_metrics
from tml.core.metric_mixin import MetricMixin, prepend_transform
import torch
from torchmetrics import MaxMetric, MetricCollection, SumMetric
@dataclass
class MockStratifierConfig:
name: str
index: int
value: int
class Count(MetricMixin, SumMetric):
def transform(self, outputs):
return {"value": 1}
Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})
def test_count_metric():
num_examples = 123
examples = [
{"stuff": 0},
] * num_examples
metric = Count()
for outputs in examples:
metric.update(outputs)
assert metric.compute().item() == num_examples
def test_collections():
max_metric = Max()
count_metric = Count()
metric = MetricCollection([max_metric, count_metric])
examples = [{"value": idx} for idx in range(123)]
for outputs in examples:
metric.update(outputs)
assert metric.compute() == {
max_metric.__class__.__name__: len(examples) - 1,
count_metric.__class__.__name__: len(examples),
}
def test_task_dependent_ctr():
num_examples = 144
batch_size = 1024
outputs = [
{
"stuff": 0,
"labels": torch.arange(0, 6).repeat(batch_size, 1),
}
for idx in range(num_examples)
]
for task_idx in range(5):
metric = core_metrics.Ctr(task_idx=task_idx)
for output in outputs:
metric.update(output)
assert metric.compute().item() == task_idx
def test_stratified_ctr():
outputs = [
{
"stuff": 0,
# [bsz, tasks]
"labels": torch.tensor(
[
[0, 1, 2, 3],
[1, 2, 3, 4],
[2, 3, 4, 0],
]
),
"stratifiers": {
# [bsz]
"level": torch.tensor(
[9, 0, 9],
),
},
}
]
stratifier = MockStratifierConfig(name="level", index=2, value=9)
for task_idx in range(5):
metric = core_metrics.Ctr(task_idx=1, stratifier=stratifier)
for output in outputs:
metric.update(output)
# From the dataset of:
# [
# [0, 1, 2, 3],
# [1, 2, 3, 4],
# [2, 3, 4, 0],
# ]
# we pick out
# [
# [0, 1, 2, 3],
# [2, 3, 4, 0],
# ]
# and with Ctr task_idx, we pick out
# [
# [1,],
# [3,],
# ]
assert metric.compute().item() == (1 + 3) / 2
def test_auc():
num_samples = 10000
metric = core_metrics.Auc(num_samples)
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 1.0, 1.0, 1.0])
outputs_correct = {"logits": preds_correct, "labels": target}
preds_bad = torch.tensor([1.0, 1.0, -1.0, -1.0, -1.0])
outputs_bad = {"logits": preds_bad, "labels": target}
metric.update(outputs_correct)
assert metric.compute().item() == 1.0
metric.reset()
metric.update(outputs_bad)
assert metric.compute().item() == 0.0
def test_pos_rank():
metric = core_metrics.PosRanks()
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
outputs_correct = {"logits": preds_correct, "labels": target}
preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
outputs_bad = {"logits": preds_bad, "labels": target}
metric.update(outputs_correct)
assert metric.compute().item() == 2.0
metric.reset()
metric.update(outputs_bad)
assert metric.compute().item() == 4.0
def test_reciprocal_rank():
metric = core_metrics.ReciprocalRank()
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
outputs_correct = {"logits": preds_correct, "labels": target}
preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
outputs_bad = {"logits": preds_bad, "labels": target}
metric.update(outputs_correct)
assert abs(metric.compute().item() - 0.6111) < 0.001
metric.reset()
metric.update(outputs_bad)
assert abs(metric.compute().item() - 0.2611) < 0.001
def test_hit_k():
hit1_metric = core_metrics.HitAtK(1)
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
outputs_correct = {"logits": preds_correct, "labels": target}
preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
outputs_bad = {"logits": preds_bad, "labels": target}
hit1_metric.update(outputs_correct)
assert abs(hit1_metric.compute().item() - 0.3333) < 0.0001
hit1_metric.reset()
hit1_metric.update(outputs_bad)
assert hit1_metric.compute().item() == 0
hit3_metric = core_metrics.HitAtK(3)
hit3_metric.update(outputs_correct)
assert abs(hit3_metric.compute().item() - 0.66666) < 0.0001
hit3_metric.reset()
hit3_metric.update(outputs_bad)
assert abs(hit3_metric.compute().item() - 0.3333) < 0.0001

View File

@@ -0,0 +1,91 @@
from dataclasses import dataclass
from typing import Tuple
from tml.common.batch import DataclassBatch
from tml.common.testing_utils import mock_pg
from tml.core import train_pipeline
import torch
from torchrec.distributed import DistributedModelParallel
@dataclass
class MockDataclassBatch(DataclassBatch):
continuous_features: torch.Tensor
labels: torch.Tensor
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.model = torch.nn.Linear(10, 1)
self.loss_fn = torch.nn.BCEWithLogitsLoss()
def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
pred = self.model(batch.continuous_features)
loss = self.loss_fn(pred, batch.labels)
return (loss, pred)
def create_batch(bsz: int):
return MockDataclassBatch(
continuous_features=torch.rand(bsz, 10).float(),
labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
)
def test_sparse_pipeline():
device = torch.device("cpu")
model = MockModule().to(device)
steps = 8
example = create_batch(1)
dataloader = iter(example for _ in range(steps + 2))
results = []
with mock_pg():
d_model = DistributedModelParallel(model)
pipeline = train_pipeline.TrainPipelineSparseDist(
model=d_model,
optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
device=device,
grad_accum=2,
)
for _ in range(steps):
results.append(pipeline.progress(dataloader))
results = [elem.detach().numpy() for elem in results]
# Check gradients are accumulated, i.e. results do not change between each 0th and 1st step.
for first, second in zip(results[::2], results[1::2]):
assert first == second, results
# Check we do update gradients, i.e. results do change between each 1st and 2nd step.
for first, second in zip(results[1::2], results[2::2]):
assert first != second, results
def test_amp():
device = torch.device("cpu")
model = MockModule().to(device)
steps = 8
example = create_batch(1)
dataloader = iter(example for _ in range(steps + 2))
results = []
with mock_pg():
d_model = DistributedModelParallel(model)
pipeline = train_pipeline.TrainPipelineSparseDist(
model=d_model,
optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
device=device,
enable_amp=True,
# Not supported on CPU.
enable_grad_scaling=False,
)
for _ in range(steps):
results.append(pipeline.progress(dataloader))
results = [elem.detach() for elem in results]
for value in results:
assert value.dtype == torch.bfloat16

626
core/train_pipeline.py Normal file
View File

@@ -0,0 +1,626 @@
"""
Taken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py
with TrainPipelineSparseDist.progress modified to support gradient accumulation.
"""
import abc
from dataclasses import dataclass, field
import logging
from typing import (
Any,
cast,
Dict,
Generic,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
)
import torch
from torch.autograd.profiler import record_function
from torch.fx.node import Node
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
ShardedModule,
)
from torchrec.distributed.types import Awaitable
from torchrec.modules.feature_processor import BaseGroupedFeatureProcessor
from torchrec.streamable import Multistreamable, Pipelineable
logger: logging.Logger = logging.getLogger(__name__)
In = TypeVar("In", bound=Pipelineable)
Out = TypeVar("Out")
class TrainPipeline(abc.ABC, Generic[In, Out]):
@abc.abstractmethod
def progress(self, dataloader_iter: Iterator[In]) -> Out:
pass
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
assert isinstance(
batch, (torch.Tensor, Pipelineable)
), f"{type(batch)} must implement Pipelineable interface"
return cast(In, batch.to(device=device, non_blocking=non_blocking))
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
if stream is None:
return
torch.cuda.current_stream().wait_stream(stream)
# As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
# PyTorch uses the "caching allocator" for memory allocation for tensors. When a tensor is
freed, its memory is likely to be reused by newly constructed tensors. By default,
# this allocator traces whether a tensor is still in use by only the CUDA stream where it
# was created. When a tensor is used by additional CUDA streams, we need to call record_stream
# to tell the allocator about all these streams. Otherwise, the allocator might free the
# underlying memory of the tensor once it is no longer used by the creator stream. This is
a notable programming trick when we write programs using multiple CUDA streams.
cur_stream = torch.cuda.current_stream()
assert isinstance(
batch, (torch.Tensor, Multistreamable)
), f"{type(batch)} must implement Multistreamable interface"
batch.record_stream(cur_stream)
class TrainPipelineBase(TrainPipeline[In, Out]):
"""
This class runs training iterations using a pipeline of two stages, each as a CUDA
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization.
"""
def __init__(
self,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: torch.device,
) -> None:
self._model = model
self._optimizer = optimizer
self._device = device
self._memcpy_stream: Optional[torch.cuda.streams.Stream] = (
torch.cuda.Stream() if device.type == "cuda" else None
)
self._cur_batch: Optional[In] = None
self._connected = False
def _connect(self, dataloader_iter: Iterator[In]) -> None:
cur_batch = next(dataloader_iter)
self._cur_batch = cur_batch
with torch.cuda.stream(self._memcpy_stream):
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
self._connected = True
def progress(self, dataloader_iter: Iterator[In]) -> Out:
if not self._connected:
self._connect(dataloader_iter)
# Fetch next batch
with record_function("## next_batch ##"):
next_batch = next(dataloader_iter)
cur_batch = self._cur_batch
assert cur_batch is not None
if self._model.training:
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()
with record_function("## wait_for_batch ##"):
_wait_for_batch(cur_batch, self._memcpy_stream)
with record_function("## forward ##"):
(losses, output) = self._model(cur_batch)
if self._model.training:
with record_function("## backward ##"):
torch.sum(losses, dim=0).backward()
# Copy the next batch to GPU
self._cur_batch = cur_batch = next_batch
with record_function("## copy_batch_to_gpu ##"):
with torch.cuda.stream(self._memcpy_stream):
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
# Update
if self._model.training:
with record_function("## optimizer ##"):
self._optimizer.step()
return output
class Tracer(torch.fx.Tracer):
# Disable proxying buffers during tracing. Ideally, proxying buffers would
# be disabled, but some models are currently mutating buffer values, which
# causes errors during tracing. If those models can be rewritten to not do
# that, we can likely remove this line
proxy_buffer_attributes = False
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
super().__init__()
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
return True
return super().is_leaf_module(m, module_qualified_name)
@dataclass
class TrainPipelineContext:
# pyre-ignore [4]
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
# pyre-ignore [4]
feature_processor_forwards: List[Any] = field(default_factory=list)
@dataclass
class ArgInfo:
# attributes of input batch, e.g. batch.attr1.attr2 call
# will produce ["attr1", "attr2"]
input_attrs: List[str]
# batch[attr1].attr2 will produce [True, False]
is_getitems: List[bool]
# name for kwarg of pipelined forward() call or None
# for a positional arg
name: Optional[str]
class PipelinedForward:
def __init__(
self,
name: str,
args: List[ArgInfo],
module: ShardedModule,
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> None:
self._name = name
self._args = args
self._module = module
self._context = context
self._dist_stream = dist_stream
# pyre-ignore [2, 24]
def __call__(self, *input, **kwargs) -> Awaitable:
assert self._name in self._context.input_dist_requests
request = self._context.input_dist_requests[self._name]
assert isinstance(request, Awaitable)
with record_function("## wait_sparse_data_dist ##"):
# Finish waiting on the dist_stream,
# in case some delayed stream scheduling happens during the wait() call.
with torch.cuda.stream(self._dist_stream):
data = request.wait()
# Make sure that both result of input_dist and context
# are properly transferred to the current stream.
if self._dist_stream is not None:
torch.cuda.current_stream().wait_stream(self._dist_stream)
cur_stream = torch.cuda.current_stream()
assert isinstance(
data, (torch.Tensor, Multistreamable)
), f"{type(data)} must implement Multistreamable interface"
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
data.record_stream(cur_stream)
ctx = self._context.module_contexts[self._name]
ctx.record_stream(cur_stream)
if len(self._context.feature_processor_forwards) > 0:
with record_function("## feature_processor ##"):
for sparse_feature in data:
if sparse_feature.id_score_list_features is not None:
for fp_forward in self._context.feature_processor_forwards:
sparse_feature.id_score_list_features = fp_forward(
sparse_feature.id_score_list_features
)
return self._module.compute_and_output_dist(self._context.module_contexts[self._name], data)
@property
def name(self) -> str:
return self._name
@property
def args(self) -> List[ArgInfo]:
return self._args
def _start_data_dist(
pipelined_modules: List[ShardedModule],
batch: In,
context: TrainPipelineContext,
) -> None:
context.input_dist_requests.clear()
context.module_contexts.clear()
for module in pipelined_modules:
forward = module.forward
assert isinstance(forward, PipelinedForward)
# Retrieve arguments for the input_dist of the EBC.
# is_getitem True means this argument is retrieved via __getitem__ (e.g. batch[attr]);
# False means it is retrieved via getattr. This info was collected in
# _rewrite_model by tracing the entire model to build the arg_info_list.
args = []
kwargs = {}
for arg_info in forward.args:
if arg_info.input_attrs:
arg = batch
for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems):
if is_getitem:
arg = arg[attr]
else:
arg = getattr(arg, attr)
if arg_info.name:
kwargs[arg_info.name] = arg
else:
args.append(arg)
else:
args.append(None)
# Start input distribution.
module_ctx = module.create_context()
context.module_contexts[forward.name] = module_ctx
context.input_dist_requests[forward.name] = module.input_dist(module_ctx, *args, **kwargs)
# Call wait on the first awaitable in the input dist for the tensor splits
for key, awaitable in context.input_dist_requests.items():
context.input_dist_requests[key] = awaitable.wait()
def _get_node_args_helper(
# pyre-ignore
arguments,
num_found: int,
feature_processor_arguments: Optional[List[Node]] = None,
) -> Tuple[List[ArgInfo], int]:
"""
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.
"""
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
for arg, arg_info in zip(arguments, arg_info_list):
if arg is None:
num_found += 1
continue
while True:
if not isinstance(arg, torch.fx.Node):
break
child_node = arg
if child_node.op == "placeholder":
num_found += 1
break
# skip this fp node
elif feature_processor_arguments is not None and child_node in feature_processor_arguments:
arg = child_node.args[0]
elif (
child_node.op == "call_function"
and child_node.target.__module__ == "builtins"
# pyre-ignore[16]
and child_node.target.__name__ == "getattr"
):
arg_info.input_attrs.insert(0, child_node.args[1])
arg_info.is_getitems.insert(0, False)
arg = child_node.args[0]
elif (
child_node.op == "call_function"
and child_node.target.__module__ == "_operator"
# pyre-ignore[16]
and child_node.target.__name__ == "getitem"
):
arg_info.input_attrs.insert(0, child_node.args[1])
arg_info.is_getitems.insert(0, True)
arg = child_node.args[0]
else:
break
return arg_info_list, num_found
def _get_node_args(
node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
num_found = 0
pos_arg_info_list, num_found = _get_node_args_helper(
node.args, num_found, feature_processor_nodes
)
kwargs_arg_info_list, num_found = _get_node_args_helper(node.kwargs.values(), num_found)
# Replace with proper names for kwargs
for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list):
arg_info_list.name = name
arg_info_list = pos_arg_info_list + kwargs_arg_info_list
return arg_info_list, num_found
def _get_unsharded_module_names_helper(
model: torch.nn.Module,
path: str,
unsharded_module_names: Set[str],
) -> bool:
sharded_children = set()
for name, child in model.named_children():
curr_path = path + name
if isinstance(child, ShardedModule):
sharded_children.add(name)
else:
child_sharded = _get_unsharded_module_names_helper(
child,
curr_path + ".",
unsharded_module_names,
)
if child_sharded:
sharded_children.add(name)
if len(sharded_children) > 0:
for name, _ in model.named_children():
if name not in sharded_children:
unsharded_module_names.add(path + name)
return len(sharded_children) > 0
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
"""
Returns a list of top-level modules that do not contain any sharded submodules.
"""
unsharded_module_names: Set[str] = set()
_get_unsharded_module_names_helper(
model,
"",
unsharded_module_names,
)
return list(unsharded_module_names)
def _rewrite_model( # noqa C901
model: torch.nn.Module,
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:
# Get underlying nn.Module
if isinstance(model, DistributedModelParallel):
model = model.module
# Collect a list of sharded modules.
sharded_modules = {}
fp_modules = {}
for name, m in model.named_modules():
if isinstance(m, ShardedModule):
sharded_modules[name] = m
if isinstance(m, BaseGroupedFeatureProcessor):
fp_modules[name] = m
# Trace a model.
tracer = Tracer(leaf_modules=_get_unsharded_module_names(model))
graph = tracer.trace(model)
feature_processor_nodes = []
# find the fp node
for node in graph.nodes:
if node.op == "call_module" and node.target in fp_modules:
feature_processor_nodes.append(node)
# Select sharded modules, which are top-level in the forward call graph,
# i.e. which don't have input transformations, i.e.
# rely only on 'builtins.getattr'.
ret = []
for node in graph.nodes:
if node.op == "call_module" and node.target in sharded_modules:
total_num_args = len(node.args) + len(node.kwargs)
if total_num_args == 0:
continue
arg_info_list, num_found = _get_node_args(node, feature_processor_nodes)
if num_found == total_num_args:
logger.info(f"Module '{node.target}' will be pipelined")
child = sharded_modules[node.target]
child.forward = PipelinedForward(
node.target,
arg_info_list,
child,
context,
dist_stream,
)
ret.append(child)
return ret
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
"""
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
forward and backward. This helps hide the all2all latency while preserving the
training forward / backward ordering.
stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
To be considered a top-level module, a module can only depend on 'getattr' calls on
input.
Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.
"""
synced_pipeline_id: Dict[int, int] = {}
def __init__(
self,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: torch.device,
enable_amp: bool = False,
enable_grad_scaling: bool = True,
grad_accum: Optional[int] = None,
) -> None:
self._model = model
self._optimizer = optimizer
self._device = device
self._enable_amp = enable_amp
# NOTE: Pending upstream feedback, but two flags because we can run AMP without CUDA but cannot scale gradients without CUDA.
# Background on gradient/loss scaling
# https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling
# https://pytorch.org/docs/stable/amp.html#gradient-scaling
self._enable_grad_scaling = enable_grad_scaling
self._grad_scaler = torch.cuda.amp.GradScaler(
enabled=self._enable_amp and self._enable_grad_scaling
)
logging.info(f"Amp is enabled: {self._enable_amp}")
# use two data streams to support two concurrent batches
if device.type == "cuda":
self._memcpy_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
self._data_dist_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
else:
if self._enable_amp:
logging.warning("Amp is enabled, but no CUDA available")
self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None
self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None
self._batch_i: Optional[In] = None
self._batch_ip1: Optional[In] = None
self._batch_ip2: Optional[In] = None
self._connected = False
self._context = TrainPipelineContext()
self._pipelined_modules: List[ShardedModule] = []
self._progress_calls = 0
if grad_accum is not None:
assert isinstance(grad_accum, int) and grad_accum > 0
self._grad_accum = grad_accum
def _connect(self, dataloader_iter: Iterator[In]) -> None:
# batch 1
with torch.cuda.stream(self._memcpy_stream):
batch_i = next(dataloader_iter)
self._batch_i = batch_i = _to_device(batch_i, self._device, non_blocking=True)
# Try to pipeline input data dist.
self._pipelined_modules = _rewrite_model(self._model, self._context, self._data_dist_stream)
with torch.cuda.stream(self._data_dist_stream):
_wait_for_batch(batch_i, self._memcpy_stream)
_start_data_dist(self._pipelined_modules, batch_i, self._context)
# batch 2
with torch.cuda.stream(self._memcpy_stream):
batch_ip1 = next(dataloader_iter)
self._batch_ip1 = batch_ip1 = _to_device(batch_ip1, self._device, non_blocking=True)
self._connected = True
self.__class__.synced_pipeline_id[id(self._model)] = id(self)
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
NOTE: This method has been updated to perform gradient accumulation.
If `_grad_accum` is set, loss values are scaled by this amount and the
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(which correspond to training steps); the update/reset then happens on every
`_grad_accum`-th step.
"""
should_step_optimizer = (
self._grad_accum is not None
and self._progress_calls > 0
and (self._progress_calls + 1) % self._grad_accum == 0
) or self._grad_accum is None
should_reset_optimizer = (
self._grad_accum is not None
and self._progress_calls > 0
and (self._progress_calls + 2) % self._grad_accum == 0
) or self._grad_accum is None
if not self._connected:
self._connect(dataloader_iter)
elif self.__class__.synced_pipeline_id.get(id(self._model), None) != id(self):
self._sync_pipeline()
self.__class__.synced_pipeline_id[id(self._model)] = id(self)
if self._model.training and should_reset_optimizer:
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()
with record_function("## copy_batch_to_gpu ##"):
with torch.cuda.stream(self._memcpy_stream):
batch_ip2 = next(dataloader_iter)
self._batch_ip2 = batch_ip2 = _to_device(batch_ip2, self._device, non_blocking=True)
batch_i = cast(In, self._batch_i)
batch_ip1 = cast(In, self._batch_ip1)
with record_function("## wait_for_batch ##"):
_wait_for_batch(batch_i, self._data_dist_stream)
# Forward
with record_function("## forward ##"):
# if using multiple streams (ie. CUDA), create an event in default stream
# before starting forward pass
if self._data_dist_stream:
event = torch.cuda.current_stream().record_event()
if self._enable_amp:
# conditionally apply the model to the batch in the autocast context
# it appears that `enabled=self._enable_amp` should handle this,
# but it does not.
with torch.autocast(
device_type=self._device.type,
dtype=torch.bfloat16,
enabled=self._enable_amp,
):
(losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))
else:
(losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))
# Data Distribution
with record_function("## sparse_data_dist ##"):
with torch.cuda.stream(self._data_dist_stream):
_wait_for_batch(batch_ip1, self._memcpy_stream)
# Ensure event in default stream has been called before
# starting data dist
if self._data_dist_stream:
# pyre-ignore [61]: Local variable `event` is undefined, or not always defined
self._data_dist_stream.wait_event(event)
_start_data_dist(self._pipelined_modules, batch_ip1, self._context)
if self._model.training:
# Backward
with record_function("## backward ##"):
# The loss is normalized by the number of accumulation steps.
# The reported loss in `output['loss']` remains the unnormalized value.
if self._grad_accum is not None:
losses = losses / self._grad_accum
self._grad_scaler.scale(torch.sum(losses, dim=0)).backward()
if should_step_optimizer:
# Update
with record_function("## optimizer ##"):
self._grad_scaler.step(self._optimizer)
self._grad_scaler.update()
self._batch_i = batch_ip1
self._batch_ip1 = batch_ip2
if self._model.training:
self._progress_calls += 1
return output
def _sync_pipeline(self) -> None:
"""
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
current train pipeline. Used when switching between train pipelines for the same
model.
"""
for module in self._pipelined_modules:
module.forward._context = self._context
module.forward._dist_stream = self._data_dist_stream
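To make the gradient-accumulation schedule in `progress` concrete, here is a small illustrative trace for `grad_accum=2` (not part of the file): on a training call, `zero_grad` runs when `(_progress_calls + 2) % grad_accum == 0`, the optimizer steps when `(_progress_calls + 1) % grad_accum == 0`, and `_progress_calls` is incremented after the call, so each optimizer step consumes two backward passes on the `1 / grad_accum`-scaled loss.
```
grad_accum = 2

def schedule(progress_calls: int):
  """Mirror the zero_grad/step conditions of TrainPipelineSparseDist.progress."""
  should_reset = progress_calls > 0 and (progress_calls + 2) % grad_accum == 0
  should_step = progress_calls > 0 and (progress_calls + 1) % grad_accum == 0
  return should_reset, should_step

# Call k of progress() sees _progress_calls == k - 1 (it is incremented afterwards).
for call in range(1, 7):
  zero_grad, opt_step = schedule(call - 1)
  print(f"call {call}: zero_grad={zero_grad}, optimizer_step={opt_step}")

# call 1: zero_grad=False, optimizer_step=False   (first backward of pair 1)
# call 2: zero_grad=False, optimizer_step=True    (second backward, then step)
# call 3: zero_grad=True,  optimizer_step=False   (grads cleared before pair 2)
# call 4: zero_grad=False, optimizer_step=True
# call 5: zero_grad=True,  optimizer_step=False
# call 6: zero_grad=False, optimizer_step=True
```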