mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-03-09 06:45:15 +01:00
update
This commit is contained in:
parent
799254345f
commit
9bb0986079
@ -39,12 +39,42 @@ Out = TypeVar("Out")
|
||||
|
||||
|
||||
class TrainPipeline(abc.ABC, Generic[In, Out]):
|
||||
"""
|
||||
Abstract base class for training pipelines.
|
||||
|
||||
Attributes:
|
||||
In (TypeVar): Input data type.
|
||||
Out (TypeVar): Output data type.
|
||||
|
||||
Methods:
|
||||
progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline.
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
||||
"""
|
||||
Make progress in the training pipeline.
|
||||
|
||||
Args:
|
||||
dataloader_iter (Iterator[In]): An iterator over input data.
|
||||
|
||||
Returns:
|
||||
Out: The output data.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
|
||||
"""
|
||||
Move a batch of data to a specified device.
|
||||
|
||||
Args:
|
||||
batch (In): The input batch.
|
||||
device (torch.device): The target device.
|
||||
non_blocking (bool): If True, move the data asynchronously.
|
||||
|
||||
Returns:
|
||||
In: The batch of data on the target device.
|
||||
"""
|
||||
assert isinstance(
|
||||
batch, (torch.Tensor, Pipelineable)
|
||||
), f"{type(batch)} must implement Pipelineable interface"
|
||||
@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
|
||||
|
||||
|
||||
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
|
||||
"""
|
||||
Wait for a batch of data on a specified stream.
|
||||
|
||||
Args:
|
||||
batch (In): The input batch.
|
||||
stream (Optional[Stream]): The CUDA stream to wait for.
|
||||
|
||||
Note:
|
||||
This function is used for managing asynchronous CUDA operations.
|
||||
"""
|
||||
if stream is None:
|
||||
return
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> N
|
||||
|
||||
class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||
"""
|
||||
This class runs training iterations using a pipeline of two stages, each as a CUDA
|
||||
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
|
||||
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
|
||||
memory, and the default stream runs forward, backward, and optimization.
|
||||
"""
|
||||
This class runs training iterations using a pipeline of two stages, each as a CUDA
|
||||
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
|
||||
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
|
||||
memory, and the default stream runs forward, backward, and optimization.
|
||||
|
||||
Attributes:
|
||||
In (TypeVar): Input data type.
|
||||
Out (TypeVar): Output data type.
|
||||
|
||||
Methods:
|
||||
__init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None:
|
||||
Initialize the TrainPipelineBase.
|
||||
|
||||
_connect(dataloader_iter: Iterator[In]) -> None:
|
||||
Establish a connection to the data loader and move the input data to the GPU.
|
||||
|
||||
progress(dataloader_iter: Iterator[In]) -> Out:
|
||||
Execute a training iteration, including forward and backward passes.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||
optimizer: torch.optim.Optimizer,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the TrainPipelineBase.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The PyTorch model to be trained.
|
||||
optimizer (torch.optim.Optimizer): The optimizer used for training.
|
||||
device (torch.device): The target device for training (CPU or GPU).
|
||||
"""
|
||||
self._model = model
|
||||
self._optimizer = optimizer
|
||||
self._device = device
|
||||
@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||
self._connected = False
|
||||
|
||||
def _connect(self, dataloader_iter: Iterator[In]) -> None:
|
||||
"""
|
||||
Establish a connection to the data loader and move the input data to the GPU.
|
||||
|
||||
Args:
|
||||
dataloader_iter (Iterator[In]): An iterator over input data.
|
||||
"""
|
||||
cur_batch = next(dataloader_iter)
|
||||
self._cur_batch = cur_batch
|
||||
with torch.cuda.stream(self._memcpy_stream):
|
||||
@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||
self._connected = True
|
||||
|
||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
||||
"""
|
||||
Execute a training iteration, including forward and backward passes.
|
||||
|
||||
Args:
|
||||
dataloader_iter (Iterator[In]): An iterator over input data.
|
||||
|
||||
Returns:
|
||||
Out: The output data.
|
||||
"""
|
||||
if not self._connected:
|
||||
self._connect(dataloader_iter)
|
||||
|
||||
@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||
|
||||
|
||||
class Tracer(torch.fx.Tracer):
|
||||
"""
|
||||
Custom tracer class for PyTorch models.
|
||||
|
||||
This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings.
|
||||
|
||||
Attributes:
|
||||
proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing.
|
||||
_leaf_modules (List[str]): List of qualified names of leaf modules.
|
||||
"""
|
||||
|
||||
# Disable proxying buffers during tracing. Ideally, proxying buffers would
|
||||
# be disabled, but some models are currently mutating buffer values, which
|
||||
# causes errors during tracing. If those models can be rewritten to not do
|
||||
@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer):
|
||||
proxy_buffer_attributes = False
|
||||
|
||||
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
Initialize the Tracer.
|
||||
|
||||
Args:
|
||||
leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing.
|
||||
"""
|
||||
super().__init__()
|
||||
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
|
||||
|
||||
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
|
||||
"""
|
||||
Check if a module is a leaf module during tracing.
|
||||
|
||||
Args:
|
||||
m (torch.nn.Module): The PyTorch module.
|
||||
module_qualified_name (str): The qualified name of the module.
|
||||
|
||||
Returns:
|
||||
bool: True if the module is considered a leaf module, False otherwise.
|
||||
"""
|
||||
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
|
||||
return True
|
||||
return super().is_leaf_module(m, module_qualified_name)
|
||||
@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer):
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineContext:
|
||||
"""
|
||||
Dataclass to store information related to the training pipeline context.
|
||||
|
||||
Attributes:
|
||||
input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests.
|
||||
module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts.
|
||||
feature_processor_forwards (List[Any]): A list of feature processor forwards.
|
||||
"""
|
||||
|
||||
# pyre-ignore [4]
|
||||
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
|
||||
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
|
||||
@ -166,6 +279,14 @@ class TrainPipelineContext:
|
||||
|
||||
@dataclass
|
||||
class ArgInfo:
|
||||
"""
|
||||
Dataclass to store information about arguments in the training pipeline.
|
||||
|
||||
Attributes:
|
||||
input_attrs (List[str]): List of attribute names of the input batch.
|
||||
is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem.
|
||||
name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments.
|
||||
"""
|
||||
# attributes of input batch, e.g. batch.attr1.attr2 call
|
||||
# will produce ["attr1", "attr2"]
|
||||
input_attrs: List[str]
|
||||
@ -177,6 +298,16 @@ class ArgInfo:
|
||||
|
||||
|
||||
class PipelinedForward:
|
||||
"""
|
||||
Represents a pipelined forward pass operation.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the forward pass.
|
||||
args (List[ArgInfo]): List of argument information for the forward pass.
|
||||
module (ShardedModule): The sharded module associated with the forward pass.
|
||||
context (TrainPipelineContext): The training pipeline context.
|
||||
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
@ -185,6 +316,16 @@ class PipelinedForward:
|
||||
context: TrainPipelineContext,
|
||||
dist_stream: Optional[torch.cuda.streams.Stream],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a PipelinedForward instance.
|
||||
|
||||
Args:
|
||||
name (str): The name of the forward pass.
|
||||
args (List[ArgInfo]): List of argument information for the forward pass.
|
||||
module (ShardedModule): The sharded module associated with the forward pass.
|
||||
context (TrainPipelineContext): The training pipeline context.
|
||||
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
|
||||
"""
|
||||
self._name = name
|
||||
self._args = args
|
||||
self._module = module
|
||||
@ -193,6 +334,16 @@ class PipelinedForward:
|
||||
|
||||
# pyre-ignore [2, 24]
|
||||
def __call__(self, *input, **kwargs) -> Awaitable:
|
||||
"""
|
||||
Perform the pipelined forward pass operation.
|
||||
|
||||
Args:
|
||||
*input: Variable-length positional arguments.
|
||||
**kwargs: Variable-length keyword arguments.
|
||||
|
||||
Returns:
|
||||
Awaitable: An awaitable object representing the forward pass result.
|
||||
"""
|
||||
assert self._name in self._context.input_dist_requests
|
||||
request = self._context.input_dist_requests[self._name]
|
||||
assert isinstance(request, Awaitable)
|
||||
@ -230,10 +381,22 @@ class PipelinedForward:
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Get the name of the forward pass.
|
||||
|
||||
Returns:
|
||||
str: The name of the forward pass.
|
||||
"""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def args(self) -> List[ArgInfo]:
|
||||
"""
|
||||
Get the list of argument information for the forward pass.
|
||||
|
||||
Returns:
|
||||
List[ArgInfo]: List of argument information.
|
||||
"""
|
||||
return self._args
|
||||
|
||||
|
||||
@ -242,6 +405,17 @@ def _start_data_dist(
|
||||
batch: In,
|
||||
context: TrainPipelineContext,
|
||||
) -> None:
|
||||
"""
|
||||
Start data distribution for a list of pipelined modules.
|
||||
|
||||
Args:
|
||||
pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules.
|
||||
batch (In): The input batch.
|
||||
context (TrainPipelineContext): The training pipeline context.
|
||||
|
||||
Returns:
|
||||
None: This function doesn't return a value.
|
||||
"""
|
||||
context.input_dist_requests.clear()
|
||||
context.module_contexts.clear()
|
||||
for module in pipelined_modules:
|
||||
@ -286,9 +460,17 @@ def _get_node_args_helper(
|
||||
feature_processor_arguments: Optional[List[Node]] = None,
|
||||
) -> Tuple[List[ArgInfo], int]:
|
||||
"""
|
||||
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
|
||||
It also counts the number of (args + kwargs) found.
|
||||
"""
|
||||
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
|
||||
It also counts the number of (args + kwargs) found.
|
||||
|
||||
Args:
|
||||
arguments: The arguments to process.
|
||||
num_found: The current count of arguments found.
|
||||
feature_processor_arguments: Optional list of feature processor arguments.
|
||||
|
||||
Returns:
|
||||
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found.
|
||||
"""
|
||||
|
||||
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
|
||||
for arg, arg_info in zip(arguments, arg_info_list):
|
||||
@ -332,6 +514,16 @@ def _get_node_args_helper(
|
||||
def _get_node_args(
|
||||
node: Node, feature_processor_nodes: Optional[List[Node]] = None
|
||||
) -> Tuple[List[ArgInfo], int]:
|
||||
"""
|
||||
Get argument information for a given node.
|
||||
|
||||
Args:
|
||||
node (Node): The node to process.
|
||||
feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes.
|
||||
|
||||
Returns:
|
||||
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found.
|
||||
"""
|
||||
num_found = 0
|
||||
pos_arg_info_list, num_found = _get_node_args_helper(
|
||||
node.args, num_found, feature_processor_nodes
|
||||
@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper(
|
||||
path: str,
|
||||
unsharded_module_names: Set[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Get the names of unsharded modules in a model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to analyze.
|
||||
path (str): The current path in the model hierarchy.
|
||||
unsharded_module_names (Set[str]): A set to store the names of unsharded modules.
|
||||
|
||||
Returns:
|
||||
bool: True if any sharded modules were found in the hierarchy, False otherwise.
|
||||
"""
|
||||
sharded_children = set()
|
||||
for name, child in model.named_children():
|
||||
curr_path = path + name
|
||||
@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper(
|
||||
|
||||
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
|
||||
"""
|
||||
Returns a list of top level modules do not contain any sharded sub modules.
|
||||
"""
|
||||
Returns a list of top-level modules that do not contain any sharded sub-modules.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to analyze.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of top-level module names without sharded sub-modules.
|
||||
"""
|
||||
|
||||
unsharded_module_names: Set[str] = set()
|
||||
_get_unsharded_module_names_helper(
|
||||
@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901
|
||||
context: TrainPipelineContext,
|
||||
dist_stream: Optional[torch.cuda.streams.Stream],
|
||||
) -> List[ShardedModule]:
|
||||
"""
|
||||
Rewrites the model to enable pipelined execution for selected sharded modules.
|
||||
|
||||
This function traces the input model using a custom tracer and identifies sharded modules
|
||||
that can be pipelined. It then creates PipelinedForward objects for these modules,
|
||||
which enable pipelining during training.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The input model to be rewritten.
|
||||
context (TrainPipelineContext): The context containing information needed for pipelining.
|
||||
dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution.
|
||||
|
||||
Returns:
|
||||
List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution.
|
||||
"""
|
||||
|
||||
# Get underlying nn.Module
|
||||
if isinstance(model, DistributedModelParallel):
|
||||
@ -442,20 +666,32 @@ def _rewrite_model( # noqa C901
|
||||
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
||||
"""
|
||||
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
|
||||
forward and backward. This helps hide the all2all latency while preserving the
|
||||
training forward / backward ordering.
|
||||
forward and backward. This helps hide the all2all latency while preserving the
|
||||
training forward / backward ordering.
|
||||
|
||||
stage 3: forward, backward - uses default CUDA stream
|
||||
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
|
||||
stage 1: device transfer - uses memcpy CUDA stream
|
||||
stage 3: forward, backward - uses default CUDA stream
|
||||
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
|
||||
stage 1: device transfer - uses memcpy CUDA stream
|
||||
|
||||
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
|
||||
To be considered a top-level module, a module can only depend on 'getattr' calls on
|
||||
input.
|
||||
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
|
||||
To be considered a top-level module, a module can only depend on 'getattr' calls on
|
||||
input.
|
||||
|
||||
Input model must be symbolically traceable with the exception of `ShardedModule` and
|
||||
`DistributedDataParallel` modules.
|
||||
"""
|
||||
Input model must be symbolically traceable with the exception of `ShardedModule` and
|
||||
`DistributedDataParallel` modules.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The input model to be used for training.
|
||||
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
|
||||
device (torch.device): The device where training will be performed.
|
||||
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
|
||||
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
|
||||
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
|
||||
|
||||
Attributes:
|
||||
synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines.
|
||||
|
||||
"""
|
||||
|
||||
synced_pipeline_id: Dict[int, int] = {}
|
||||
|
||||
@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
||||
enable_grad_scaling: bool = True,
|
||||
grad_accum: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the training pipeline.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The input model to be used for training.
|
||||
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
|
||||
device (torch.device): The device where training will be performed.
|
||||
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
|
||||
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
|
||||
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
|
||||
"""
|
||||
self._model = model
|
||||
self._optimizer = optimizer
|
||||
self._device = device
|
||||
@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
||||
self._grad_accum = grad_accum
|
||||
|
||||
def _connect(self, dataloader_iter: Iterator[In]) -> None:
|
||||
"""
|
||||
Connects the training pipeline to data and prepares for forward and backward passes.
|
||||
|
||||
Args:
|
||||
dataloader_iter (Iterator[In]): An iterator providing input data batches.
|
||||
"""
|
||||
|
||||
# batch 1
|
||||
with torch.cuda.stream(self._memcpy_stream):
|
||||
batch_i = next(dataloader_iter)
|
||||
@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
||||
|
||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
||||
"""
|
||||
NOTE: This method has been updated to perform gradient accumulation.
|
||||
If `_grad_accum` is set, then loss values are scaled by this amount and
|
||||
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
|
||||
(congruent to training steps), and then update/reset on every `_grad_accum`th
|
||||
step.
|
||||
Progresses through the training pipeline, performing forward and backward passes.
|
||||
|
||||
"""
|
||||
NOTE: This method has been updated to perform gradient accumulation.
|
||||
If `_grad_accum` is set, then loss values are scaled by this amount and
|
||||
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
|
||||
(congruent to training steps), and then update/reset on every `_grad_accum`th
|
||||
step.
|
||||
|
||||
Args:
|
||||
dataloader_iter (Iterator[In]): An iterator providing input data batches.
|
||||
|
||||
Returns:
|
||||
Out: The output of the forward pass.
|
||||
"""
|
||||
should_step_optimizer = (
|
||||
self._grad_accum is not None
|
||||
and self._progress_calls > 0
|
||||
@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
||||
|
||||
def _sync_pipeline(self) -> None:
|
||||
"""
|
||||
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
|
||||
current train pipeline. Used when switching between train pipelines for the same
|
||||
model.
|
||||
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
|
||||
current train pipeline. Used when switching between train pipelines for the same
|
||||
model.
|
||||
"""
|
||||
for module in self._pipelined_modules:
|
||||
module.forward._context = self._context
|
||||
|
Loading…
x
Reference in New Issue
Block a user