mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-25 05:11:10 +01:00
update
This commit is contained in:
parent
590e8b76fe
commit
db4ff958f6
@ -14,7 +14,24 @@ from absl import logging as logging
|
|||||||
|
|
||||||
|
|
||||||
def setup_absl_logging():
|
def setup_absl_logging():
|
||||||
"""Make sure that absl logging pushes to stdout rather than stderr."""
|
"""
|
||||||
|
Configure absl-py logging to direct log messages to stdout and apply a custom log message format.
|
||||||
|
|
||||||
|
This function ensures that log messages generated by the absl-py library are written to stdout
|
||||||
|
rather than stderr. It also applies a custom log message format that includes module, function,
|
||||||
|
line number, log level, and the log message content.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This function should be called once at the beginning of your script or application to
|
||||||
|
configure absl-py logging.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To use this function, simply call it at the start of your script:
|
||||||
|
```
|
||||||
|
setup_absl_logging()
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
logging.get_absl_handler().python_handler.stream = sys.stdout
|
logging.get_absl_handler().python_handler.stream = sys.stdout
|
||||||
formatter = py_logging.Formatter(
|
formatter = py_logging.Formatter(
|
||||||
fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"
|
fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"
|
||||||
|
@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging
|
|||||||
|
|
||||||
class Testtlogging(unittest.TestCase):
|
class Testtlogging(unittest.TestCase):
|
||||||
def test_warn_once(self):
|
def test_warn_once(self):
|
||||||
|
"""
|
||||||
|
Test that warning messages are logged only once when using the assertLogs context manager.
|
||||||
|
|
||||||
|
This unit test checks the behavior of the logging system when warning messages are issued
|
||||||
|
multiple times within the same context. It uses the assertLogs context manager to capture
|
||||||
|
log messages at the INFO level and verifies that warning messages are logged only once.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To use this test case, call it using a test runner like unittest:
|
||||||
|
```
|
||||||
|
python -m unittest your_test_module.TestLogging.test_warn_once
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
with self.assertLogs(level="INFO") as captured_logs:
|
with self.assertLogs(level="INFO") as captured_logs:
|
||||||
logging.info("first info")
|
logging.info("first info")
|
||||||
logging.warning("first warning")
|
logging.warning("first warning")
|
||||||
|
@ -18,7 +18,35 @@ import torch.distributed as dist
|
|||||||
|
|
||||||
|
|
||||||
def rank_specific(logger):
|
def rank_specific(logger):
|
||||||
"""Ensures that we only override a given logger once."""
|
"""
|
||||||
|
Customize logger behavior based on the distributed environment and rank.
|
||||||
|
|
||||||
|
This function allows for customizing the behavior of a logger based on the distributed environment and the rank
|
||||||
|
of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages
|
||||||
|
depending on the rank or limit the number of redundant logs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger: The logger object to customize.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The customized logger.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To use this function with the `logging` module:
|
||||||
|
```python
|
||||||
|
import logging
|
||||||
|
from rank_specific_logging import rank_specific
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
rank_specific(logger)
|
||||||
|
```
|
||||||
|
|
||||||
|
Customization:
|
||||||
|
- Messages are only logged if the distributed environment is not initialized or if the rank matches.
|
||||||
|
- The 'warning' method is limited to logging a single redundant warning.
|
||||||
|
- Logging from rank -1 is redirected to include the rank information.
|
||||||
|
|
||||||
|
"""
|
||||||
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
|
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
@ -8,11 +8,60 @@ import pydantic
|
|||||||
|
|
||||||
|
|
||||||
class PiecewiseConstant(base_config.BaseConfig):
|
class PiecewiseConstant(base_config.BaseConfig):
|
||||||
|
"""
|
||||||
|
Configuration for a piecewise constant learning rate schedule.
|
||||||
|
|
||||||
|
This configuration class allows you to specify a piecewise constant learning rate schedule
|
||||||
|
by defining boundaries and corresponding learning rate values.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
learning_rate_boundaries (List[int], optional): List of step boundaries at which
|
||||||
|
the learning rate will change. If None, no boundaries are defined.
|
||||||
|
learning_rate_values (List[float], optional): List of learning rate values
|
||||||
|
corresponding to the boundaries. If None, no values are defined.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To configure a piecewise constant learning rate schedule, create an instance of this class
|
||||||
|
and set the attributes accordingly. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
piecewise_lr = PiecewiseConstant(
|
||||||
|
learning_rate_boundaries=[1000, 2000, 3000],
|
||||||
|
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The number of learning rate values should be one more than the number of boundaries.
|
||||||
|
|
||||||
|
"""
|
||||||
learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
|
learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
|
||||||
learning_rate_values: typing.List[float] = pydantic.Field(None)
|
learning_rate_values: typing.List[float] = pydantic.Field(None)
|
||||||
|
|
||||||
|
|
||||||
class LinearRampToConstant(base_config.BaseConfig):
|
class LinearRampToConstant(base_config.BaseConfig):
|
||||||
|
"""
|
||||||
|
Configuration for a linear ramp-up to constant learning rate schedule.
|
||||||
|
|
||||||
|
This configuration class allows you to specify a learning rate schedule that ramps up linearly
|
||||||
|
from zero to a constant value over a specified number of steps.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
learning_rate (float): The final constant learning rate.
|
||||||
|
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To configure a linear ramp-up to a constant learning rate, create an instance of this class
|
||||||
|
and set the attributes accordingly. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
linear_ramp_lr = LinearRampToConstant(
|
||||||
|
learning_rate=0.1,
|
||||||
|
num_ramp_steps=1000
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
learning_rate: float
|
learning_rate: float
|
||||||
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
|
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
|
||||||
description="Number of steps to ramp this up from zero."
|
description="Number of steps to ramp this up from zero."
|
||||||
@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
class LinearRampToCosine(base_config.BaseConfig):
|
class LinearRampToCosine(base_config.BaseConfig):
|
||||||
|
"""
|
||||||
|
Configuration for a linear ramp-up to cosine decay learning rate schedule.
|
||||||
|
|
||||||
|
This configuration class allows you to specify a learning rate schedule that ramps up linearly
|
||||||
|
from zero, then decays following a cosine schedule to a final constant learning rate.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
learning_rate (float): The initial learning rate at the start of ramp-up.
|
||||||
|
final_learning_rate (float): The final constant learning rate after decay.
|
||||||
|
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
|
||||||
|
final_num_steps (PositiveInt): Final number of steps where decay stops.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To configure a linear ramp-up to cosine decay learning rate, create an instance of this
|
||||||
|
class and set the attributes accordingly. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
ramp_to_cosine_lr = LinearRampToCosine(
|
||||||
|
learning_rate=0.01,
|
||||||
|
final_learning_rate=0.001,
|
||||||
|
num_ramp_steps=1000,
|
||||||
|
final_num_steps=5000
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
learning_rate: float
|
learning_rate: float
|
||||||
final_learning_rate: float
|
final_learning_rate: float
|
||||||
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
|
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
|
||||||
@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
class LearningRate(base_config.BaseConfig):
|
class LearningRate(base_config.BaseConfig):
|
||||||
|
"""
|
||||||
|
Learning rate configuration for training.
|
||||||
|
|
||||||
|
This configuration class allows you to specify different learning rate schedules
|
||||||
|
for your training process.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
constant (float, optional): Constant learning rate to be used throughout training.
|
||||||
|
linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly
|
||||||
|
and then decays following a cosine schedule.
|
||||||
|
linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up
|
||||||
|
linearly and then remains constant.
|
||||||
|
piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified
|
||||||
|
boundaries with corresponding values.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To configure a learning rate schedule, create an instance of this class and set the
|
||||||
|
attributes accordingly. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
learning_rate = LearningRate(
|
||||||
|
constant=0.01,
|
||||||
|
linear_ramp_to_cosine=LinearRampToCosine(
|
||||||
|
learning_rate=0.1,
|
||||||
|
final_learning_rate=0.001,
|
||||||
|
num_ramp_steps=1000,
|
||||||
|
final_num_steps=5000
|
||||||
|
)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Each learning rate schedule attribute can be set to `None` if not needed.
|
||||||
|
|
||||||
|
"""
|
||||||
constant: float = pydantic.Field(None, one_of="lr")
|
constant: float = pydantic.Field(None, one_of="lr")
|
||||||
linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
|
linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
|
||||||
linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
|
linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
|
||||||
@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
class OptimizerAlgorithmConfig(base_config.BaseConfig):
|
class OptimizerAlgorithmConfig(base_config.BaseConfig):
|
||||||
"""Base class for optimizer configurations."""
|
"""
|
||||||
|
Base class for optimizer configurations.
|
||||||
|
|
||||||
|
This base configuration class provides a structure for specifying various optimizer-related
|
||||||
|
settings, including the learning rate and different learning rate schedules.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
lr (float): The base learning rate used by the optimizer.
|
||||||
|
|
||||||
|
Subclasses should inherit from this base class and define additional attributes specific to
|
||||||
|
the optimizer algorithm they represent.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To create a custom optimizer configuration, create a subclass of this base class and
|
||||||
|
define the necessary attributes. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyOptimizerConfig(OptimizerAlgorithmConfig):
|
||||||
|
momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.")
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This base class does not include specific optimizer settings. Subclasses should define
|
||||||
|
the optimizer-specific attributes as needed.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
lr: float
|
lr: float
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class AdamConfig(OptimizerAlgorithmConfig):
|
class AdamConfig(OptimizerAlgorithmConfig):
|
||||||
# see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
|
"""
|
||||||
|
Configuration for the Adam optimizer.
|
||||||
|
|
||||||
|
This configuration class allows you to specify the hyperparameters for the Adam optimizer.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
lr (float): The learning rate for optimization.
|
||||||
|
betas (Tuple[float, float], optional): Coefficients used for computing running averages
|
||||||
|
of gradient and squared gradient. Defaults to (0.9, 0.999).
|
||||||
|
eps (float, optional): A small constant added to the denominator for numerical stability.
|
||||||
|
Defaults to 1e-7.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To configure the Adam optimizer, create an instance of this class and set the attributes
|
||||||
|
accordingly. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
adam_optimizer = AdamConfig(
|
||||||
|
lr=0.001,
|
||||||
|
betas=(0.9, 0.999),
|
||||||
|
eps=1e-8
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
[PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
|
||||||
|
|
||||||
|
"""
|
||||||
lr: float
|
lr: float
|
||||||
betas: typing.Tuple[float, float] = [0.9, 0.999]
|
betas: typing.Tuple[float, float] = [0.9, 0.999]
|
||||||
eps: float = 1e-7 # Numerical stability in denominator.
|
eps: float = 1e-7 # Numerical stability in denominator.
|
||||||
|
|
||||||
|
|
||||||
class SgdConfig(OptimizerAlgorithmConfig):
|
class SgdConfig(OptimizerAlgorithmConfig):
|
||||||
|
"""
|
||||||
|
Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||||
|
|
||||||
|
This configuration class allows you to specify the hyperparameters for the SGD optimizer.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
lr (float): The learning rate for optimization.
|
||||||
|
momentum (float, optional): The momentum factor for SGD. Defaults to 0.0.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To configure the SGD optimizer, create an instance of this class and set the attributes
|
||||||
|
accordingly. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
sgd_optimizer = SgdConfig(
|
||||||
|
lr=0.01,
|
||||||
|
momentum=0.9
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
lr: float
|
lr: float
|
||||||
momentum: float = 0.0
|
momentum: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
class AdagradConfig(OptimizerAlgorithmConfig):
|
class AdagradConfig(OptimizerAlgorithmConfig):
|
||||||
|
"""
|
||||||
|
Configuration for the optimizer used during training.
|
||||||
|
|
||||||
|
This configuration class allows you to specify the optimizer for training, including
|
||||||
|
options for various optimizer algorithms.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
learning_rate (LearningRate, optional): Learning rate configuration. Defaults to None.
|
||||||
|
adam (AdamConfig, optional): Configuration for the Adam optimizer. Defaults to None.
|
||||||
|
sgd (SgdConfig, optional): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||||
|
Defaults to None.
|
||||||
|
adagrad (AdagradConfig, optional): Configuration for the Adagrad optimizer. Defaults to None.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To configure the optimizer for training, create an instance of this class and set the
|
||||||
|
attributes accordingly. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
optimizer_config = OptimizerConfig(
|
||||||
|
learning_rate=LearningRate(constant=0.001),
|
||||||
|
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
lr: float
|
lr: float
|
||||||
eps: float = 0
|
eps: float = 0
|
||||||
|
|
||||||
|
|
||||||
class OptimizerConfig(base_config.BaseConfig):
|
class OptimizerConfig(base_config.BaseConfig):
|
||||||
|
"""
|
||||||
|
Configuration for defining different optimizer algorithms and their parameters.
|
||||||
|
|
||||||
|
This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad,
|
||||||
|
along with their respective hyperparameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
learning_rate (LearningRate): The learning rate configuration, which can include
|
||||||
|
constant learning rates or other learning rate schedules.
|
||||||
|
adam (AdamConfig): Configuration for the Adam optimizer.
|
||||||
|
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||||
|
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
optimizer_config = OptimizerConfig(
|
||||||
|
learning_rate=LearningRate(constant=0.001),
|
||||||
|
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
learning_rate (LearningRate): The learning rate configuration.
|
||||||
|
adam (AdamConfig): Configuration for the Adam optimizer.
|
||||||
|
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||||
|
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an
|
||||||
|
`OptimizerConfig` instance.
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
- `LearningRate`: Configuration for specifying learning rates.
|
||||||
|
- `AdamConfig`: Configuration for the Adam optimizer.
|
||||||
|
- `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||||
|
- `AdagradConfig`: Configuration for the Adagrad optimizer.
|
||||||
|
|
||||||
|
"""
|
||||||
learning_rate: LearningRate = pydantic.Field(
|
learning_rate: LearningRate = pydantic.Field(
|
||||||
None,
|
None,
|
||||||
description="Constant learning rates",
|
description="Constant learning rates",
|
||||||
@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
|
def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
|
||||||
|
"""
|
||||||
|
Get the optimizer algorithm configuration from the given `OptimizerConfig`.
|
||||||
|
|
||||||
|
This function extracts and returns the specific optimizer algorithm configuration
|
||||||
|
(e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer_config (OptimizerConfig): The optimizer configuration object containing
|
||||||
|
one of the optimizer algorithm configurations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm
|
||||||
|
configuration extracted from `optimizer_config`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no optimizer algorithm is selected in `optimizer_config`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
optimizer_config = OptimizerConfig(
|
||||||
|
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
|
||||||
|
)
|
||||||
|
algorithm_config = get_optimizer_algorithm_config(optimizer_config)
|
||||||
|
# `algorithm_config` will be an instance of `AdamConfig`.
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
if optimizer_config.adam is not None:
|
if optimizer_config.adam is not None:
|
||||||
return optimizer_config.adam
|
return optimizer_config.adam
|
||||||
elif optimizer_config.sgd is not None:
|
elif optimizer_config.sgd is not None:
|
||||||
|
@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging
|
|||||||
|
|
||||||
|
|
||||||
def compute_lr(lr_config, step):
|
def compute_lr(lr_config, step):
|
||||||
"""Compute a learning rate."""
|
"""
|
||||||
|
Compute the learning rate based on the specified learning rate configuration.
|
||||||
|
|
||||||
|
This function calculates the learning rate according to the given configuration, which can include
|
||||||
|
constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule.
|
||||||
|
step (int): The current training step or iteration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The computed learning rate for the current step.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the `lr_config` is invalid or contains conflicting options.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
lr_schedule = LearningRate(
|
||||||
|
constant=0.001,
|
||||||
|
piecewise_constant=PiecewiseConstant(
|
||||||
|
learning_rate_boundaries=[1000, 2000, 3000],
|
||||||
|
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
current_step = 2500
|
||||||
|
learning_rate = compute_lr(lr_schedule, current_step)
|
||||||
|
```
|
||||||
|
"""
|
||||||
if lr_config.constant is not None:
|
if lr_config.constant is not None:
|
||||||
return lr_config.constant
|
return lr_config.constant
|
||||||
elif lr_config.piecewise_constant is not None:
|
elif lr_config.piecewise_constant is not None:
|
||||||
@ -46,11 +74,54 @@ def compute_lr(lr_config, step):
|
|||||||
|
|
||||||
|
|
||||||
class LRShim(_LRScheduler):
|
class LRShim(_LRScheduler):
|
||||||
"""Shim to get learning rates into a LRScheduler.
|
|
||||||
|
|
||||||
This adheres to the torch.optim scheduler API and can be plugged anywhere that
|
|
||||||
e.g. exponential decay can be used.
|
|
||||||
"""
|
"""
|
||||||
|
Learning Rate Scheduler Shim to adjust learning rates during training.
|
||||||
|
|
||||||
|
This class acts as a shim to apply different learning rates to individual parameter groups
|
||||||
|
within an optimizer. It adheres to the torch.optim scheduler API and can be used with various
|
||||||
|
optimizers, allowing fine-grained control over learning rates based on configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted.
|
||||||
|
lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their
|
||||||
|
corresponding learning rate configurations.
|
||||||
|
last_epoch (int, optional): The index of the last epoch. Default is -1.
|
||||||
|
verbose (bool, optional): If True, prints a warning message when accessing learning rates
|
||||||
|
using the deprecated `get_lr()` method. Default is False.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the number of parameter groups in the optimizer does not match the number
|
||||||
|
of learning rate configurations provided.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
To obtain the last computed learning rates, please use `get_last_lr()`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||||
|
lr_schedule = {
|
||||||
|
'main': LearningRate(constant=0.01),
|
||||||
|
'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant(
|
||||||
|
learning_rate_boundaries=[1000, 2000],
|
||||||
|
learning_rate_values=[0.01, 0.001]
|
||||||
|
))
|
||||||
|
}
|
||||||
|
lr_shim = LRShim(optimizer, lr_schedule)
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
# Train the model
|
||||||
|
train(...)
|
||||||
|
# Update learning rates at the end of each epoch
|
||||||
|
lr_shim.step(epoch)
|
||||||
|
|
||||||
|
final_lr_main = lr_shim.get_last_lr()['main']
|
||||||
|
final_lr_auxiliary = lr_shim.get_last_lr()['auxiliary']
|
||||||
|
```
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
- `LearningRate`: Configuration for specifying learning rates.
|
||||||
|
- `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig):
|
|||||||
def build_optimizer(
|
def build_optimizer(
|
||||||
model: torch.nn.Module, optimizer_config: OptimizerConfig
|
model: torch.nn.Module, optimizer_config: OptimizerConfig
|
||||||
) -> Tuple[Optimizer, _LRScheduler]:
|
) -> Tuple[Optimizer, _LRScheduler]:
|
||||||
"""Builds an optimizer and LR scheduler from an OptimizerConfig.
|
|
||||||
Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
|
|
||||||
"""
|
"""
|
||||||
|
Build an optimizer and learning rate scheduler based on the provided optimizer configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The PyTorch model for which the optimizer will be created.
|
||||||
|
optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer
|
||||||
|
algorithm and learning rate settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler
|
||||||
|
objects.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This function is intended for cases where you want the same optimizer and learning rate
|
||||||
|
schedule for all model parameters.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
model = MyModel()
|
||||||
|
optimizer_config = OptimizerConfig(
|
||||||
|
learning_rate=LearningRate(constant=0.01),
|
||||||
|
sgd=SgdConfig(lr=0.01, momentum=0.9)
|
||||||
|
)
|
||||||
|
optimizer, scheduler = build_optimizer(model, optimizer_config)
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
# Train the model with the optimizer
|
||||||
|
train(model, optimizer, ...)
|
||||||
|
# Update learning rates at the end of each epoch
|
||||||
|
scheduler.step(epoch)
|
||||||
|
```
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
- `OptimizerConfig`: Configuration for specifying optimizer settings.
|
||||||
|
- `LRShim`: Learning rate scheduler shim for fine-grained learning rate control.
|
||||||
|
"""
|
||||||
optimizer_class = get_optimizer_class(optimizer_config)
|
optimizer_class = get_optimizer_class(optimizer_config)
|
||||||
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
|
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
|
||||||
# We're passing everything in as one group here
|
# We're passing everything in as one group here
|
||||||
|
@ -4,6 +4,17 @@ import pydantic
|
|||||||
|
|
||||||
|
|
||||||
class TwhinDataConfig(base_config.BaseConfig):
|
class TwhinDataConfig(base_config.BaseConfig):
|
||||||
|
"""
|
||||||
|
Configuration for Twhin model training data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root (str): The root directory for the training data.
|
||||||
|
per_replica_batch_size (pydantic.PositiveInt): Batch size per replica.
|
||||||
|
global_negatives (int): The number of global negatives.
|
||||||
|
in_batch_negatives (int): The number of in-batch negatives.
|
||||||
|
limit (pydantic.PositiveInt): The limit on the number of data points to use.
|
||||||
|
offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None.
|
||||||
|
"""
|
||||||
data_root: str
|
data_root: str
|
||||||
per_replica_batch_size: pydantic.PositiveInt
|
per_replica_batch_size: pydantic.PositiveInt
|
||||||
global_negatives: int
|
global_negatives: int
|
||||||
|
@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset
|
|||||||
|
|
||||||
|
|
||||||
def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
|
def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
|
||||||
|
"""
|
||||||
|
Create a dataset for Twhin model training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_config (TwhinDataConfig): The data configuration for the dataset.
|
||||||
|
model_config (TwhinModelConfig): The model configuration containing embeddings and relations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EdgesDataset: The dataset for Twhin model training.
|
||||||
|
"""
|
||||||
tables = model_config.embeddings.tables
|
tables = model_config.embeddings.tables
|
||||||
table_sizes = {table.name: table.num_embeddings for table in tables}
|
table_sizes = {table.name: table.num_embeddings for table in tables}
|
||||||
relations = model_config.relations
|
relations = model_config.relations
|
||||||
|
@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EdgeBatch(DataclassBatch):
|
class EdgeBatch(DataclassBatch):
|
||||||
|
"""
|
||||||
|
Batch data structure for edge-based models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings.
|
||||||
|
labels (torch.Tensor): Tensor containing labels.
|
||||||
|
rels (torch.Tensor): Tensor containing relation information.
|
||||||
|
weights (torch.Tensor): Tensor containing weights.
|
||||||
|
"""
|
||||||
nodes: KeyedJaggedTensor
|
nodes: KeyedJaggedTensor
|
||||||
labels: torch.Tensor
|
labels: torch.Tensor
|
||||||
rels: torch.Tensor
|
rels: torch.Tensor
|
||||||
@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch):
|
|||||||
|
|
||||||
|
|
||||||
class EdgesDataset(Dataset):
|
class EdgesDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Dataset for edge-based models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_pattern (str): The file pattern for the dataset.
|
||||||
|
table_sizes (Dict[str, int]): A dictionary of table names and their sizes.
|
||||||
|
relations (List[Relation]): A list of relations between tables.
|
||||||
|
lhs_column_name (str): The name of the left-hand-side column.
|
||||||
|
rhs_column_name (str): The name of the right-hand-side column.
|
||||||
|
rel_column_name (str): The name of the relation column.
|
||||||
|
**dataset_kwargs: Additional keyword arguments for the parent Dataset class.
|
||||||
|
"""
|
||||||
rng = np.random.default_rng()
|
rng = np.random.default_rng()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -56,6 +77,15 @@ class EdgesDataset(Dataset):
|
|||||||
super().__init__(file_pattern=file_pattern, **dataset_kwargs)
|
super().__init__(file_pattern=file_pattern, **dataset_kwargs)
|
||||||
|
|
||||||
def pa_to_batch(self, batch: pa.RecordBatch):
|
def pa_to_batch(self, batch: pa.RecordBatch):
|
||||||
|
"""
|
||||||
|
Converts a pyarrow RecordBatch to an EdgeBatch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (pa.RecordBatch): A pyarrow RecordBatch containing data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights.
|
||||||
|
"""
|
||||||
lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
|
lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
|
||||||
rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
|
rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
|
||||||
rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
|
rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
|
||||||
@ -74,6 +104,14 @@ class EdgesDataset(Dataset):
|
|||||||
) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:
|
) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:
|
||||||
|
|
||||||
"""Process edges that contain lhs index, rhs index, relation index.
|
"""Process edges that contain lhs index, rhs index, relation index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lhs (torch.Tensor): Tensor containing left-hand-side indices.
|
||||||
|
rhs (torch.Tensor): Tensor containing right-hand-side indices.
|
||||||
|
rel (torch.Tensor): Tensor containing relation indices.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs.
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```
|
```
|
||||||
@ -147,6 +185,12 @@ class EdgesDataset(Dataset):
|
|||||||
return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)
|
return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)
|
||||||
|
|
||||||
def to_batches(self):
|
def to_batches(self):
|
||||||
|
"""
|
||||||
|
Converts data to batches.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
pa.RecordBatch: A pyarrow RecordBatch containing data.
|
||||||
|
"""
|
||||||
ds = super().to_batches()
|
ds = super().to_batches()
|
||||||
batch_size = self._dataset_kwargs["batch_size"]
|
batch_size = self._dataset_kwargs["batch_size"]
|
||||||
|
|
||||||
|
@ -10,8 +10,29 @@ from pydantic import validator
|
|||||||
|
|
||||||
|
|
||||||
class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
|
class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
|
||||||
|
"""
|
||||||
|
Configuration class for Twhin model embeddings.
|
||||||
|
|
||||||
|
This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types
|
||||||
|
for all tables in the Twhin model embeddings configuration match.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
tables (List[TableConfig]): A list of table configurations for the model's embeddings.
|
||||||
|
"""
|
||||||
@validator("tables")
|
@validator("tables")
|
||||||
def embedding_dims_match(cls, tables):
|
def embedding_dims_match(cls, tables):
|
||||||
|
"""
|
||||||
|
Validate that embedding dimensions and data types match for all tables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tables (List[TableConfig]): List of table configurations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[TableConfig]: The list of validated table configurations.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If embedding dimensions or data types do not match.
|
||||||
|
"""
|
||||||
embedding_dim = tables[0].embedding_dim
|
embedding_dim = tables[0].embedding_dim
|
||||||
data_type = tables[0].data_type
|
data_type = tables[0].data_type
|
||||||
for table in tables:
|
for table in tables:
|
||||||
@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
|
|||||||
|
|
||||||
|
|
||||||
class Operator(str, enum.Enum):
|
class Operator(str, enum.Enum):
|
||||||
|
"""
|
||||||
|
Enumeration of operator types.
|
||||||
|
|
||||||
|
This enumeration defines different types of operators that can be applied to Twhin model relations.
|
||||||
|
"""
|
||||||
TRANSLATION = "translation"
|
TRANSLATION = "translation"
|
||||||
|
|
||||||
|
|
||||||
class Relation(pydantic.BaseModel):
|
class Relation(pydantic.BaseModel):
|
||||||
"""graph relationship properties and operator"""
|
"""
|
||||||
|
Configuration class for graph relationships in the Twhin model.
|
||||||
|
|
||||||
|
This class defines properties and operators for graph relationships in the Twhin model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name (str): The name of the relationship.
|
||||||
|
lhs (str): The name of the entity on the left-hand side of the relation.
|
||||||
|
rhs (str): The name of the entity on the right-hand side of the relation.
|
||||||
|
operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product.
|
||||||
|
"""
|
||||||
|
|
||||||
name: str = pydantic.Field(..., description="Relationship name.")
|
name: str = pydantic.Field(..., description="Relationship name.")
|
||||||
lhs: str = pydantic.Field(
|
lhs: str = pydantic.Field(
|
||||||
@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class TwhinModelConfig(base_config.BaseConfig):
|
class TwhinModelConfig(base_config.BaseConfig):
|
||||||
|
"""
|
||||||
|
Configuration class for the Twhin model.
|
||||||
|
|
||||||
|
This class defines configuration options specific to the Twhin model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings.
|
||||||
|
relations (List[Relation]): List of graph relationship configurations.
|
||||||
|
translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation.
|
||||||
|
"""
|
||||||
embeddings: TwhinEmbeddingsConfig
|
embeddings: TwhinEmbeddingsConfig
|
||||||
relations: typing.List[Relation]
|
relations: typing.List[Relation]
|
||||||
translation_optimizer: OptimizerConfig
|
translation_optimizer: OptimizerConfig
|
||||||
|
|
||||||
@validator("relations", each_item=True)
|
@validator("relations", each_item=True)
|
||||||
def valid_node_types(cls, relation, values, **kwargs):
|
def valid_node_types(cls, relation, values, **kwargs):
|
||||||
|
"""
|
||||||
|
Validate that the specified node types in relations are valid table names in embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
relation (Relation): A single relation configuration.
|
||||||
|
values (dict): The values dictionary containing the "embeddings" configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Relation: The validated relation configuration.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the specified node types are not valid table names in embeddings.
|
||||||
|
"""
|
||||||
table_names = [table.name for table in values["embeddings"].tables]
|
table_names = [table.name for table in values["embeddings"].tables]
|
||||||
assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
|
assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
|
||||||
assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"
|
assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"
|
||||||
|
@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa
|
|||||||
|
|
||||||
|
|
||||||
class TwhinModel(nn.Module):
|
class TwhinModel(nn.Module):
|
||||||
|
"""
|
||||||
|
Twhin model for graph-based entity embeddings and translation.
|
||||||
|
|
||||||
|
This class defines the Twhin model, which is used for learning embeddings of entities in a graph
|
||||||
|
and applying translations to these embeddings based on graph relationships.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_config (TwhinModelConfig): Configuration for the Twhin model.
|
||||||
|
data_config (TwhinDataConfig): Configuration for the data used by the model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
batch_size (int): The batch size used for training.
|
||||||
|
table_names (List[str]): Names of tables in the model's embeddings.
|
||||||
|
large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings.
|
||||||
|
embedding_dim (int): Dimensionality of entity embeddings.
|
||||||
|
num_tables (int): Number of tables in the model's embeddings.
|
||||||
|
in_batch_negatives (int): Number of in-batch negative samples to use during training.
|
||||||
|
global_negatives (int): Number of global negative samples to use during training.
|
||||||
|
num_relations (int): Number of graph relationships in the model.
|
||||||
|
all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings.
|
||||||
|
|
||||||
|
"""
|
||||||
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
|
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_size = data_config.per_replica_batch_size
|
self.batch_size = data_config.per_replica_batch_size
|
||||||
@ -31,7 +53,17 @@ class TwhinModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, batch: EdgeBatch):
|
def forward(self, batch: EdgeBatch):
|
||||||
|
"""
|
||||||
|
Forward pass of the Twhin model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (EdgeBatch): Input batch containing graph edge information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary containing model output with "logits" and "probabilities".
|
||||||
|
- "logits" (torch.Tensor): Logit scores.
|
||||||
|
- "probabilities" (torch.Tensor): Sigmoid probabilities.
|
||||||
|
"""
|
||||||
# B x D
|
# B x D
|
||||||
trans_embs = self.all_trans_embs.data[batch.rels]
|
trans_embs = self.all_trans_embs.data[batch.rels]
|
||||||
|
|
||||||
@ -98,6 +130,18 @@ class TwhinModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
|
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
|
||||||
|
"""
|
||||||
|
Apply optimizers to the Twhin model's embeddings.
|
||||||
|
|
||||||
|
This function applies optimizers to the embeddings of the Twhin model based on the provided configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (TwhinModel): The Twhin model to apply optimizers to.
|
||||||
|
model_config (TwhinModelConfig): Configuration for the Twhin model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TwhinModel: The Twhin model with optimizers applied to its embeddings.
|
||||||
|
"""
|
||||||
for table in model_config.embeddings.tables:
|
for table in model_config.embeddings.tables:
|
||||||
optimizer_class = get_optimizer_class(table.optimizer)
|
optimizer_class = get_optimizer_class(table.optimizer)
|
||||||
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
|
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
|
||||||
@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Initialize a TwhinModelAndLoss module.
|
||||||
model: torch module to wrap.
|
|
||||||
loss_fn: Function for calculating loss, should accept logits and labels.
|
Args:
|
||||||
"""
|
model: The torch module to wrap.
|
||||||
|
loss_fn: A function for calculating loss, should accept logits and labels.
|
||||||
|
data_config: Configuration for Twhin data.
|
||||||
|
device: The torch device to use for calculations.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.loss_fn = loss_fn
|
self.loss_fn = loss_fn
|
||||||
@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module):
|
|||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
|
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
|
||||||
"""Runs model forward and calculates loss according to given loss_fn.
|
|
||||||
|
|
||||||
NOTE: The input signature here needs to be a Pipelineable object for
|
|
||||||
prefetching purposes during training using torchrec's pipeline. However
|
|
||||||
the underlying model signature needs to be exportable to onnx, requiring
|
|
||||||
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
Run the model forward and calculate the loss according to the given loss_fn.
|
||||||
|
|
||||||
|
NOTE: The input signature here needs to be a Pipelineable object for
|
||||||
|
prefetching purposes during training using torchrec's pipeline. However
|
||||||
|
the underlying model signature needs to be exportable to onnx, requiring
|
||||||
|
generic python types. see https://pytorch.org/docs/stable/onnx.html#types
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch ("RecapBatch"): The input batch for model inference.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of
|
||||||
|
additional outputs including logits, labels, and weights.
|
||||||
|
"""
|
||||||
outputs = self.model(batch)
|
outputs = self.model(batch)
|
||||||
logits = outputs["logits"]
|
logits = outputs["logits"]
|
||||||
|
|
||||||
|
@ -18,6 +18,12 @@ EMB_DIM = 128
|
|||||||
|
|
||||||
|
|
||||||
def twhin_model_config() -> TwhinModelConfig:
|
def twhin_model_config() -> TwhinModelConfig:
|
||||||
|
"""
|
||||||
|
Create a configuration for the Twhin model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TwhinModelConfig: The Twhin model configuration.
|
||||||
|
"""
|
||||||
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
|
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
|
||||||
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
|
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
|
||||||
|
|
||||||
@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig:
|
|||||||
|
|
||||||
|
|
||||||
def twhin_data_config() -> TwhinDataConfig:
|
def twhin_data_config() -> TwhinDataConfig:
|
||||||
|
"""
|
||||||
|
Create a configuration for the Twhin data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TwhinDataConfig: The Twhin data configuration.
|
||||||
|
"""
|
||||||
data_config = TwhinDataConfig(
|
data_config = TwhinDataConfig(
|
||||||
data_root="/",
|
data_root="/",
|
||||||
per_replica_batch_size=10,
|
per_replica_batch_size=10,
|
||||||
@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig:
|
|||||||
|
|
||||||
|
|
||||||
def test_twhin_model():
|
def test_twhin_model():
|
||||||
|
"""
|
||||||
|
Test the Twhin model creation and optimization.
|
||||||
|
|
||||||
|
This function creates a Twhin model using the specified configuration and tests its optimization. It also checks
|
||||||
|
the device placement of model parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
model_config = twhin_model_config()
|
model_config = twhin_model_config()
|
||||||
loss_fn = F.binary_cross_entropy_with_logits
|
loss_fn = F.binary_cross_entropy_with_logits
|
||||||
|
|
||||||
|
@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt"
|
|||||||
|
|
||||||
|
|
||||||
def _lr_from_config(optimizer_config):
|
def _lr_from_config(optimizer_config):
|
||||||
|
"""Get the learning rate from an optimizer configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer_config: Optimizer configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Learning rate from the optimizer configuration.
|
||||||
|
"""
|
||||||
if optimizer_config.learning_rate is not None:
|
if optimizer_config.learning_rate is not None:
|
||||||
return optimizer_config.learning_rate
|
return optimizer_config.learning_rate
|
||||||
else:
|
else:
|
||||||
@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config):
|
|||||||
def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
|
def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
|
||||||
"""Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.
|
"""Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: TwhinModel to build optimizer for.
|
model: TwhinModel to build optimizer for.
|
||||||
config: TwhinConfig for model.
|
config: TwhinModelConfig for model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optimizer for model.
|
Optimizer for model.
|
||||||
"""
|
"""
|
||||||
translation_optimizer_fn = functools.partial(
|
translation_optimizer_fn = functools.partial(
|
||||||
get_optimizer_class(config.translation_optimizer),
|
get_optimizer_class(config.translation_optimizer),
|
||||||
**get_optimizer_algorithm_config(config.translation_optimizer).dict(),
|
**get_optimizer_algorithm_config(config.translation_optimizer).dict(),
|
||||||
|
@ -37,6 +37,12 @@ def run(
|
|||||||
all_config: TwhinConfig,
|
all_config: TwhinConfig,
|
||||||
save_dir: Optional[str] = None,
|
save_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
"""Run the training process for TwhinModel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_config (TwhinConfig): The configuration for the entire Twhin model.
|
||||||
|
save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None.
|
||||||
|
"""
|
||||||
train_dataset = create_dataset(all_config.train_data, all_config.model)
|
train_dataset = create_dataset(all_config.train_data, all_config.model)
|
||||||
|
|
||||||
if env.is_reader():
|
if env.is_reader():
|
||||||
@ -80,6 +86,11 @@ def run(
|
|||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
|
"""Main entry point for the Twhin training script.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
argv: Command-line arguments.
|
||||||
|
"""
|
||||||
logging.info("Starting")
|
logging.info("Starting")
|
||||||
|
|
||||||
logging.info(f"parsing config from {FLAGS.config_yaml_path}...")
|
logging.info(f"parsing config from {FLAGS.config_yaml_path}...")
|
||||||
|
@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging
|
|||||||
|
|
||||||
|
|
||||||
class _Reader(pa.flight.FlightServerBase):
|
class _Reader(pa.flight.FlightServerBase):
|
||||||
"""Distributed reader flight server wrapping a dataset."""
|
"""
|
||||||
|
Distributed reader flight server wrapping a dataset.
|
||||||
|
|
||||||
|
This class implements a Flight server that wraps a dataset, allowing clients to retrieve data
|
||||||
|
from the dataset over the Flight protocol. It is designed to be used in a distributed environment
|
||||||
|
for efficient data access.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location (str): The location of the Flight server.
|
||||||
|
ds (Dataset): The dataset to be wrapped by the Flight server.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_location (str): The location of the Flight server.
|
||||||
|
_ds (Dataset): The dataset wrapped by the Flight server.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
do_get(_, __): Handles Flight requests for data retrieval.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Flight is an Apache Arrow project that provides a framework for efficient data transfer.
|
||||||
|
This class allows clients to retrieve data from the dataset using Flight.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, location: str, ds: "Dataset"):
|
def __init__(self, location: str, ds: "Dataset"):
|
||||||
|
"""
|
||||||
|
Initialize a new _Reader instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location (str): The location of the Flight server.
|
||||||
|
ds (Dataset): The dataset to be wrapped by the Flight server.
|
||||||
|
"""
|
||||||
super().__init__(location=location)
|
super().__init__(location=location)
|
||||||
self._location = location
|
self._location = location
|
||||||
self._ds = ds
|
self._ds = ds
|
||||||
|
|
||||||
def do_get(self, _, __):
|
def do_get(self, _, __):
|
||||||
|
"""
|
||||||
|
Handle Flight requests for data retrieval.
|
||||||
|
|
||||||
|
This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_: Unused argument.
|
||||||
|
__: Unused argument.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
An updated schema (to account for column selection) must be given to the stream.
|
||||||
|
"""
|
||||||
# NB: An updated schema (to account for column selection) has to be given the stream.
|
# NB: An updated schema (to account for column selection) has to be given the stream.
|
||||||
schema = next(iter(self._ds.to_batches())).schema
|
schema = next(iter(self._ds.to_batches())).schema
|
||||||
batches = self._ds.to_batches()
|
batches = self._ds.to_batches()
|
||||||
@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase):
|
|||||||
|
|
||||||
|
|
||||||
class Dataset(torch.utils.data.IterableDataset):
|
class Dataset(torch.utils.data.IterableDataset):
|
||||||
|
"""
|
||||||
|
A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading.
|
||||||
|
|
||||||
|
This class enables efficient loading of data from Parquet files using PyArrow.
|
||||||
|
It is designed to be used as an IterableDataset in PyTorch for training and inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
|
||||||
|
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
|
||||||
|
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset
|
||||||
|
for more details.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
LOCATION (str): The default location for the Flight server used for data distribution.
|
||||||
|
_file_pattern (str): The glob pattern specifying Parquet files in the dataset.
|
||||||
|
_fs: The filesystem object used for file operations.
|
||||||
|
_dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method.
|
||||||
|
_files (list): A list of file paths matching the glob pattern.
|
||||||
|
_schema (pa.Schema): The schema of the Parquet dataset.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
serve(): Start serving the dataset using a Flight server.
|
||||||
|
to_batches(): Generate batches of data from the Parquet dataset.
|
||||||
|
pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch.
|
||||||
|
dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch
|
||||||
|
to create DataLoader instances for training or inference.
|
||||||
|
"""
|
||||||
LOCATION = "grpc://0.0.0.0:2222"
|
LOCATION = "grpc://0.0.0.0:2222"
|
||||||
|
|
||||||
def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
|
def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
|
||||||
"""Specify batch size and column to select for.
|
"""
|
||||||
|
Initialize a new Dataset instance. Specify batch size and column to select for.
|
||||||
|
|
||||||
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
|
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
|
||||||
"""
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
|
||||||
|
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
|
||||||
|
"""
|
||||||
self._file_pattern = file_pattern
|
self._file_pattern = file_pattern
|
||||||
self._fs = infer_fs(self._file_pattern)
|
self._fs = infer_fs(self._file_pattern)
|
||||||
self._dataset_kwargs = dataset_kwargs
|
self._dataset_kwargs = dataset_kwargs
|
||||||
@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset):
|
|||||||
self._validate_columns()
|
self._validate_columns()
|
||||||
|
|
||||||
def _validate_columns(self):
|
def _validate_columns(self):
|
||||||
|
"""
|
||||||
|
Validate the specified columns against the dataset schema.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If any specified columns are not found in the dataset schema.
|
||||||
|
"""
|
||||||
columns = set(self._dataset_kwargs.get("columns", []))
|
columns = set(self._dataset_kwargs.get("columns", []))
|
||||||
wrong_columns = set(columns) - set(self._schema.names)
|
wrong_columns = set(columns) - set(self._schema.names)
|
||||||
if wrong_columns:
|
if wrong_columns:
|
||||||
raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")
|
raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")
|
||||||
|
|
||||||
def serve(self):
|
def serve(self):
|
||||||
|
"""Start serving the dataset using a Flight server."""
|
||||||
self.reader = _Reader(location=self.LOCATION, ds=self)
|
self.reader = _Reader(location=self.LOCATION, ds=self)
|
||||||
self.reader.serve()
|
self.reader.serve()
|
||||||
|
|
||||||
def _create_dataset(self):
|
def _create_dataset(self):
|
||||||
|
"""Create a PyArrow dataset for data retrieval."""
|
||||||
|
|
||||||
return pads.dataset(
|
return pads.dataset(
|
||||||
source=random.sample(self._files, len(self._files))[0],
|
source=random.sample(self._files, len(self._files))[0],
|
||||||
format="parquet",
|
format="parquet",
|
||||||
@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
|
def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
|
||||||
|
"""
|
||||||
|
Convert a Parquet RecordBatch to a custom data batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (pa.RecordBatch): A batch of data from the Parquet dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataclassBatch: A custom data batch used in PyTorch training.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: This method must be implemented in derived classes.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def dataloader(self, remote: bool = False):
|
def dataloader(self, remote: bool = False):
|
||||||
|
"""
|
||||||
|
Create a PyTorch DataLoader for iterating through the dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataLoader: A PyTorch DataLoader for iterating through the dataset.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If `remote` is True, a remote DataLoader is created for distributed training using Flight.
|
||||||
|
"""
|
||||||
if not remote:
|
if not remote:
|
||||||
return map(self.pa_to_batch, self.to_batches())
|
return map(self.pa_to_batch, self.to_batches())
|
||||||
readers = get_readers(2)
|
readers = get_readers(2)
|
||||||
@ -117,6 +230,25 @@ GRPC_OPTIONS = [
|
|||||||
|
|
||||||
|
|
||||||
def get_readers(num_readers_per_worker: int):
|
def get_readers(num_readers_per_worker: int):
|
||||||
|
"""
|
||||||
|
Get Flight readers for distributed data loading.
|
||||||
|
|
||||||
|
This function retrieves Flight readers for distributed data loading in a PyTorch environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_readers_per_worker (int): The number of Flight readers to retrieve per worker.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Flight readers are used to fetch data in a distributed manner for efficient data loading.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To obtain Flight readers, use the following code:
|
||||||
|
|
||||||
|
>>> readers = get_readers(num_readers_per_worker=2)
|
||||||
|
"""
|
||||||
addresses = env.get_flight_server_addresses()
|
addresses = env.get_flight_server_addresses()
|
||||||
|
|
||||||
readers = []
|
readers = []
|
||||||
|
@ -21,6 +21,16 @@ import torch.distributed as dist
|
|||||||
|
|
||||||
|
|
||||||
def maybe_start_dataset_service():
|
def maybe_start_dataset_service():
|
||||||
|
"""
|
||||||
|
Start the dataset service if readers are available and required dependencies are met.
|
||||||
|
|
||||||
|
This function checks if readers are available and if the required TensorFlow version is >= 2.5.
|
||||||
|
If both conditions are met and the current environment is the dispatcher or reader, it starts
|
||||||
|
the TensorFlow dataset service.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If the required TensorFlow version is not met (>= 2.5).
|
||||||
|
"""
|
||||||
if not env.has_readers():
|
if not env.has_readers():
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -59,6 +69,24 @@ def maybe_start_dataset_service():
|
|||||||
def register_dataset(
|
def register_dataset(
|
||||||
dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
|
dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Register a dataset with the distributed dataset service.
|
||||||
|
|
||||||
|
This function registers a dataset with the distributed dataset service and broadcasts the dataset ID
|
||||||
|
and job name to all processes in the distributed environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset (tf.data.Dataset): The dataset to be registered.
|
||||||
|
dataset_service (str): The name of the dataset service.
|
||||||
|
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[int, str]: A tuple containing the dataset ID and job name.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This function should be called on the rank 0 process.
|
||||||
|
|
||||||
|
"""
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
dataset_id = _register_dataset(
|
dataset_id = _register_dataset(
|
||||||
service=dataset_service,
|
service=dataset_service,
|
||||||
@ -82,6 +110,23 @@ def distribute_from_dataset_id(
|
|||||||
compression: Optional[str] = "AUTO",
|
compression: Optional[str] = "AUTO",
|
||||||
prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,
|
prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,
|
||||||
) -> tf.data.Dataset:
|
) -> tf.data.Dataset:
|
||||||
|
"""
|
||||||
|
Distribute a dataset from a registered dataset ID.
|
||||||
|
|
||||||
|
This function consumes a dataset from the distributed dataset service using the provided dataset ID
|
||||||
|
and job name. It also supports prefetching for improved performance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_service (str): The name of the dataset service.
|
||||||
|
dataset_id (int): The ID of the dataset to be consumed.
|
||||||
|
job_name (Optional[str]): The name of the job associated with the dataset (optional).
|
||||||
|
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
|
||||||
|
prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tf.data.Dataset: The distributed dataset.
|
||||||
|
|
||||||
|
"""
|
||||||
logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}")
|
logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}")
|
||||||
dataset = _from_dataset_id(
|
dataset = _from_dataset_id(
|
||||||
processing_mode="parallel_epochs",
|
processing_mode="parallel_epochs",
|
||||||
@ -97,15 +142,28 @@ def distribute_from_dataset_id(
|
|||||||
|
|
||||||
|
|
||||||
def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
|
def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
|
||||||
"""Torch-compatible and distributed-training-aware dataset service distributor.
|
|
||||||
|
|
||||||
- rank 0 process will register the given dataset.
|
|
||||||
- rank 0 process will broadcast job name and dataset id.
|
|
||||||
- all rank processes will consume from the same job/dataset.
|
|
||||||
|
|
||||||
Without this, dataset workers will try to serve 1 job per rank process and OOM.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption.
|
||||||
|
|
||||||
|
This function is used to distribute a dataset in a distributed training environment. It performs the
|
||||||
|
following steps:
|
||||||
|
- On the rank 0 process, it registers the given dataset with the distributed dataset service.
|
||||||
|
- It broadcasts the job name and dataset ID to all rank processes.
|
||||||
|
- All rank processes then consume the same dataset from the distributed dataset service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset (tf.data.Dataset): The TensorFlow dataset to be distributed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tf.data.Dataset: The distributed TensorFlow dataset.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- If there are no reader processes in the distributed environment, the original dataset is returned
|
||||||
|
without any distribution.
|
||||||
|
- This function is intended for use in distributed training environments to prevent out-of-memory (OOM)
|
||||||
|
issues caused by each rank process trying to serve one job.
|
||||||
|
|
||||||
|
"""
|
||||||
if not env.has_readers():
|
if not env.has_readers():
|
||||||
return dataset
|
return dataset
|
||||||
dataset_service = env.get_dds()
|
dataset_service = env.get_dds()
|
||||||
|
@ -12,6 +12,17 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def create_dataset(tmpdir):
|
def create_dataset(tmpdir):
|
||||||
|
"""
|
||||||
|
Create a mock dataset for testing.
|
||||||
|
|
||||||
|
This function creates a mock dataset using PyArrow and Parquet for testing purposes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmpdir: A temporary directory where the dataset will be created.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MockDataset: A mock dataset for testing.
|
||||||
|
"""
|
||||||
|
|
||||||
table = pa.table(
|
table = pa.table(
|
||||||
{
|
{
|
||||||
@ -34,6 +45,14 @@ def create_dataset(tmpdir):
|
|||||||
|
|
||||||
|
|
||||||
def test_dataset(tmpdir):
|
def test_dataset(tmpdir):
|
||||||
|
"""
|
||||||
|
Test the created dataset.
|
||||||
|
|
||||||
|
This function tests the created mock dataset and checks if it behaves as expected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmpdir: A temporary directory used for testing.
|
||||||
|
"""
|
||||||
ds = create_dataset(tmpdir)
|
ds = create_dataset(tmpdir)
|
||||||
batch = next(iter(ds.dataloader(remote=False)))
|
batch = next(iter(ds.dataloader(remote=False)))
|
||||||
assert batch.batch_size == 2
|
assert batch.batch_size == 2
|
||||||
@ -46,6 +65,14 @@ def test_dataset(tmpdir):
|
|||||||
reason="Multiprocessing doesn't work on github yet.",
|
reason="Multiprocessing doesn't work on github yet.",
|
||||||
)
|
)
|
||||||
def test_distributed_dataset(tmpdir):
|
def test_distributed_dataset(tmpdir):
|
||||||
|
"""
|
||||||
|
Test the distributed dataset.
|
||||||
|
|
||||||
|
This function tests the distributed version of the mock dataset using multiprocessing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmpdir: A temporary directory used for testing.
|
||||||
|
"""
|
||||||
MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"}
|
MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"}
|
||||||
|
|
||||||
def _client():
|
def _client():
|
||||||
|
151
reader/utils.py
151
reader/utils.py
@ -11,11 +11,55 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def roundrobin(*iterables):
|
def roundrobin(*iterables):
|
||||||
"""Round robin through provided iterables, useful for simple load balancing.
|
|
||||||
|
|
||||||
Adapted from https://docs.python.org/3/library/itertools.html.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
Iterate through provided iterables in a round-robin fashion.
|
||||||
|
|
||||||
|
This function takes multiple iterables and returns an iterator that yields elements from
|
||||||
|
each iterable in a round-robin manner. It continues cycling through the iterables until
|
||||||
|
all of them are exhausted.
|
||||||
|
|
||||||
|
Adapted from https://docs.python.org/3/library/itertools.html.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*iterables: One or more iterable objects to iterate through.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Elements from the provided iterables in a round-robin fashion.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
StopIteration: If all provided iterables are exhausted.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
iterable1 = [1, 2, 3]
|
||||||
|
iterable2 = ['a', 'b', 'c']
|
||||||
|
iterable3 = [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
for item in roundrobin(iterable1, iterable2, iterable3):
|
||||||
|
print(item)
|
||||||
|
|
||||||
|
# Output:
|
||||||
|
# 1
|
||||||
|
# 'a'
|
||||||
|
# 0.1
|
||||||
|
# 2
|
||||||
|
# 'b'
|
||||||
|
# 0.2
|
||||||
|
# 3
|
||||||
|
# 'c'
|
||||||
|
# 0.3
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- If one of the provided iterables is shorter than the others, the function will
|
||||||
|
continue iterating through the remaining iterables until all are exhausted.
|
||||||
|
- If an iterable raises an exception during iteration, a warning message is logged,
|
||||||
|
and the function continues with the next iterable.
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
- `itertools.cycle`: A function that repeatedly cycles through elements of an iterable.
|
||||||
|
- `itertools.islice`: A function to slice an iterable to limit the number of iterations.
|
||||||
|
"""
|
||||||
num_active = len(iterables)
|
num_active = len(iterables)
|
||||||
nexts = itertools.cycle(iter(it).__next__ for it in iterables)
|
nexts = itertools.cycle(iter(it).__next__ for it in iterables)
|
||||||
while num_active:
|
while num_active:
|
||||||
@ -35,6 +79,48 @@ def roundrobin(*iterables):
|
|||||||
|
|
||||||
|
|
||||||
def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
|
def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
|
||||||
|
"""
|
||||||
|
Monitor the speed and progress of data loading using a data loader.
|
||||||
|
|
||||||
|
This function iterates through a data loader for a specified number of steps or until
|
||||||
|
the end of the data loader is reached, periodically logging progress information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_loader: The data loader to monitor.
|
||||||
|
max_steps: The maximum number of steps to iterate through the data loader.
|
||||||
|
frequency: The frequency (in steps) at which to log progress.
|
||||||
|
peek (optional): If specified, it indicates the frequency (in steps) at which to log
|
||||||
|
batch contents for inspection.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Create a data loader (replace with your own DataLoader configuration)
|
||||||
|
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
|
||||||
|
|
||||||
|
# Monitor data loading speed and progress
|
||||||
|
speed_check(data_loader, max_steps=1000, frequency=50, peek=500)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_loader: The data loader to monitor.
|
||||||
|
max_steps: The maximum number of steps to iterate through the data loader.
|
||||||
|
frequency: The frequency (in steps) at which to log progress.
|
||||||
|
peek (optional): If specified, it indicates the frequency (in steps) at which to log
|
||||||
|
batch contents for inspection.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- The function logs information about elapsed time, the number of examples processed,
|
||||||
|
and the processing speed in examples per second.
|
||||||
|
- If `peek` is provided, batch contents will be logged for inspection at the specified
|
||||||
|
frequency.
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
- `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and
|
||||||
|
iterating through datasets.
|
||||||
|
"""
|
||||||
num_examples = 0
|
num_examples = 0
|
||||||
prev = time.perf_counter()
|
prev = time.perf_counter()
|
||||||
for idx, batch in enumerate(data_loader):
|
for idx, batch in enumerate(data_loader):
|
||||||
@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]
|
|||||||
|
|
||||||
|
|
||||||
def pa_to_torch(array: pa.array) -> torch.Tensor:
|
def pa_to_torch(array: pa.array) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Convert a PyArrow Array to a PyTorch Tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
array (pa.array): The PyArrow Array to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
import pyarrow as pa
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Create a PyArrow Array
|
||||||
|
arrow_array = pa.array([1, 2, 3])
|
||||||
|
|
||||||
|
# Convert it to a PyTorch Tensor
|
||||||
|
torch_tensor = pa_to_torch(arrow_array)
|
||||||
|
```
|
||||||
|
"""
|
||||||
return torch.from_numpy(array.to_numpy())
|
return torch.from_numpy(array.to_numpy())
|
||||||
|
|
||||||
|
|
||||||
def create_default_pa_to_batch(schema) -> DataclassBatch:
|
def create_default_pa_to_batch(schema) -> DataclassBatch:
|
||||||
""" """
|
"""
|
||||||
|
Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
import pyarrow as pa
|
||||||
|
from dataclass_batch import DataclassBatch
|
||||||
|
|
||||||
|
# Define a PyArrow schema
|
||||||
|
schema = pa.schema([
|
||||||
|
("feature1", pa.float64()),
|
||||||
|
("feature2", pa.int64()),
|
||||||
|
("label", pa.int64()),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Create the conversion function
|
||||||
|
pa_to_batch = create_default_pa_to_batch(schema)
|
||||||
|
|
||||||
|
# Create a PyArrow RecordBatch
|
||||||
|
record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
|
||||||
|
"feature1": [1.0, 2.0, None],
|
||||||
|
"feature2": [10, 20, 30],
|
||||||
|
"label": [0, 1, None],
|
||||||
|
}))
|
||||||
|
|
||||||
|
# Convert the RecordBatch to a custom DataclassBatch
|
||||||
|
custom_batch = pa_to_batch(record_batch)
|
||||||
|
```
|
||||||
|
"""
|
||||||
_CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
|
_CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
|
||||||
|
|
||||||
def get_imputation_value(pa_type):
|
def get_imputation_value(pa_type):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user