Mirror of https://github.com/twitter/the-algorithm-ml.git
Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings
core/__init__.py (new empty file)
core/config/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from tml.core.config.base_config import BaseConfig
from tml.core.config.config_load import load_config_from_yaml

# Make mypy happy by explicitly re-exporting the symbols intended for end-user use.
__all__ = ["BaseConfig", "load_config_from_yaml"]
core/config/base_config.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""Base class for all configs (forbids extra fields)."""

import collections
import functools
import yaml

import pydantic


class BaseConfig(pydantic.BaseModel):
  """Base class for all derived config classes.

  This class provides some convenient functionality:
    - Disallows extra fields when constructing an object. User error
      should be reduced by exact arguments.
    - "one_of" fields. A subclass can group optional fields and enforce
      that only one of the fields be set. For example:

      ```
      class ExampleConfig(BaseConfig):
        x: int = Field(None, one_of="group_1")
        y: int = Field(None, one_of="group_1")

      ExampleConfig(x=1)       # ok
      ExampleConfig(y=1)       # ok
      ExampleConfig(x=1, y=1)  # throws error
      ```
  """

  class Config:
    """Forbids extras."""

    extra = pydantic.Extra.forbid  # noqa

  @classmethod
  @functools.lru_cache()
  def _field_data_map(cls, field_data_name):
    """Create a map of fields with the provided field data."""
    schema = cls.schema()
    one_of = collections.defaultdict(list)
    for field, fdata in schema["properties"].items():
      if field_data_name in fdata:
        one_of[fdata[field_data_name]].append(field)
    return one_of

  @pydantic.root_validator
  def _one_of_check(cls, values):
    """Validate that each 'one_of' group has exactly one field set."""
    one_of_map = cls._field_data_map("one_of")
    for one_of, field_names in one_of_map.items():
      if sum([values.get(n, None) is not None for n in field_names]) != 1:
        raise ValueError(f"Exactly one of {','.join(field_names)} required.")
    return values

  @pydantic.root_validator
  def _at_most_one_of_check(cls, values):
    """Validate that each 'at_most_one_of' group has at most one field set."""
    at_most_one_of_map = cls._field_data_map("at_most_one_of")
    for one_of, field_names in at_most_one_of_map.items():
      if sum([values.get(n, None) is not None for n in field_names]) > 1:
        raise ValueError(f"At most one of {','.join(field_names)} can be set.")
    return values

  def pretty_print(self) -> str:
    """Return a human-readable (YAML) representation of the config, useful for logging."""
    return yaml.dump(self.dict())
core/config/base_config_test.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from unittest import TestCase

from tml.core.config import BaseConfig

import pydantic


class BaseConfigTest(TestCase):
  def test_extra_forbidden(self):
    class Config(BaseConfig):
      x: int

    Config(x=1)
    with self.assertRaises(pydantic.ValidationError):
      Config(x=1, y=2)

  def test_one_of(self):
    class Config(BaseConfig):
      x: int = pydantic.Field(None, one_of="f")
      y: int = pydantic.Field(None, one_of="f")

    with self.assertRaises(pydantic.ValidationError):
      Config()
    Config(x=1)
    Config(y=1)
    with self.assertRaises(pydantic.ValidationError):
      Config(x=1, y=3)

  def test_at_most_one_of(self):
    class Config(BaseConfig):
      x: int = pydantic.Field(None, at_most_one_of="f")
      y: str = pydantic.Field(None, at_most_one_of="f")

    Config()
    Config(x=1)
    Config(y="a")
    with self.assertRaises(pydantic.ValidationError):
      Config(x=1, y="a")
core/config/config_load.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import yaml
import string
import getpass
import os
from typing import Type

from tml.core.config.base_config import BaseConfig


def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
  """Recommended way to load and parse a config file (a YAML file).

  Because we have a shared filesystem, the recommended route to running jobs is to put
  modified config files with the desired parameters somewhere on the filesystem and run
  jobs pointing to them.
  """

  def _substitute(s):
    return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())

  with open(yaml_path, "r") as f:
    raw_contents = f.read()
  obj = yaml.safe_load(_substitute(raw_contents))

  return config_type.parse_obj(obj)
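A minimal usage sketch of this loader together with the BaseConfig classes above. The config class, field names, and YAML contents here are hypothetical illustrations (not part of this commit); the `${USER}` substitution mirrors the `_substitute` helper defined in the file.

# Hypothetical usage sketch: write a YAML file whose values may reference
# environment variables (or ${USER}), then parse it into a BaseConfig subclass.
import os
import tempfile

import pydantic

from tml.core.config import BaseConfig, load_config_from_yaml


class JobConfig(BaseConfig):  # hypothetical example config, not from this commit
  save_dir: str
  learning_rate: float = pydantic.Field(0.01)


yaml_text = "save_dir: /tmp/${USER}/run1\nlearning_rate: 0.05\n"
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
  f.write(yaml_text)

config = load_config_from_yaml(JobConfig, f.name)
# ${USER} was replaced via string.Template.safe_substitute, so save_dir now
# contains the current user's name, e.g. /tmp/alice/run1.
print(config.pretty_print())
os.unlink(f.name)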
core/config/test_config_load.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from unittest import TestCase

from tml.core.config import BaseConfig, load_config_from_yaml

import getpass
import pydantic


class _PointlessConfig(BaseConfig):
  a: int
  user: str


def test_load_config_from_yaml(tmp_path):
  yaml_path = tmp_path.joinpath("test.yaml").as_posix()
  with open(yaml_path, "w") as yaml_file:
    yaml_file.write("""a: 3\nuser: ${USER}\n""")

  pointless_config = load_config_from_yaml(_PointlessConfig, yaml_path)

  assert pointless_config.a == 3
  assert pointless_config.user == getpass.getuser()
core/config/training.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from typing import Any, Dict, List, Optional

from tml.common.wandb import WandbConfig
from tml.core.config import base_config
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.projects.twhin.models.config import TwhinModelConfig

import pydantic


class RuntimeConfig(base_config.BaseConfig):
  wandb: WandbConfig = pydantic.Field(None)
  enable_tensorfloat32: bool = pydantic.Field(
    False, description="Use tensorfloat32 if on Ampere devices."
  )
  enable_amp: bool = pydantic.Field(False, description="Enable automatic mixed precision.")


class TrainingConfig(base_config.BaseConfig):
  save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
  num_train_steps: pydantic.PositiveInt = 10000
  initial_checkpoint_dir: str = pydantic.Field(
    None, description="Directory of initial checkpoints", at_most_one_of="initialization"
  )
  checkpoint_every_n: pydantic.PositiveInt = 1000
  checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(
    None, description="Maximum number of checkpoints to keep. Defaults to keeping all."
  )
  train_log_every_n: pydantic.PositiveInt = 1000
  num_eval_steps: int = pydantic.Field(
    16384, description="Number of evaluation steps. If < 0 the entire dataset will be used."
  )
  eval_log_every_n: pydantic.PositiveInt = 5000

  eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60

  gradient_accumulation: int = pydantic.Field(
    None, description="Number of replica steps to accumulate gradients."
  )
  num_epochs: pydantic.PositiveInt = 1
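A minimal sketch of constructing these configs directly. The field values below are illustrative only; how a project nests these configs under its own top-level config is project-specific and outside this commit.

# Illustrative values only; the fields mirror the declarations above.
from tml.core.config.training import RuntimeConfig, TrainingConfig

runtime = RuntimeConfig(enable_tensorfloat32=True, enable_amp=False)
training = TrainingConfig(
  save_dir="/tmp/model",
  num_train_steps=20000,
  checkpoint_every_n=1000,
  num_eval_steps=4096,
  gradient_accumulation=2,  # accumulate gradients over 2 replica steps
)
# BaseConfig.pretty_print() dumps the parsed config as YAML for logging.
print(training.pretty_print())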
core/custom_training_loop.py (new file, 316 lines)
@@ -0,0 +1,316 @@
"""Torch and torchrec specific training and evaluation loops.

Features (go/100_enablements):
- CUDA data-fetch, compute, gradient-push overlap
- Large learnable embeddings through torchrec
- On/off-chief evaluation
- Warmstart/checkpoint management
- go/dataset-service 0-copy integration

"""
import datetime
import os
from typing import Callable, Dict, Iterable, List, Mapping, Optional

from tml.common import log_weights
import tml.common.checkpointing.snapshot as snapshot_lib
from tml.core.losses import get_global_loss_detached
from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
from tml.core.train_pipeline import TrainPipelineSparseDist

import tree
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import _LRScheduler
import torchmetrics as tm


def get_new_iterator(iterable: Iterable):
  """
  Obtains a new iterator from the iterable. If the iterable uses tf.data.Dataset internally,
  getting a new iterator every N steps avoids a memory leak. To avoid the memory leak,
  calling iter(iterable) should return a "fresh" iterator using a fresh
  (new instance of) tf.data.Iterator.
  In particular, iterable can be a torch.utils.data.IterableDataset or a
  torch.utils.data.DataLoader.

  When using DDS, performing this reset does not change the order in which elements are received
  (excluding elements already prefetched) provided that iter(iterable) internally uses
  a new instance of tf.data.Dataset created by calling from_dataset_id.
  This requirement is satisfied by RecapDataset.
  :param iterable:
  :return:
  """
  return iter(iterable)


def _get_step_fn(pipeline, data_iterator, training: bool):
  def step_fn():
    # It turns out that model.train() and model.eval() simply switch a single field inside the
    # model class, so it's somewhat safer to wrap them in here.
    if training:
      pipeline._model.train()
    else:
      pipeline._model.eval()

    outputs = pipeline.progress(data_iterator)
    return tree.map_structure(lambda elem: elem.detach(), outputs)

  return step_fn


@torch.no_grad()
def _run_evaluation(
  pipeline,
  dataset,
  eval_steps: int,
  metrics: tm.MetricCollection,
  eval_batch_size: int,
  logger=None,
):
  """Runs the evaluation loop over all evaluation iterators."""
  dataset = get_new_iterator(dataset)
  step_fn = _get_step_fn(pipeline, dataset, training=False)
  last_time = datetime.datetime.now()
  logging.info(f"Starting {eval_steps} steps of evaluation.")
  for _ in range(eval_steps):
    outputs = step_fn()
    metrics.update(outputs)
  eval_ex_per_s = (
    eval_batch_size * eval_steps / (datetime.datetime.now() - last_time).total_seconds()
  )
  logging.info(f"eval examples_per_s : {eval_ex_per_s}")
  metrics_result = metrics.compute()
  # Resetting at end to release metrics memory not in use.
  # Reset metrics to prevent accumulation between multiple evaluation splits and not report a
  # running average.
  metrics.reset()
  return metrics_result


def train(
  model: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  device: str,
  save_dir: str,
  logging_interval: int,
  train_steps: int,
  checkpoint_frequency: int,
  dataset: Iterable,
  worker_batch_size: int,
  num_workers: Optional[int] = 0,
  enable_amp: bool = False,
  initial_checkpoint_dir: Optional[str] = None,
  gradient_accumulation: Optional[int] = None,
  logger_initializer: Optional[Callable] = None,
  scheduler: _LRScheduler = None,
  metrics: Optional[tm.MetricCollection] = None,
  parameters_to_log: Optional[Dict[str, Callable]] = None,
  tables_to_log: Optional[List[str]] = None,
) -> None:
  """Runs training and eval on the given TrainPipeline.

  Args:
    dataset: data iterator for the training set
    evaluation_iterators: data iterators for the different evaluation sets
    scheduler: optional learning rate scheduler
    output_transform_for_metrics: optional transformation functions to transform the model
      output and labels into a format the metrics can understand
  """

  train_pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=device,
    enable_amp=enable_amp,
    grad_accum=gradient_accumulation,
  )  # type: ignore[var-annotated]

  # We explicitly initialize optimizer state here so that checkpointing will work properly.
  if hasattr(train_pipeline._optimizer, "init_state"):
    train_pipeline._optimizer.init_state()

  save_state = {
    "model": train_pipeline._model,
    "optimizer": train_pipeline._optimizer,
    "scaler": train_pipeline._grad_scaler,
  }

  chosen_checkpoint = None
  checkpoint_handler = snapshot_lib.Snapshot(
    save_dir=save_dir,
    state=save_state,
  )

  if save_dir:
    chosen_checkpoint = snapshot_lib.get_checkpoint(save_dir=save_dir, missing_ok=True)

  start_step = 0
  start_walltime = 0.0
  if chosen_checkpoint:
    # Skip restoration and exit if we should be finished.
    chosen_checkpoint_global_step = snapshot_lib.step_from_checkpoint(chosen_checkpoint)
    if not chosen_checkpoint_global_step < dist.get_world_size() * train_steps:
      logging.info(
        "Not restoring and finishing training as latest checkpoint "
        f"{chosen_checkpoint} found "
        f"at global_step ({chosen_checkpoint_global_step}) >= "
        f"train_steps ({dist.get_world_size() * train_steps})"
      )
      return
    logging.info(f"Restoring latest checkpoint from global_step {chosen_checkpoint_global_step}")
    checkpoint_handler.restore(chosen_checkpoint)
    start_step = checkpoint_handler.step
    start_walltime = checkpoint_handler.walltime
  elif initial_checkpoint_dir:
    base, ckpt_step = os.path.split(initial_checkpoint_dir)
    warmstart_handler = snapshot_lib.Snapshot(
      save_dir=base,
      state=save_state,
    )
    ckpt = snapshot_lib.get_checkpoint(save_dir=base, missing_ok=False, global_step=int(ckpt_step))
    logging.info(
      f"Restoring from initial_checkpoint_dir: {initial_checkpoint_dir}, but keeping starting step as 0."
    )
    warmstart_handler.restore(ckpt)

  train_logger = logger_initializer(mode="train") if logger_initializer else None
  train_step_fn = _get_step_fn(train_pipeline, get_new_iterator(dataset), training=True)

  # Count the number of parameters in the model directly when creating it.
  nb_param = 0
  for p in model.parameters():
    nb_param += p.numel()
  logging.info(f"Model has {nb_param} parameters")

  last_time = datetime.datetime.now()
  start_time = last_time
  last_pending_snapshot = None
  for step in range(start_step, train_steps + 1):
    checkpoint_handler.step = step
    outputs = train_step_fn()
    step_done_time = datetime.datetime.now()
    checkpoint_handler.walltime = (step_done_time - start_time).total_seconds() + start_walltime

    if scheduler:
      scheduler.step()

    if step % logging_interval == 0:
      interval_time = (step_done_time - last_time).total_seconds()
      steps_per_s = logging_interval / interval_time
      worker_example_per_s = steps_per_s * worker_batch_size
      global_example_per_s = worker_example_per_s * (1 + (num_workers or 0))
      global_step = step

      log_values = {
        "global_step": global_step,
        "loss": get_global_loss_detached(outputs["loss"]),
        "steps_per_s": steps_per_s,
        "global_example_per_s": global_example_per_s,
        "worker_examples_per_s": worker_example_per_s,
        "active_training_walltime": checkpoint_handler.walltime,
      }
      if parameters_to_log:
        log_values.update(
          log_weights.weights_to_log(
            model=model,
            how_to_log=parameters_to_log,
          )
        )
      log_values = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), log_values)

      if tables_to_log:
        log_values.update(
          log_weights.log_ebc_norms(
            model_state_dict=train_pipeline._model.state_dict(),
            ebc_keys=tables_to_log,
          )
        )
      if train_logger:
        train_logger.log(log_values, step=global_step)
      log_line = ", ".join(f"{name}: {value}" for name, value in log_values.items())
      logging.info(f"Step: {step}, training. {log_line}")
      last_time = step_done_time

    # If we just restored, do not save again.
    if checkpoint_frequency and step > start_step and step % checkpoint_frequency == 0:
      if last_pending_snapshot and not last_pending_snapshot.done():
        logging.warning(
          "Begin a new snapshot and the last one hasn't finished. That probably indicates "
          "either you're snapshotting really often or something is wrong. Will now block and "
          "wait for snapshot to finish before beginning the next one."
        )
        last_pending_snapshot.wait()
      last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())

  # Save if we did not just save.
  if checkpoint_frequency and step % checkpoint_frequency != 0:
    # For the final save, wait for the checkpoint to write to make sure the process doesn't
    # finish before it's completed.
    last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())
  logging.info(f"Finished training steps: {step}, global_steps: {step * dist.get_world_size()}")

  if last_pending_snapshot:
    logging.info(f"Waiting for any checkpoints to finish.")
    last_pending_snapshot.wait()


def log_eval_results(
  results,
  eval_logger,
  partition_name: str,
  step: int,
):
  results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
  logging.info(f"Step: {step}, evaluation ({partition_name}).")
  for metric_name, metric_value in results.items():
    logging.info(f"\t{metric_name}: {metric_value:1.4e}")

  if eval_logger:
    eval_logger.log(results, step=step, commit=True)


def only_evaluate(
  model: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  device: str,
  save_dir: str,
  num_train_steps: int,
  dataset: Iterable,
  eval_batch_size: int,
  num_eval_steps: int,
  eval_timeout_in_s: int,
  eval_logger: Callable,
  partition_name: str,
  metrics: Optional[tm.MetricCollection] = None,
):
  logging.info(f"Evaluating on partition {partition_name}.")
  logging.info("Computing metrics:")
  logging.info(metrics)
  eval_pipeline = TrainPipelineSparseDist(model, optimizer, device)  # type: ignore[var-annotated]
  save_state = {
    "model": eval_pipeline._model,
    "optimizer": eval_pipeline._optimizer,
  }
  checkpoint_handler = snapshot_lib.Snapshot(
    save_dir=save_dir,
    state=save_state,
  )
  for checkpoint_path in snapshot_lib.checkpoints_iterator(save_dir, timeout=eval_timeout_in_s):
    checkpoint_handler.restore(checkpoint_path)
    step = checkpoint_handler.step
    dataset = get_new_iterator(dataset)
    results = _run_evaluation(
      pipeline=eval_pipeline,
      dataset=dataset,
      eval_steps=num_eval_steps,
      eval_batch_size=eval_batch_size,
      metrics=metrics,
    )
    log_eval_results(results, eval_logger, partition_name, step=step)
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
      snapshot_lib.mark_done_eval(checkpoint_path, partition_name)
    if step >= num_train_steps:
      return
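A small worked example of the checkpoint bookkeeping in train() above, which is easy to misread: snapshots are saved with global_step = step * world_size, and resumption compares that stored value against world_size * train_steps. The numbers below are illustrative only and not taken from any real run.

# Illustrative arithmetic only: how the per-worker step count maps to the
# global_step recorded by checkpoint_handler.save(...) above.
world_size = 4          # number of training replicas
train_steps = 10_000    # per-worker steps requested
step = 7_500            # per-worker step at which a snapshot is taken

global_step = step * world_size              # 30_000, the value stored in the checkpoint
finish_threshold = world_size * train_steps  # 40_000

# On restart, training resumes because the stored global_step is below the threshold;
# once global_step >= finish_threshold, train() logs a message and returns without restoring.
assert global_step < finish_threshold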
core/debug_training_loop.py (new file, 45 lines)
@@ -0,0 +1,45 @@
"""This is a very limited-feature training loop useful for interactive debugging.

It is not intended for actual model training (it is not fast, doesn't compile the model).
It does not support checkpointing.

Suggested use:

  from tml.core import debug_training_loop
  debug_training_loop.train(...)
"""

from typing import Iterable, Optional, Dict, Callable, List
import torch
from torch.optim.lr_scheduler import _LRScheduler
import torchmetrics as tm

from tml.ml_logging.torch_logging import logging


def train(
  model: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  train_steps: int,
  dataset: Iterable,
  scheduler: _LRScheduler = None,
  # Accept any arguments (to be compatible with the real training loop)
  # but just ignore them.
  *args,
  **kwargs,
) -> None:

  logging.warning("Running debug training loop, don't use for model training.")

  data_iter = iter(dataset)
  for step in range(0, train_steps + 1):
    x = next(data_iter)
    optimizer.zero_grad()
    loss, outputs = model.forward(x)
    loss.backward()
    optimizer.step()

    if scheduler:
      scheduler.step()

    logging.info(f"Step {step} completed. Loss = {loss}")
core/loss_type.py (new file, 7 lines)
@@ -0,0 +1,7 @@
"""Loss type enums."""
from enum import Enum


class LossType(str, Enum):
  CROSS_ENTROPY = "cross_entropy"
  BCE_WITH_LOGITS = "bce_with_logits"
core/losses.py (new file, 109 lines)
@@ -0,0 +1,109 @@
"""Loss functions -- including multi-task ones."""

import typing

from tml.core.loss_type import LossType
from tml.ml_logging.torch_logging import logging

import torch


def _maybe_warn(reduction: str):
  """
  Warn about reductions other than mean.
  """
  if reduction != "mean":
    logging.warn(
      f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal "
      f"to the gradient without DDP only for mean reduction. If you need this property for "
      f"the provided reduction {reduction}, it needs to be implemented."
    )


def build_loss(
  loss_type: LossType,
  reduction="mean",
):
  _maybe_warn(reduction)
  f = _LOSS_TYPE_TO_FUNCTION[loss_type]

  def loss_fn(logits, labels):
    return f(logits, labels.type_as(logits), reduction=reduction)

  return loss_fn


def get_global_loss_detached(local_loss, reduction="mean"):
  """
  Perform all_reduce to obtain the global loss using the provided reduction.
  :param local_loss: The local loss of the current rank.
  :param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
  :return: The reduced & detached global loss.
  """
  if reduction != "mean":
    logging.warn(
      f"The reduction used in this function should be the same as the one used by "
      f"the DDP model. By default DDP uses mean, so ensure that DDP is appropriately "
      f"modified for reduction {reduction}."
    )

  if reduction not in ["mean", "sum"]:
    raise ValueError(f"Reduction {reduction} is currently unsupported.")

  global_loss = local_loss.detach()

  if reduction == "mean":
    global_loss.div_(torch.distributed.get_world_size())

  torch.distributed.all_reduce(global_loss)
  return global_loss


def build_multi_task_loss(
  loss_type: LossType,
  tasks: typing.List[str],
  task_loss_reduction="mean",
  global_reduction="mean",
  pos_weights=None,
):
  _maybe_warn(global_reduction)
  _maybe_warn(task_loss_reduction)
  f = _LOSS_TYPE_TO_FUNCTION[loss_type]

  loss_reduction_fns = {
    "mean": torch.mean,
    "sum": torch.sum,
    "min": torch.min,
    "max": torch.max,
    "median": torch.median,
  }

  def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor):
    if pos_weights is None:
      torch_weights = torch.ones([len(tasks)])
    else:
      torch_weights = torch.tensor(pos_weights)

    losses = {}
    for task_idx, task in enumerate(tasks):
      task_logits = logits[:, task_idx]
      label = labels[:, task_idx].type_as(task_logits)

      loss = f(
        task_logits,
        label,
        reduction=task_loss_reduction,
        pos_weight=torch_weights[task_idx],
        weight=weights[:, task_idx],
      )
      losses[f"loss/{task}"] = loss

    losses["loss"] = loss_reduction_fns[global_reduction](torch.stack(list(losses.values())))
    return losses

  return loss_fn


_LOSS_TYPE_TO_FUNCTION = {
  LossType.BCE_WITH_LOGITS: torch.nn.functional.binary_cross_entropy_with_logits
}
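A minimal sketch of calling build_multi_task_loss, following the shapes the inner loss_fn above expects (logits, labels, and weights are all [batch_size, num_tasks]). The task names and tensor values are illustrative only.

# Illustrative usage (not part of this commit).
import torch

from tml.core.loss_type import LossType
from tml.core.losses import build_multi_task_loss

loss_fn = build_multi_task_loss(
  loss_type=LossType.BCE_WITH_LOGITS,
  tasks=["click", "fav"],          # hypothetical task names
  pos_weights=[1.0, 2.0],          # up-weight positives of the second task
)

logits = torch.randn(4, 2)                 # [batch_size, num_tasks]
labels = torch.randint(0, 2, (4, 2))       # binary labels per task
weights = torch.ones(4, 2)                 # per-example, per-task example weights

losses = loss_fn(logits, labels, weights)
# Returns one entry per task plus a combined entry reduced by global_reduction:
# {"loss/click": ..., "loss/fav": ..., "loss": mean of the per-task losses}
print(sorted(losses.keys()))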
core/metric_mixin.py (new file, 104 lines)
@@ -0,0 +1,104 @@
"""
Mixin that requires a transform to munge the output dictionary of tensors a
model produces into the form that torchmetrics.Metric.update expects.

By unifying on our signature for `update`, we can also now use
torchmetrics.MetricCollection, which requires that all metrics have
the same call signature.

To use, override this with a transform that munges `outputs`
into a kwargs dict that the inherited metric.update accepts.

Here are two examples of how to extend torchmetrics.SumMetric so that it accepts
an output dictionary of tensors and munges it to what SumMetric expects (a single `value`)
for its update method.

1. Using as a mixin to inherit from or define a new metric class.

  class Count(MetricMixin, SumMetric):
    def transform(self, outputs):
      return {'value': 1}

2. Redefine an existing metric class.

  SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1})

"""
from abc import abstractmethod
from typing import Callable, Dict, List

from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]

import torch
import torchmetrics


class MetricMixin:
  @abstractmethod
  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
    ...

  def update(self, outputs: Dict[str, torch.Tensor]):
    results = self.transform(outputs)
    # Do not try to update if any tensor is empty as a result of stratification.
    for value in results.values():
      if torch.is_tensor(value) and not value.nelement():
        return
    super().update(**results)


class TaskMixin:
  def __init__(self, task_idx: int = -1, **kwargs):
    super().__init__(**kwargs)
    self._task_idx = task_idx


class StratifyMixin:
  def __init__(
    self,
    stratifier=None,
    **kwargs,
  ):
    super().__init__(**kwargs)
    self._stratifier = stratifier

  def maybe_apply_stratification(
    self, outputs: Dict[str, torch.Tensor], value_names: List[str]
  ) -> Dict[str, torch.Tensor]:
    """Pick out examples whose stratifier feature is equal to a specific stratifier indicator value."""
    outputs = outputs.copy()
    if not self._stratifier:
      return outputs
    stratifiers = outputs.get("stratifiers")
    if not stratifiers:
      return outputs
    if stratifiers.get(self._stratifier.name) is None:
      return outputs

    mask = torch.flatten(outputs["stratifiers"][self._stratifier.name] == self._stratifier.value)
    target_slice = torch.squeeze(mask.nonzero(), -1)
    for value_name in value_names:
      target = outputs[value_name]
      outputs[value_name] = torch.index_select(target, 0, target_slice)
    return outputs


def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
  """Returns a new class using MetricMixin and the given base_metric.

  Functionally the same as using inheritance, just saves some lines of code
  if there is no need for class attributes.

  """

  def transform_method(_self, *args, **kwargs):
    return transform(*args, **kwargs)

  return type(
    base_metric.__name__,
    (
      MetricMixin,
      base_metric,
    ),
    {"transform": transform_method},
  )
core/metrics.py (new file, 148 lines)
@@ -0,0 +1,148 @@
"""Common metrics that also support multi-task models.

We assume multi-task models will output [task_idx, ...] predictions.

"""
from typing import Any, Dict

from tml.core.metric_mixin import MetricMixin, StratifyMixin, TaskMixin

import torch
import torchmetrics as tm


def probs_and_labels(
  outputs: Dict[str, torch.Tensor],
  task_idx: int,
) -> Dict[str, torch.Tensor]:
  preds = outputs["probabilities"]
  target = outputs["labels"]
  if task_idx >= 0:
    preds = preds[:, task_idx]
    target = target[:, task_idx]
  return {
    "preds": preds,
    "target": target.int(),
  }


class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["labels"])
    value = outputs["labels"]
    if self._task_idx >= 0:
      value = value[:, self._task_idx]
    return {"value": value}


class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["labels"])
    value = outputs["labels"]
    if self._task_idx >= 0:
      value = value[:, self._task_idx]
    return {"value": value}


class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
    value = outputs["probabilities"]
    if self._task_idx >= 0:
      value = value[:, self._task_idx]
    return {"value": value}


class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
    return probs_and_labels(outputs, self._task_idx)


class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
    return probs_and_labels(outputs, self._task_idx)


class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
    return probs_and_labels(outputs, self._task_idx)


class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  Based on:
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
  """

  def __init__(self, num_samples, **kwargs):
    super().__init__(**kwargs)
    self.num_samples = num_samples

  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    scores, labels = outputs["logits"], outputs["labels"]
    pos_scores = scores[labels == 1]
    neg_scores = scores[labels == 0]
    result = {
      "value": pos_scores[torch.randint(len(pos_scores), (self.num_samples,))]
      > neg_scores[torch.randint(len(neg_scores), (self.num_samples,))]
    }
    return result


class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  The ranks of all positives.
  Based on:
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    scores, labels = outputs["logits"], outputs["labels"]
    _, sorted_indices = scores.sort(descending=True)
    pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1  # all ranks start from 1
    result = {"value": pos_ranks}
    return result


class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  The reciprocal of the ranks of all positives.
  Based on:
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    scores, labels = outputs["logits"], outputs["labels"]
    _, sorted_indices = scores.sort(descending=True)
    pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1  # all ranks start from 1
    result = {"value": torch.div(torch.ones_like(pos_ranks), pos_ranks)}
    return result


class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  The fraction of positives that rank in the top K among their negatives.
  Note that this is basically precision@k.
  Based on:
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
  """

  def __init__(self, k: int, **kwargs):
    super().__init__(**kwargs)
    self.k = k

  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    scores, labels = outputs["logits"], outputs["labels"]
    _, sorted_indices = scores.sort(descending=True)
    pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1  # all ranks start from 1
    result = {"value": (pos_ranks <= self.k).float()}
    return result
core/test_metrics.py (new file, 187 lines)
@@ -0,0 +1,187 @@
from dataclasses import dataclass

from tml.core import metrics as core_metrics
from tml.core.metric_mixin import MetricMixin, prepend_transform

import torch
from torchmetrics import MaxMetric, MetricCollection, SumMetric


@dataclass
class MockStratifierConfig:
  name: str
  index: int
  value: int


class Count(MetricMixin, SumMetric):
  def transform(self, outputs):
    return {"value": 1}


Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})


def test_count_metric():
  num_examples = 123
  examples = [
    {"stuff": 0},
  ] * num_examples

  metric = Count()
  for outputs in examples:
    metric.update(outputs)

  assert metric.compute().item() == num_examples


def test_collections():
  max_metric = Max()
  count_metric = Count()
  metric = MetricCollection([max_metric, count_metric])

  examples = [{"value": idx} for idx in range(123)]
  for outputs in examples:
    metric.update(outputs)

  assert metric.compute() == {
    max_metric.__class__.__name__: len(examples) - 1,
    count_metric.__class__.__name__: len(examples),
  }


def test_task_dependent_ctr():
  num_examples = 144
  batch_size = 1024
  outputs = [
    {
      "stuff": 0,
      "labels": torch.arange(0, 6).repeat(batch_size, 1),
    }
    for idx in range(num_examples)
  ]

  for task_idx in range(5):
    metric = core_metrics.Ctr(task_idx=task_idx)
    for output in outputs:
      metric.update(output)
    assert metric.compute().item() == task_idx


def test_stratified_ctr():
  outputs = [
    {
      "stuff": 0,
      # [bsz, tasks]
      "labels": torch.tensor(
        [
          [0, 1, 2, 3],
          [1, 2, 3, 4],
          [2, 3, 4, 0],
        ]
      ),
      "stratifiers": {
        # [bsz]
        "level": torch.tensor(
          [9, 0, 9],
        ),
      },
    }
  ]

  stratifier = MockStratifierConfig(name="level", index=2, value=9)
  for task_idx in range(5):
    metric = core_metrics.Ctr(task_idx=1, stratifier=stratifier)
    for output in outputs:
      metric.update(output)
    # From the dataset of:
    # [
    #   [0, 1, 2, 3],
    #   [1, 2, 3, 4],
    #   [2, 3, 4, 0],
    # ]
    # we pick out
    # [
    #   [0, 1, 2, 3],
    #   [2, 3, 4, 0],
    # ]
    # and with Ctr task_idx, we pick out
    # [
    #   [1,],
    #   [3,],
    # ]
    assert metric.compute().item() == (1 + 3) / 2


def test_auc():
  num_samples = 10000
  metric = core_metrics.Auc(num_samples)
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 1.0, 1.0, 1.0])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.0, -1.0, -1.0])
  outputs_bad = {"logits": preds_bad, "labels": target}

  metric.update(outputs_correct)
  assert metric.compute().item() == 1.0

  metric.reset()
  metric.update(outputs_bad)
  assert metric.compute().item() == 0.0


def test_pos_rank():
  metric = core_metrics.PosRanks()
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
  outputs_bad = {"logits": preds_bad, "labels": target}

  metric.update(outputs_correct)
  assert metric.compute().item() == 2.0

  metric.reset()
  metric.update(outputs_bad)
  assert metric.compute().item() == 4.0


def test_reciprocal_rank():
  metric = core_metrics.ReciprocalRank()
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
  outputs_bad = {"logits": preds_bad, "labels": target}

  metric.update(outputs_correct)
  assert abs(metric.compute().item() - 0.6111) < 0.001

  metric.reset()
  metric.update(outputs_bad)
  assert abs(metric.compute().item() - 0.2611) < 0.001


def test_hit_k():
  hit1_metric = core_metrics.HitAtK(1)
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
  outputs_correct = {"logits": preds_correct, "labels": target}
  preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
  outputs_bad = {"logits": preds_bad, "labels": target}

  hit1_metric.update(outputs_correct)
  assert abs(hit1_metric.compute().item() - 0.3333) < 0.0001

  hit1_metric.reset()
  hit1_metric.update(outputs_bad)

  assert hit1_metric.compute().item() == 0

  hit3_metric = core_metrics.HitAtK(3)
  hit3_metric.update(outputs_correct)
  assert abs(hit3_metric.compute().item() - 0.66666) < 0.0001

  hit3_metric.reset()
  hit3_metric.update(outputs_bad)
  assert abs(hit3_metric.compute().item() - 0.3333) < 0.0001
core/test_train_pipeline.py (new file, 91 lines)
@@ -0,0 +1,91 @@
from dataclasses import dataclass
from typing import Tuple

from tml.common.batch import DataclassBatch
from tml.common.testing_utils import mock_pg
from tml.core import train_pipeline

import torch
from torchrec.distributed import DistributedModelParallel


@dataclass
class MockDataclassBatch(DataclassBatch):
  continuous_features: torch.Tensor
  labels: torch.Tensor


class MockModule(torch.nn.Module):
  def __init__(self) -> None:
    super().__init__()
    self.model = torch.nn.Linear(10, 1)
    self.loss_fn = torch.nn.BCEWithLogitsLoss()

  def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
    pred = self.model(batch.continuous_features)
    loss = self.loss_fn(pred, batch.labels)
    return (loss, pred)


def create_batch(bsz: int):
  return MockDataclassBatch(
    continuous_features=torch.rand(bsz, 10).float(),
    labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
  )


def test_sparse_pipeline():
  device = torch.device("cpu")
  model = MockModule().to(device)

  steps = 8
  example = create_batch(1)
  dataloader = iter(example for _ in range(steps + 2))

  results = []
  with mock_pg():
    d_model = DistributedModelParallel(model)
    pipeline = train_pipeline.TrainPipelineSparseDist(
      model=d_model,
      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
      device=device,
      grad_accum=2,
    )
    for _ in range(steps):
      results.append(pipeline.progress(dataloader))

  results = [elem.detach().numpy() for elem in results]
  # Check gradients are accumulated, i.e. results do not change between each 0th and 1st step.
  for first, second in zip(results[::2], results[1::2]):
    assert first == second, results

  # Check we do update gradients, i.e. results do change between each 1st and 2nd step.
  for first, second in zip(results[1::2], results[2::2]):
    assert first != second, results


def test_amp():
  device = torch.device("cpu")
  model = MockModule().to(device)

  steps = 8
  example = create_batch(1)
  dataloader = iter(example for _ in range(steps + 2))

  results = []
  with mock_pg():
    d_model = DistributedModelParallel(model)
    pipeline = train_pipeline.TrainPipelineSparseDist(
      model=d_model,
      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
      device=device,
      enable_amp=True,
      # Not supported on CPU.
      enable_grad_scaling=False,
    )
    for _ in range(steps):
      results.append(pipeline.progress(dataloader))

  results = [elem.detach() for elem in results]
  for value in results:
    assert value.dtype == torch.bfloat16
626
core/train_pipeline.py
Normal file
626
core/train_pipeline.py
Normal file
@ -0,0 +1,626 @@
|
||||
"""
|
||||
Taken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py
|
||||
with TrainPipelineSparseDist.progress modified to support gradient accumulation.
|
||||
|
||||
"""
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
cast,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import torch
|
||||
from torch.autograd.profiler import record_function
|
||||
from torch.fx.node import Node
|
||||
from torchrec.distributed.model_parallel import (
|
||||
DistributedModelParallel,
|
||||
ShardedModule,
|
||||
)
|
||||
from torchrec.distributed.types import Awaitable
|
||||
from torchrec.modules.feature_processor import BaseGroupedFeatureProcessor
|
||||
from torchrec.streamable import Multistreamable, Pipelineable
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
In = TypeVar("In", bound=Pipelineable)
|
||||
Out = TypeVar("Out")
|
||||
|
||||
|
||||
class TrainPipeline(abc.ABC, Generic[In, Out]):
|
||||
@abc.abstractmethod
|
||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
||||
pass
|
||||
|
||||
|
||||
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
|
||||
assert isinstance(
|
||||
batch, (torch.Tensor, Pipelineable)
|
||||
), f"{type(batch)} must implement Pipelineable interface"
|
||||
return cast(In, batch.to(device=device, non_blocking=non_blocking))
|
||||
|
||||
|
||||
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
|
||||
if stream is None:
|
||||
return
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
# As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
|
||||
# PyTorch uses the "caching allocator" for memory allocation for tensors. When a tensor is
|
||||
# freed, its memory is likely to be reused by newly constructed tenosrs. By default,
|
||||
# this allocator traces whether a tensor is still in use by only the CUDA stream where it
|
||||
# was created. When a tensor is used by additional CUDA streams, we need to call record_stream
|
||||
# to tell the allocator about all these streams. Otherwise, the allocator might free the
|
||||
# underlying memory of the tensor once it is no longer used by the creator stream. This is
|
||||
# a notable programming trick when we write programs using multi CUDA streams.
|
||||
cur_stream = torch.cuda.current_stream()
|
||||
assert isinstance(
|
||||
batch, (torch.Tensor, Multistreamable)
|
||||
), f"{type(batch)} must implement Multistreamable interface"
|
||||
batch.record_stream(cur_stream)
|
||||
|
||||
|
||||
class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||
"""
|
||||
This class runs training iterations using a pipeline of two stages, each as a CUDA
|
||||
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
|
||||
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
|
||||
memory, and the default stream runs forward, backward, and optimization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
self._model = model
|
||||
self._optimizer = optimizer
|
||||
self._device = device
|
||||
self._memcpy_stream: Optional[torch.cuda.streams.Stream] = (
|
||||
torch.cuda.Stream() if device.type == "cuda" else None
|
||||
)
|
||||
self._cur_batch: Optional[In] = None
|
||||
self._connected = False
|
||||
|
||||
def _connect(self, dataloader_iter: Iterator[In]) -> None:
|
||||
cur_batch = next(dataloader_iter)
|
||||
self._cur_batch = cur_batch
|
||||
with torch.cuda.stream(self._memcpy_stream):
|
||||
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
|
||||
self._connected = True
|
||||
|
||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
||||
if not self._connected:
|
||||
self._connect(dataloader_iter)
|
||||
|
||||
# Fetch next batch
|
||||
with record_function("## next_batch ##"):
|
||||
next_batch = next(dataloader_iter)
|
||||
cur_batch = self._cur_batch
|
||||
assert cur_batch is not None
|
||||
|
||||
if self._model.training:
|
||||
with record_function("## zero_grad ##"):
|
||||
self._optimizer.zero_grad()
|
||||
|
||||
with record_function("## wait_for_batch ##"):
|
||||
_wait_for_batch(cur_batch, self._memcpy_stream)
|
||||
|
||||
with record_function("## forward ##"):
|
||||
(losses, output) = self._model(cur_batch)
|
||||
|
||||
if self._model.training:
|
||||
with record_function("## backward ##"):
|
||||
torch.sum(losses, dim=0).backward()
|
||||
|
||||
# Copy the next batch to GPU
|
||||
self._cur_batch = cur_batch = next_batch
|
||||
with record_function("## copy_batch_to_gpu ##"):
|
||||
with torch.cuda.stream(self._memcpy_stream):
|
||||
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
|
||||
|
||||
# Update
|
||||
if self._model.training:
|
||||
with record_function("## optimizer ##"):
|
||||
self._optimizer.step()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Tracer(torch.fx.Tracer):
|
||||
# Disable proxying buffers during tracing. Ideally, proxying buffers would
|
||||
# be disabled, but some models are currently mutating buffer values, which
|
||||
# causes errors during tracing. If those models can be rewritten to not do
|
||||
# that, we can likely remove this line
|
||||
proxy_buffer_attributes = False
|
||||
|
||||
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
|
||||
|
||||
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
|
||||
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
|
||||
return True
|
||||
return super().is_leaf_module(m, module_qualified_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineContext:
|
||||
# pyre-ignore [4]
|
||||
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
|
||||
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
|
||||
# pyre-ignore [4]
|
||||
feature_processor_forwards: List[Any] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArgInfo:
|
||||
# attributes of input batch, e.g. batch.attr1.attr2 call
|
||||
# will produce ["attr1", "attr2"]
|
||||
input_attrs: List[str]
|
||||
# batch[attr1].attr2 will produce [True, False]
|
||||
is_getitems: List[bool]
|
||||
# name for kwarg of pipelined forward() call or None
|
||||
# for a positional arg
|
||||
name: Optional[str]
|
||||
|
||||
|
||||
class PipelinedForward:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
args: List[ArgInfo],
|
||||
module: ShardedModule,
|
||||
context: TrainPipelineContext,
|
||||
dist_stream: Optional[torch.cuda.streams.Stream],
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._args = args
|
||||
self._module = module
|
||||
self._context = context
|
||||
self._dist_stream = dist_stream
|
||||
|
||||
# pyre-ignore [2, 24]
|
||||
def __call__(self, *input, **kwargs) -> Awaitable:
|
||||
assert self._name in self._context.input_dist_requests
|
||||
request = self._context.input_dist_requests[self._name]
|
||||
assert isinstance(request, Awaitable)
|
||||
with record_function("## wait_sparse_data_dist ##"):
|
||||
# Finish waiting on the dist_stream,
|
||||
# in case some delayed stream scheduling happens during the wait() call.
|
||||
with torch.cuda.stream(self._dist_stream):
|
||||
data = request.wait()
|
||||
|
||||
# Make sure that both result of input_dist and context
|
||||
# are properly transferred to the current stream.
|
||||
if self._dist_stream is not None:
|
||||
torch.cuda.current_stream().wait_stream(self._dist_stream)
|
||||
cur_stream = torch.cuda.current_stream()
|
||||
|
||||
assert isinstance(
|
||||
data, (torch.Tensor, Multistreamable)
|
||||
), f"{type(data)} must implement Multistreamable interface"
|
||||
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
|
||||
data.record_stream(cur_stream)
|
||||
|
||||
ctx = self._context.module_contexts[self._name]
|
||||
ctx.record_stream(cur_stream)
|
||||
|
||||
if len(self._context.feature_processor_forwards) > 0:
|
||||
with record_function("## feature_processor ##"):
|
||||
for sparse_feature in data:
|
||||
if sparse_feature.id_score_list_features is not None:
|
||||
for fp_forward in self._context.feature_processor_forwards:
|
||||
sparse_feature.id_score_list_features = fp_forward(
|
||||
sparse_feature.id_score_list_features
|
||||
)
|
||||
|
||||
return self._module.compute_and_output_dist(self._context.module_contexts[self._name], data)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def args(self) -> List[ArgInfo]:
|
||||
return self._args
|
||||
|
||||
|
||||
def _start_data_dist(
|
||||
pipelined_modules: List[ShardedModule],
|
||||
batch: In,
|
||||
context: TrainPipelineContext,
|
||||
) -> None:
|
||||
context.input_dist_requests.clear()
|
||||
context.module_contexts.clear()
|
||||
for module in pipelined_modules:
|
||||
forward = module.forward
|
||||
assert isinstance(forward, PipelinedForward)
|
||||
|
||||
# Retrieve argument for the input_dist of EBC
|
||||
# is_getitem True means this argument could be retrieved by a list
|
||||
# False means this argument is getting while getattr
|
||||
# and this info was done in the _rewrite_model by tracing the
|
||||
# entire model to get the arg_info_list
|
||||
args = []
|
||||
kwargs = {}
|
||||
for arg_info in forward.args:
|
||||
if arg_info.input_attrs:
|
||||
arg = batch
|
||||
for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems):
|
||||
if is_getitem:
|
||||
arg = arg[attr]
|
||||
else:
|
||||
arg = getattr(arg, attr)
|
||||
if arg_info.name:
|
||||
kwargs[arg_info.name] = arg
|
||||
else:
|
||||
args.append(arg)
|
||||
else:
|
||||
args.append(None)
|
||||
# Start input distribution.
|
||||
module_ctx = module.create_context()
|
||||
context.module_contexts[forward.name] = module_ctx
|
||||
context.input_dist_requests[forward.name] = module.input_dist(module_ctx, *args, **kwargs)
|
||||
|
||||
# Call wait on the first awaitable in the input dist for the tensor splits
|
||||
for key, awaitable in context.input_dist_requests.items():
|
||||
context.input_dist_requests[key] = awaitable.wait()
|
||||
|
||||
|
||||
def _get_node_args_helper(
|
||||
# pyre-ignore
|
||||
arguments,
|
||||
num_found: int,
|
||||
feature_processor_arguments: Optional[List[Node]] = None,
|
||||
) -> Tuple[List[ArgInfo], int]:
|
||||
"""
|
||||
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
|
||||
It also counts the number of (args + kwargs) found.
|
||||
"""
|
||||
|
||||
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
|
||||
for arg, arg_info in zip(arguments, arg_info_list):
|
||||
if arg is None:
|
||||
num_found += 1
|
||||
continue
|
||||
while True:
|
||||
if not isinstance(arg, torch.fx.Node):
|
||||
break
|
||||
child_node = arg
|
||||
|
||||
if child_node.op == "placeholder":
|
||||
num_found += 1
|
||||
break
|
||||
# skip this fp node
|
||||
elif feature_processor_arguments is not None and child_node in feature_processor_arguments:
|
||||
arg = child_node.args[0]
|
||||
elif (
|
||||
child_node.op == "call_function"
|
||||
and child_node.target.__module__ == "builtins"
|
||||
# pyre-ignore[16]
|
||||
and child_node.target.__name__ == "getattr"
|
||||
):
|
||||
arg_info.input_attrs.insert(0, child_node.args[1])
|
||||
arg_info.is_getitems.insert(0, False)
|
||||
arg = child_node.args[0]
|
||||
elif (
|
||||
child_node.op == "call_function"
|
||||
and child_node.target.__module__ == "_operator"
|
||||
# pyre-ignore[16]
|
||||
and child_node.target.__name__ == "getitem"
|
||||
):
|
||||
arg_info.input_attrs.insert(0, child_node.args[1])
|
||||
arg_info.is_getitems.insert(0, True)
|
||||
arg = child_node.args[0]
|
||||
else:
|
||||
break
|
||||
return arg_info_list, num_found


def _get_node_args(
  node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
  num_found = 0
  pos_arg_info_list, num_found = _get_node_args_helper(
    node.args, num_found, feature_processor_nodes
  )
  kwargs_arg_info_list, num_found = _get_node_args_helper(node.kwargs.values(), num_found)

  # Replace with proper names for kwargs.
  for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list):
    arg_info_list.name = name

  arg_info_list = pos_arg_info_list + kwargs_arg_info_list
  return arg_info_list, num_found


def _get_unsharded_module_names_helper(
  model: torch.nn.Module,
  path: str,
  unsharded_module_names: Set[str],
) -> bool:
  sharded_children = set()
  for name, child in model.named_children():
    curr_path = path + name
    if isinstance(child, ShardedModule):
      sharded_children.add(name)
    else:
      child_sharded = _get_unsharded_module_names_helper(
        child,
        curr_path + ".",
        unsharded_module_names,
      )
      if child_sharded:
        sharded_children.add(name)

  if len(sharded_children) > 0:
    for name, _ in model.named_children():
      if name not in sharded_children:
        unsharded_module_names.add(path + name)

  return len(sharded_children) > 0


def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
  """
  Returns a list of top-level modules that do not contain any sharded submodules.
  """

  unsharded_module_names: Set[str] = set()
  _get_unsharded_module_names_helper(
    model,
    "",
    unsharded_module_names,
  )
  return list(unsharded_module_names)


def _rewrite_model(  # noqa C901
  model: torch.nn.Module,
  context: TrainPipelineContext,
  dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:

  # Get the underlying nn.Module.
  if isinstance(model, DistributedModelParallel):
    model = model.module

  # Collect the sharded modules and feature-processor modules.
  sharded_modules = {}
  fp_modules = {}
  for name, m in model.named_modules():
    if isinstance(m, ShardedModule):
      sharded_modules[name] = m
    if isinstance(m, BaseGroupedFeatureProcessor):
      fp_modules[name] = m

  # Trace the model.
  tracer = Tracer(leaf_modules=_get_unsharded_module_names(model))
  graph = tracer.trace(model)

  feature_processor_nodes = []
  # Find the feature-processor nodes.
  for node in graph.nodes:
    if node.op == "call_module" and node.target in fp_modules:
      feature_processor_nodes.append(node)
  # Select sharded modules that are top-level in the forward call graph, i.e. whose
  # inputs have no transformations and rely only on 'builtins.getattr'
  # (or 'operator.getitem') applied to the model input.
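  # Illustrative example (hypothetical module/attribute names): a call such as
  #   self.large_embeddings(batch.sparse_features)
  # qualifies for pipelining because every argument traces back to the input via
  # getattr/getitem only, whereas self.large_embeddings(some_transform(batch))
  # would not, since `some_transform` is an intermediate node in the traced graph.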
  ret = []
  for node in graph.nodes:
    if node.op == "call_module" and node.target in sharded_modules:
      total_num_args = len(node.args) + len(node.kwargs)
      if total_num_args == 0:
        continue
      arg_info_list, num_found = _get_node_args(node, feature_processor_nodes)
      if num_found == total_num_args:
        logger.info(f"Module '{node.target}' will be pipelined")
        child = sharded_modules[node.target]
        child.forward = PipelinedForward(
          node.target,
          arg_info_list,
          child,
          context,
          dist_stream,
        )
        ret.append(child)
  return ret


class TrainPipelineSparseDist(TrainPipeline[In, Out]):
  """
  This pipeline overlaps device transfer and `ShardedModule.input_dist()` with
  forward and backward. This helps hide the all2all latency while preserving the
  training forward / backward ordering.

  stage 3: forward, backward - uses default CUDA stream
  stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
  stage 1: device transfer - uses memcpy CUDA stream

  `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
  To be considered a top-level module, a module can only depend on 'getattr' calls on
  the input.

  The input model must be symbolically traceable with the exception of `ShardedModule`
  and `DistributedDataParallel` modules.
  """

  synced_pipeline_id: Dict[int, int] = {}

  def __init__(
    self,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    enable_amp: bool = False,
    enable_grad_scaling: bool = True,
    grad_accum: Optional[int] = None,
  ) -> None:
    self._model = model
    self._optimizer = optimizer
    self._device = device
    self._enable_amp = enable_amp
    # NOTE: Pending upstream feedback, we keep two flags because we can run AMP without
    # CUDA but cannot scale gradients without CUDA.
    # Background on gradient/loss scaling:
    # https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling
    # https://pytorch.org/docs/stable/amp.html#gradient-scaling
    self._enable_grad_scaling = enable_grad_scaling
    self._grad_scaler = torch.cuda.amp.GradScaler(
      enabled=self._enable_amp and self._enable_grad_scaling
    )
    logging.info(f"Amp is enabled: {self._enable_amp}")

    # Use two data streams to support two concurrent batches.
    if device.type == "cuda":
      self._memcpy_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
      self._data_dist_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
    else:
      if self._enable_amp:
        logging.warning("Amp is enabled, but no CUDA available")
      self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None
      self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None
    self._batch_i: Optional[In] = None
    self._batch_ip1: Optional[In] = None
    self._batch_ip2: Optional[In] = None
    self._connected = False
    self._context = TrainPipelineContext()
    self._pipelined_modules: List[ShardedModule] = []

    self._progress_calls = 0
    if grad_accum is not None:
      assert isinstance(grad_accum, int) and grad_accum > 0
    # Always set the attribute (it is read unconditionally in `progress`).
    self._grad_accum = grad_accum

  def _connect(self, dataloader_iter: Iterator[In]) -> None:
    # batch 1
    with torch.cuda.stream(self._memcpy_stream):
      batch_i = next(dataloader_iter)
      self._batch_i = batch_i = _to_device(batch_i, self._device, non_blocking=True)
      # Try to pipeline input data dist.
      self._pipelined_modules = _rewrite_model(self._model, self._context, self._data_dist_stream)

    with torch.cuda.stream(self._data_dist_stream):
      _wait_for_batch(batch_i, self._memcpy_stream)
      _start_data_dist(self._pipelined_modules, batch_i, self._context)

    # batch 2
    with torch.cuda.stream(self._memcpy_stream):
      batch_ip1 = next(dataloader_iter)
      self._batch_ip1 = batch_ip1 = _to_device(batch_ip1, self._device, non_blocking=True)
    self._connected = True
    self.__class__.synced_pipeline_id[id(self._model)] = id(self)

  def progress(self, dataloader_iter: Iterator[In]) -> Out:
    """
    NOTE: This method has been updated to perform gradient accumulation.

    If `_grad_accum` is set, losses are divided by `_grad_accum`, the optimizer
    update and gradient reset are skipped on intermediate calls of `progress`
    (each call corresponds to one training step), and the update/reset run on
    every `_grad_accum`-th step.
    """
    should_step_optimizer = (
      self._grad_accum is not None
      and self._progress_calls > 0
      and (self._progress_calls + 1) % self._grad_accum == 0
    ) or self._grad_accum is None
    should_reset_optimizer = (
      self._grad_accum is not None
      and self._progress_calls > 0
      and (self._progress_calls + 2) % self._grad_accum == 0
    ) or self._grad_accum is None
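
    # Illustrative schedule (assuming _grad_accum == 4 and training mode): the
    # predicates above make zero_grad fire on progress calls 2, 6, 10, ... and
    # optimizer.step / scaler.update fire on calls 3, 7, 11, ...; all other calls only
    # accumulate the (scaled) gradients from backward. With _grad_accum == None both
    # predicates are True on every call, i.e. plain per-step updates.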

    if not self._connected:
      self._connect(dataloader_iter)
    elif self.__class__.synced_pipeline_id.get(id(self._model), None) != id(self):
      self._sync_pipeline()
      self.__class__.synced_pipeline_id[id(self._model)] = id(self)

    if self._model.training and should_reset_optimizer:
      with record_function("## zero_grad ##"):
        self._optimizer.zero_grad()

    with record_function("## copy_batch_to_gpu ##"):
      with torch.cuda.stream(self._memcpy_stream):
        batch_ip2 = next(dataloader_iter)
        self._batch_ip2 = batch_ip2 = _to_device(batch_ip2, self._device, non_blocking=True)
    batch_i = cast(In, self._batch_i)
    batch_ip1 = cast(In, self._batch_ip1)

    with record_function("## wait_for_batch ##"):
      _wait_for_batch(batch_i, self._data_dist_stream)

    # Forward
    with record_function("## forward ##"):
      # If using multiple streams (i.e. CUDA), create an event in the default stream
      # before starting the forward pass.
      if self._data_dist_stream:
        event = torch.cuda.current_stream().record_event()
      if self._enable_amp:
        # Conditionally apply the model to the batch in the autocast context.
        # It appears that `enabled=self._enable_amp` should handle this, but it does not.
        with torch.autocast(
          device_type=self._device.type,
          dtype=torch.bfloat16,
          enabled=self._enable_amp,
        ):
          (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))
      else:
        (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))

    # Data distribution
    with record_function("## sparse_data_dist ##"):
      with torch.cuda.stream(self._data_dist_stream):
        _wait_for_batch(batch_ip1, self._memcpy_stream)
        # Ensure the event in the default stream has been recorded before
        # starting data dist.
        if self._data_dist_stream:
          # pyre-ignore [61]: Local variable `event` is undefined, or not always defined
          self._data_dist_stream.wait_event(event)
        _start_data_dist(self._pipelined_modules, batch_ip1, self._context)

    if self._model.training:
      # Backward
      with record_function("## backward ##"):
        # The loss is normalized by the number of accumulation steps.
        # The reported loss in `output['loss']` remains the unnormalized value.
        if self._grad_accum is not None:
          losses = losses / self._grad_accum
        self._grad_scaler.scale(torch.sum(losses, dim=0)).backward()

      if should_step_optimizer:
        # Update
        with record_function("## optimizer ##"):
          self._grad_scaler.step(self._optimizer)
          self._grad_scaler.update()

    self._batch_i = batch_ip1
    self._batch_ip1 = batch_ip2

    if self._model.training:
      self._progress_calls += 1

    return output

  def _sync_pipeline(self) -> None:
    """
    Syncs `PipelinedForward` for sharded modules with the context and dist stream of the
    current train pipeline. Used when switching between train pipelines for the same
    model.
    """
    for module in self._pipelined_modules:
      module.forward._context = self._context
      module.forward._dist_stream = self._data_dist_stream
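

# A minimal usage sketch (illustrative only; `build_model`, `build_optimizer`, and
# `build_dataloader` are hypothetical helpers, not part of this module):
#
#   device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#   model = build_model(device)          # typically a DistributedModelParallel-wrapped model
#   optimizer = build_optimizer(model)
#   pipeline = TrainPipelineSparseDist(model, optimizer, device, grad_accum=4)
#   dataloader_iter = iter(build_dataloader())
#   model.train()
#   while True:
#     try:
#       output = pipeline.progress(dataloader_iter)  # one training step per call
#     except StopIteration:
#       break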