mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-03-09 06:45:15 +01:00
update
This commit is contained in:
parent
799254345f
commit
9bb0986079
@ -39,12 +39,42 @@ Out = TypeVar("Out")
|
|||||||
|
|
||||||
|
|
||||||
class TrainPipeline(abc.ABC, Generic[In, Out]):
|
class TrainPipeline(abc.ABC, Generic[In, Out]):
|
||||||
|
"""
|
||||||
|
Abstract base class for training pipelines.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
In (TypeVar): Input data type.
|
||||||
|
Out (TypeVar): Output data type.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline.
|
||||||
|
"""
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
||||||
|
"""
|
||||||
|
Make progress in the training pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataloader_iter (Iterator[In]): An iterator over input data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Out: The output data.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
|
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
|
||||||
|
"""
|
||||||
|
Move a batch of data to a specified device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (In): The input batch.
|
||||||
|
device (torch.device): The target device.
|
||||||
|
non_blocking (bool): If True, move the data asynchronously.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
In: The batch of data on the target device.
|
||||||
|
"""
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
batch, (torch.Tensor, Pipelineable)
|
batch, (torch.Tensor, Pipelineable)
|
||||||
), f"{type(batch)} must implement Pipelineable interface"
|
), f"{type(batch)} must implement Pipelineable interface"
|
||||||
@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
|
|||||||
|
|
||||||
|
|
||||||
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
|
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
|
||||||
|
"""
|
||||||
|
Wait for a batch of data on a specified stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (In): The input batch.
|
||||||
|
stream (Optional[Stream]): The CUDA stream to wait for.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This function is used for managing asynchronous CUDA operations.
|
||||||
|
"""
|
||||||
if stream is None:
|
if stream is None:
|
||||||
return
|
return
|
||||||
torch.cuda.current_stream().wait_stream(stream)
|
torch.cuda.current_stream().wait_stream(stream)
|
||||||
@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> N
|
|||||||
|
|
||||||
class TrainPipelineBase(TrainPipeline[In, Out]):
|
class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||||
"""
|
"""
|
||||||
This class runs training iterations using a pipeline of two stages, each as a CUDA
|
This class runs training iterations using a pipeline of two stages, each as a CUDA
|
||||||
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
|
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
|
||||||
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
|
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
|
||||||
memory, and the default stream runs forward, backward, and optimization.
|
memory, and the default stream runs forward, backward, and optimization.
|
||||||
"""
|
|
||||||
|
Attributes:
|
||||||
|
In (TypeVar): Input data type.
|
||||||
|
Out (TypeVar): Output data type.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
__init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None:
|
||||||
|
Initialize the TrainPipelineBase.
|
||||||
|
|
||||||
|
_connect(dataloader_iter: Iterator[In]) -> None:
|
||||||
|
Establish a connection to the data loader and move the input data to the GPU.
|
||||||
|
|
||||||
|
progress(dataloader_iter: Iterator[In]) -> Out:
|
||||||
|
Execute a training iteration, including forward and backward passes.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
|||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the TrainPipelineBase.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The PyTorch model to be trained.
|
||||||
|
optimizer (torch.optim.Optimizer): The optimizer used for training.
|
||||||
|
device (torch.device): The target device for training (CPU or GPU).
|
||||||
|
"""
|
||||||
self._model = model
|
self._model = model
|
||||||
self._optimizer = optimizer
|
self._optimizer = optimizer
|
||||||
self._device = device
|
self._device = device
|
||||||
@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
|||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
||||||
def _connect(self, dataloader_iter: Iterator[In]) -> None:
|
def _connect(self, dataloader_iter: Iterator[In]) -> None:
|
||||||
|
"""
|
||||||
|
Establish a connection to the data loader and move the input data to the GPU.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataloader_iter (Iterator[In]): An iterator over input data.
|
||||||
|
"""
|
||||||
cur_batch = next(dataloader_iter)
|
cur_batch = next(dataloader_iter)
|
||||||
self._cur_batch = cur_batch
|
self._cur_batch = cur_batch
|
||||||
with torch.cuda.stream(self._memcpy_stream):
|
with torch.cuda.stream(self._memcpy_stream):
|
||||||
@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
|||||||
self._connected = True
|
self._connected = True
|
||||||
|
|
||||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
||||||
|
"""
|
||||||
|
Execute a training iteration, including forward and backward passes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataloader_iter (Iterator[In]): An iterator over input data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Out: The output data.
|
||||||
|
"""
|
||||||
if not self._connected:
|
if not self._connected:
|
||||||
self._connect(dataloader_iter)
|
self._connect(dataloader_iter)
|
||||||
|
|
||||||
@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
|||||||
|
|
||||||
|
|
||||||
class Tracer(torch.fx.Tracer):
|
class Tracer(torch.fx.Tracer):
|
||||||
|
"""
|
||||||
|
Custom tracer class for PyTorch models.
|
||||||
|
|
||||||
|
This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing.
|
||||||
|
_leaf_modules (List[str]): List of qualified names of leaf modules.
|
||||||
|
"""
|
||||||
|
|
||||||
# Disable proxying buffers during tracing. Ideally, proxying buffers would
|
# Disable proxying buffers during tracing. Ideally, proxying buffers would
|
||||||
# be disabled, but some models are currently mutating buffer values, which
|
# be disabled, but some models are currently mutating buffer values, which
|
||||||
# causes errors during tracing. If those models can be rewritten to not do
|
# causes errors during tracing. If those models can be rewritten to not do
|
||||||
@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer):
|
|||||||
proxy_buffer_attributes = False
|
proxy_buffer_attributes = False
|
||||||
|
|
||||||
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
|
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the Tracer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
|
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
|
||||||
|
|
||||||
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
|
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a module is a leaf module during tracing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
m (torch.nn.Module): The PyTorch module.
|
||||||
|
module_qualified_name (str): The qualified name of the module.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the module is considered a leaf module, False otherwise.
|
||||||
|
"""
|
||||||
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
|
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
|
||||||
return True
|
return True
|
||||||
return super().is_leaf_module(m, module_qualified_name)
|
return super().is_leaf_module(m, module_qualified_name)
|
||||||
@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainPipelineContext:
|
class TrainPipelineContext:
|
||||||
|
"""
|
||||||
|
Dataclass to store information related to the training pipeline context.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests.
|
||||||
|
module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts.
|
||||||
|
feature_processor_forwards (List[Any]): A list of feature processor forwards.
|
||||||
|
"""
|
||||||
|
|
||||||
# pyre-ignore [4]
|
# pyre-ignore [4]
|
||||||
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
|
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
|
||||||
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
|
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
|
||||||
@ -166,6 +279,14 @@ class TrainPipelineContext:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ArgInfo:
|
class ArgInfo:
|
||||||
|
"""
|
||||||
|
Dataclass to store information about arguments in the training pipeline.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
input_attrs (List[str]): List of attribute names of the input batch.
|
||||||
|
is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem.
|
||||||
|
name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments.
|
||||||
|
"""
|
||||||
# attributes of input batch, e.g. batch.attr1.attr2 call
|
# attributes of input batch, e.g. batch.attr1.attr2 call
|
||||||
# will produce ["attr1", "attr2"]
|
# will produce ["attr1", "attr2"]
|
||||||
input_attrs: List[str]
|
input_attrs: List[str]
|
||||||
@ -177,6 +298,16 @@ class ArgInfo:
|
|||||||
|
|
||||||
|
|
||||||
class PipelinedForward:
|
class PipelinedForward:
|
||||||
|
"""
|
||||||
|
Represents a pipelined forward pass operation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name (str): The name of the forward pass.
|
||||||
|
args (List[ArgInfo]): List of argument information for the forward pass.
|
||||||
|
module (ShardedModule): The sharded module associated with the forward pass.
|
||||||
|
context (TrainPipelineContext): The training pipeline context.
|
||||||
|
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
@ -185,6 +316,16 @@ class PipelinedForward:
|
|||||||
context: TrainPipelineContext,
|
context: TrainPipelineContext,
|
||||||
dist_stream: Optional[torch.cuda.streams.Stream],
|
dist_stream: Optional[torch.cuda.streams.Stream],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize a PipelinedForward instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the forward pass.
|
||||||
|
args (List[ArgInfo]): List of argument information for the forward pass.
|
||||||
|
module (ShardedModule): The sharded module associated with the forward pass.
|
||||||
|
context (TrainPipelineContext): The training pipeline context.
|
||||||
|
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
|
||||||
|
"""
|
||||||
self._name = name
|
self._name = name
|
||||||
self._args = args
|
self._args = args
|
||||||
self._module = module
|
self._module = module
|
||||||
@ -193,6 +334,16 @@ class PipelinedForward:
|
|||||||
|
|
||||||
# pyre-ignore [2, 24]
|
# pyre-ignore [2, 24]
|
||||||
def __call__(self, *input, **kwargs) -> Awaitable:
|
def __call__(self, *input, **kwargs) -> Awaitable:
|
||||||
|
"""
|
||||||
|
Perform the pipelined forward pass operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*input: Variable-length positional arguments.
|
||||||
|
**kwargs: Variable-length keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Awaitable: An awaitable object representing the forward pass result.
|
||||||
|
"""
|
||||||
assert self._name in self._context.input_dist_requests
|
assert self._name in self._context.input_dist_requests
|
||||||
request = self._context.input_dist_requests[self._name]
|
request = self._context.input_dist_requests[self._name]
|
||||||
assert isinstance(request, Awaitable)
|
assert isinstance(request, Awaitable)
|
||||||
@ -230,10 +381,22 @@ class PipelinedForward:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the name of the forward pass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The name of the forward pass.
|
||||||
|
"""
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> List[ArgInfo]:
|
def args(self) -> List[ArgInfo]:
|
||||||
|
"""
|
||||||
|
Get the list of argument information for the forward pass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ArgInfo]: List of argument information.
|
||||||
|
"""
|
||||||
return self._args
|
return self._args
|
||||||
|
|
||||||
|
|
||||||
@ -242,6 +405,17 @@ def _start_data_dist(
|
|||||||
batch: In,
|
batch: In,
|
||||||
context: TrainPipelineContext,
|
context: TrainPipelineContext,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Start data distribution for a list of pipelined modules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules.
|
||||||
|
batch (In): The input batch.
|
||||||
|
context (TrainPipelineContext): The training pipeline context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None: This function doesn't return a value.
|
||||||
|
"""
|
||||||
context.input_dist_requests.clear()
|
context.input_dist_requests.clear()
|
||||||
context.module_contexts.clear()
|
context.module_contexts.clear()
|
||||||
for module in pipelined_modules:
|
for module in pipelined_modules:
|
||||||
@ -286,9 +460,17 @@ def _get_node_args_helper(
|
|||||||
feature_processor_arguments: Optional[List[Node]] = None,
|
feature_processor_arguments: Optional[List[Node]] = None,
|
||||||
) -> Tuple[List[ArgInfo], int]:
|
) -> Tuple[List[ArgInfo], int]:
|
||||||
"""
|
"""
|
||||||
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
|
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
|
||||||
It also counts the number of (args + kwargs) found.
|
It also counts the number of (args + kwargs) found.
|
||||||
"""
|
|
||||||
|
Args:
|
||||||
|
arguments: The arguments to process.
|
||||||
|
num_found: The current count of arguments found.
|
||||||
|
feature_processor_arguments: Optional list of feature processor arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found.
|
||||||
|
"""
|
||||||
|
|
||||||
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
|
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
|
||||||
for arg, arg_info in zip(arguments, arg_info_list):
|
for arg, arg_info in zip(arguments, arg_info_list):
|
||||||
@ -332,6 +514,16 @@ def _get_node_args_helper(
|
|||||||
def _get_node_args(
|
def _get_node_args(
|
||||||
node: Node, feature_processor_nodes: Optional[List[Node]] = None
|
node: Node, feature_processor_nodes: Optional[List[Node]] = None
|
||||||
) -> Tuple[List[ArgInfo], int]:
|
) -> Tuple[List[ArgInfo], int]:
|
||||||
|
"""
|
||||||
|
Get argument information for a given node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (Node): The node to process.
|
||||||
|
feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found.
|
||||||
|
"""
|
||||||
num_found = 0
|
num_found = 0
|
||||||
pos_arg_info_list, num_found = _get_node_args_helper(
|
pos_arg_info_list, num_found = _get_node_args_helper(
|
||||||
node.args, num_found, feature_processor_nodes
|
node.args, num_found, feature_processor_nodes
|
||||||
@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper(
|
|||||||
path: str,
|
path: str,
|
||||||
unsharded_module_names: Set[str],
|
unsharded_module_names: Set[str],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Get the names of unsharded modules in a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The model to analyze.
|
||||||
|
path (str): The current path in the model hierarchy.
|
||||||
|
unsharded_module_names (Set[str]): A set to store the names of unsharded modules.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if any sharded modules were found in the hierarchy, False otherwise.
|
||||||
|
"""
|
||||||
sharded_children = set()
|
sharded_children = set()
|
||||||
for name, child in model.named_children():
|
for name, child in model.named_children():
|
||||||
curr_path = path + name
|
curr_path = path + name
|
||||||
@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper(
|
|||||||
|
|
||||||
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
|
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Returns a list of top level modules do not contain any sharded sub modules.
|
Returns a list of top-level modules that do not contain any sharded sub-modules.
|
||||||
"""
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The model to analyze.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: A list of top-level module names without sharded sub-modules.
|
||||||
|
"""
|
||||||
|
|
||||||
unsharded_module_names: Set[str] = set()
|
unsharded_module_names: Set[str] = set()
|
||||||
_get_unsharded_module_names_helper(
|
_get_unsharded_module_names_helper(
|
||||||
@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901
|
|||||||
context: TrainPipelineContext,
|
context: TrainPipelineContext,
|
||||||
dist_stream: Optional[torch.cuda.streams.Stream],
|
dist_stream: Optional[torch.cuda.streams.Stream],
|
||||||
) -> List[ShardedModule]:
|
) -> List[ShardedModule]:
|
||||||
|
"""
|
||||||
|
Rewrites the model to enable pipelined execution for selected sharded modules.
|
||||||
|
|
||||||
|
This function traces the input model using a custom tracer and identifies sharded modules
|
||||||
|
that can be pipelined. It then creates PipelinedForward objects for these modules,
|
||||||
|
which enable pipelining during training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The input model to be rewritten.
|
||||||
|
context (TrainPipelineContext): The context containing information needed for pipelining.
|
||||||
|
dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution.
|
||||||
|
"""
|
||||||
|
|
||||||
# Get underlying nn.Module
|
# Get underlying nn.Module
|
||||||
if isinstance(model, DistributedModelParallel):
|
if isinstance(model, DistributedModelParallel):
|
||||||
@ -442,20 +666,32 @@ def _rewrite_model( # noqa C901
|
|||||||
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
||||||
"""
|
"""
|
||||||
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
|
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
|
||||||
forward and backward. This helps hide the all2all latency while preserving the
|
forward and backward. This helps hide the all2all latency while preserving the
|
||||||
training forward / backward ordering.
|
training forward / backward ordering.
|
||||||
|
|
||||||
stage 3: forward, backward - uses default CUDA stream
|
stage 3: forward, backward - uses default CUDA stream
|
||||||
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
|
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
|
||||||
stage 1: device transfer - uses memcpy CUDA stream
|
stage 1: device transfer - uses memcpy CUDA stream
|
||||||
|
|
||||||
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
|
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
|
||||||
To be considered a top-level module, a module can only depend on 'getattr' calls on
|
To be considered a top-level module, a module can only depend on 'getattr' calls on
|
||||||
input.
|
input.
|
||||||
|
|
||||||
Input model must be symbolically traceable with the exception of `ShardedModule` and
|
Input model must be symbolically traceable with the exception of `ShardedModule` and
|
||||||
`DistributedDataParallel` modules.
|
`DistributedDataParallel` modules.
|
||||||
"""
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The input model to be used for training.
|
||||||
|
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
|
||||||
|
device (torch.device): The device where training will be performed.
|
||||||
|
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
|
||||||
|
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
|
||||||
|
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
synced_pipeline_id: Dict[int, int] = {}
|
synced_pipeline_id: Dict[int, int] = {}
|
||||||
|
|
||||||
@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
|||||||
enable_grad_scaling: bool = True,
|
enable_grad_scaling: bool = True,
|
||||||
grad_accum: Optional[int] = None,
|
grad_accum: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initializes the training pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The input model to be used for training.
|
||||||
|
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
|
||||||
|
device (torch.device): The device where training will be performed.
|
||||||
|
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
|
||||||
|
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
|
||||||
|
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
|
||||||
|
"""
|
||||||
self._model = model
|
self._model = model
|
||||||
self._optimizer = optimizer
|
self._optimizer = optimizer
|
||||||
self._device = device
|
self._device = device
|
||||||
@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
|||||||
self._grad_accum = grad_accum
|
self._grad_accum = grad_accum
|
||||||
|
|
||||||
def _connect(self, dataloader_iter: Iterator[In]) -> None:
|
def _connect(self, dataloader_iter: Iterator[In]) -> None:
|
||||||
|
"""
|
||||||
|
Connects the training pipeline to data and prepares for forward and backward passes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataloader_iter (Iterator[In]): An iterator providing input data batches.
|
||||||
|
"""
|
||||||
|
|
||||||
# batch 1
|
# batch 1
|
||||||
with torch.cuda.stream(self._memcpy_stream):
|
with torch.cuda.stream(self._memcpy_stream):
|
||||||
batch_i = next(dataloader_iter)
|
batch_i = next(dataloader_iter)
|
||||||
@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
|||||||
|
|
||||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
||||||
"""
|
"""
|
||||||
NOTE: This method has been updated to perform gradient accumulation.
|
Progresses through the training pipeline, performing forward and backward passes.
|
||||||
If `_grad_accum` is set, then loss values are scaled by this amount and
|
|
||||||
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
|
|
||||||
(congruent to training steps), and then update/reset on every `_grad_accum`th
|
|
||||||
step.
|
|
||||||
|
|
||||||
"""
|
NOTE: This method has been updated to perform gradient accumulation.
|
||||||
|
If `_grad_accum` is set, then loss values are scaled by this amount and
|
||||||
|
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
|
||||||
|
(congruent to training steps), and then update/reset on every `_grad_accum`th
|
||||||
|
step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataloader_iter (Iterator[In]): An iterator providing input data batches.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Out: The output of the forward pass.
|
||||||
|
"""
|
||||||
should_step_optimizer = (
|
should_step_optimizer = (
|
||||||
self._grad_accum is not None
|
self._grad_accum is not None
|
||||||
and self._progress_calls > 0
|
and self._progress_calls > 0
|
||||||
@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
|||||||
|
|
||||||
def _sync_pipeline(self) -> None:
|
def _sync_pipeline(self) -> None:
|
||||||
"""
|
"""
|
||||||
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
|
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
|
||||||
current train pipeline. Used when switching between train pipelines for the same
|
current train pipeline. Used when switching between train pipelines for the same
|
||||||
model.
|
model.
|
||||||
"""
|
"""
|
||||||
for module in self._pipelined_modules:
|
for module in self._pipelined_modules:
|
||||||
module.forward._context = self._context
|
module.forward._context = self._context
|
||||||
|
Loading…
x
Reference in New Issue
Block a user