This commit is contained in:
rajveer43 2023-09-13 11:22:13 +05:30
parent 590e8b76fe
commit db4ff958f6
17 changed files with 1059 additions and 45 deletions

View File

@ -14,7 +14,24 @@ from absl import logging as logging
def setup_absl_logging():
"""Make sure that absl logging pushes to stdout rather than stderr."""
"""
Configure absl-py logging to direct log messages to stdout and apply a custom log message format.
This function ensures that log messages generated by the absl-py library are written to stdout
rather than stderr. It also applies a custom log message format that includes module, function,
line number, log level, and the log message content.
Note:
This function should be called once at the beginning of your script or application to
configure absl-py logging.
Example:
To use this function, simply call it at the start of your script:
```
setup_absl_logging()
```
"""
logging.get_absl_handler().python_handler.stream = sys.stdout
formatter = py_logging.Formatter(
fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"

View File

@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging
class Testtlogging(unittest.TestCase):
def test_warn_once(self):
"""
Test that warning messages are logged only once when using the assertLogs context manager.
This unit test checks the behavior of the logging system when warning messages are issued
multiple times within the same context. It uses the assertLogs context manager to capture
log messages at the INFO level and verifies that warning messages are logged only once.
Example:
To use this test case, call it using a test runner like unittest:
```
python -m unittest your_test_module.TestLogging.test_warn_once
```
"""
with self.assertLogs(level="INFO") as captured_logs:
logging.info("first info")
logging.warning("first warning")

View File

@ -18,7 +18,35 @@ import torch.distributed as dist
def rank_specific(logger):
"""Ensures that we only override a given logger once."""
"""
Customize logger behavior based on the distributed environment and rank.
This function allows for customizing the behavior of a logger based on the distributed environment and the rank
of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages
depending on the rank or limit the number of redundant logs.
Args:
logger: The logger object to customize.
Returns:
The customized logger.
Example:
To use this function with the `logging` module:
```python
import logging
from rank_specific_logging import rank_specific
logger = logging.getLogger(__name__)
rank_specific(logger)
```
Customization:
- Messages are only logged if the distributed environment is not initialized or if the rank matches.
- The 'warning' method is limited to logging a single redundant warning.
- Logging from rank -1 is redirected to include the rank information.
"""
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
return logger

View File

@ -8,11 +8,60 @@ import pydantic
class PiecewiseConstant(base_config.BaseConfig):
"""
Configuration for a piecewise constant learning rate schedule.
This configuration class allows you to specify a piecewise constant learning rate schedule
by defining boundaries and corresponding learning rate values.
Attributes:
learning_rate_boundaries (List[int], optional): List of step boundaries at which
the learning rate will change. If None, no boundaries are defined.
learning_rate_values (List[float], optional): List of learning rate values
corresponding to the boundaries. If None, no values are defined.
Example:
To configure a piecewise constant learning rate schedule, create an instance of this class
and set the attributes accordingly. For example:
```python
piecewise_lr = PiecewiseConstant(
learning_rate_boundaries=[1000, 2000, 3000],
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
)
```
Note:
The number of learning rate values should be one more than the number of boundaries.
"""
learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
learning_rate_values: typing.List[float] = pydantic.Field(None)
class LinearRampToConstant(base_config.BaseConfig):
"""
Configuration for a linear ramp-up to constant learning rate schedule.
This configuration class allows you to specify a learning rate schedule that ramps up linearly
from zero to a constant value over a specified number of steps.
Attributes:
learning_rate (float): The final constant learning rate.
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
Example:
To configure a linear ramp-up to a constant learning rate, create an instance of this class
and set the attributes accordingly. For example:
```python
linear_ramp_lr = LinearRampToConstant(
learning_rate=0.1,
num_ramp_steps=1000
)
```
"""
learning_rate: float
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
description="Number of steps to ramp this up from zero."
@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig):
class LinearRampToCosine(base_config.BaseConfig):
"""
Configuration for a linear ramp-up to cosine decay learning rate schedule.
This configuration class allows you to specify a learning rate schedule that ramps up linearly
from zero, then decays following a cosine schedule to a final constant learning rate.
Attributes:
learning_rate (float): The initial learning rate at the start of ramp-up.
final_learning_rate (float): The final constant learning rate after decay.
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
final_num_steps (PositiveInt): Final number of steps where decay stops.
Example:
To configure a linear ramp-up to cosine decay learning rate, create an instance of this
class and set the attributes accordingly. For example:
```python
ramp_to_cosine_lr = LinearRampToCosine(
learning_rate=0.01,
final_learning_rate=0.001,
num_ramp_steps=1000,
final_num_steps=5000
)
```
"""
learning_rate: float
final_learning_rate: float
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig):
class LearningRate(base_config.BaseConfig):
"""
Learning rate configuration for training.
This configuration class allows you to specify different learning rate schedules
for your training process.
Attributes:
constant (float, optional): Constant learning rate to be used throughout training.
linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly
and then decays following a cosine schedule.
linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up
linearly and then remains constant.
piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified
boundaries with corresponding values.
Example:
To configure a learning rate schedule, create an instance of this class and set the
attributes accordingly. For example:
```python
learning_rate = LearningRate(
constant=0.01,
linear_ramp_to_cosine=LinearRampToCosine(
learning_rate=0.1,
final_learning_rate=0.001,
num_ramp_steps=1000,
final_num_steps=5000
)
)
```
Note:
Each learning rate schedule attribute can be set to `None` if not needed.
"""
constant: float = pydantic.Field(None, one_of="lr")
linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig):
class OptimizerAlgorithmConfig(base_config.BaseConfig):
"""Base class for optimizer configurations."""
"""
Base class for optimizer configurations.
This base configuration class provides a structure for specifying various optimizer-related
settings, including the learning rate and different learning rate schedules.
Attributes:
lr (float): The base learning rate used by the optimizer.
Subclasses should inherit from this base class and define additional attributes specific to
the optimizer algorithm they represent.
Example:
To create a custom optimizer configuration, create a subclass of this base class and
define the necessary attributes. For example:
```python
class MyOptimizerConfig(OptimizerAlgorithmConfig):
momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.")
```
Note:
This base class does not include specific optimizer settings. Subclasses should define
the optimizer-specific attributes as needed.
"""
lr: float
...
class AdamConfig(OptimizerAlgorithmConfig):
# see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
"""
Configuration for the Adam optimizer.
This configuration class allows you to specify the hyperparameters for the Adam optimizer.
Attributes:
lr (float): The learning rate for optimization.
betas (Tuple[float, float], optional): Coefficients used for computing running averages
of gradient and squared gradient. Defaults to (0.9, 0.999).
eps (float, optional): A small constant added to the denominator for numerical stability.
Defaults to 1e-7.
Example:
To configure the Adam optimizer, create an instance of this class and set the attributes
accordingly. For example:
```python
adam_optimizer = AdamConfig(
lr=0.001,
betas=(0.9, 0.999),
eps=1e-8
)
```
See Also:
[PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
"""
lr: float
betas: typing.Tuple[float, float] = [0.9, 0.999]
eps: float = 1e-7 # Numerical stability in denominator.
class SgdConfig(OptimizerAlgorithmConfig):
"""
Configuration for the Stochastic Gradient Descent (SGD) optimizer.
This configuration class allows you to specify the hyperparameters for the SGD optimizer.
Attributes:
lr (float): The learning rate for optimization.
momentum (float, optional): The momentum factor for SGD. Defaults to 0.0.
Example:
To configure the SGD optimizer, create an instance of this class and set the attributes
accordingly. For example:
```python
sgd_optimizer = SgdConfig(
lr=0.01,
momentum=0.9
)
```
"""
lr: float
momentum: float = 0.0
class AdagradConfig(OptimizerAlgorithmConfig):
"""
Configuration for the optimizer used during training.
This configuration class allows you to specify the optimizer for training, including
options for various optimizer algorithms.
Attributes:
learning_rate (LearningRate, optional): Learning rate configuration. Defaults to None.
adam (AdamConfig, optional): Configuration for the Adam optimizer. Defaults to None.
sgd (SgdConfig, optional): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
Defaults to None.
adagrad (AdagradConfig, optional): Configuration for the Adagrad optimizer. Defaults to None.
Example:
To configure the optimizer for training, create an instance of this class and set the
attributes accordingly. For example:
```python
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.001),
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
)
```
"""
lr: float
eps: float = 0
class OptimizerConfig(base_config.BaseConfig):
"""
Configuration for defining different optimizer algorithms and their parameters.
This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad,
along with their respective hyperparameters.
Args:
learning_rate (LearningRate): The learning rate configuration, which can include
constant learning rates or other learning rate schedules.
adam (AdamConfig): Configuration for the Adam optimizer.
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
Example:
```python
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.001),
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8),
)
```
Attributes:
learning_rate (LearningRate): The learning rate configuration.
adam (AdamConfig): Configuration for the Adam optimizer.
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
Note:
You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an
`OptimizerConfig` instance.
See Also:
- `LearningRate`: Configuration for specifying learning rates.
- `AdamConfig`: Configuration for the Adam optimizer.
- `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer.
- `AdagradConfig`: Configuration for the Adagrad optimizer.
"""
learning_rate: LearningRate = pydantic.Field(
None,
description="Constant learning rates",
@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig):
def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
"""
Get the optimizer algorithm configuration from the given `OptimizerConfig`.
This function extracts and returns the specific optimizer algorithm configuration
(e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`.
Args:
optimizer_config (OptimizerConfig): The optimizer configuration object containing
one of the optimizer algorithm configurations.
Returns:
Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm
configuration extracted from `optimizer_config`.
Raises:
ValueError: If no optimizer algorithm is selected in `optimizer_config`.
Example:
```python
optimizer_config = OptimizerConfig(
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
)
algorithm_config = get_optimizer_algorithm_config(optimizer_config)
# `algorithm_config` will be an instance of `AdamConfig`.
```
"""
if optimizer_config.adam is not None:
return optimizer_config.adam
elif optimizer_config.sgd is not None:

