Mirror of https://github.com/twitter/the-algorithm-ml.git, synced 2025-01-09 22:39:22 +01:00

Commit db4ff958f6 ("update"); parent 590e8b76fe.
@ -14,7 +14,24 @@ from absl import logging as logging
|
||||
|
||||
|
||||
def setup_absl_logging():
|
||||
"""Make sure that absl logging pushes to stdout rather than stderr."""
|
||||
"""
|
||||
Configure absl-py logging to direct log messages to stdout and apply a custom log message format.
|
||||
|
||||
This function ensures that log messages generated by the absl-py library are written to stdout
|
||||
rather than stderr. It also applies a custom log message format that includes module, function,
|
||||
line number, log level, and the log message content.
|
||||
|
||||
Note:
|
||||
This function should be called once at the beginning of your script or application to
|
||||
configure absl-py logging.
|
||||
|
||||
Example:
|
||||
To use this function, simply call it at the start of your script:
|
||||
```
|
||||
setup_absl_logging()
|
||||
```
|
||||
|
||||
"""
|
||||
logging.get_absl_handler().python_handler.stream = sys.stdout
|
||||
formatter = py_logging.Formatter(
|
||||
fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"
|
||||
|
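For reference, a minimal usage sketch of the helper above; the import path is an assumption based on this file's location in the repo and is not shown verbatim in the diff:

```python
# Sketch only: the module path below is assumed, not taken from this diff.
from absl import logging

from tml.ml_logging.absl_logging import setup_absl_logging  # assumed import path

setup_absl_logging()  # route absl logs to stdout with the custom format
logging.info("dataset ready")
# Emits a line shaped roughly like:
# [my_module.my_function:7 - INFO] dataset ready
```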
@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging
|
||||
|
||||
class Testtlogging(unittest.TestCase):
|
||||
def test_warn_once(self):
|
||||
"""
|
||||
Test that warning messages are logged only once when using the assertLogs context manager.
|
||||
|
||||
This unit test checks the behavior of the logging system when warning messages are issued
|
||||
multiple times within the same context. It uses the assertLogs context manager to capture
|
||||
log messages at the INFO level and verifies that warning messages are logged only once.
|
||||
|
||||
Example:
|
||||
To use this test case, call it using a test runner like unittest:
|
||||
```
|
||||
python -m unittest your_test_module.Testtlogging.test_warn_once
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
with self.assertLogs(level="INFO") as captured_logs:
|
||||
logging.info("first info")
|
||||
logging.warning("first warning")
|
||||
|
@ -18,7 +18,35 @@ import torch.distributed as dist
|
||||
|
||||
|
||||
def rank_specific(logger):
|
||||
"""Ensures that we only override a given logger once."""
|
||||
"""
|
||||
Customize logger behavior based on the distributed environment and rank.
|
||||
|
||||
This function allows for customizing the behavior of a logger based on the distributed environment and the rank
|
||||
of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages
|
||||
depending on the rank or limit the number of redundant logs.
|
||||
|
||||
Args:
|
||||
logger: The logger object to customize.
|
||||
|
||||
Returns:
|
||||
The customized logger.
|
||||
|
||||
Example:
|
||||
To use this function with the `logging` module:
|
||||
```python
|
||||
import logging
|
||||
from rank_specific_logging import rank_specific
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
rank_specific(logger)
|
||||
```
|
||||
|
||||
Customization:
|
||||
- Messages are only logged if the distributed environment is not initialized or if the rank matches.
|
||||
- The 'warning' method is limited to logging a single redundant warning.
|
||||
- Logging from rank -1 is redirected to include the rank information.
|
||||
|
||||
"""
|
||||
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
|
||||
return logger
|
||||
|
||||
|
@ -8,11 +8,60 @@ import pydantic
|
||||
|
||||
|
||||
class PiecewiseConstant(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for a piecewise constant learning rate schedule.
|
||||
|
||||
This configuration class allows you to specify a piecewise constant learning rate schedule
|
||||
by defining boundaries and corresponding learning rate values.
|
||||
|
||||
Attributes:
|
||||
learning_rate_boundaries (List[int], optional): List of step boundaries at which
|
||||
the learning rate will change. If None, no boundaries are defined.
|
||||
learning_rate_values (List[float], optional): List of learning rate values
|
||||
corresponding to the boundaries. If None, no values are defined.
|
||||
|
||||
Example:
|
||||
To configure a piecewise constant learning rate schedule, create an instance of this class
|
||||
and set the attributes accordingly. For example:
|
||||
|
||||
```python
|
||||
piecewise_lr = PiecewiseConstant(
|
||||
learning_rate_boundaries=[1000, 2000, 3000],
|
||||
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
|
||||
)
|
||||
```
|
||||
|
||||
Note:
|
||||
The number of learning rate values should be one more than the number of boundaries.
|
||||
|
||||
"""
|
||||
learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
|
||||
learning_rate_values: typing.List[float] = pydantic.Field(None)
|
||||
|
||||
|
||||
class LinearRampToConstant(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for a linear ramp-up to constant learning rate schedule.
|
||||
|
||||
This configuration class allows you to specify a learning rate schedule that ramps up linearly
|
||||
from zero to a constant value over a specified number of steps.
|
||||
|
||||
Attributes:
|
||||
learning_rate (float): The final constant learning rate.
|
||||
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
|
||||
|
||||
Example:
|
||||
To configure a linear ramp-up to a constant learning rate, create an instance of this class
|
||||
and set the attributes accordingly. For example:
|
||||
|
||||
```python
|
||||
linear_ramp_lr = LinearRampToConstant(
|
||||
learning_rate=0.1,
|
||||
num_ramp_steps=1000
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
learning_rate: float
|
||||
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
|
||||
description="Number of steps to ramp this up from zero."
|
||||
@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig):
|
||||
|
||||
|
||||
class LinearRampToCosine(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for a linear ramp-up to cosine decay learning rate schedule.
|
||||
|
||||
This configuration class allows you to specify a learning rate schedule that ramps up linearly
|
||||
from zero, then decays following a cosine schedule to a final constant learning rate.
|
||||
|
||||
Attributes:
|
||||
learning_rate (float): The initial learning rate at the start of ramp-up.
|
||||
final_learning_rate (float): The final constant learning rate after decay.
|
||||
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
|
||||
final_num_steps (PositiveInt): Final number of steps where decay stops.
|
||||
|
||||
Example:
|
||||
To configure a linear ramp-up to cosine decay learning rate, create an instance of this
|
||||
class and set the attributes accordingly. For example:
|
||||
|
||||
```python
|
||||
ramp_to_cosine_lr = LinearRampToCosine(
|
||||
learning_rate=0.01,
|
||||
final_learning_rate=0.001,
|
||||
num_ramp_steps=1000,
|
||||
final_num_steps=5000
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
learning_rate: float
|
||||
final_learning_rate: float
|
||||
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
|
||||
@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig):
|
||||
|
||||
|
||||
class LearningRate(base_config.BaseConfig):
|
||||
"""
|
||||
Learning rate configuration for training.
|
||||
|
||||
This configuration class allows you to specify different learning rate schedules
|
||||
for your training process.
|
||||
|
||||
Attributes:
|
||||
constant (float, optional): Constant learning rate to be used throughout training.
|
||||
linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly
|
||||
and then decays following a cosine schedule.
|
||||
linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up
|
||||
linearly and then remains constant.
|
||||
piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified
|
||||
boundaries with corresponding values.
|
||||
|
||||
Example:
|
||||
To configure a learning rate schedule, create an instance of this class and set the
|
||||
attributes accordingly. For example:
|
||||
|
||||
```python
|
||||
learning_rate = LearningRate(
|
||||
constant=0.01,
|
||||
linear_ramp_to_cosine=LinearRampToCosine(
|
||||
learning_rate=0.1,
|
||||
final_learning_rate=0.001,
|
||||
num_ramp_steps=1000,
|
||||
final_num_steps=5000
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
Note:
|
||||
Each learning rate schedule attribute can be set to `None` if not needed.
|
||||
|
||||
"""
|
||||
constant: float = pydantic.Field(None, one_of="lr")
|
||||
linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
|
||||
linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
|
||||
@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig):
|
||||
|
||||
|
||||
class OptimizerAlgorithmConfig(base_config.BaseConfig):
|
||||
"""Base class for optimizer configurations."""
|
||||
"""
|
||||
Base class for optimizer configurations.
|
||||
|
||||
This base configuration class provides a structure for specifying various optimizer-related
|
||||
settings, including the learning rate and different learning rate schedules.
|
||||
|
||||
Attributes:
|
||||
lr (float): The base learning rate used by the optimizer.
|
||||
|
||||
Subclasses should inherit from this base class and define additional attributes specific to
|
||||
the optimizer algorithm they represent.
|
||||
|
||||
Example:
|
||||
To create a custom optimizer configuration, create a subclass of this base class and
|
||||
define the necessary attributes. For example:
|
||||
|
||||
```python
|
||||
class MyOptimizerConfig(OptimizerAlgorithmConfig):
|
||||
momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.")
|
||||
```
|
||||
|
||||
Note:
|
||||
This base class does not include specific optimizer settings. Subclasses should define
|
||||
the optimizer-specific attributes as needed.
|
||||
|
||||
"""
|
||||
|
||||
lr: float
|
||||
...
|
||||
|
||||
|
||||
class AdamConfig(OptimizerAlgorithmConfig):
|
||||
# see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
|
||||
"""
|
||||
Configuration for the Adam optimizer.
|
||||
|
||||
This configuration class allows you to specify the hyperparameters for the Adam optimizer.
|
||||
|
||||
Attributes:
|
||||
lr (float): The learning rate for optimization.
|
||||
betas (Tuple[float, float], optional): Coefficients used for computing running averages
|
||||
of gradient and squared gradient. Defaults to (0.9, 0.999).
|
||||
eps (float, optional): A small constant added to the denominator for numerical stability.
|
||||
Defaults to 1e-7.
|
||||
|
||||
Example:
|
||||
To configure the Adam optimizer, create an instance of this class and set the attributes
|
||||
accordingly. For example:
|
||||
|
||||
```python
|
||||
adam_optimizer = AdamConfig(
|
||||
lr=0.001,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8
|
||||
)
|
||||
```
|
||||
|
||||
See Also:
|
||||
[PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
|
||||
|
||||
"""
|
||||
lr: float
|
||||
betas: typing.Tuple[float, float] = [0.9, 0.999]
|
||||
eps: float = 1e-7 # Numerical stability in denominator.
|
||||
|
||||
|
||||
class SgdConfig(OptimizerAlgorithmConfig):
|
||||
"""
|
||||
Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||
|
||||
This configuration class allows you to specify the hyperparameters for the SGD optimizer.
|
||||
|
||||
Attributes:
|
||||
lr (float): The learning rate for optimization.
|
||||
momentum (float, optional): The momentum factor for SGD. Defaults to 0.0.
|
||||
|
||||
Example:
|
||||
To configure the SGD optimizer, create an instance of this class and set the attributes
|
||||
accordingly. For example:
|
||||
|
||||
```python
|
||||
sgd_optimizer = SgdConfig(
|
||||
lr=0.01,
|
||||
momentum=0.9
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
lr: float
|
||||
momentum: float = 0.0
|
||||
|
||||
|
||||
class AdagradConfig(OptimizerAlgorithmConfig):
|
||||
"""
|
||||
Configuration for the optimizer used during training.
|
||||
|
||||
This configuration class allows you to specify the optimizer for training, including
|
||||
options for various optimizer algorithms.
|
||||
|
||||
Attributes:
|
||||
learning_rate (LearningRate, optional): Learning rate configuration. Defaults to None.
|
||||
adam (AdamConfig, optional): Configuration for the Adam optimizer. Defaults to None.
|
||||
sgd (SgdConfig, optional): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||
Defaults to None.
|
||||
adagrad (AdagradConfig, optional): Configuration for the Adagrad optimizer. Defaults to None.
|
||||
|
||||
Example:
|
||||
To configure the optimizer for training, create an instance of this class and set the
|
||||
attributes accordingly. For example:
|
||||
|
||||
```python
|
||||
optimizer_config = OptimizerConfig(
|
||||
learning_rate=LearningRate(constant=0.001),
|
||||
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
lr: float
|
||||
eps: float = 0
|
||||
|
||||
|
||||
class OptimizerConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for defining different optimizer algorithms and their parameters.
|
||||
|
||||
This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad,
|
||||
along with their respective hyperparameters.
|
||||
|
||||
Args:
|
||||
learning_rate (LearningRate): The learning rate configuration, which can include
|
||||
constant learning rates or other learning rate schedules.
|
||||
adam (AdamConfig): Configuration for the Adam optimizer.
|
||||
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
|
||||
|
||||
Example:
|
||||
```python
|
||||
optimizer_config = OptimizerConfig(
|
||||
learning_rate=LearningRate(constant=0.001),
|
||||
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8),
|
||||
)
|
||||
```
|
||||
|
||||
Attributes:
|
||||
learning_rate (LearningRate): The learning rate configuration.
|
||||
adam (AdamConfig): Configuration for the Adam optimizer.
|
||||
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
|
||||
|
||||
Note:
|
||||
You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an
|
||||
`OptimizerConfig` instance.
|
||||
|
||||
See Also:
|
||||
- `LearningRate`: Configuration for specifying learning rates.
|
||||
- `AdamConfig`: Configuration for the Adam optimizer.
|
||||
- `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||
- `AdagradConfig`: Configuration for the Adagrad optimizer.
|
||||
|
||||
"""
|
||||
learning_rate: LearningRate = pydantic.Field(
|
||||
None,
|
||||
description="Constant learning rates",
|
||||
@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
|
||||
"""
|
||||
Get the optimizer algorithm configuration from the given `OptimizerConfig`.
|
||||
|
||||
This function extracts and returns the specific optimizer algorithm configuration
|
||||
(e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`.
|
||||
|
||||
Args:
|
||||
optimizer_config (OptimizerConfig): The optimizer configuration object containing
|
||||
one of the optimizer algorithm configurations.
|
||||
|
||||
Returns:
|
||||
Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm
|
||||
configuration extracted from `optimizer_config`.
|
||||
|
||||
Raises:
|
||||
ValueError: If no optimizer algorithm is selected in `optimizer_config`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
optimizer_config = OptimizerConfig(
|
||||
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
|
||||
)
|
||||
algorithm_config = get_optimizer_algorithm_config(optimizer_config)
|
||||
# `algorithm_config` will be an instance of `AdamConfig`.
|
||||
```
|
||||
|
||||
"""
|
||||
if optimizer_config.adam is not None:
|
||||
return optimizer_config.adam
|
||||
elif optimizer_config.sgd is not None:
|
||||
|
@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging
|
||||
|
||||
|
||||
def compute_lr(lr_config, step):
|
||||
"""Compute a learning rate."""
|
||||
"""
|
||||
Compute the learning rate based on the specified learning rate configuration.
|
||||
|
||||
This function calculates the learning rate according to the given configuration, which can include
|
||||
constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing.
|
||||
|
||||
Args:
|
||||
lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule.
|
||||
step (int): The current training step or iteration.
|
||||
|
||||
Returns:
|
||||
float: The computed learning rate for the current step.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `lr_config` is invalid or contains conflicting options.
|
||||
|
||||
Example:
|
||||
```python
|
||||
lr_schedule = LearningRate(
|
||||
constant=0.001,
|
||||
piecewise_constant=PiecewiseConstant(
|
||||
learning_rate_boundaries=[1000, 2000, 3000],
|
||||
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
|
||||
)
|
||||
)
|
||||
current_step = 2500
|
||||
learning_rate = compute_lr(lr_schedule, current_step)
|
||||
```
|
||||
"""
|
||||
if lr_config.constant is not None:
|
||||
return lr_config.constant
|
||||
elif lr_config.piecewise_constant is not None:
|
||||
@ -46,11 +74,54 @@ def compute_lr(lr_config, step):
|
||||
|
||||
|
||||
class LRShim(_LRScheduler):
|
||||
"""Shim to get learning rates into a LRScheduler.
|
||||
|
||||
This adheres to the torch.optim scheduler API and can be plugged anywhere that
|
||||
e.g. exponential decay can be used.
|
||||
"""
|
||||
Learning Rate Scheduler Shim to adjust learning rates during training.
|
||||
|
||||
This class acts as a shim to apply different learning rates to individual parameter groups
|
||||
within an optimizer. It adheres to the torch.optim scheduler API and can be used with various
|
||||
optimizers, allowing fine-grained control over learning rates based on configuration.
|
||||
|
||||
Args:
|
||||
optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted.
|
||||
lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their
|
||||
corresponding learning rate configurations.
|
||||
last_epoch (int, optional): The index of the last epoch. Default is -1.
|
||||
verbose (bool, optional): If True, prints a warning message when accessing learning rates
|
||||
using the deprecated `get_lr()` method. Default is False.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of parameter groups in the optimizer does not match the number
|
||||
of learning rate configurations provided.
|
||||
|
||||
Note:
|
||||
To obtain the last computed learning rates, please use `get_last_lr()`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
lr_schedule = {
|
||||
'main': LearningRate(constant=0.01),
|
||||
'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant(
|
||||
learning_rate_boundaries=[1000, 2000],
|
||||
learning_rate_values=[0.01, 0.001]
|
||||
))
|
||||
}
|
||||
lr_shim = LRShim(optimizer, lr_schedule)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
# Train the model
|
||||
train(...)
|
||||
# Update learning rates at the end of each epoch
|
||||
lr_shim.step(epoch)
|
||||
|
||||
final_lrs = lr_shim.get_last_lr()  # one value per parameter group, in optimizer order
|
||||
```
|
||||
|
||||
See Also:
|
||||
- `LearningRate`: Configuration for specifying learning rates.
|
||||
- `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig):
|
||||
def build_optimizer(
|
||||
model: torch.nn.Module, optimizer_config: OptimizerConfig
|
||||
) -> Tuple[Optimizer, _LRScheduler]:
|
||||
"""Builds an optimizer and LR scheduler from an OptimizerConfig.
|
||||
Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
|
||||
"""
|
||||
Build an optimizer and learning rate scheduler based on the provided optimizer configuration.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The PyTorch model for which the optimizer will be created.
|
||||
optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer
|
||||
algorithm and learning rate settings.
|
||||
|
||||
Returns:
|
||||
Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler
|
||||
objects.
|
||||
|
||||
Note:
|
||||
This function is intended for cases where you want the same optimizer and learning rate
|
||||
schedule for all model parameters.
|
||||
|
||||
Example:
|
||||
```python
|
||||
model = MyModel()
|
||||
optimizer_config = OptimizerConfig(
|
||||
learning_rate=LearningRate(constant=0.01),
|
||||
sgd=SgdConfig(lr=0.01, momentum=0.9)
|
||||
)
|
||||
optimizer, scheduler = build_optimizer(model, optimizer_config)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
# Train the model with the optimizer
|
||||
train(model, optimizer, ...)
|
||||
# Update learning rates at the end of each epoch
|
||||
scheduler.step(epoch)
|
||||
```
|
||||
|
||||
See Also:
|
||||
- `OptimizerConfig`: Configuration for specifying optimizer settings.
|
||||
- `LRShim`: Learning rate scheduler shim for fine-grained learning rate control.
|
||||
"""
|
||||
optimizer_class = get_optimizer_class(optimizer_config)
|
||||
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
|
||||
# We're passing everything in as one group here
|
||||
|
@ -4,6 +4,17 @@ import pydantic
|
||||
|
||||
|
||||
class TwhinDataConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for Twhin model training data.
|
||||
|
||||
Args:
|
||||
data_root (str): The root directory for the training data.
|
||||
per_replica_batch_size (pydantic.PositiveInt): Batch size per replica.
|
||||
global_negatives (int): The number of global negatives.
|
||||
in_batch_negatives (int): The number of in-batch negatives.
|
||||
limit (pydantic.PositiveInt): The limit on the number of data points to use.
|
||||
offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None.
|
||||
"""
|
||||
data_root: str
|
||||
per_replica_batch_size: pydantic.PositiveInt
|
||||
global_negatives: int
|
||||
|
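A construction sketch for this config using only the fields documented above; the values and the import path are illustrative assumptions, not taken from the repo:

```python
# Illustrative values; import path assumed from the project layout shown in this diff.
from tml.projects.twhin.data.config import TwhinDataConfig

data_config = TwhinDataConfig(
    data_root="/tmp/twhin_edges",   # hypothetical location of training edges
    per_replica_batch_size=1024,
    global_negatives=10,
    in_batch_negatives=10,
    limit=100000,
)
```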
@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset
|
||||
|
||||
|
||||
def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
|
||||
"""
|
||||
Create a dataset for Twhin model training.
|
||||
|
||||
Args:
|
||||
data_config (TwhinDataConfig): The data configuration for the dataset.
|
||||
model_config (TwhinModelConfig): The model configuration containing embeddings and relations.
|
||||
|
||||
Returns:
|
||||
EdgesDataset: The dataset for Twhin model training.
|
||||
"""
|
||||
tables = model_config.embeddings.tables
|
||||
table_sizes = {table.name: table.num_embeddings for table in tables}
|
||||
relations = model_config.relations
|
||||
|
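Usage is then a single call (sketch; `data_config` and `model_config` are assumed to be built as in the config examples elsewhere in this commit):

```python
# Returns an EdgesDataset wired with the tables and relations from model_config.
train_dataset = create_dataset(data_config, model_config)
```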
@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
||||
|
||||
@dataclass
|
||||
class EdgeBatch(DataclassBatch):
|
||||
"""
|
||||
Batch data structure for edge-based models.
|
||||
|
||||
Args:
|
||||
nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings.
|
||||
labels (torch.Tensor): Tensor containing labels.
|
||||
rels (torch.Tensor): Tensor containing relation information.
|
||||
weights (torch.Tensor): Tensor containing weights.
|
||||
"""
|
||||
nodes: KeyedJaggedTensor
|
||||
labels: torch.Tensor
|
||||
rels: torch.Tensor
|
||||
@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch):
|
||||
|
||||
|
||||
class EdgesDataset(Dataset):
|
||||
"""
|
||||
Dataset for edge-based models.
|
||||
|
||||
Args:
|
||||
file_pattern (str): The file pattern for the dataset.
|
||||
table_sizes (Dict[str, int]): A dictionary of table names and their sizes.
|
||||
relations (List[Relation]): A list of relations between tables.
|
||||
lhs_column_name (str): The name of the left-hand-side column.
|
||||
rhs_column_name (str): The name of the right-hand-side column.
|
||||
rel_column_name (str): The name of the relation column.
|
||||
**dataset_kwargs: Additional keyword arguments for the parent Dataset class.
|
||||
"""
|
||||
rng = np.random.default_rng()
|
||||
|
||||
def __init__(
|
||||
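A hedged sketch of constructing the dataset directly with the arguments documented above; the file pattern, table sizes, and column names are illustrative, and `relations` is assumed to be the `List[Relation]` from the model config:

```python
# Illustrative only; batch_size is forwarded to the parent Dataset via **dataset_kwargs.
edges = EdgesDataset(
    file_pattern="/tmp/edges/*.parquet",                    # hypothetical glob
    table_sizes={"user": 1_000_000, "tweet": 5_000_000},    # hypothetical table sizes
    relations=relations,
    lhs_column_name="lhs",
    rhs_column_name="rhs",
    rel_column_name="rel",
    batch_size=1024,
)
batch = next(iter(edges.dataloader(remote=False)))          # yields EdgeBatch instances
```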
@ -56,6 +77,15 @@ class EdgesDataset(Dataset):
|
||||
super().__init__(file_pattern=file_pattern, **dataset_kwargs)
|
||||
|
||||
def pa_to_batch(self, batch: pa.RecordBatch):
|
||||
"""
|
||||
Converts a pyarrow RecordBatch to an EdgeBatch.
|
||||
|
||||
Args:
|
||||
batch (pa.RecordBatch): A pyarrow RecordBatch containing data.
|
||||
|
||||
Returns:
|
||||
EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights.
|
||||
"""
|
||||
lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
|
||||
rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
|
||||
rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
|
||||
@ -74,6 +104,14 @@ class EdgesDataset(Dataset):
|
||||
) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:
|
||||
|
||||
"""Process edges that contain lhs index, rhs index, relation index.
|
||||
|
||||
Args:
|
||||
lhs (torch.Tensor): Tensor containing left-hand-side indices.
|
||||
rhs (torch.Tensor): Tensor containing right-hand-side indices.
|
||||
rel (torch.Tensor): Tensor containing relation indices.
|
||||
|
||||
Returns:
|
||||
Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs.
|
||||
Example:
|
||||
|
||||
```
|
||||
@ -147,6 +185,12 @@ class EdgesDataset(Dataset):
|
||||
return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)
|
||||
|
||||
def to_batches(self):
|
||||
"""
|
||||
Converts data to batches.
|
||||
|
||||
Yields:
|
||||
pa.RecordBatch: A pyarrow RecordBatch containing data.
|
||||
"""
|
||||
ds = super().to_batches()
|
||||
batch_size = self._dataset_kwargs["batch_size"]
|
||||
|
||||
|
@ -10,8 +10,29 @@ from pydantic import validator
|
||||
|
||||
|
||||
class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
|
||||
"""
|
||||
Configuration class for Twhin model embeddings.
|
||||
|
||||
This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types
|
||||
for all tables in the Twhin model embeddings configuration match.
|
||||
|
||||
Attributes:
|
||||
tables (List[TableConfig]): A list of table configurations for the model's embeddings.
|
||||
"""
|
||||
@validator("tables")
|
||||
def embedding_dims_match(cls, tables):
|
||||
"""
|
||||
Validate that embedding dimensions and data types match for all tables.
|
||||
|
||||
Args:
|
||||
tables (List[TableConfig]): List of table configurations.
|
||||
|
||||
Returns:
|
||||
List[TableConfig]: The list of validated table configurations.
|
||||
|
||||
Raises:
|
||||
AssertionError: If embedding dimensions or data types do not match.
|
||||
"""
|
||||
embedding_dim = tables[0].embedding_dim
|
||||
data_type = tables[0].data_type
|
||||
for table in tables:
|
||||
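The validator above reduces to the following plain-Python check over the documented table fields (a conceptual restatement, not the pydantic plumbing):

```python
# Conceptual restatement of embedding_dims_match; `tables` is the list being validated.
def check_tables_consistent(tables):
    embedding_dim = tables[0].embedding_dim
    data_type = tables[0].data_type
    for table in tables:
        assert table.embedding_dim == embedding_dim, "embedding dimensions must match"
        assert table.data_type == data_type, "data types must match"
    return tables
```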
@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
|
||||
|
||||
|
||||
class Operator(str, enum.Enum):
|
||||
"""
|
||||
Enumeration of operator types.
|
||||
|
||||
This enumeration defines different types of operators that can be applied to Twhin model relations.
|
||||
"""
|
||||
TRANSLATION = "translation"
|
||||
|
||||
|
||||
class Relation(pydantic.BaseModel):
|
||||
"""graph relationship properties and operator"""
|
||||
"""
|
||||
Configuration class for graph relationships in the Twhin model.
|
||||
|
||||
This class defines properties and operators for graph relationships in the Twhin model.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the relationship.
|
||||
lhs (str): The name of the entity on the left-hand side of the relation.
|
||||
rhs (str): The name of the entity on the right-hand side of the relation.
|
||||
operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product.
|
||||
"""
|
||||
|
||||
name: str = pydantic.Field(..., description="Relationship name.")
|
||||
lhs: str = pydantic.Field(
|
||||
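A construction sketch for a relation, using the fields documented above; the entity names are illustrative:

```python
# Illustrative entity names; TRANSLATION is the only operator defined in this config.
follows = Relation(
    name="follows",
    lhs="user",
    rhs="user",
    operator=Operator.TRANSLATION,
)
```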
@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel):
|
||||
|
||||
|
||||
class TwhinModelConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration class for the Twhin model.
|
||||
|
||||
This class defines configuration options specific to the Twhin model.
|
||||
|
||||
Attributes:
|
||||
embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings.
|
||||
relations (List[Relation]): List of graph relationship configurations.
|
||||
translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation.
|
||||
"""
|
||||
embeddings: TwhinEmbeddingsConfig
|
||||
relations: typing.List[Relation]
|
||||
translation_optimizer: OptimizerConfig
|
||||
|
||||
@validator("relations", each_item=True)
|
||||
def valid_node_types(cls, relation, values, **kwargs):
|
||||
"""
|
||||
Validate that the specified node types in relations are valid table names in embeddings.
|
||||
|
||||
Args:
|
||||
relation (Relation): A single relation configuration.
|
||||
values (dict): The values dictionary containing the "embeddings" configuration.
|
||||
|
||||
Returns:
|
||||
Relation: The validated relation configuration.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the specified node types are not valid table names in embeddings.
|
||||
"""
|
||||
table_names = [table.name for table in values["embeddings"].tables]
|
||||
assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
|
||||
assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"
|
||||
|
@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa
|
||||
|
||||
|
||||
class TwhinModel(nn.Module):
|
||||
"""
|
||||
Twhin model for graph-based entity embeddings and translation.
|
||||
|
||||
This class defines the Twhin model, which is used for learning embeddings of entities in a graph
|
||||
and applying translations to these embeddings based on graph relationships.
|
||||
|
||||
Args:
|
||||
model_config (TwhinModelConfig): Configuration for the Twhin model.
|
||||
data_config (TwhinDataConfig): Configuration for the data used by the model.
|
||||
|
||||
Attributes:
|
||||
batch_size (int): The batch size used for training.
|
||||
table_names (List[str]): Names of tables in the model's embeddings.
|
||||
large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings.
|
||||
embedding_dim (int): Dimensionality of entity embeddings.
|
||||
num_tables (int): Number of tables in the model's embeddings.
|
||||
in_batch_negatives (int): Number of in-batch negative samples to use during training.
|
||||
global_negatives (int): Number of global negative samples to use during training.
|
||||
num_relations (int): Number of graph relationships in the model.
|
||||
all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings.
|
||||
|
||||
"""
|
||||
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
|
||||
super().__init__()
|
||||
self.batch_size = data_config.per_replica_batch_size
|
||||
@ -31,7 +53,17 @@ class TwhinModel(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, batch: EdgeBatch):
|
||||
"""
|
||||
Forward pass of the Twhin model.
|
||||
|
||||
Args:
|
||||
batch (EdgeBatch): Input batch containing graph edge information.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing model output with "logits" and "probabilities".
|
||||
- "logits" (torch.Tensor): Logit scores.
|
||||
- "probabilities" (torch.Tensor): Sigmoid probabilities.
|
||||
"""
|
||||
# B x D
|
||||
trans_embs = self.all_trans_embs.data[batch.rels]
|
||||
|
||||
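A hedged end-to-end sketch of the forward pass described above; the model and data configs are assumed to be built as in the test fixtures later in this commit, and the batch comes from the EdgesDataset dataloader:

```python
# Sketch: wiring follows the docstrings in this commit, values are illustrative.
model = TwhinModel(model_config=model_config, data_config=data_config)
batch = next(iter(create_dataset(data_config, model_config).dataloader(remote=False)))

outputs = model(batch)                    # EdgeBatch in
logits = outputs["logits"]                # raw edge scores
probabilities = outputs["probabilities"]  # sigmoid of the logits
```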
@ -98,6 +130,18 @@ class TwhinModel(nn.Module):
|
||||
|
||||
|
||||
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
|
||||
"""
|
||||
Apply optimizers to the Twhin model's embeddings.
|
||||
|
||||
This function applies optimizers to the embeddings of the Twhin model based on the provided configuration.
|
||||
|
||||
Args:
|
||||
model (TwhinModel): The Twhin model to apply optimizers to.
|
||||
model_config (TwhinModelConfig): Configuration for the Twhin model.
|
||||
|
||||
Returns:
|
||||
TwhinModel: The Twhin model with optimizers applied to its embeddings.
|
||||
"""
|
||||
for table in model_config.embeddings.tables:
|
||||
optimizer_class = get_optimizer_class(table.optimizer)
|
||||
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
|
||||
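Attaching the per-table optimizers is then a single call (sketch):

```python
# Applies each table's configured optimizer to its embeddings in the backward pass.
model = apply_optimizers(model, model_config)
```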
@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module):
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
model: torch module to wrap.
|
||||
loss_fn: Function for calculating loss, should accept logits and labels.
|
||||
"""
|
||||
Initialize a TwhinModelAndLoss module.
|
||||
|
||||
Args:
|
||||
model: The torch module to wrap.
|
||||
loss_fn: A function for calculating loss, should accept logits and labels.
|
||||
data_config: Configuration for Twhin data.
|
||||
device: The torch device to use for calculations.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.loss_fn = loss_fn
|
||||
@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module):
|
||||
self.device = device
|
||||
|
||||
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
|
||||
"""Runs model forward and calculates loss according to given loss_fn.
|
||||
|
||||
NOTE: The input signature here needs to be a Pipelineable object for
|
||||
prefetching purposes during training using torchrec's pipeline. However
|
||||
the underlying model signature needs to be exportable to onnx, requiring
|
||||
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
|
||||
|
||||
"""
|
||||
Run the model forward and calculate the loss according to the given loss_fn.
|
||||
|
||||
NOTE: The input signature here needs to be a Pipelineable object for
|
||||
prefetching purposes during training using torchrec's pipeline. However
|
||||
the underlying model signature needs to be exportable to onnx, requiring
|
||||
generic python types. see https://pytorch.org/docs/stable/onnx.html#types
|
||||
|
||||
Args:
|
||||
batch ("RecapBatch"): The input batch for model inference.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of
|
||||
additional outputs including logits, labels, and weights.
|
||||
"""
|
||||
outputs = self.model(batch)
|
||||
logits = outputs["logits"]
|
||||
|
||||
|
@ -18,6 +18,12 @@ EMB_DIM = 128
|
||||
|
||||
|
||||
def twhin_model_config() -> TwhinModelConfig:
|
||||
"""
|
||||
Create a configuration for the Twhin model.
|
||||
|
||||
Returns:
|
||||
TwhinModelConfig: The Twhin model configuration.
|
||||
"""
|
||||
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
|
||||
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
|
||||
|
||||
@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig:
|
||||
|
||||
|
||||
def twhin_data_config() -> TwhinDataConfig:
|
||||
"""
|
||||
Create a configuration for the Twhin data.
|
||||
|
||||
Returns:
|
||||
TwhinDataConfig: The Twhin data configuration.
|
||||
"""
|
||||
data_config = TwhinDataConfig(
|
||||
data_root="/",
|
||||
per_replica_batch_size=10,
|
||||
@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig:
|
||||
|
||||
|
||||
def test_twhin_model():
|
||||
"""
|
||||
Test the Twhin model creation and optimization.
|
||||
|
||||
This function creates a Twhin model using the specified configuration and tests its optimization. It also checks
|
||||
the device placement of model parameters.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
model_config = twhin_model_config()
|
||||
loss_fn = F.binary_cross_entropy_with_logits
|
||||
|
||||
|
@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt"
|
||||
|
||||
|
||||
def _lr_from_config(optimizer_config):
|
||||
"""Get the learning rate from an optimizer configuration.
|
||||
|
||||
Args:
|
||||
optimizer_config: Optimizer configuration.
|
||||
|
||||
Returns:
|
||||
Learning rate from the optimizer configuration.
|
||||
"""
|
||||
if optimizer_config.learning_rate is not None:
|
||||
return optimizer_config.learning_rate
|
||||
else:
|
||||
@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config):
|
||||
def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
|
||||
"""Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.
|
||||
|
||||
Args:
|
||||
model: TwhinModel to build optimizer for.
|
||||
config: TwhinConfig for model.
|
||||
Args:
|
||||
model: TwhinModel to build optimizer for.
|
||||
config: TwhinModelConfig for model.
|
||||
|
||||
Returns:
|
||||
Optimizer for model.
|
||||
"""
|
||||
Returns:
|
||||
Optimizer for model.
|
||||
"""
|
||||
translation_optimizer_fn = functools.partial(
|
||||
get_optimizer_class(config.translation_optimizer),
|
||||
**get_optimizer_algorithm_config(config.translation_optimizer).dict(),
|
||||
|
@ -37,6 +37,12 @@ def run(
|
||||
all_config: TwhinConfig,
|
||||
save_dir: Optional[str] = None,
|
||||
):
|
||||
"""Run the training process for TwhinModel.
|
||||
|
||||
Args:
|
||||
all_config (TwhinConfig): The configuration for the entire Twhin model.
|
||||
save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None.
|
||||
"""
|
||||
train_dataset = create_dataset(all_config.train_data, all_config.model)
|
||||
|
||||
if env.is_reader():
|
||||
@ -80,6 +86,11 @@ def run(
|
||||
|
||||
|
||||
def main(argv):
|
||||
"""Main entry point for the Twhin training script.
|
||||
|
||||
Args:
|
||||
argv: Command-line arguments.
|
||||
"""
|
||||
logging.info("Starting")
|
||||
|
||||
logging.info(f"parsing config from {FLAGS.config_yaml_path}...")
|
||||
|
@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging
|
||||
|
||||
|
||||
class _Reader(pa.flight.FlightServerBase):
|
||||
"""Distributed reader flight server wrapping a dataset."""
|
||||
"""
|
||||
Distributed reader flight server wrapping a dataset.
|
||||
|
||||
This class implements a Flight server that wraps a dataset, allowing clients to retrieve data
|
||||
from the dataset over the Flight protocol. It is designed to be used in a distributed environment
|
||||
for efficient data access.
|
||||
|
||||
Args:
|
||||
location (str): The location of the Flight server.
|
||||
ds (Dataset): The dataset to be wrapped by the Flight server.
|
||||
|
||||
Attributes:
|
||||
_location (str): The location of the Flight server.
|
||||
_ds (Dataset): The dataset wrapped by the Flight server.
|
||||
|
||||
Methods:
|
||||
do_get(_, __): Handles Flight requests for data retrieval.
|
||||
|
||||
Note:
|
||||
Flight is an Apache Arrow project that provides a framework for efficient data transfer.
|
||||
This class allows clients to retrieve data from the dataset using Flight.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, location: str, ds: "Dataset"):
|
||||
"""
|
||||
Initialize a new _Reader instance.
|
||||
|
||||
Args:
|
||||
location (str): The location of the Flight server.
|
||||
ds (Dataset): The dataset to be wrapped by the Flight server.
|
||||
"""
|
||||
super().__init__(location=location)
|
||||
self._location = location
|
||||
self._ds = ds
|
||||
|
||||
def do_get(self, _, __):
|
||||
"""
|
||||
Handle Flight requests for data retrieval.
|
||||
|
||||
This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol.
|
||||
|
||||
Args:
|
||||
_: Unused argument.
|
||||
__: Unused argument.
|
||||
|
||||
Returns:
|
||||
pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset.
|
||||
|
||||
Note:
|
||||
An updated schema (to account for column selection) must be given to the stream.
|
||||
"""
|
||||
# NB: An updated schema (to account for column selection) has to be given the stream.
|
||||
schema = next(iter(self._ds.to_batches())).schema
|
||||
batches = self._ds.to_batches()
|
||||
@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase):
|
||||
|
||||
|
||||
class Dataset(torch.utils.data.IterableDataset):
|
||||
"""
|
||||
A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading.
|
||||
|
||||
This class enables efficient loading of data from Parquet files using PyArrow.
|
||||
It is designed to be used as an IterableDataset in PyTorch for training and inference.
|
||||
|
||||
Args:
|
||||
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
|
||||
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
|
||||
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset
|
||||
for more details.
|
||||
|
||||
Attributes:
|
||||
LOCATION (str): The default location for the Flight server used for data distribution.
|
||||
_file_pattern (str): The glob pattern specifying Parquet files in the dataset.
|
||||
_fs: The filesystem object used for file operations.
|
||||
_dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method.
|
||||
_files (list): A list of file paths matching the glob pattern.
|
||||
_schema (pa.Schema): The schema of the Parquet dataset.
|
||||
|
||||
Methods:
|
||||
serve(): Start serving the dataset using a Flight server.
|
||||
to_batches(): Generate batches of data from the Parquet dataset.
|
||||
pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch.
|
||||
dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset.
|
||||
|
||||
Note:
|
||||
This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch
|
||||
to create DataLoader instances for training or inference.
|
||||
"""
|
||||
LOCATION = "grpc://0.0.0.0:2222"
|
||||
|
||||
def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
|
||||
"""Specify batch size and column to select for.
|
||||
"""
|
||||
Initialize a new Dataset instance. Specify batch size and column to select for.
|
||||
|
||||
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
|
||||
"""
|
||||
|
||||
|
||||
Args:
|
||||
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
|
||||
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
|
||||
"""
|
||||
self._file_pattern = file_pattern
|
||||
self._fs = infer_fs(self._file_pattern)
|
||||
self._dataset_kwargs = dataset_kwargs
|
||||
@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset):
|
||||
self._validate_columns()
|
||||
|
||||
def _validate_columns(self):
|
||||
"""
|
||||
Validate the specified columns against the dataset schema.
|
||||
|
||||
Raises:
|
||||
Exception: If any specified columns are not found in the dataset schema.
|
||||
"""
|
||||
columns = set(self._dataset_kwargs.get("columns", []))
|
||||
wrong_columns = set(columns) - set(self._schema.names)
|
||||
if wrong_columns:
|
||||
raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")
|
||||
|
||||
def serve(self):
|
||||
"""Start serving the dataset using a Flight server."""
|
||||
self.reader = _Reader(location=self.LOCATION, ds=self)
|
||||
self.reader.serve()
|
||||
|
||||
def _create_dataset(self):
|
||||
"""Create a PyArrow dataset for data retrieval."""
|
||||
|
||||
return pads.dataset(
|
||||
source=random.sample(self._files, len(self._files))[0],
|
||||
format="parquet",
|
||||
@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset):
|
||||
|
||||
@abc.abstractmethod
|
||||
def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
|
||||
"""
|
||||
Convert a Parquet RecordBatch to a custom data batch.
|
||||
|
||||
Args:
|
||||
batch (pa.RecordBatch): A batch of data from the Parquet dataset.
|
||||
|
||||
Returns:
|
||||
DataclassBatch: A custom data batch used in PyTorch training.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in derived classes.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def dataloader(self, remote: bool = False):
|
||||
"""
|
||||
Create a PyTorch DataLoader for iterating through the dataset.
|
||||
|
||||
Args:
|
||||
remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training.
|
||||
|
||||
Returns:
|
||||
DataLoader: A PyTorch DataLoader for iterating through the dataset.
|
||||
|
||||
Note:
|
||||
If `remote` is True, a remote DataLoader is created for distributed training using Flight.
|
||||
"""
|
||||
if not remote:
|
||||
return map(self.pa_to_batch, self.to_batches())
|
||||
readers = get_readers(2)
|
||||
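Putting the pieces above together, a hedged sketch of a concrete subclass: `pa_to_batch` is the only abstract hook, and `batch_size` is forwarded to PyArrow through `**dataset_kwargs`. The import paths and column names are assumptions, not shown in this diff:

```python
from dataclasses import dataclass

import pyarrow as pa
import torch

from tml.common.batch import DataclassBatch   # assumed import path
from tml.reader.dataset import Dataset        # assumed import path


@dataclass
class MyBatch(DataclassBatch):
    feature: torch.Tensor
    label: torch.Tensor


class MyDataset(Dataset):
    """Minimal concrete Dataset: implements the abstract pa_to_batch hook."""

    def pa_to_batch(self, batch: pa.RecordBatch) -> MyBatch:
        # Hypothetical columns "feature" and "label" in the Parquet files.
        return MyBatch(
            feature=torch.from_numpy(batch.column("feature").to_numpy()),
            label=torch.from_numpy(batch.column("label").to_numpy()),
        )


ds = MyDataset(file_pattern="/tmp/data/*.parquet", batch_size=256)
first = next(iter(ds.dataloader(remote=False)))
```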
@ -117,6 +230,25 @@ GRPC_OPTIONS = [
|
||||
|
||||
|
||||
def get_readers(num_readers_per_worker: int):
|
||||
"""
|
||||
Get Flight readers for distributed data loading.
|
||||
|
||||
This function retrieves Flight readers for distributed data loading in a PyTorch environment.
|
||||
|
||||
Args:
|
||||
num_readers_per_worker (int): The number of Flight readers to retrieve per worker.
|
||||
|
||||
Returns:
|
||||
List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading.
|
||||
|
||||
Note:
|
||||
Flight readers are used to fetch data in a distributed manner for efficient data loading.
|
||||
|
||||
Example:
|
||||
To obtain Flight readers, use the following code:
|
||||
|
||||
>>> readers = get_readers(num_readers_per_worker=2)
|
||||
"""
|
||||
addresses = env.get_flight_server_addresses()
|
||||
|
||||
readers = []
|
||||
|
@ -21,6 +21,16 @@ import torch.distributed as dist
|
||||
|
||||
|
||||
def maybe_start_dataset_service():
|
||||
"""
|
||||
Start the dataset service if readers are available and required dependencies are met.
|
||||
|
||||
This function checks if readers are available and if the required TensorFlow version is >= 2.5.
|
||||
If both conditions are met and the current environment is the dispatcher or reader, it starts
|
||||
the TensorFlow dataset service.
|
||||
|
||||
Raises:
|
||||
Exception: If the required TensorFlow version is not met (>= 2.5).
|
||||
"""
|
||||
if not env.has_readers():
|
||||
return
|
||||
|
||||
@ -59,6 +69,24 @@ def maybe_start_dataset_service():
|
||||
def register_dataset(
|
||||
dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
|
||||
):
|
||||
"""
|
||||
Register a dataset with the distributed dataset service.
|
||||
|
||||
This function registers a dataset with the distributed dataset service and broadcasts the dataset ID
|
||||
and job name to all processes in the distributed environment.
|
||||
|
||||
Args:
|
||||
dataset (tf.data.Dataset): The dataset to be registered.
|
||||
dataset_service (str): The name of the dataset service.
|
||||
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
|
||||
|
||||
Returns:
|
||||
Tuple[int, str]: A tuple containing the dataset ID and job name.
|
||||
|
||||
Note:
|
||||
This function should be called on the rank 0 process.
|
||||
|
||||
"""
|
||||
if dist.get_rank() == 0:
|
||||
dataset_id = _register_dataset(
|
||||
service=dataset_service,
|
||||
@ -82,6 +110,23 @@ def distribute_from_dataset_id(
|
||||
compression: Optional[str] = "AUTO",
|
||||
prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,
|
||||
) -> tf.data.Dataset:
|
||||
"""
|
||||
Distribute a dataset from a registered dataset ID.
|
||||
|
||||
This function consumes a dataset from the distributed dataset service using the provided dataset ID
|
||||
and job name. It also supports prefetching for improved performance.
|
||||
|
||||
Args:
|
||||
dataset_service (str): The name of the dataset service.
|
||||
dataset_id (int): The ID of the dataset to be consumed.
|
||||
job_name (Optional[str]): The name of the job associated with the dataset (optional).
|
||||
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
|
||||
prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE).
|
||||
|
||||
Returns:
|
||||
tf.data.Dataset: The distributed dataset.
|
||||
|
||||
"""
|
||||
logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}")
|
||||
dataset = _from_dataset_id(
|
||||
processing_mode="parallel_epochs",
|
||||
@ -97,15 +142,28 @@ def distribute_from_dataset_id(
|
||||
|
||||
|
||||
def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
|
||||
"""Torch-compatible and distributed-training-aware dataset service distributor.
|
||||
|
||||
- rank 0 process will register the given dataset.
|
||||
- rank 0 process will broadcast job name and dataset id.
|
||||
- all rank processes will consume from the same job/dataset.
|
||||
|
||||
Without this, dataset workers will try to serve 1 job per rank process and OOM.
|
||||
|
||||
"""
|
||||
Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption.
|
||||
|
||||
This function is used to distribute a dataset in a distributed training environment. It performs the
|
||||
following steps:
|
||||
- On the rank 0 process, it registers the given dataset with the distributed dataset service.
|
||||
- It broadcasts the job name and dataset ID to all rank processes.
|
||||
- All rank processes then consume the same dataset from the distributed dataset service.
|
||||
|
||||
Args:
|
||||
dataset (tf.data.Dataset): The TensorFlow dataset to be distributed.
|
||||
|
||||
Returns:
|
||||
tf.data.Dataset: The distributed TensorFlow dataset.
|
||||
|
||||
Note:
|
||||
- If there are no reader processes in the distributed environment, the original dataset is returned
|
||||
without any distribution.
|
||||
- This function is intended for use in distributed training environments to prevent out-of-memory (OOM)
|
||||
issues caused by each rank process trying to serve one job.
|
||||
|
||||
"""
|
||||
if not env.has_readers():
|
||||
return dataset
|
||||
dataset_service = env.get_dds()
|
||||
|
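In practice the wrapper above is a drop-in call around an existing tf.data pipeline; a sketch, with an illustrative in-memory dataset standing in for the real reader input:

```python
import tensorflow as tf

# Every training rank runs the same code; only rank 0 registers the dataset,
# and the call is a no-op when no reader processes are configured.
dataset = tf.data.Dataset.from_tensor_slices(list(range(1024))).batch(32)
dataset = maybe_distribute_dataset(dataset)
for batch in dataset:
    ...
```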
@ -12,6 +12,17 @@ import torch
|
||||
|
||||
|
||||
def create_dataset(tmpdir):
|
||||
"""
|
||||
Create a mock dataset for testing.
|
||||
|
||||
This function creates a mock dataset using PyArrow and Parquet for testing purposes.
|
||||
|
||||
Args:
|
||||
tmpdir: A temporary directory where the dataset will be created.
|
||||
|
||||
Returns:
|
||||
MockDataset: A mock dataset for testing.
|
||||
"""
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
@ -34,6 +45,14 @@ def create_dataset(tmpdir):
|
||||
|
||||
|
||||
def test_dataset(tmpdir):
|
||||
"""
|
||||
Test the created dataset.
|
||||
|
||||
This function tests the created mock dataset and checks if it behaves as expected.
|
||||
|
||||
Args:
|
||||
tmpdir: A temporary directory used for testing.
|
||||
"""
|
||||
ds = create_dataset(tmpdir)
|
||||
batch = next(iter(ds.dataloader(remote=False)))
|
||||
assert batch.batch_size == 2
|
||||
@ -46,6 +65,14 @@ def test_dataset(tmpdir):
|
||||
reason="Multiprocessing doesn't work on github yet.",
|
||||
)
|
||||
def test_distributed_dataset(tmpdir):
|
||||
"""
|
||||
Test the distributed dataset.
|
||||
|
||||
This function tests the distributed version of the mock dataset using multiprocessing.
|
||||
|
||||
Args:
|
||||
tmpdir: A temporary directory used for testing.
|
||||
"""
|
||||
MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"}
|
||||
|
||||
def _client():
|
||||
|
reader/utils.py
@ -11,11 +11,55 @@ import torch
|
||||
|
||||
|
||||
def roundrobin(*iterables):
|
||||
"""Round robin through provided iterables, useful for simple load balancing.
|
||||
|
||||
Adapted from https://docs.python.org/3/library/itertools.html.
|
||||
|
||||
"""
|
||||
Iterate through provided iterables in a round-robin fashion.
|
||||
|
||||
This function takes multiple iterables and returns an iterator that yields elements from
|
||||
each iterable in a round-robin manner. It continues cycling through the iterables until
|
||||
all of them are exhausted.
|
||||
|
||||
Adapted from https://docs.python.org/3/library/itertools.html.
|
||||
|
||||
Args:
|
||||
*iterables: One or more iterable objects to iterate through.
|
||||
|
||||
Yields:
|
||||
Elements from the provided iterables in a round-robin fashion.
|
||||
|
||||
Raises:
|
||||
StopIteration: If all provided iterables are exhausted.
|
||||
|
||||
Example:
|
||||
```python
|
||||
iterable1 = [1, 2, 3]
|
||||
iterable2 = ['a', 'b', 'c']
|
||||
iterable3 = [0.1, 0.2, 0.3]
|
||||
|
||||
for item in roundrobin(iterable1, iterable2, iterable3):
|
||||
print(item)
|
||||
|
||||
# Output:
|
||||
# 1
|
||||
# 'a'
|
||||
# 0.1
|
||||
# 2
|
||||
# 'b'
|
||||
# 0.2
|
||||
# 3
|
||||
# 'c'
|
||||
# 0.3
|
||||
```
|
||||
|
||||
Note:
|
||||
- If one of the provided iterables is shorter than the others, the function will
|
||||
continue iterating through the remaining iterables until all are exhausted.
|
||||
- If an iterable raises an exception during iteration, a warning message is logged,
|
||||
and the function continues with the next iterable.
|
||||
|
||||
See Also:
|
||||
- `itertools.cycle`: A function that repeatedly cycles through elements of an iterable.
|
||||
- `itertools.islice`: A function to slice an iterable to limit the number of iterations.
|
||||
"""
|
||||
num_active = len(iterables)
|
||||
nexts = itertools.cycle(iter(it).__next__ for it in iterables)
|
||||
while num_active:
|
||||
@ -35,6 +79,48 @@ def roundrobin(*iterables):
|
||||
|
||||
|
||||
def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
|
||||
"""
|
||||
Monitor the speed and progress of data loading using a data loader.
|
||||
|
||||
This function iterates through a data loader for a specified number of steps or until
|
||||
the end of the data loader is reached, periodically logging progress information.
|
||||
|
||||
Args:
|
||||
data_loader: The data loader to monitor.
|
||||
max_steps: The maximum number of steps to iterate through the data loader.
|
||||
frequency: The frequency (in steps) at which to log progress.
|
||||
peek (optional): If specified, it indicates the frequency (in steps) at which to log
|
||||
batch contents for inspection.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Create a data loader (replace with your own DataLoader configuration)
|
||||
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
|
||||
|
||||
# Monitor data loading speed and progress
|
||||
speed_check(data_loader, max_steps=1000, frequency=50, peek=500)
|
||||
```
|
||||
|
||||
|
||||
Note:
|
||||
- The function logs information about elapsed time, the number of examples processed,
|
||||
and the processing speed in examples per second.
|
||||
- If `peek` is provided, batch contents will be logged for inspection at the specified
|
||||
frequency.
|
||||
|
||||
See Also:
|
||||
- `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and
|
||||
iterating through datasets.
|
||||
"""
|
||||
num_examples = 0
|
||||
prev = time.perf_counter()
|
||||
for idx, batch in enumerate(data_loader):
|
||||
@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]
|
||||
|
||||
|
||||
def pa_to_torch(array: pa.array) -> torch.Tensor:
|
||||
"""
|
||||
Convert a PyArrow Array to a PyTorch Tensor.
|
||||
|
||||
Args:
|
||||
array (pa.array): The PyArrow Array to convert.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
|
||||
# Create a PyArrow Array
|
||||
arrow_array = pa.array([1, 2, 3])
|
||||
|
||||
# Convert it to a PyTorch Tensor
|
||||
torch_tensor = pa_to_torch(arrow_array)
|
||||
```
|
||||
"""
|
||||
return torch.from_numpy(array.to_numpy())
|
||||
|
||||
|
||||
def create_default_pa_to_batch(schema) -> DataclassBatch:
|
||||
""" """
|
||||
"""
|
||||
Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data.
|
||||
|
||||
Args:
|
||||
schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.
|
||||
|
||||
Returns:
|
||||
callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import pyarrow as pa
|
||||
from dataclass_batch import DataclassBatch
|
||||
|
||||
# Define a PyArrow schema
|
||||
schema = pa.schema([
|
||||
("feature1", pa.float64()),
|
||||
("feature2", pa.int64()),
|
||||
("label", pa.int64()),
|
||||
])
|
||||
|
||||
# Create the conversion function
|
||||
pa_to_batch = create_default_pa_to_batch(schema)
|
||||
|
||||
# Create a PyArrow RecordBatch
|
||||
record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
|
||||
"feature1": [1.0, 2.0, None],
|
||||
"feature2": [10, 20, 30],
|
||||
"label": [0, 1, None],
|
||||
}))
|
||||
|
||||
# Convert the RecordBatch to a custom DataclassBatch
|
||||
custom_batch = pa_to_batch(record_batch)
|
||||
```
|
||||
"""
|
||||
_CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
|
||||
|
||||
def get_imputation_value(pa_type):
|
||||
|