This commit is contained in:
rajveer43 2023-09-11 20:27:52 +05:30
parent 799254345f
commit 9bb0986079

View File

@ -39,12 +39,42 @@ Out = TypeVar("Out")
class TrainPipeline(abc.ABC, Generic[In, Out]):
"""
Abstract base class for training pipelines.
Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.
Methods:
progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline.
"""
@abc.abstractmethod
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Make progress in the training pipeline.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
Returns:
Out: The output data.
"""
pass
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
"""
Move a batch of data to a specified device.
Args:
batch (In): The input batch.
device (torch.device): The target device.
non_blocking (bool): If True, move the data asynchronously.
Returns:
In: The batch of data on the target device.
"""
assert isinstance(
batch, (torch.Tensor, Pipelineable)
), f"{type(batch)} must implement Pipelineable interface"
@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
"""
Wait for a batch of data on a specified stream.
Args:
batch (In): The input batch.
stream (Optional[Stream]): The CUDA stream to wait for.
Note:
This function is used for managing asynchronous CUDA operations.
"""
if stream is None:
return
torch.cuda.current_stream().wait_stream(stream)
@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> N
class TrainPipelineBase(TrainPipeline[In, Out]):
"""
This class runs training iterations using a pipeline of two stages, each as a CUDA
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization.
"""
This class runs training iterations using a pipeline of two stages, each as a CUDA
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization.
Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.
Methods:
__init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None:
Initialize the TrainPipelineBase.
_connect(dataloader_iter: Iterator[In]) -> None:
Establish a connection to the data loader and move the input data to the GPU.
progress(dataloader_iter: Iterator[In]) -> Out:
Execute a training iteration, including forward and backward passes.
"""
def __init__(
self,
@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
optimizer: torch.optim.Optimizer,
device: torch.device,
) -> None:
"""
Initialize the TrainPipelineBase.
Args:
model (torch.nn.Module): The PyTorch model to be trained.
optimizer (torch.optim.Optimizer): The optimizer used for training.
device (torch.device): The target device for training (CPU or GPU).
"""
self._model = model
self._optimizer = optimizer
self._device = device
@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
self._connected = False
def _connect(self, dataloader_iter: Iterator[In]) -> None:
"""
Establish a connection to the data loader and move the input data to the GPU.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
"""
cur_batch = next(dataloader_iter)
self._cur_batch = cur_batch
with torch.cuda.stream(self._memcpy_stream):
@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
self._connected = True
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Execute a training iteration, including forward and backward passes.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
Returns:
Out: The output data.
"""
if not self._connected:
self._connect(dataloader_iter)
@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
class Tracer(torch.fx.Tracer):
"""
Custom tracer class for PyTorch models.
This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings.
Attributes:
proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing.
_leaf_modules (List[str]): List of qualified names of leaf modules.
"""
# Disable proxying buffers during tracing. Ideally, proxying buffers would
# be disabled, but some models are currently mutating buffer values, which
# causes errors during tracing. If those models can be rewritten to not do
@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer):
proxy_buffer_attributes = False
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
"""
Initialize the Tracer.
Args:
leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing.
"""
super().__init__()
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
"""
Check if a module is a leaf module during tracing.
Args:
m (torch.nn.Module): The PyTorch module.
module_qualified_name (str): The qualified name of the module.
Returns:
bool: True if the module is considered a leaf module, False otherwise.
"""
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
return True
return super().is_leaf_module(m, module_qualified_name)
@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer):
@dataclass
class TrainPipelineContext:
"""
Dataclass to store information related to the training pipeline context.
Attributes:
input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests.
module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts.
feature_processor_forwards (List[Any]): A list of feature processor forwards.
"""
# pyre-ignore [4]
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
@ -166,6 +279,14 @@ class TrainPipelineContext:
@dataclass
class ArgInfo:
"""
Dataclass to store information about arguments in the training pipeline.
Attributes:
input_attrs (List[str]): List of attribute names of the input batch.
is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem.
name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments.
"""
# attributes of input batch, e.g. batch.attr1.attr2 call
# will produce ["attr1", "attr2"]
input_attrs: List[str]
@ -177,6 +298,16 @@ class ArgInfo:
class PipelinedForward:
"""
Represents a pipelined forward pass operation.
Attributes:
name (str): The name of the forward pass.
args (List[ArgInfo]): List of argument information for the forward pass.
module (ShardedModule): The sharded module associated with the forward pass.
context (TrainPipelineContext): The training pipeline context.
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
"""
def __init__(
self,
name: str,
@ -185,6 +316,16 @@ class PipelinedForward:
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> None:
"""
Initialize a PipelinedForward instance.
Args:
name (str): The name of the forward pass.
args (List[ArgInfo]): List of argument information for the forward pass.
module (ShardedModule): The sharded module associated with the forward pass.
context (TrainPipelineContext): The training pipeline context.
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
"""
self._name = name
self._args = args
self._module = module
@ -193,6 +334,16 @@ class PipelinedForward:
# pyre-ignore [2, 24]
def __call__(self, *input, **kwargs) -> Awaitable:
"""
Perform the pipelined forward pass operation.
Args:
*input: Variable-length positional arguments.
**kwargs: Variable-length keyword arguments.
Returns:
Awaitable: An awaitable object representing the forward pass result.
"""
assert self._name in self._context.input_dist_requests
request = self._context.input_dist_requests[self._name]
assert isinstance(request, Awaitable)
@ -230,10 +381,22 @@ class PipelinedForward:
@property
def name(self) -> str:
"""
Get the name of the forward pass.
Returns:
str: The name of the forward pass.
"""
return self._name
@property
def args(self) -> List[ArgInfo]:
"""
Get the list of argument information for the forward pass.
Returns:
List[ArgInfo]: List of argument information.
"""
return self._args
@ -242,6 +405,17 @@ def _start_data_dist(
batch: In,
context: TrainPipelineContext,
) -> None:
"""
Start data distribution for a list of pipelined modules.
Args:
pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules.
batch (In): The input batch.
context (TrainPipelineContext): The training pipeline context.
Returns:
None: This function doesn't return a value.
"""
context.input_dist_requests.clear()
context.module_contexts.clear()
for module in pipelined_modules:
@ -286,9 +460,17 @@ def _get_node_args_helper(
feature_processor_arguments: Optional[List[Node]] = None,
) -> Tuple[List[ArgInfo], int]:
"""
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.
"""
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.
Args:
arguments: The arguments to process.
num_found: The current count of arguments found.
feature_processor_arguments: Optional list of feature processor arguments.
Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found.
"""
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
for arg, arg_info in zip(arguments, arg_info_list):
@ -332,6 +514,16 @@ def _get_node_args_helper(
def _get_node_args(
node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
"""
Get argument information for a given node.
Args:
node (Node): The node to process.
feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes.
Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found.
"""
num_found = 0
pos_arg_info_list, num_found = _get_node_args_helper(
node.args, num_found, feature_processor_nodes
@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper(
path: str,
unsharded_module_names: Set[str],
) -> bool:
"""
Get the names of unsharded modules in a model.
Args:
model (torch.nn.Module): The model to analyze.
path (str): The current path in the model hierarchy.
unsharded_module_names (Set[str]): A set to store the names of unsharded modules.
Returns:
bool: True if any sharded modules were found in the hierarchy, False otherwise.
"""
sharded_children = set()
for name, child in model.named_children():
curr_path = path + name
@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper(
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
"""
Returns a list of top level modules do not contain any sharded sub modules.
"""
Returns a list of top-level modules that do not contain any sharded sub-modules.
Args:
model (torch.nn.Module): The model to analyze.
Returns:
List[str]: A list of top-level module names without sharded sub-modules.
"""
unsharded_module_names: Set[str] = set()
_get_unsharded_module_names_helper(
@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:
"""
Rewrites the model to enable pipelined execution for selected sharded modules.
This function traces the input model using a custom tracer and identifies sharded modules
that can be pipelined. It then creates PipelinedForward objects for these modules,
which enable pipelining during training.
Args:
model (torch.nn.Module): The input model to be rewritten.
context (TrainPipelineContext): The context containing information needed for pipelining.
dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution.
Returns:
List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution.
"""
# Get underlying nn.Module
if isinstance(model, DistributedModelParallel):
@ -442,20 +666,32 @@ def _rewrite_model( # noqa C901
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
"""
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
forward and backward. This helps hide the all2all latency while preserving the
training forward / backward ordering.
forward and backward. This helps hide the all2all latency while preserving the
training forward / backward ordering.
stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream
stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
To be considered a top-level module, a module can only depend on 'getattr' calls on
input.
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
To be considered a top-level module, a module can only depend on 'getattr' calls on
input.
Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.
"""
Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.
Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
Attributes:
synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines.
"""
synced_pipeline_id: Dict[int, int] = {}
@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
enable_grad_scaling: bool = True,
grad_accum: Optional[int] = None,
) -> None:
"""
Initializes the training pipeline.
Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
"""
self._model = model
self._optimizer = optimizer
self._device = device
@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
self._grad_accum = grad_accum
def _connect(self, dataloader_iter: Iterator[In]) -> None:
"""
Connects the training pipeline to data and prepares for forward and backward passes.
Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.
"""
# batch 1
with torch.cuda.stream(self._memcpy_stream):
batch_i = next(dataloader_iter)
@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
NOTE: This method has been updated to perform gradient accumulation.
If `_grad_accum` is set, then loss values are scaled by this amount and
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(congruent to training steps), and then update/reset on every `_grad_accum`th
step.
Progresses through the training pipeline, performing forward and backward passes.
"""
NOTE: This method has been updated to perform gradient accumulation.
If `_grad_accum` is set, then loss values are scaled by this amount and
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(congruent to training steps), and then update/reset on every `_grad_accum`th
step.
Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.
Returns:
Out: The output of the forward pass.
"""
should_step_optimizer = (
self._grad_accum is not None
and self._progress_calls > 0
@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
def _sync_pipeline(self) -> None:
"""
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
current train pipeline. Used when switching between train pipelines for the same
model.
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
current train pipeline. Used when switching between train pipelines for the same
model.
"""
for module in self._pipelined_modules:
module.forward._context = self._context