View File

@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging
def compute_lr(lr_config, step):
"""Compute a learning rate."""
"""
Compute the learning rate based on the specified learning rate configuration.
This function calculates the learning rate according to the given configuration, which can include
constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing.
Args:
lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule.
step (int): The current training step or iteration.
Returns:
float: The computed learning rate for the current step.
Raises:
ValueError: If the `lr_config` is invalid or contains conflicting options.
Example:
```python
lr_schedule = LearningRate(
constant=0.001,
piecewise_constant=PiecewiseConstant(
learning_rate_boundaries=[1000, 2000, 3000],
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
)
)
current_step = 2500
learning_rate = compute_lr(lr_schedule, current_step)
```
"""
if lr_config.constant is not None:
return lr_config.constant
elif lr_config.piecewise_constant is not None:
@ -46,11 +74,54 @@ def compute_lr(lr_config, step):
class LRShim(_LRScheduler):
"""Shim to get learning rates into a LRScheduler.
This adheres to the torch.optim scheduler API and can be plugged anywhere that
e.g. exponential decay can be used.
"""
Learning Rate Scheduler Shim to adjust learning rates during training.
This class acts as a shim to apply different learning rates to individual parameter groups
within an optimizer. It adheres to the torch.optim scheduler API and can be used with various
optimizers, allowing fine-grained control over learning rates based on configuration.
Args:
optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted.
lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their
corresponding learning rate configurations.
last_epoch (int, optional): The index of the last epoch. Default is -1.
verbose (bool, optional): If True, prints a warning message when accessing learning rates
using the deprecated `get_lr()` method. Default is False.
Raises:
ValueError: If the number of parameter groups in the optimizer does not match the number
of learning rate configurations provided.
Note:
To obtain the last computed learning rates, please use `get_last_lr()`.
Example:
```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
lr_schedule = {
'main': LearningRate(constant=0.01),
'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant(
learning_rate_boundaries=[1000, 2000],
learning_rate_values=[0.01, 0.001]
))
}
lr_shim = LRShim(optimizer, lr_schedule)
for epoch in range(num_epochs):
# Train the model
train(...)
# Update learning rates at the end of each epoch
lr_shim.step(epoch)
final_lr_main = lr_shim.get_last_lr()['main']
final_lr_auxiliary = lr_shim.get_last_lr()['auxiliary']
```
See Also:
- `LearningRate`: Configuration for specifying learning rates.
- `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules.
"""
def __init__(
self,
@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig):
def build_optimizer(
model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
"""Builds an optimizer and LR scheduler from an OptimizerConfig.
Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
"""
Build an optimizer and learning rate scheduler based on the provided optimizer configuration.
Args:
model (torch.nn.Module): The PyTorch model for which the optimizer will be created.
optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer
algorithm and learning rate settings.
Returns:
Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler
objects.
Note:
This function is intended for cases where you want the same optimizer and learning rate
schedule for all model parameters.
Example:
```python
model = MyModel()
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.01),
sgd=SgdConfig(lr=0.01, momentum=0.9)
)
optimizer, scheduler = build_optimizer(model, optimizer_config)
for epoch in range(num_epochs):
# Train the model with the optimizer
train(model, optimizer, ...)
# Update learning rates at the end of each epoch
scheduler.step(epoch)
```
See Also:
- `OptimizerConfig`: Configuration for specifying optimizer settings.
- `LRShim`: Learning rate scheduler shim for fine-grained learning rate control.
"""
optimizer_class = get_optimizer_class(optimizer_config)
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
# We're passing everything in as one group here

View File

@ -4,6 +4,17 @@ import pydantic
class TwhinDataConfig(base_config.BaseConfig):
"""
Configuration for Twhin model training data.
Args:
data_root (str): The root directory for the training data.
per_replica_batch_size (pydantic.PositiveInt): Batch size per replica.
global_negatives (int): The number of global negatives.
in_batch_negatives (int): The number of in-batch negatives.
limit (pydantic.PositiveInt): The limit on the number of data points to use.
offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None.
"""
data_root: str
per_replica_batch_size: pydantic.PositiveInt
global_negatives: int

View File

@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset
def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
"""
Create a dataset for Twhin model training.
Args:
data_config (TwhinDataConfig): The data configuration for the dataset.
model_config (TwhinModelConfig): The model configuration containing embeddings and relations.
Returns:
EdgesDataset: The dataset for Twhin model training.
"""
tables = model_config.embeddings.tables
table_sizes = {table.name: table.num_embeddings for table in tables}
relations = model_config.relations

View File

@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@dataclass
class EdgeBatch(DataclassBatch):
"""
Batch data structure for edge-based models.
Args:
nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings.
labels (torch.Tensor): Tensor containing labels.
rels (torch.Tensor): Tensor containing relation information.
weights (torch.Tensor): Tensor containing weights.
"""
nodes: KeyedJaggedTensor
labels: torch.Tensor
rels: torch.Tensor
@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch):
class EdgesDataset(Dataset):
"""
Dataset for edge-based models.
Args:
file_pattern (str): The file pattern for the dataset.
table_sizes (Dict[str, int]): A dictionary of table names and their sizes.
relations (List[Relation]): A list of relations between tables.
lhs_column_name (str): The name of the left-hand-side column.
rhs_column_name (str): The name of the right-hand-side column.
rel_column_name (str): The name of the relation column.
**dataset_kwargs: Additional keyword arguments for the parent Dataset class.
"""
rng = np.random.default_rng()
def __init__(
@ -56,6 +77,15 @@ class EdgesDataset(Dataset):
super().__init__(file_pattern=file_pattern, **dataset_kwargs)
def pa_to_batch(self, batch: pa.RecordBatch):
"""
Converts a pyarrow RecordBatch to an EdgeBatch.
Args:
batch (pa.RecordBatch): A pyarrow RecordBatch containing data.
Returns:
EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights.
"""
lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
@ -74,6 +104,14 @@ class EdgesDataset(Dataset):
) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:
"""Process edges that contain lhs index, rhs index, relation index.
Args:
lhs (torch.Tensor): Tensor containing left-hand-side indices.
rhs (torch.Tensor): Tensor containing right-hand-side indices.
rel (torch.Tensor): Tensor containing relation indices.
Returns:
Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs.
Example:
```
@ -147,6 +185,12 @@ class EdgesDataset(Dataset):
return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)
def to_batches(self):
"""
Converts data to batches.
Yields:
pa.RecordBatch: A pyarrow RecordBatch containing data.
"""
ds = super().to_batches()
batch_size = self._dataset_kwargs["batch_size"]

