rajveer43 2023-09-11 20:27:52 +05:30
parent 799254345f
commit 9bb0986079


@@ -39,12 +39,42 @@ Out = TypeVar("Out")
class TrainPipeline(abc.ABC, Generic[In, Out]):
"""
Abstract base class for training pipelines.
Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.
Methods:
progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline.
"""
@abc.abstractmethod
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Make progress in the training pipeline.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
Returns:
Out: The output data.
"""
pass
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
"""
Move a batch of data to a specified device.
Args:
batch (In): The input batch.
device (torch.device): The target device.
non_blocking (bool): If True, move the data asynchronously.
Returns:
In: The batch of data on the target device.
"""
assert isinstance(
batch, (torch.Tensor, Pipelineable)
), f"{type(batch)} must implement Pipelineable interface"
@@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
"""
Wait for a batch of data on a specified stream.
Args:
batch (In): The input batch.
stream (Optional[Stream]): The CUDA stream to wait for.
Note:
This function is used for managing asynchronous CUDA operations.
"""
if stream is None:
return
torch.cuda.current_stream().wait_stream(stream)
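Below is an illustrative sketch (assuming a CUDA device is available; names are placeholders) of the producer/consumer stream handshake that `_wait_for_batch` builds on:

import torch

copy_stream = torch.cuda.Stream()
cpu_batch = torch.randn(32, 16).pin_memory()
with torch.cuda.stream(copy_stream):
    # Asynchronous host-to-device copy issued on the side stream.
    gpu_batch = cpu_batch.to("cuda", non_blocking=True)
# Make the default stream wait for the copy, then record the tensor against the
# consuming stream so the caching allocator does not reuse its memory too early.
torch.cuda.current_stream().wait_stream(copy_stream)
gpu_batch.record_stream(torch.cuda.current_stream())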
@@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
class TrainPipelineBase(TrainPipeline[In, Out]):
"""
This class runs training iterations using a pipeline of two stages, each as a CUDA
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization.
Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.
Methods:
__init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None:
Initialize the TrainPipelineBase.
_connect(dataloader_iter: Iterator[In]) -> None:
Establish a connection to the data loader and move the input data to the GPU.
progress(dataloader_iter: Iterator[In]) -> Out:
Execute a training iteration, including forward and backward passes.
"""
def __init__(
self,
@@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
optimizer: torch.optim.Optimizer,
device: torch.device,
) -> None:
"""
Initialize the TrainPipelineBase.
Args:
model (torch.nn.Module): The PyTorch model to be trained.
optimizer (torch.optim.Optimizer): The optimizer used for training.
device (torch.device): The target device for training (CPU or GPU).
"""
self._model = model
self._optimizer = optimizer
self._device = device
@@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
self._connected = False
def _connect(self, dataloader_iter: Iterator[In]) -> None:
"""
Establish a connection to the data loader and move the input data to the GPU.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
"""
cur_batch = next(dataloader_iter)
self._cur_batch = cur_batch
with torch.cuda.stream(self._memcpy_stream):
@@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
self._connected = True
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Execute a training iteration, including forward and backward passes.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
Returns:
Out: The output data.
"""
if not self._connected:
self._connect(dataloader_iter)
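As a usage sketch (the model, optimizer, and dataloader names below are hypothetical, and the model is assumed to return something `progress` can backpropagate through):

device = torch.device("cuda")
model = MyModel().to(device)                      # hypothetical nn.Module
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
pipeline = TrainPipelineBase(model, optimizer, device)

batches = iter(train_dataloader)                  # hypothetical DataLoader
while True:
    try:
        output = pipeline.progress(batches)       # one training iteration per call
    except StopIteration:
        break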
@@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
class Tracer(torch.fx.Tracer):
"""
Custom tracer class for PyTorch models.
This tracer traces PyTorch models while treating a configurable set of modules as leaf nodes and keeping buffer proxying disabled.
Attributes:
proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing.
_leaf_modules (List[str]): List of qualified names of leaf modules.
"""
# Disable proxying buffers during tracing. Ideally, proxying buffers would
# be disabled, but some models are currently mutating buffer values, which
# causes errors during tracing. If those models can be rewritten to not do
@@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer):
proxy_buffer_attributes = False
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
"""
Initialize the Tracer.
Args:
leaf_modules (Optional[List[str]]): Qualified names of modules to treat as leaf nodes during tracing.
"""
super().__init__()
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
"""
Check if a module is a leaf module during tracing.
Args:
m (torch.nn.Module): The PyTorch module.
module_qualified_name (str): The qualified name of the module.
Returns:
bool: True if the module is considered a leaf module, False otherwise.
"""
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
return True
return super().is_leaf_module(m, module_qualified_name)
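A rough sketch of how a tracer like this is typically driven with `torch.fx` (variable names are illustrative; `_rewrite_model` below describes the real call site):

leaf_names = _get_unsharded_module_names(model)   # keep unsharded towers opaque
tracer = Tracer(leaf_modules=leaf_names)
graph = tracer.trace(model)
# ShardedModule instances and the listed leaf modules appear as single call_module
# nodes in the resulting graph rather than being traced through.
graph_module = torch.fx.GraphModule(model, graph)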
@@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer):
@dataclass
class TrainPipelineContext:
"""
Dataclass to store information related to the training pipeline context.
Attributes:
input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests.
module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts.
feature_processor_forwards (List[Any]): A list of feature processor forwards.
"""
# pyre-ignore [4]
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
@@ -166,6 +279,14 @@ class TrainPipelineContext:
@dataclass
class ArgInfo:
"""
Dataclass to store information about arguments in the training pipeline.
Attributes:
input_attrs (List[str]): List of attribute names of the input batch.
is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem.
name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments.
"""
# attributes of input batch, e.g. batch.attr1.attr2 call
# will produce ["attr1", "attr2"]
input_attrs: List[str]
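As a rough illustration of this encoding (the access path is hypothetical): an argument built as `batch.user.features["f1"]` would be captured roughly as the ArgInfo below.

ArgInfo(
    input_attrs=["user", "features", "f1"],
    is_getitems=[False, False, True],  # only the last hop is a __getitem__ access
    name=None,                         # passed positionally, so no keyword name
)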
@@ -177,6 +298,16 @@ class ArgInfo:
class PipelinedForward:
"""
Represents a pipelined forward pass operation.
Attributes:
name (str): The name of the forward pass.
args (List[ArgInfo]): List of argument information for the forward pass.
module (ShardedModule): The sharded module associated with the forward pass.
context (TrainPipelineContext): The training pipeline context.
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
"""
def __init__(
self,
name: str,
@@ -185,6 +316,16 @@ class PipelinedForward:
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> None:
"""
Initialize a PipelinedForward instance.
Args:
name (str): The name of the forward pass.
args (List[ArgInfo]): List of argument information for the forward pass.
module (ShardedModule): The sharded module associated with the forward pass.
context (TrainPipelineContext): The training pipeline context.
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
"""
self._name = name
self._args = args
self._module = module
@@ -193,6 +334,16 @@ class PipelinedForward:
# pyre-ignore [2, 24]
def __call__(self, *input, **kwargs) -> Awaitable:
"""
Perform the pipelined forward pass operation.
Args:
*input: Variable positional arguments.
**kwargs: Arbitrary keyword arguments.
Returns:
Awaitable: An awaitable object representing the forward pass result.
"""
assert self._name in self._context.input_dist_requests
request = self._context.input_dist_requests[self._name]
assert isinstance(request, Awaitable)
@@ -230,10 +381,22 @@ class PipelinedForward:
@property
def name(self) -> str:
"""
Get the name of the forward pass.
Returns:
str: The name of the forward pass.
"""
return self._name
@property
def args(self) -> List[ArgInfo]:
"""
Get the list of argument information for the forward pass.
Returns:
List[ArgInfo]: List of argument information.
"""
return self._args
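Conceptually, `_rewrite_model` (further below) installs one of these objects in place of each pipelined sharded module's forward; a hedged sketch of that wiring, with illustrative names:

sharded_module.forward = PipelinedForward(
    name="sparse_arch.ebc",           # hypothetical fully-qualified module name
    args=arg_info_list,               # as produced by _get_node_args(...)
    module=sharded_module,
    context=pipeline_context,         # the shared TrainPipelineContext
    dist_stream=data_dist_stream,     # CUDA stream used for input_dist
)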
@@ -242,6 +405,17 @@ def _start_data_dist(
batch: In,
context: TrainPipelineContext,
) -> None:
"""
Start data distribution for a list of pipelined modules.
Args:
pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules.
batch (In): The input batch.
context (TrainPipelineContext): The training pipeline context.
Returns:
None: This function doesn't return a value.
"""
context.input_dist_requests.clear()
context.module_contexts.clear()
for module in pipelined_modules:
@@ -286,9 +460,17 @@ def _get_node_args_helper(
feature_processor_arguments: Optional[List[Node]] = None,
) -> Tuple[List[ArgInfo], int]:
""" """
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found. It also counts the number of (args + kwargs) found.
"""
Args:
arguments: The arguments to process.
num_found: The current count of arguments found.
feature_processor_arguments: Optional list of feature processor arguments.
Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found.
"""
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
for arg, arg_info in zip(arguments, arg_info_list):
@@ -332,6 +514,16 @@ def _get_node_args_helper(
def _get_node_args(
node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
"""
Get argument information for a given node.
Args:
node (Node): The node to process.
feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes.
Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found.
"""
num_found = 0
pos_arg_info_list, num_found = _get_node_args_helper(
node.args, num_found, feature_processor_nodes
@@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper(
path: str,
unsharded_module_names: Set[str],
) -> bool:
"""
Get the names of unsharded modules in a model.
Args:
model (torch.nn.Module): The model to analyze.
path (str): The current path in the model hierarchy.
unsharded_module_names (Set[str]): A set to store the names of unsharded modules.
Returns:
bool: True if any sharded modules were found in the hierarchy, False otherwise.
"""
sharded_children = set()
for name, child in model.named_children():
curr_path = path + name
@@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper(
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
"""
Returns a list of top-level modules that do not contain any sharded sub-modules.
Args:
model (torch.nn.Module): The model to analyze.
Returns:
List[str]: A list of top-level module names without sharded sub-modules.
"""
unsharded_module_names: Set[str] = set()
_get_unsharded_module_names_helper(
@@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:
"""
Rewrites the model to enable pipelined execution for selected sharded modules.
This function traces the input model using a custom tracer and identifies sharded modules
that can be pipelined. It then creates PipelinedForward objects for these modules,
which enable pipelining during training.
Args:
model (torch.nn.Module): The input model to be rewritten.
context (TrainPipelineContext): The context containing information needed for pipelining.
dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution.
Returns:
List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution.
"""
# Get underlying nn.Module
if isinstance(model, DistributedModelParallel):
@@ -442,20 +666,32 @@ def _rewrite_model( # noqa C901
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
"""
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
forward and backward. This helps hide the all2all latency while preserving the
training forward / backward ordering.
stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
To be considered a top-level module, a module can only depend on 'getattr' calls on
input.
Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.
Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
Attributes:
synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines.
"""
synced_pipeline_id: Dict[int, int] = {}
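A hypothetical end-to-end sketch (placeholder names; assumes the model is already wrapped in DistributedModelParallel and the dataloader yields device-transferable batches):

pipeline = TrainPipelineSparseDist(
    model=dmp_model,
    optimizer=torch.optim.Adagrad(dmp_model.parameters(), lr=0.01),
    device=torch.device("cuda"),
    grad_accum=4,                     # step the optimizer once every 4 progress() calls
)
batches = iter(train_dataloader)
for _ in range(num_steps):
    output = pipeline.progress(batches)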
@@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
enable_grad_scaling: bool = True,
grad_accum: Optional[int] = None,
) -> None:
"""
Initializes the training pipeline.
Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
"""
self._model = model
self._optimizer = optimizer
self._device = device
@@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
self._grad_accum = grad_accum
def _connect(self, dataloader_iter: Iterator[In]) -> None:
"""
Connects the training pipeline to data and prepares for forward and backward passes.
Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.
"""
# batch 1
with torch.cuda.stream(self._memcpy_stream):
batch_i = next(dataloader_iter)
@@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Progresses through the training pipeline, performing forward and backward passes.
NOTE: This method has been updated to perform gradient accumulation.
If `_grad_accum` is set, then loss values are scaled by this amount and
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(congruent to training steps), and then update/reset on every `_grad_accum`th
step.
Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.
Returns:
Out: The output of the forward pass.
"""
should_step_optimizer = (
self._grad_accum is not None
and self._progress_calls > 0
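A small worked example of the accumulation schedule described in the docstring above, assuming `grad_accum=4` and that losses are scaled by `1 / grad_accum` before backward (the scaling factor is an assumption for illustration):

grad_accum = 4
for call_idx in range(1, 9):                  # eight progress() calls
    loss_scale = 1.0 / grad_accum             # assumed scaling applied to each call's loss
    should_step = call_idx % grad_accum == 0  # optimizer steps on calls 4 and 8 only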
@@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
def _sync_pipeline(self) -> None:
"""
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
current train pipeline. Used when switching between train pipelines for the same
model.
"""
for module in self._pipelined_modules:
module.forward._context = self._context