diff --git a/core/train_pipeline.py b/core/train_pipeline.py
index cde587e..3209988 100644
--- a/core/train_pipeline.py
+++ b/core/train_pipeline.py
@@ -39,12 +39,42 @@ Out = TypeVar("Out")
 
 
 class TrainPipeline(abc.ABC, Generic[In, Out]):
+    """
+    Abstract base class for training pipelines.
+
+    Subclasses are generic over the input batch type `In` and the output type
+    `Out`, and must implement `progress`.
+    """
+
     @abc.abstractmethod
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        """
+        Run a single training iteration.
+
+        Args:
+            dataloader_iter (Iterator[In]): An iterator over input batches.
+
+        Returns:
+            Out: The output of the iteration (typically the model output).
+        """
        pass
 
 
 def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
+    """
+    Move a batch of data to the specified device.
+
+    Args:
+        batch (In): The input batch.
+        device (torch.device): The target device.
+        non_blocking (bool): If True, perform the copy asynchronously with
+            respect to the host.
+
+    Returns:
+        In: The batch on the target device.
+    """
     assert isinstance(
         batch, (torch.Tensor, Pipelineable)
     ), f"{type(batch)} must implement Pipelineable interface"
@@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
 
 
 def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
+    """
+    Make the current CUDA stream wait until `stream` has finished producing
+    `batch`.
+
+    Args:
+        batch (In): The input batch produced on `stream`.
+        stream (Optional[Stream]): The CUDA stream to synchronize with; if
+            None, this function is a no-op.
+    """
     if stream is None:
         return
     torch.cuda.current_stream().wait_stream(stream)
@@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
 
 class TrainPipelineBase(TrainPipeline[In, Out]):
     """
-    This class runs training iterations using a pipeline of two stages, each as a CUDA
-    stream, namely, the current (default) stream and `self._memcpy_stream`. For each
-    iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
-    memory, and the default stream runs forward, backward, and optimization.
-    """
+    This class runs training iterations using a pipeline of two stages, each on
+    its own CUDA stream: the current (default) stream and `self._memcpy_stream`.
+    For each iteration, `self._memcpy_stream` copies the input batch from host
+    (CPU) memory to GPU memory, and the default stream runs forward, backward,
+    and the optimizer step.
+    """
 
     def __init__(
         self,
@@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
         optimizer: torch.optim.Optimizer,
         device: torch.device,
     ) -> None:
+        """
+        Initialize the TrainPipelineBase.
+
+        Args:
+            model (torch.nn.Module): The model to train.
+            optimizer (torch.optim.Optimizer): The optimizer used for training.
+            device (torch.device): The target device for training (CPU or GPU).
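+
+        Example (illustrative sketch; ``MyModel``, ``dataloader``, and
+        ``num_steps`` are placeholders, not names defined in this module)::
+
+            device = torch.device("cuda")
+            model = MyModel().to(device)
+            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+            pipeline = TrainPipelineBase(model, optimizer, device)
+            data_iter = iter(dataloader)
+            for _ in range(num_steps):
+                # Each call copies the next batch to the device on
+                # `self._memcpy_stream` while the default stream computes.
+                output = pipeline.progress(data_iter)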
+ """ self._model = model self._optimizer = optimizer self._device = device @@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]): self._connected = False def _connect(self, dataloader_iter: Iterator[In]) -> None: + """ + Establish a connection to the data loader and move the input data to the GPU. + + Args: + dataloader_iter (Iterator[In]): An iterator over input data. + """ cur_batch = next(dataloader_iter) self._cur_batch = cur_batch with torch.cuda.stream(self._memcpy_stream): @@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]): self._connected = True def progress(self, dataloader_iter: Iterator[In]) -> Out: + """ + Execute a training iteration, including forward and backward passes. + + Args: + dataloader_iter (Iterator[In]): An iterator over input data. + + Returns: + Out: The output data. + """ if not self._connected: self._connect(dataloader_iter) @@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]): class Tracer(torch.fx.Tracer): + """ + Custom tracer class for PyTorch models. + + This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings. + + Attributes: + proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing. + _leaf_modules (List[str]): List of qualified names of leaf modules. + """ + # Disable proxying buffers during tracing. Ideally, proxying buffers would # be disabled, but some models are currently mutating buffer values, which # causes errors during tracing. If those models can be rewritten to not do @@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer): proxy_buffer_attributes = False def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: + """ + Initialize the Tracer. + + Args: + leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing. + """ super().__init__() self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + Check if a module is a leaf module during tracing. + + Args: + m (torch.nn.Module): The PyTorch module. + module_qualified_name (str): The qualified name of the module. + + Returns: + bool: True if the module is considered a leaf module, False otherwise. + """ if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules: return True return super().is_leaf_module(m, module_qualified_name) @@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer): @dataclass class TrainPipelineContext: + """ + Dataclass to store information related to the training pipeline context. + + Attributes: + input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests. + module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts. + feature_processor_forwards (List[Any]): A list of feature processor forwards. + """ + # pyre-ignore [4] input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict) module_contexts: Dict[str, Multistreamable] = field(default_factory=dict) @@ -166,6 +279,14 @@ class TrainPipelineContext: @dataclass class ArgInfo: + """ + Dataclass to store information about arguments in the training pipeline. + + Attributes: + input_attrs (List[str]): List of attribute names of the input batch. + is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem. 
+    """
+
     # attributes of input batch, e.g. batch.attr1.attr2 call
     # will produce ["attr1", "attr2"]
     input_attrs: List[str]
@@ -177,6 +298,16 @@ class ArgInfo:
 
 
 class PipelinedForward:
+    """
+    Callable that replaces a sharded module's forward() with a pipelined
+    version: input distribution is issued ahead of time, and the call waits on
+    the pending request before running the module's compute.
+
+    Attributes:
+        name (str): Name of the sharded module this forward belongs to.
+        args (List[ArgInfo]): Argument information for the forward call.
+        module (ShardedModule): The sharded module being pipelined.
+        context (TrainPipelineContext): The shared training pipeline context.
+        dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream on which
+            input distribution runs.
+    """
+
     def __init__(
         self,
@@ -185,6 +316,16 @@ class PipelinedForward:
         context: TrainPipelineContext,
         dist_stream: Optional[torch.cuda.streams.Stream],
     ) -> None:
+        """
+        Initialize a PipelinedForward instance.
+
+        Args:
+            name (str): Name of the sharded module this forward belongs to.
+            args (List[ArgInfo]): Argument information for the forward call.
+            module (ShardedModule): The sharded module being pipelined.
+            context (TrainPipelineContext): The shared training pipeline
+                context.
+            dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream on
+                which input distribution runs.
+        """
         self._name = name
         self._args = args
         self._module = module
@@ -193,6 +334,16 @@ class PipelinedForward:
 
     # pyre-ignore [2, 24]
     def __call__(self, *input, **kwargs) -> Awaitable:
+        """
+        Run the pipelined forward pass: wait on the input-dist request
+        previously issued for this module, then run the module's compute on
+        the distributed inputs.
+
+        Args:
+            *input: Positional arguments (unused; inputs come from the pending
+                input-dist request).
+            **kwargs: Keyword arguments (unused likewise).
+
+        Returns:
+            Awaitable: An awaitable of the module's output.
+        """
         assert self._name in self._context.input_dist_requests
         request = self._context.input_dist_requests[self._name]
         assert isinstance(request, Awaitable)
@@ -230,10 +381,22 @@ class PipelinedForward:
 
     @property
     def name(self) -> str:
+        """
+        Returns:
+            str: The name of the sharded module this forward belongs to.
+        """
         return self._name
 
     @property
     def args(self) -> List[ArgInfo]:
+        """
+        Returns:
+            List[ArgInfo]: Argument information for the forward call.
+        """
         return self._args
@@ -242,6 +405,17 @@ def _start_data_dist(
     batch: In,
     context: TrainPipelineContext,
 ) -> None:
+    """
+    Kick off input distribution (`input_dist`) on the given batch for each
+    pipelined module, storing the resulting requests and module contexts in
+    `context`.
+
+    Args:
+        pipelined_modules (List[ShardedModule]): The pipelined sharded modules.
+        batch (In): The input batch.
+        context (TrainPipelineContext): The pipeline context to populate.
+    """
     context.input_dist_requests.clear()
     context.module_contexts.clear()
     for module in pipelined_modules:
@@ -286,9 +460,17 @@ def _get_node_args_helper(
     feature_processor_arguments: Optional[List[Node]] = None,
 ) -> Tuple[List[ArgInfo], int]:
     """
-    Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
-    It also counts the number of (args + kwargs) found.
-    """
+    Goes through the args/kwargs of an FX node and arranges them into a list of
+    `ArgInfo`s, counting the number of (args + kwargs) found.
+
+    Args:
+        arguments: Positional or keyword argument values of the node.
+        num_found: Running count of arguments found so far.
+        feature_processor_arguments: Optional list of feature processor nodes.
+
+    Returns:
+        Tuple[List[ArgInfo], int]: The list of `ArgInfo` objects and the
+            updated count of arguments found.
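+
+    Example (illustrative; mirrors the ``batch.attr1.attr2`` case noted in
+    `ArgInfo`)::
+
+        # For a node whose only argument expression is ``batch.attr1.attr2``:
+        arg_infos, num_found = _get_node_args_helper(node.args, 0)
+        # arg_infos == [ArgInfo(["attr1", "attr2"], [False, False], None)]
+        # num_found == 1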
+    """
     arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
     for arg, arg_info in zip(arguments, arg_info_list):
@@ -332,6 +514,16 @@
 def _get_node_args(
     node: Node, feature_processor_nodes: Optional[List[Node]] = None
 ) -> Tuple[List[ArgInfo], int]:
+    """
+    Collect argument information for a node's positional and keyword arguments.
+
+    Args:
+        node (Node): The FX node to process.
+        feature_processor_nodes (Optional[List[Node]]): Optional list of
+            feature processor nodes.
+
+    Returns:
+        Tuple[List[ArgInfo], int]: The list of `ArgInfo` objects and the number
+            of arguments found.
+    """
     num_found = 0
     pos_arg_info_list, num_found = _get_node_args_helper(
         node.args, num_found, feature_processor_nodes
@@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper(
     path: str,
     unsharded_module_names: Set[str],
 ) -> bool:
+    """
+    Recursively collect the names of top-level modules that contain no sharded
+    sub-modules.
+
+    Args:
+        model (torch.nn.Module): The module to analyze.
+        path (str): The current dotted path in the module hierarchy.
+        unsharded_module_names (Set[str]): Output set of unsharded module
+            names.
+
+    Returns:
+        bool: True if any sharded modules were found under `model`, False
+            otherwise.
+    """
     sharded_children = set()
     for name, child in model.named_children():
         curr_path = path + name
@@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper(
 
 def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
     """
-    Returns a list of top level modules do not contain any sharded sub modules.
-    """
+    Returns a list of top-level modules that do not contain any sharded
+    sub-modules.
+
+    Args:
+        model (torch.nn.Module): The model to analyze.
+
+    Returns:
+        List[str]: Names of top-level modules without sharded sub-modules.
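+
+    Example (illustrative; ``dense`` and ``sparse`` are placeholder child
+    module names)::
+
+        # For a model whose ``sparse`` child contains sharded modules and
+        # whose ``dense`` child does not:
+        _get_unsharded_module_names(model)  # -> ["dense"]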
+    """
 
     unsharded_module_names: Set[str] = set()
     _get_unsharded_module_names_helper(
@@ -392,6 +601,21 @@ def _rewrite_model(  # noqa C901
     context: TrainPipelineContext,
     dist_stream: Optional[torch.cuda.streams.Stream],
 ) -> List[ShardedModule]:
+    """
+    Rewrites the model to enable pipelined execution for selected sharded
+    modules.
+
+    This function traces the input model with a custom tracer, identifies
+    sharded modules whose inputs depend only on attribute access on the input
+    batch, and replaces their forward() with a `PipelinedForward` so that input
+    distribution can be pipelined during training.
+
+    Args:
+        model (torch.nn.Module): The model to rewrite.
+        context (TrainPipelineContext): Context shared with the pipelined
+            forwards.
+        dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream used
+            for data distribution.
+
+    Returns:
+        List[ShardedModule]: The sharded modules rewritten for pipelined
+            execution.
+    """
 
     # Get underlying nn.Module
     if isinstance(model, DistributedModelParallel):
@@ -442,20 +666,32 @@
 
 class TrainPipelineSparseDist(TrainPipeline[In, Out]):
     """
     This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
-    forward and backward. This helps hide the all2all latency while preserving the
-    training forward / backward ordering.
+    forward and backward. This helps hide the all2all latency while preserving the
+    training forward / backward ordering.
 
-    stage 3: forward, backward - uses default CUDA stream
-    stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
-    stage 1: device transfer - uses memcpy CUDA stream
+    stage 3: forward, backward - uses default CUDA stream
+    stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
+    stage 1: device transfer - uses memcpy CUDA stream
 
-    `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
-    To be considered a top-level module, a module can only depend on 'getattr' calls on
-    input.
+    `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
+    To be considered a top-level module, a module can only depend on 'getattr' calls on
+    input.
 
-    Input model must be symbolically traceable with the exception of `ShardedModule` and
-    `DistributedDataParallel` modules.
-    """
+    The input model must be symbolically traceable, with the exception of
+    `ShardedModule` and `DistributedDataParallel` modules.
+
+    Args:
+        model (torch.nn.Module): The model to train.
+        optimizer (torch.optim.Optimizer): The optimizer for updating model
+            parameters.
+        device (torch.device): The device on which training runs.
+        enable_amp (bool, optional): Whether to enable automatic mixed
+            precision (AMP). Defaults to False.
+        enable_grad_scaling (bool, optional): Whether to enable gradient
+            scaling. Defaults to True.
+        grad_accum (Optional[int], optional): Number of gradient accumulation
+            steps. Defaults to None.
+
+    Attributes:
+        synced_pipeline_id (Dict[int, int]): Tracks, per model, which pipeline
+            instance is currently synced with it.
+    """
 
     synced_pipeline_id: Dict[int, int] = {}
@@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
         enable_grad_scaling: bool = True,
         grad_accum: Optional[int] = None,
     ) -> None:
+        """
+        Initialize the training pipeline.
+
+        Args:
+            model (torch.nn.Module): The model to train.
+            optimizer (torch.optim.Optimizer): The optimizer for updating model
+                parameters.
+            device (torch.device): The device on which training runs.
+            enable_amp (bool, optional): Whether to enable automatic mixed
+                precision (AMP). Defaults to False.
+            enable_grad_scaling (bool, optional): Whether to enable gradient
+                scaling. Defaults to True.
+            grad_accum (Optional[int], optional): Number of gradient
+                accumulation steps. Defaults to None.
+        """
         self._model = model
         self._optimizer = optimizer
         self._device = device
@@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
         self._grad_accum = grad_accum
 
     def _connect(self, dataloader_iter: Iterator[In]) -> None:
+        """
+        Prime the pipeline: fetch the first batch and begin copying it to the
+        device on `self._memcpy_stream`, preparing the model for pipelined
+        forward and backward passes.
+
+        Args:
+            dataloader_iter (Iterator[In]): An iterator providing input batches.
+        """
+
         # batch 1
         with torch.cuda.stream(self._memcpy_stream):
             batch_i = next(dataloader_iter)
@@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         """
-        NOTE: This method has been updated to perform gradient accumulation.
-        If `_grad_accum` is set, then loss values are scaled by this amount and
-        optimizer update/reset is skipped for `_grad_accum` calls of `progress`
-        (congruent to training steps), and then update/reset on every `_grad_accum`th
-        step.
-
-        """
+        Progress the training pipeline by one step: run the forward and
+        backward passes for the current batch.
+
+        NOTE: This method also performs gradient accumulation. If `_grad_accum`
+        is set, loss values are scaled by that amount, the optimizer
+        update/reset is skipped between steps, and the update/reset runs on
+        every `_grad_accum`-th call of `progress` (each call corresponds to one
+        training step).
+
+        Args:
+            dataloader_iter (Iterator[In]): An iterator providing input batches.
+
+        Returns:
+            Out: The output of the forward pass.
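+
+        Example (illustrative sketch; assumes a pipeline constructed with
+        ``grad_accum=4``; ``data_iter`` and ``num_steps`` are placeholders)::
+
+            for _ in range(num_steps):
+                # Loss is scaled for accumulation; the optimizer update/reset
+                # runs only on every 4th call.
+                output = pipeline.progress(data_iter)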
+        """
         should_step_optimizer = (
             self._grad_accum is not None
             and self._progress_calls > 0
@@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
 
     def _sync_pipeline(self) -> None:
         """
-        Syncs `PipelinedForward` for sharded modules with context and dist stream of the
-        current train pipeline. Used when switching between train pipelines for the same
-        model.
+        Syncs each sharded module's `PipelinedForward` with the context and
+        dist stream of the current train pipeline. Used when switching between
+        train pipelines for the same model.
         """
         for module in self._pipelined_modules:
             module.forward._context = self._context