View File

@ -10,8 +10,29 @@ from pydantic import validator
class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
"""
Configuration class for Twhin model embeddings.
This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types
for all tables in the Twhin model embeddings configuration match.
Attributes:
tables (List[TableConfig]): A list of table configurations for the model's embeddings.
"""
@validator("tables")
def embedding_dims_match(cls, tables):
"""
Validate that embedding dimensions and data types match for all tables.
Args:
tables (List[TableConfig]): List of table configurations.
Returns:
List[TableConfig]: The list of validated table configurations.
Raises:
AssertionError: If embedding dimensions or data types do not match.
"""
embedding_dim = tables[0].embedding_dim
data_type = tables[0].data_type
for table in tables:
@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
class Operator(str, enum.Enum):
"""
Enumeration of operator types.
This enumeration defines different types of operators that can be applied to Twhin model relations.
"""
TRANSLATION = "translation"
class Relation(pydantic.BaseModel):
"""graph relationship properties and operator"""
"""
Configuration class for graph relationships in the Twhin model.
This class defines properties and operators for graph relationships in the Twhin model.
Attributes:
name (str): The name of the relationship.
lhs (str): The name of the entity on the left-hand side of the relation.
rhs (str): The name of the entity on the right-hand side of the relation.
operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product.
"""
name: str = pydantic.Field(..., description="Relationship name.")
lhs: str = pydantic.Field(
@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel):
class TwhinModelConfig(base_config.BaseConfig):
"""
Configuration class for the Twhin model.
This class defines configuration options specific to the Twhin model.
Attributes:
embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings.
relations (List[Relation]): List of graph relationship configurations.
translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation.
"""
embeddings: TwhinEmbeddingsConfig
relations: typing.List[Relation]
translation_optimizer: OptimizerConfig
@validator("relations", each_item=True)
def valid_node_types(cls, relation, values, **kwargs):
"""
Validate that the specified node types in relations are valid table names in embeddings.
Args:
relation (Relation): A single relation configuration.
values (dict): The values dictionary containing the "embeddings" configuration.
Returns:
Relation: The validated relation configuration.
Raises:
AssertionError: If the specified node types are not valid table names in embeddings.
"""
table_names = [table.name for table in values["embeddings"].tables]
assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"

View File

@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa
class TwhinModel(nn.Module):
"""
Twhin model for graph-based entity embeddings and translation.
This class defines the Twhin model, which is used for learning embeddings of entities in a graph
and applying translations to these embeddings based on graph relationships.
Args:
model_config (TwhinModelConfig): Configuration for the Twhin model.
data_config (TwhinDataConfig): Configuration for the data used by the model.
Attributes:
batch_size (int): The batch size used for training.
table_names (List[str]): Names of tables in the model's embeddings.
large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings.
embedding_dim (int): Dimensionality of entity embeddings.
num_tables (int): Number of tables in the model's embeddings.
in_batch_negatives (int): Number of in-batch negative samples to use during training.
global_negatives (int): Number of global negative samples to use during training.
num_relations (int): Number of graph relationships in the model.
all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings.
"""
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
super().__init__()
self.batch_size = data_config.per_replica_batch_size
@ -31,7 +53,17 @@ class TwhinModel(nn.Module):
)
def forward(self, batch: EdgeBatch):
"""
Forward pass of the Twhin model.
Args:
batch (EdgeBatch): Input batch containing graph edge information.
Returns:
dict: A dictionary containing model output with "logits" and "probabilities".
- "logits" (torch.Tensor): Logit scores.
- "probabilities" (torch.Tensor): Sigmoid probabilities.
"""
# B x D
trans_embs = self.all_trans_embs.data[batch.rels]
@ -98,6 +130,18 @@ class TwhinModel(nn.Module):
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
"""
Apply optimizers to the Twhin model's embeddings.
This function applies optimizers to the embeddings of the Twhin model based on the provided configuration.
Args:
model (TwhinModel): The Twhin model to apply optimizers to.
model_config (TwhinModelConfig): Configuration for the Twhin model.
Returns:
TwhinModel: The Twhin model with optimizers applied to its embeddings.
"""
for table in model_config.embeddings.tables:
optimizer_class = get_optimizer_class(table.optimizer)
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module):
device: torch.device,
) -> None:
"""
Args:
model: torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels.
"""
Initialize a TwhinModelAndLoss module.
Args:
model: The torch module to wrap.
loss_fn: A function for calculating loss, should accept logits and labels.
data_config: Configuration for Twhin data.
device: The torch device to use for calculations.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module):
self.device = device
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
Run the model forward and calculate the loss according to the given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types
Args:
batch ("RecapBatch"): The input batch for model inference.
Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of
additional outputs including logits, labels, and weights.
"""
outputs = self.model(batch)
logits = outputs["logits"]

View File

@ -18,6 +18,12 @@ EMB_DIM = 128
def twhin_model_config() -> TwhinModelConfig:
"""
Create a configuration for the Twhin model.
Returns:
TwhinModelConfig: The Twhin model configuration.
"""
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig:
def twhin_data_config() -> TwhinDataConfig:
"""
Create a configuration for the Twhin data.
Returns:
TwhinDataConfig: The Twhin data configuration.
"""
data_config = TwhinDataConfig(
data_root="/",
per_replica_batch_size=10,
@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig:
def test_twhin_model():
"""
Test the Twhin model creation and optimization.
This function creates a Twhin model using the specified configuration and tests its optimization. It also checks
the device placement of model parameters.
Returns:
None
"""
model_config = twhin_model_config()
loss_fn = F.binary_cross_entropy_with_logits

View File

@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt"
def _lr_from_config(optimizer_config):
"""Get the learning rate from an optimizer configuration.
Args:
optimizer_config: Optimizer configuration.
Returns:
Learning rate from the optimizer configuration.
"""
if optimizer_config.learning_rate is not None:
return optimizer_config.learning_rate
else:
@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config):
def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
"""Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.
Args:
model: TwhinModel to build optimizer for.
config: TwhinConfig for model.
Args:
model: TwhinModel to build optimizer for.
config: TwhinModelConfig for model.
Returns:
Optimizer for model.
"""
Returns:
Optimizer for model.
"""
translation_optimizer_fn = functools.partial(
get_optimizer_class(config.translation_optimizer),
**get_optimizer_algorithm_config(config.translation_optimizer).dict(),

View File

@ -37,6 +37,12 @@ def run(
all_config: TwhinConfig,
save_dir: Optional[str] = None,
):
"""Run the training process for TwhinModel.
Args:
all_config (TwhinConfig): The configuration for the entire Twhin model.
save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None.
"""
train_dataset = create_dataset(all_config.train_data, all_config.model)
if env.is_reader():
@ -80,6 +86,11 @@ def run(
def main(argv):
"""Main entry point for the Twhin training script.
Args:
argv: Command-line arguments.
"""
logging.info("Starting")
logging.info(f"parsing config from {FLAGS.config_yaml_path}...")

View File

@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging
class _Reader(pa.flight.FlightServerBase):
"""Distributed reader flight server wrapping a dataset."""
"""
Distributed reader flight server wrapping a dataset.
This class implements a Flight server that wraps a dataset, allowing clients to retrieve data
from the dataset over the Flight protocol. It is designed to be used in a distributed environment
for efficient data access.
Args:
location (str): The location of the Flight server.
ds (Dataset): The dataset to be wrapped by the Flight server.
Attributes:
_location (str): The location of the Flight server.
_ds (Dataset): The dataset wrapped by the Flight server.
Methods:
do_get(_, __): Handles Flight requests for data retrieval.
Note:
Flight is an Apache Arrow project that provides a framework for efficient data transfer.
This class allows clients to retrieve data from the dataset using Flight.
"""
def __init__(self, location: str, ds: "Dataset"):
"""
Initialize a new _Reader instance.
Args:
location (str): The location of the Flight server.
ds (Dataset): The dataset to be wrapped by the Flight server.
"""
super().__init__(location=location)
self._location = location
self._ds = ds
def do_get(self, _, __):
"""
Handle Flight requests for data retrieval.
This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol.
Args:
_: Unused argument.
__: Unused argument.
Returns:
pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset.
Note:
An updated schema (to account for column selection) must be given to the stream.
"""
# NB: An updated schema (to account for column selection) has to be given the stream.
schema = next(iter(self._ds.to_batches())).schema
batches = self._ds.to_batches()
@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase):
class Dataset(torch.utils.data.IterableDataset):
"""
A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading.
This class enables efficient loading of data from Parquet files using PyArrow.
It is designed to be used as an IterableDataset in PyTorch for training and inference.
Args:
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset
for more details.
Attributes:
LOCATION (str): The default location for the Flight server used for data distribution.
_file_pattern (str): The glob pattern specifying Parquet files in the dataset.
_fs: The filesystem object used for file operations.
_dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method.
_files (list): A list of file paths matching the glob pattern.
_schema (pa.Schema): The schema of the Parquet dataset.
Methods:
serve(): Start serving the dataset using a Flight server.
to_batches(): Generate batches of data from the Parquet dataset.
pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch.
dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset.
Note:
This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch
to create DataLoader instances for training or inference.
"""
LOCATION = "grpc://0.0.0.0:2222"
def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
"""Specify batch size and column to select for.
"""
Initialize a new Dataset instance. Specify batch size and column to select for.
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
"""
Args:
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
"""
self._file_pattern = file_pattern
self._fs = infer_fs(self._file_pattern)
self._dataset_kwargs = dataset_kwargs
@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset):
self._validate_columns()
def _validate_columns(self):
"""
Validate the specified columns against the dataset schema.
Raises:
Exception: If any specified columns are not found in the dataset schema.
"""
columns = set(self._dataset_kwargs.get("columns", []))
wrong_columns = set(columns) - set(self._schema.names)
if wrong_columns:
raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")
def serve(self):
"""Start serving the dataset using a Flight server."""
self.reader = _Reader(location=self.LOCATION, ds=self)
self.reader.serve()
def _create_dataset(self):
"""Create a PyArrow dataset for data retrieval."""
return pads.dataset(
source=random.sample(self._files, len(self._files))[0],
format="parquet",
@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset):
@abc.abstractmethod
def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
"""
Convert a Parquet RecordBatch to a custom data batch.
Args:
batch (pa.RecordBatch): A batch of data from the Parquet dataset.
Returns:
DataclassBatch: A custom data batch used in PyTorch training.
Raises:
NotImplementedError: This method must be implemented in derived classes.
"""
raise NotImplementedError
def dataloader(self, remote: bool = False):
"""
Create a PyTorch DataLoader for iterating through the dataset.
Args:
remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training.
Returns:
DataLoader: A PyTorch DataLoader for iterating through the dataset.
Note:
If `remote` is True, a remote DataLoader is created for distributed training using Flight.
"""
if not remote:
return map(self.pa_to_batch, self.to_batches())
readers = get_readers(2)
@ -117,6 +230,25 @@ GRPC_OPTIONS = [
def get_readers(num_readers_per_worker: int):
"""
Get Flight readers for distributed data loading.
This function retrieves Flight readers for distributed data loading in a PyTorch environment.
Args:
num_readers_per_worker (int): The number of Flight readers to retrieve per worker.
Returns:
List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading.
Note:
Flight readers are used to fetch data in a distributed manner for efficient data loading.
Example:
To obtain Flight readers, use the following code:
>>> readers = get_readers(num_readers_per_worker=2)
"""
addresses = env.get_flight_server_addresses()
readers = []

View File

@ -21,6 +21,16 @@ import torch.distributed as dist
def maybe_start_dataset_service():
"""
Start the dataset service if readers are available and required dependencies are met.
This function checks if readers are available and if the required TensorFlow version is >= 2.5.
If both conditions are met and the current environment is the dispatcher or reader, it starts
the TensorFlow dataset service.
Raises:
Exception: If the required TensorFlow version is not met (>= 2.5).
"""
if not env.has_readers():
return
@ -59,6 +69,24 @@ def maybe_start_dataset_service():
def register_dataset(
dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
):
"""
Register a dataset with the distributed dataset service.
This function registers a dataset with the distributed dataset service and broadcasts the dataset ID
and job name to all processes in the distributed environment.
Args:
dataset (tf.data.Dataset): The dataset to be registered.
dataset_service (str): The name of the dataset service.
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
Returns:
Tuple[int, str]: A tuple containing the dataset ID and job name.
Note:
This function should be called on the rank 0 process.
"""
if dist.get_rank() == 0:
dataset_id = _register_dataset(
service=dataset_service,
@ -82,6 +110,23 @@ def distribute_from_dataset_id(
compression: Optional[str] = "AUTO",
prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,
) -> tf.data.Dataset:
"""
Distribute a dataset from a registered dataset ID.
This function consumes a dataset from the distributed dataset service using the provided dataset ID
and job name. It also supports prefetching for improved performance.
Args:
dataset_service (str): The name of the dataset service.
dataset_id (int): The ID of the dataset to be consumed.
job_name (Optional[str]): The name of the job associated with the dataset (optional).
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE).
Returns:
tf.data.Dataset: The distributed dataset.
"""
logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}")
dataset = _from_dataset_id(
processing_mode="parallel_epochs",
@ -97,15 +142,28 @@ def distribute_from_dataset_id(
def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
"""Torch-compatible and distributed-training-aware dataset service distributor.
- rank 0 process will register the given dataset.
- rank 0 process will broadcast job name and dataset id.
- all rank processes will consume from the same job/dataset.
Without this, dataset workers will try to serve 1 job per rank process and OOM.
"""
Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption.
This function is used to distribute a dataset in a distributed training environment. It performs the
following steps:
- On the rank 0 process, it registers the given dataset with the distributed dataset service.
- It broadcasts the job name and dataset ID to all rank processes.
- All rank processes then consume the same dataset from the distributed dataset service.
Args:
dataset (tf.data.Dataset): The TensorFlow dataset to be distributed.
Returns:
tf.data.Dataset: The distributed TensorFlow dataset.
Note:
- If there are no reader processes in the distributed environment, the original dataset is returned
without any distribution.
- This function is intended for use in distributed training environments to prevent out-of-memory (OOM)
issues caused by each rank process trying to serve one job.
"""
if not env.has_readers():
return dataset
dataset_service = env.get_dds()

View File

@ -12,6 +12,17 @@ import torch
def create_dataset(tmpdir):
"""
Create a mock dataset for testing.
This function creates a mock dataset using PyArrow and Parquet for testing purposes.
Args:
tmpdir: A temporary directory where the dataset will be created.
Returns:
MockDataset: A mock dataset for testing.
"""
table = pa.table(
{
@ -34,6 +45,14 @@ def create_dataset(tmpdir):
def test_dataset(tmpdir):
"""
Test the created dataset.
This function tests the created mock dataset and checks if it behaves as expected.
Args:
tmpdir: A temporary directory used for testing.
"""
ds = create_dataset(tmpdir)
batch = next(iter(ds.dataloader(remote=False)))
assert batch.batch_size == 2
@ -46,6 +65,14 @@ def test_dataset(tmpdir):
reason="Multiprocessing doesn't work on github yet.",
)
def test_distributed_dataset(tmpdir):
"""
Test the distributed dataset.
This function tests the distributed version of the mock dataset using multiprocessing.
Args:
tmpdir: A temporary directory used for testing.
"""
MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"}
def _client():

View File

@ -11,11 +11,55 @@ import torch
def roundrobin(*iterables):
"""Round robin through provided iterables, useful for simple load balancing.
Adapted from https://docs.python.org/3/library/itertools.html.
"""
Iterate through provided iterables in a round-robin fashion.
This function takes multiple iterables and returns an iterator that yields elements from
each iterable in a round-robin manner. It continues cycling through the iterables until
all of them are exhausted.
Adapted from https://docs.python.org/3/library/itertools.html.
Args:
*iterables: One or more iterable objects to iterate through.
Yields:
Elements from the provided iterables in a round-robin fashion.
Raises:
StopIteration: If all provided iterables are exhausted.
Example:
```python
iterable1 = [1, 2, 3]
iterable2 = ['a', 'b', 'c']
iterable3 = [0.1, 0.2, 0.3]
for item in roundrobin(iterable1, iterable2, iterable3):
print(item)
# Output:
# 1
# 'a'
# 0.1
# 2
# 'b'
# 0.2
# 3
# 'c'
# 0.3
```
Note:
- If one of the provided iterables is shorter than the others, the function will
continue iterating through the remaining iterables until all are exhausted.
- If an iterable raises an exception during iteration, a warning message is logged,
and the function continues with the next iterable.
See Also:
- `itertools.cycle`: A function that repeatedly cycles through elements of an iterable.
- `itertools.islice`: A function to slice an iterable to limit the number of iterations.
"""
num_active = len(iterables)
nexts = itertools.cycle(iter(it).__next__ for it in iterables)
while num_active:
@ -35,6 +79,48 @@ def roundrobin(*iterables):
def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
"""
Monitor the speed and progress of data loading using a data loader.
This function iterates through a data loader for a specified number of steps or until
the end of the data loader is reached, periodically logging progress information.
Args:
data_loader: The data loader to monitor.
max_steps: The maximum number of steps to iterate through the data loader.
frequency: The frequency (in steps) at which to log progress.
peek (optional): If specified, it indicates the frequency (in steps) at which to log
batch contents for inspection.
Example:
```python
import torch
from torch.utils.data import DataLoader
# Create a data loader (replace with your own DataLoader configuration)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Monitor data loading speed and progress
speed_check(data_loader, max_steps=1000, frequency=50, peek=500)
```
Args:
data_loader: The data loader to monitor.
max_steps: The maximum number of steps to iterate through the data loader.
frequency: The frequency (in steps) at which to log progress.
peek (optional): If specified, it indicates the frequency (in steps) at which to log
batch contents for inspection.
Note:
- The function logs information about elapsed time, the number of examples processed,
and the processing speed in examples per second.
- If `peek` is provided, batch contents will be logged for inspection at the specified
frequency.
See Also:
- `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and
iterating through datasets.
"""
num_examples = 0
prev = time.perf_counter()
for idx, batch in enumerate(data_loader):
@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]
def pa_to_torch(array: pa.array) -> torch.Tensor:
"""
Convert a PyArrow Array to a PyTorch Tensor.
Args:
array (pa.array): The PyArrow Array to convert.
Returns:
torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array.
Example:
```python
import pyarrow as pa
import torch
# Create a PyArrow Array
arrow_array = pa.array([1, 2, 3])
# Convert it to a PyTorch Tensor
torch_tensor = pa_to_torch(arrow_array)
```
"""
return torch.from_numpy(array.to_numpy())
def create_default_pa_to_batch(schema) -> DataclassBatch:
""" """
"""
Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data.
Args:
schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.
Returns:
callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.
Example:
```python
import pyarrow as pa
from dataclass_batch import DataclassBatch
# Define a PyArrow schema
schema = pa.schema([
("feature1", pa.float64()),
("feature2", pa.int64()),
("label", pa.int64()),
])
# Create the conversion function
pa_to_batch = create_default_pa_to_batch(schema)
# Create a PyArrow RecordBatch
record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
"feature1": [1.0, 2.0, None],
"feature2": [10, 20, 30],
"label": [0, 1, None],
}))
# Convert the RecordBatch to a custom DataclassBatch
custom_batch = pa_to_batch(record_batch)
```
"""
_CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
def get_imputation_value(pa_type):