diff --git a/ml_logging/absl_logging.py b/ml_logging/absl_logging.py index e4d9e9e..07be8bb 100644 --- a/ml_logging/absl_logging.py +++ b/ml_logging/absl_logging.py @@ -14,7 +14,24 @@ from absl import logging as logging def setup_absl_logging(): - """Make sure that absl logging pushes to stdout rather than stderr.""" + """ + Configure absl-py logging to direct log messages to stdout and apply a custom log message format. + + This function ensures that log messages generated by the absl-py library are written to stdout + rather than stderr. It also applies a custom log message format that includes module, function, + line number, log level, and the log message content. + + Note: + This function should be called once at the beginning of your script or application to + configure absl-py logging. + + Example: + To use this function, simply call it at the start of your script: + ``` + setup_absl_logging() + ``` + + """ logging.get_absl_handler().python_handler.stream = sys.stdout formatter = py_logging.Formatter( fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s" diff --git a/ml_logging/test_torch_logging.py b/ml_logging/test_torch_logging.py index e3a2e90..36c141d 100644 --- a/ml_logging/test_torch_logging.py +++ b/ml_logging/test_torch_logging.py @@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging class Testtlogging(unittest.TestCase): def test_warn_once(self): + """ + Test that warning messages are logged only once when using the assertLogs context manager. + + This unit test checks the behavior of the logging system when warning messages are issued + multiple times within the same context. It uses the assertLogs context manager to capture + log messages at the INFO level and verifies that warning messages are logged only once. + + Example: + To use this test case, call it using a test runner like unittest: + ``` + python -m unittest your_test_module.TestLogging.test_warn_once + ``` + + """ + with self.assertLogs(level="INFO") as captured_logs: logging.info("first info") logging.warning("first warning") diff --git a/ml_logging/torch_logging.py b/ml_logging/torch_logging.py index e791c46..2f60114 100644 --- a/ml_logging/torch_logging.py +++ b/ml_logging/torch_logging.py @@ -18,7 +18,35 @@ import torch.distributed as dist def rank_specific(logger): - """Ensures that we only override a given logger once.""" + """ + Customize logger behavior based on the distributed environment and rank. + + This function allows for customizing the behavior of a logger based on the distributed environment and the rank + of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages + depending on the rank or limit the number of redundant logs. + + Args: + logger: The logger object to customize. + + Returns: + The customized logger. + + Example: + To use this function with the `logging` module: + ```python + import logging + from rank_specific_logging import rank_specific + + logger = logging.getLogger(__name__) + rank_specific(logger) + ``` + + Customization: + - Messages are only logged if the distributed environment is not initialized or if the rank matches. + - The 'warning' method is limited to logging a single redundant warning. + - Logging from rank -1 is redirected to include the rank information. 
+ + """ if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"): return logger diff --git a/optimizers/config.py b/optimizers/config.py index f5011f0..dc5c94a 100644 --- a/optimizers/config.py +++ b/optimizers/config.py @@ -8,11 +8,60 @@ import pydantic class PiecewiseConstant(base_config.BaseConfig): + """ + Configuration for a piecewise constant learning rate schedule. + + This configuration class allows you to specify a piecewise constant learning rate schedule + by defining boundaries and corresponding learning rate values. + + Attributes: + learning_rate_boundaries (List[int], optional): List of step boundaries at which + the learning rate will change. If None, no boundaries are defined. + learning_rate_values (List[float], optional): List of learning rate values + corresponding to the boundaries. If None, no values are defined. + + Example: + To configure a piecewise constant learning rate schedule, create an instance of this class + and set the attributes accordingly. For example: + + ```python + piecewise_lr = PiecewiseConstant( + learning_rate_boundaries=[1000, 2000, 3000], + learning_rate_values=[0.1, 0.05, 0.01, 0.001] + ) + ``` + + Note: + The number of learning rate values should be one more than the number of boundaries. + + """ learning_rate_boundaries: typing.List[int] = pydantic.Field(None) learning_rate_values: typing.List[float] = pydantic.Field(None) class LinearRampToConstant(base_config.BaseConfig): + """ + Configuration for a linear ramp-up to constant learning rate schedule. + + This configuration class allows you to specify a learning rate schedule that ramps up linearly + from zero to a constant value over a specified number of steps. + + Attributes: + learning_rate (float): The final constant learning rate. + num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero. + + Example: + To configure a linear ramp-up to a constant learning rate, create an instance of this class + and set the attributes accordingly. For example: + + ```python + linear_ramp_lr = LinearRampToConstant( + learning_rate=0.1, + num_ramp_steps=1000 + ) + ``` + + """ learning_rate: float num_ramp_steps: pydantic.PositiveInt = pydantic.Field( description="Number of steps to ramp this up from zero." @@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig): class LinearRampToCosine(base_config.BaseConfig): + """ + Configuration for a linear ramp-up to cosine decay learning rate schedule. + + This configuration class allows you to specify a learning rate schedule that ramps up linearly + from zero, then decays following a cosine schedule to a final constant learning rate. + + Attributes: + learning_rate (float): The initial learning rate at the start of ramp-up. + final_learning_rate (float): The final constant learning rate after decay. + num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero. + final_num_steps (PositiveInt): Final number of steps where decay stops. + + Example: + To configure a linear ramp-up to cosine decay learning rate, create an instance of this + class and set the attributes accordingly. 
For example: + + ```python + ramp_to_cosine_lr = LinearRampToCosine( + learning_rate=0.01, + final_learning_rate=0.001, + num_ramp_steps=1000, + final_num_steps=5000 + ) + ``` + + """ learning_rate: float final_learning_rate: float num_ramp_steps: pydantic.PositiveInt = pydantic.Field( @@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig): class LearningRate(base_config.BaseConfig): + """ + Learning rate configuration for training. + + This configuration class allows you to specify different learning rate schedules + for your training process. + + Attributes: + constant (float, optional): Constant learning rate to be used throughout training. + linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly + and then decays following a cosine schedule. + linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up + linearly and then remains constant. + piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified + boundaries with corresponding values. + + Example: + To configure a learning rate schedule, create an instance of this class and set the + attributes accordingly. For example: + + ```python + learning_rate = LearningRate( + constant=0.01, + linear_ramp_to_cosine=LinearRampToCosine( + learning_rate=0.1, + final_learning_rate=0.001, + num_ramp_steps=1000, + final_num_steps=5000 + ) + ) + ``` + + Note: + Each learning rate schedule attribute can be set to `None` if not needed. + + """ constant: float = pydantic.Field(None, one_of="lr") linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr") linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr") @@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig): class OptimizerAlgorithmConfig(base_config.BaseConfig): - """Base class for optimizer configurations.""" + """ + Base class for optimizer configurations. + + This base configuration class provides a structure for specifying various optimizer-related + settings, including the learning rate and different learning rate schedules. + + Attributes: + lr (float): The base learning rate used by the optimizer. + + Subclasses should inherit from this base class and define additional attributes specific to + the optimizer algorithm they represent. + + Example: + To create a custom optimizer configuration, create a subclass of this base class and + define the necessary attributes. For example: + + ```python + class MyOptimizerConfig(OptimizerAlgorithmConfig): + momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.") + ``` + + Note: + This base class does not include specific optimizer settings. Subclasses should define + the optimizer-specific attributes as needed. + + """ lr: float ... class AdamConfig(OptimizerAlgorithmConfig): - # see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam + """ + Configuration for the Adam optimizer. + + This configuration class allows you to specify the hyperparameters for the Adam optimizer. + + Attributes: + lr (float): The learning rate for optimization. + betas (Tuple[float, float], optional): Coefficients used for computing running averages + of gradient and squared gradient. Defaults to (0.9, 0.999). + eps (float, optional): A small constant added to the denominator for numerical stability. + Defaults to 1e-7. + + Example: + To configure the Adam optimizer, create an instance of this class and set the attributes + accordingly. 
For example:
+
+    ```python
+    adam_optimizer = AdamConfig(
+        lr=0.001,
+        betas=(0.9, 0.999),
+        eps=1e-8
+    )
+    ```
+
+  See Also:
+    [PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
+
+  """
   lr: float
   betas: typing.Tuple[float, float] = [0.9, 0.999]
   eps: float = 1e-7  # Numerical stability in denominator.
 
 
 class SgdConfig(OptimizerAlgorithmConfig):
+  """
+  Configuration for the Stochastic Gradient Descent (SGD) optimizer.
+
+  This configuration class allows you to specify the hyperparameters for the SGD optimizer.
+
+  Attributes:
+    lr (float): The learning rate for optimization.
+    momentum (float, optional): The momentum factor for SGD. Defaults to 0.0.
+
+  Example:
+    To configure the SGD optimizer, create an instance of this class and set the attributes
+    accordingly. For example:
+
+    ```python
+    sgd_optimizer = SgdConfig(
+        lr=0.01,
+        momentum=0.9
+    )
+    ```
+
+  """
   lr: float
   momentum: float = 0.0
 
 
 class AdagradConfig(OptimizerAlgorithmConfig):
+  """
+  Configuration for the Adagrad optimizer.
+
+  This configuration class allows you to specify the hyperparameters for the Adagrad optimizer.
+
+  Attributes:
+    lr (float): The learning rate for optimization.
+    eps (float, optional): A small constant added to the denominator for numerical stability.
+      Defaults to 0.
+
+  Example:
+    To configure the Adagrad optimizer, create an instance of this class and set the attributes
+    accordingly. For example:
+
+    ```python
+    adagrad_optimizer = AdagradConfig(
+        lr=0.01,
+        eps=1e-10
+    )
+    ```
+
+  """
   lr: float
   eps: float = 0
 
 
 class OptimizerConfig(base_config.BaseConfig):
+  """
+  Configuration for defining different optimizer algorithms and their parameters.
+
+  This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad,
+  along with their respective hyperparameters.
+
+  Attributes:
+    learning_rate (LearningRate): The learning rate configuration, which can include
+      constant learning rates or other learning rate schedules.
+    adam (AdamConfig): Configuration for the Adam optimizer.
+    sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
+    adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
+
+  Example:
+    ```python
+    optimizer_config = OptimizerConfig(
+        learning_rate=LearningRate(constant=0.001),
+        adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8),
+    )
+    ```
+
+  Note:
+    You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an
+    `OptimizerConfig` instance.
+
+  See Also:
+    - `LearningRate`: Configuration for specifying learning rates.
+    - `AdamConfig`: Configuration for the Adam optimizer.
+    - `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer.
+    - `AdagradConfig`: Configuration for the Adagrad optimizer.
+ + """ learning_rate: LearningRate = pydantic.Field( None, description="Constant learning rates", @@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig): def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig): + """ + Get the optimizer algorithm configuration from the given `OptimizerConfig`. + + This function extracts and returns the specific optimizer algorithm configuration + (e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`. + + Args: + optimizer_config (OptimizerConfig): The optimizer configuration object containing + one of the optimizer algorithm configurations. + + Returns: + Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm + configuration extracted from `optimizer_config`. + + Raises: + ValueError: If no optimizer algorithm is selected in `optimizer_config`. + + Example: + ```python + optimizer_config = OptimizerConfig( + adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8) + ) + algorithm_config = get_optimizer_algorithm_config(optimizer_config) + # `algorithm_config` will be an instance of `AdamConfig`. + ``` + + """ if optimizer_config.adam is not None: return optimizer_config.adam elif optimizer_config.sgd is not None: diff --git a/optimizers/optimizer.py b/optimizers/optimizer.py index 4517368..e930f73 100644 --- a/optimizers/optimizer.py +++ b/optimizers/optimizer.py @@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging def compute_lr(lr_config, step): - """Compute a learning rate.""" + """ + Compute the learning rate based on the specified learning rate configuration. + + This function calculates the learning rate according to the given configuration, which can include + constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing. + + Args: + lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule. + step (int): The current training step or iteration. + + Returns: + float: The computed learning rate for the current step. + + Raises: + ValueError: If the `lr_config` is invalid or contains conflicting options. + + Example: + ```python + lr_schedule = LearningRate( + constant=0.001, + piecewise_constant=PiecewiseConstant( + learning_rate_boundaries=[1000, 2000, 3000], + learning_rate_values=[0.1, 0.05, 0.01, 0.001] + ) + ) + current_step = 2500 + learning_rate = compute_lr(lr_schedule, current_step) + ``` + """ if lr_config.constant is not None: return lr_config.constant elif lr_config.piecewise_constant is not None: @@ -46,11 +74,54 @@ def compute_lr(lr_config, step): class LRShim(_LRScheduler): - """Shim to get learning rates into a LRScheduler. - - This adheres to the torch.optim scheduler API and can be plugged anywhere that - e.g. exponential decay can be used. """ + Learning Rate Scheduler Shim to adjust learning rates during training. + + This class acts as a shim to apply different learning rates to individual parameter groups + within an optimizer. It adheres to the torch.optim scheduler API and can be used with various + optimizers, allowing fine-grained control over learning rates based on configuration. + + Args: + optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted. + lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their + corresponding learning rate configurations. + last_epoch (int, optional): The index of the last epoch. Default is -1. 
+ verbose (bool, optional): If True, prints a warning message when accessing learning rates + using the deprecated `get_lr()` method. Default is False. + + Raises: + ValueError: If the number of parameter groups in the optimizer does not match the number + of learning rate configurations provided. + + Note: + To obtain the last computed learning rates, please use `get_last_lr()`. + + Example: + ```python + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + lr_schedule = { + 'main': LearningRate(constant=0.01), + 'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant( + learning_rate_boundaries=[1000, 2000], + learning_rate_values=[0.01, 0.001] + )) + } + lr_shim = LRShim(optimizer, lr_schedule) + + for epoch in range(num_epochs): + # Train the model + train(...) + # Update learning rates at the end of each epoch + lr_shim.step(epoch) + + final_lr_main = lr_shim.get_last_lr()['main'] + final_lr_auxiliary = lr_shim.get_last_lr()['auxiliary'] + ``` + + See Also: + - `LearningRate`: Configuration for specifying learning rates. + - `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules. + """ def __init__( self, @@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig): def build_optimizer( model: torch.nn.Module, optimizer_config: OptimizerConfig ) -> Tuple[Optimizer, _LRScheduler]: - """Builds an optimizer and LR scheduler from an OptimizerConfig. - Note: use this when you want the same optimizer and learning rate schedule for all your parameters. """ + Build an optimizer and learning rate scheduler based on the provided optimizer configuration. + + Args: + model (torch.nn.Module): The PyTorch model for which the optimizer will be created. + optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer + algorithm and learning rate settings. + + Returns: + Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler + objects. + + Note: + This function is intended for cases where you want the same optimizer and learning rate + schedule for all model parameters. + + Example: + ```python + model = MyModel() + optimizer_config = OptimizerConfig( + learning_rate=LearningRate(constant=0.01), + sgd=SgdConfig(lr=0.01, momentum=0.9) + ) + optimizer, scheduler = build_optimizer(model, optimizer_config) + + for epoch in range(num_epochs): + # Train the model with the optimizer + train(model, optimizer, ...) + # Update learning rates at the end of each epoch + scheduler.step(epoch) + ``` + + See Also: + - `OptimizerConfig`: Configuration for specifying optimizer settings. + - `LRShim`: Learning rate scheduler shim for fine-grained learning rate control. + """ optimizer_class = get_optimizer_class(optimizer_config) optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict()) # We're passing everything in as one group here diff --git a/projects/twhin/data/config.py b/projects/twhin/data/config.py index 55e3cec..05c70c4 100644 --- a/projects/twhin/data/config.py +++ b/projects/twhin/data/config.py @@ -4,6 +4,17 @@ import pydantic class TwhinDataConfig(base_config.BaseConfig): + """ + Configuration for Twhin model training data. + + Args: + data_root (str): The root directory for the training data. + per_replica_batch_size (pydantic.PositiveInt): Batch size per replica. + global_negatives (int): The number of global negatives. + in_batch_negatives (int): The number of in-batch negatives. + limit (pydantic.PositiveInt): The limit on the number of data points to use. 
+ offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None. + """ data_root: str per_replica_batch_size: pydantic.PositiveInt global_negatives: int diff --git a/projects/twhin/data/data.py b/projects/twhin/data/data.py index 37bf8f5..b0afe7f 100644 --- a/projects/twhin/data/data.py +++ b/projects/twhin/data/data.py @@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig): + """ + Create a dataset for Twhin model training. + + Args: + data_config (TwhinDataConfig): The data configuration for the dataset. + model_config (TwhinModelConfig): The model configuration containing embeddings and relations. + + Returns: + EdgesDataset: The dataset for Twhin model training. + """ tables = model_config.embeddings.tables table_sizes = {table.name: table.num_embeddings for table in tables} relations = model_config.relations diff --git a/projects/twhin/data/edges.py b/projects/twhin/data/edges.py index f7864b1..778a188 100644 --- a/projects/twhin/data/edges.py +++ b/projects/twhin/data/edges.py @@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @dataclass class EdgeBatch(DataclassBatch): + """ + Batch data structure for edge-based models. + + Args: + nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings. + labels (torch.Tensor): Tensor containing labels. + rels (torch.Tensor): Tensor containing relation information. + weights (torch.Tensor): Tensor containing weights. + """ nodes: KeyedJaggedTensor labels: torch.Tensor rels: torch.Tensor @@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch): class EdgesDataset(Dataset): + """ + Dataset for edge-based models. + + Args: + file_pattern (str): The file pattern for the dataset. + table_sizes (Dict[str, int]): A dictionary of table names and their sizes. + relations (List[Relation]): A list of relations between tables. + lhs_column_name (str): The name of the left-hand-side column. + rhs_column_name (str): The name of the right-hand-side column. + rel_column_name (str): The name of the relation column. + **dataset_kwargs: Additional keyword arguments for the parent Dataset class. + """ rng = np.random.default_rng() def __init__( @@ -56,6 +77,15 @@ class EdgesDataset(Dataset): super().__init__(file_pattern=file_pattern, **dataset_kwargs) def pa_to_batch(self, batch: pa.RecordBatch): + """ + Converts a pyarrow RecordBatch to an EdgeBatch. + + Args: + batch (pa.RecordBatch): A pyarrow RecordBatch containing data. + + Returns: + EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights. + """ lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy()) rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy()) rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy()) @@ -74,6 +104,14 @@ class EdgesDataset(Dataset): ) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: """Process edges that contain lhs index, rhs index, relation index. + + Args: + lhs (torch.Tensor): Tensor containing left-hand-side indices. + rhs (torch.Tensor): Tensor containing right-hand-side indices. + rel (torch.Tensor): Tensor containing relation indices. + + Returns: + Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs. Example: ``` @@ -147,6 +185,12 @@ class EdgesDataset(Dataset): return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths) def to_batches(self): + """ + Converts data to batches. 
+ + Yields: + pa.RecordBatch: A pyarrow RecordBatch containing data. + """ ds = super().to_batches() batch_size = self._dataset_kwargs["batch_size"] diff --git a/projects/twhin/models/config.py b/projects/twhin/models/config.py index aafc9e6..c7b2e1c 100644 --- a/projects/twhin/models/config.py +++ b/projects/twhin/models/config.py @@ -10,8 +10,29 @@ from pydantic import validator class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): + """ + Configuration class for Twhin model embeddings. + + This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types + for all tables in the Twhin model embeddings configuration match. + + Attributes: + tables (List[TableConfig]): A list of table configurations for the model's embeddings. + """ @validator("tables") def embedding_dims_match(cls, tables): + """ + Validate that embedding dimensions and data types match for all tables. + + Args: + tables (List[TableConfig]): List of table configurations. + + Returns: + List[TableConfig]: The list of validated table configurations. + + Raises: + AssertionError: If embedding dimensions or data types do not match. + """ embedding_dim = tables[0].embedding_dim data_type = tables[0].data_type for table in tables: @@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): class Operator(str, enum.Enum): + """ + Enumeration of operator types. + + This enumeration defines different types of operators that can be applied to Twhin model relations. + """ TRANSLATION = "translation" class Relation(pydantic.BaseModel): - """graph relationship properties and operator""" + """ + Configuration class for graph relationships in the Twhin model. + + This class defines properties and operators for graph relationships in the Twhin model. + + Attributes: + name (str): The name of the relationship. + lhs (str): The name of the entity on the left-hand side of the relation. + rhs (str): The name of the entity on the right-hand side of the relation. + operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product. + """ name: str = pydantic.Field(..., description="Relationship name.") lhs: str = pydantic.Field( @@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel): class TwhinModelConfig(base_config.BaseConfig): + """ + Configuration class for the Twhin model. + + This class defines configuration options specific to the Twhin model. + + Attributes: + embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings. + relations (List[Relation]): List of graph relationship configurations. + translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation. + """ embeddings: TwhinEmbeddingsConfig relations: typing.List[Relation] translation_optimizer: OptimizerConfig @validator("relations", each_item=True) def valid_node_types(cls, relation, values, **kwargs): + """ + Validate that the specified node types in relations are valid table names in embeddings. + + Args: + relation (Relation): A single relation configuration. + values (dict): The values dictionary containing the "embeddings" configuration. + + Returns: + Relation: The validated relation configuration. + + Raises: + AssertionError: If the specified node types are not valid table names in embeddings. 
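+
+    Example:
+      A minimal sketch of a relation that passes this check, assuming hypothetical
+      embedding tables named ``user`` and ``tweet`` in ``embeddings.tables``:
+
+      ```python
+      relation = Relation(
+          name="favorites",
+          lhs="user",
+          rhs="tweet",
+          operator=Operator.TRANSLATION,
+      )
+      ```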
+ """ table_names = [table.name for table in values["embeddings"].tables] assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}" assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}" diff --git a/projects/twhin/models/models.py b/projects/twhin/models/models.py index b1dfaa7..0b5f1cb 100644 --- a/projects/twhin/models/models.py +++ b/projects/twhin/models/models.py @@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa class TwhinModel(nn.Module): + """ + Twhin model for graph-based entity embeddings and translation. + + This class defines the Twhin model, which is used for learning embeddings of entities in a graph + and applying translations to these embeddings based on graph relationships. + + Args: + model_config (TwhinModelConfig): Configuration for the Twhin model. + data_config (TwhinDataConfig): Configuration for the data used by the model. + + Attributes: + batch_size (int): The batch size used for training. + table_names (List[str]): Names of tables in the model's embeddings. + large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings. + embedding_dim (int): Dimensionality of entity embeddings. + num_tables (int): Number of tables in the model's embeddings. + in_batch_negatives (int): Number of in-batch negative samples to use during training. + global_negatives (int): Number of global negative samples to use during training. + num_relations (int): Number of graph relationships in the model. + all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings. + + """ def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig): super().__init__() self.batch_size = data_config.per_replica_batch_size @@ -31,7 +53,17 @@ class TwhinModel(nn.Module): ) def forward(self, batch: EdgeBatch): + """ + Forward pass of the Twhin model. + Args: + batch (EdgeBatch): Input batch containing graph edge information. + + Returns: + dict: A dictionary containing model output with "logits" and "probabilities". + - "logits" (torch.Tensor): Logit scores. + - "probabilities" (torch.Tensor): Sigmoid probabilities. + """ # B x D trans_embs = self.all_trans_embs.data[batch.rels] @@ -98,6 +130,18 @@ class TwhinModel(nn.Module): def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig): + """ + Apply optimizers to the Twhin model's embeddings. + + This function applies optimizers to the embeddings of the Twhin model based on the provided configuration. + + Args: + model (TwhinModel): The Twhin model to apply optimizers to. + model_config (TwhinModelConfig): Configuration for the Twhin model. + + Returns: + TwhinModel: The Twhin model with optimizers applied to its embeddings. + """ for table in model_config.embeddings.tables: optimizer_class = get_optimizer_class(table.optimizer) optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict() @@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module): device: torch.device, ) -> None: """ - Args: - model: torch module to wrap. - loss_fn: Function for calculating loss, should accept logits and labels. - """ + Initialize a TwhinModelAndLoss module. + + Args: + model: The torch module to wrap. + loss_fn: A function for calculating loss, should accept logits and labels. + data_config: Configuration for Twhin data. + device: The torch device to use for calculations. 
+ """ super().__init__() self.model = model self.loss_fn = loss_fn @@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module): self.device = device def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] - """Runs model forward and calculates loss according to given loss_fn. - - NOTE: The input signature here needs to be a Pipelineable object for - prefetching purposes during training using torchrec's pipeline. However - the underlying model signature needs to be exportable to onnx, requiring - generic python types. see https://pytorch.org/docs/stable/onnx.html#types. - """ + Run the model forward and calculate the loss according to the given loss_fn. + + NOTE: The input signature here needs to be a Pipelineable object for + prefetching purposes during training using torchrec's pipeline. However + the underlying model signature needs to be exportable to onnx, requiring + generic python types. see https://pytorch.org/docs/stable/onnx.html#types + + Args: + batch ("RecapBatch"): The input batch for model inference. + + Returns: + Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of + additional outputs including logits, labels, and weights. + """ outputs = self.model(batch) logits = outputs["logits"] diff --git a/projects/twhin/models/test_models.py b/projects/twhin/models/test_models.py index f8203ef..05e9b2b 100644 --- a/projects/twhin/models/test_models.py +++ b/projects/twhin/models/test_models.py @@ -18,6 +18,12 @@ EMB_DIM = 128 def twhin_model_config() -> TwhinModelConfig: + """ + Create a configuration for the Twhin model. + + Returns: + TwhinModelConfig: The Twhin model configuration. + """ sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01)) sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02)) @@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig: def twhin_data_config() -> TwhinDataConfig: + """ + Create a configuration for the Twhin data. + + Returns: + TwhinDataConfig: The Twhin data configuration. + """ data_config = TwhinDataConfig( data_root="/", per_replica_batch_size=10, @@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig: def test_twhin_model(): + """ + Test the Twhin model creation and optimization. + + This function creates a Twhin model using the specified configuration and tests its optimization. It also checks + the device placement of model parameters. + + Returns: + None + """ model_config = twhin_model_config() loss_fn = F.binary_cross_entropy_with_logits diff --git a/projects/twhin/optimizer.py b/projects/twhin/optimizer.py index 7c43b56..e39e0d0 100644 --- a/projects/twhin/optimizer.py +++ b/projects/twhin/optimizer.py @@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt" def _lr_from_config(optimizer_config): + """Get the learning rate from an optimizer configuration. + + Args: + optimizer_config: Optimizer configuration. + + Returns: + Learning rate from the optimizer configuration. + """ if optimizer_config.learning_rate is not None: return optimizer_config.learning_rate else: @@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config): def build_optimizer(model: TwhinModel, config: TwhinModelConfig): """Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations. - Args: - model: TwhinModel to build optimizer for. - config: TwhinConfig for model. + Args: + model: TwhinModel to build optimizer for. + config: TwhinModelConfig for model. - Returns: - Optimizer for model. - """ + Returns: + Optimizer for model. 
+ """ translation_optimizer_fn = functools.partial( get_optimizer_class(config.translation_optimizer), **get_optimizer_algorithm_config(config.translation_optimizer).dict(), diff --git a/projects/twhin/run.py b/projects/twhin/run.py index 12081e0..6124bc1 100644 --- a/projects/twhin/run.py +++ b/projects/twhin/run.py @@ -37,6 +37,12 @@ def run( all_config: TwhinConfig, save_dir: Optional[str] = None, ): + """Run the training process for TwhinModel. + + Args: + all_config (TwhinConfig): The configuration for the entire Twhin model. + save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None. + """ train_dataset = create_dataset(all_config.train_data, all_config.model) if env.is_reader(): @@ -80,6 +86,11 @@ def run( def main(argv): + """Main entry point for the Twhin training script. + + Args: + argv: Command-line arguments. + """ logging.info("Starting") logging.info(f"parsing config from {FLAGS.config_yaml_path}...") diff --git a/reader/dataset.py b/reader/dataset.py index 6e811cc..5d873cd 100644 --- a/reader/dataset.py +++ b/reader/dataset.py @@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging class _Reader(pa.flight.FlightServerBase): - """Distributed reader flight server wrapping a dataset.""" + """ + Distributed reader flight server wrapping a dataset. + + This class implements a Flight server that wraps a dataset, allowing clients to retrieve data + from the dataset over the Flight protocol. It is designed to be used in a distributed environment + for efficient data access. + + Args: + location (str): The location of the Flight server. + ds (Dataset): The dataset to be wrapped by the Flight server. + + Attributes: + _location (str): The location of the Flight server. + _ds (Dataset): The dataset wrapped by the Flight server. + + Methods: + do_get(_, __): Handles Flight requests for data retrieval. + + Note: + Flight is an Apache Arrow project that provides a framework for efficient data transfer. + This class allows clients to retrieve data from the dataset using Flight. + + """ def __init__(self, location: str, ds: "Dataset"): + """ + Initialize a new _Reader instance. + + Args: + location (str): The location of the Flight server. + ds (Dataset): The dataset to be wrapped by the Flight server. + """ super().__init__(location=location) self._location = location self._ds = ds def do_get(self, _, __): + """ + Handle Flight requests for data retrieval. + + This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol. + + Args: + _: Unused argument. + __: Unused argument. + + Returns: + pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset. + + Note: + An updated schema (to account for column selection) must be given to the stream. + """ # NB: An updated schema (to account for column selection) has to be given the stream. schema = next(iter(self._ds.to_batches())).schema batches = self._ds.to_batches() @@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase): class Dataset(torch.utils.data.IterableDataset): + """ + A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading. + + This class enables efficient loading of data from Parquet files using PyArrow. + It is designed to be used as an IterableDataset in PyTorch for training and inference. + + Args: + file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset. + **dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method. 
+ Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset + for more details. + + Attributes: + LOCATION (str): The default location for the Flight server used for data distribution. + _file_pattern (str): The glob pattern specifying Parquet files in the dataset. + _fs: The filesystem object used for file operations. + _dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method. + _files (list): A list of file paths matching the glob pattern. + _schema (pa.Schema): The schema of the Parquet dataset. + + Methods: + serve(): Start serving the dataset using a Flight server. + to_batches(): Generate batches of data from the Parquet dataset. + pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch. + dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset. + + Note: + This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch + to create DataLoader instances for training or inference. + """ LOCATION = "grpc://0.0.0.0:2222" def __init__(self, file_pattern: str, **dataset_kwargs) -> None: - """Specify batch size and column to select for. + """ + Initialize a new Dataset instance. Specify batch size and column to select for. Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset. - """ + + + Args: + file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset. + **dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method. + """ self._file_pattern = file_pattern self._fs = infer_fs(self._file_pattern) self._dataset_kwargs = dataset_kwargs @@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset): self._validate_columns() def _validate_columns(self): + """ + Validate the specified columns against the dataset schema. + + Raises: + Exception: If any specified columns are not found in the dataset schema. + """ columns = set(self._dataset_kwargs.get("columns", [])) wrong_columns = set(columns) - set(self._schema.names) if wrong_columns: raise Exception(f"Specified columns {list(wrong_columns)} not in schema.") def serve(self): + """Start serving the dataset using a Flight server.""" self.reader = _Reader(location=self.LOCATION, ds=self) self.reader.serve() def _create_dataset(self): + """Create a PyArrow dataset for data retrieval.""" + return pads.dataset( source=random.sample(self._files, len(self._files))[0], format="parquet", @@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset): @abc.abstractmethod def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch: + """ + Convert a Parquet RecordBatch to a custom data batch. + + Args: + batch (pa.RecordBatch): A batch of data from the Parquet dataset. + + Returns: + DataclassBatch: A custom data batch used in PyTorch training. + + Raises: + NotImplementedError: This method must be implemented in derived classes. + """ raise NotImplementedError def dataloader(self, remote: bool = False): + """ + Create a PyTorch DataLoader for iterating through the dataset. + + Args: + remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training. + + Returns: + DataLoader: A PyTorch DataLoader for iterating through the dataset. + + Note: + If `remote` is True, a remote DataLoader is created for distributed training using Flight. 
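+
+    Example:
+      A minimal local-iteration sketch, assuming a concrete subclass ``MyDataset`` that
+      implements ``pa_to_batch`` and a hypothetical file pattern:
+
+      ```python
+      ds = MyDataset(file_pattern="data/*.parquet", batch_size=1024)
+      for batch in ds.dataloader(remote=False):
+          ...
+      ```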
+ """ if not remote: return map(self.pa_to_batch, self.to_batches()) readers = get_readers(2) @@ -117,6 +230,25 @@ GRPC_OPTIONS = [ def get_readers(num_readers_per_worker: int): + """ + Get Flight readers for distributed data loading. + + This function retrieves Flight readers for distributed data loading in a PyTorch environment. + + Args: + num_readers_per_worker (int): The number of Flight readers to retrieve per worker. + + Returns: + List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading. + + Note: + Flight readers are used to fetch data in a distributed manner for efficient data loading. + + Example: + To obtain Flight readers, use the following code: + + >>> readers = get_readers(num_readers_per_worker=2) + """ addresses = env.get_flight_server_addresses() readers = [] diff --git a/reader/dds.py b/reader/dds.py index 7b49893..54b2c02 100644 --- a/reader/dds.py +++ b/reader/dds.py @@ -21,6 +21,16 @@ import torch.distributed as dist def maybe_start_dataset_service(): + """ + Start the dataset service if readers are available and required dependencies are met. + + This function checks if readers are available and if the required TensorFlow version is >= 2.5. + If both conditions are met and the current environment is the dispatcher or reader, it starts + the TensorFlow dataset service. + + Raises: + Exception: If the required TensorFlow version is not met (>= 2.5). + """ if not env.has_readers(): return @@ -59,6 +69,24 @@ def maybe_start_dataset_service(): def register_dataset( dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO" ): + """ + Register a dataset with the distributed dataset service. + + This function registers a dataset with the distributed dataset service and broadcasts the dataset ID + and job name to all processes in the distributed environment. + + Args: + dataset (tf.data.Dataset): The dataset to be registered. + dataset_service (str): The name of the dataset service. + compression (Optional[str]): The compression type for the dataset (default is "AUTO"). + + Returns: + Tuple[int, str]: A tuple containing the dataset ID and job name. + + Note: + This function should be called on the rank 0 process. + + """ if dist.get_rank() == 0: dataset_id = _register_dataset( service=dataset_service, @@ -82,6 +110,23 @@ def distribute_from_dataset_id( compression: Optional[str] = "AUTO", prefetch: Optional[int] = tf.data.experimental.AUTOTUNE, ) -> tf.data.Dataset: + """ + Distribute a dataset from a registered dataset ID. + + This function consumes a dataset from the distributed dataset service using the provided dataset ID + and job name. It also supports prefetching for improved performance. + + Args: + dataset_service (str): The name of the dataset service. + dataset_id (int): The ID of the dataset to be consumed. + job_name (Optional[str]): The name of the job associated with the dataset (optional). + compression (Optional[str]): The compression type for the dataset (default is "AUTO"). + prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE). + + Returns: + tf.data.Dataset: The distributed dataset. + + """ logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}") dataset = _from_dataset_id( processing_mode="parallel_epochs", @@ -97,15 +142,28 @@ def distribute_from_dataset_id( def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset: - """Torch-compatible and distributed-training-aware dataset service distributor. 
- - - rank 0 process will register the given dataset. - - rank 0 process will broadcast job name and dataset id. - - all rank processes will consume from the same job/dataset. - - Without this, dataset workers will try to serve 1 job per rank process and OOM. - """ + Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption. + + This function is used to distribute a dataset in a distributed training environment. It performs the + following steps: + - On the rank 0 process, it registers the given dataset with the distributed dataset service. + - It broadcasts the job name and dataset ID to all rank processes. + - All rank processes then consume the same dataset from the distributed dataset service. + + Args: + dataset (tf.data.Dataset): The TensorFlow dataset to be distributed. + + Returns: + tf.data.Dataset: The distributed TensorFlow dataset. + + Note: + - If there are no reader processes in the distributed environment, the original dataset is returned + without any distribution. + - This function is intended for use in distributed training environments to prevent out-of-memory (OOM) + issues caused by each rank process trying to serve one job. + + """ if not env.has_readers(): return dataset dataset_service = env.get_dds() diff --git a/reader/test_dataset.py b/reader/test_dataset.py index 85e756f..d9400e4 100644 --- a/reader/test_dataset.py +++ b/reader/test_dataset.py @@ -12,6 +12,17 @@ import torch def create_dataset(tmpdir): + """ + Create a mock dataset for testing. + + This function creates a mock dataset using PyArrow and Parquet for testing purposes. + + Args: + tmpdir: A temporary directory where the dataset will be created. + + Returns: + MockDataset: A mock dataset for testing. + """ table = pa.table( { @@ -34,6 +45,14 @@ def create_dataset(tmpdir): def test_dataset(tmpdir): + """ + Test the created dataset. + + This function tests the created mock dataset and checks if it behaves as expected. + + Args: + tmpdir: A temporary directory used for testing. + """ ds = create_dataset(tmpdir) batch = next(iter(ds.dataloader(remote=False))) assert batch.batch_size == 2 @@ -46,6 +65,14 @@ def test_dataset(tmpdir): reason="Multiprocessing doesn't work on github yet.", ) def test_distributed_dataset(tmpdir): + """ + Test the distributed dataset. + + This function tests the distributed version of the mock dataset using multiprocessing. + + Args: + tmpdir: A temporary directory used for testing. + """ MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"} def _client(): diff --git a/reader/utils.py b/reader/utils.py index fc0e34c..c5c9721 100644 --- a/reader/utils.py +++ b/reader/utils.py @@ -11,11 +11,55 @@ import torch def roundrobin(*iterables): - """Round robin through provided iterables, useful for simple load balancing. - - Adapted from https://docs.python.org/3/library/itertools.html. - """ + Iterate through provided iterables in a round-robin fashion. + + This function takes multiple iterables and returns an iterator that yields elements from + each iterable in a round-robin manner. It continues cycling through the iterables until + all of them are exhausted. + + Adapted from https://docs.python.org/3/library/itertools.html. + + Args: + *iterables: One or more iterable objects to iterate through. + + Yields: + Elements from the provided iterables in a round-robin fashion. + + Raises: + StopIteration: If all provided iterables are exhausted. 
+ + Example: + ```python + iterable1 = [1, 2, 3] + iterable2 = ['a', 'b', 'c'] + iterable3 = [0.1, 0.2, 0.3] + + for item in roundrobin(iterable1, iterable2, iterable3): + print(item) + + # Output: + # 1 + # 'a' + # 0.1 + # 2 + # 'b' + # 0.2 + # 3 + # 'c' + # 0.3 + ``` + + Note: + - If one of the provided iterables is shorter than the others, the function will + continue iterating through the remaining iterables until all are exhausted. + - If an iterable raises an exception during iteration, a warning message is logged, + and the function continues with the next iterable. + + See Also: + - `itertools.cycle`: A function that repeatedly cycles through elements of an iterable. + - `itertools.islice`: A function to slice an iterable to limit the number of iterations. + """ num_active = len(iterables) nexts = itertools.cycle(iter(it).__next__ for it in iterables) while num_active: @@ -35,6 +79,48 @@ def roundrobin(*iterables): def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]): + """ + Monitor the speed and progress of data loading using a data loader. + + This function iterates through a data loader for a specified number of steps or until + the end of the data loader is reached, periodically logging progress information. + + Args: + data_loader: The data loader to monitor. + max_steps: The maximum number of steps to iterate through the data loader. + frequency: The frequency (in steps) at which to log progress. + peek (optional): If specified, it indicates the frequency (in steps) at which to log + batch contents for inspection. + + Example: + ```python + import torch + from torch.utils.data import DataLoader + + # Create a data loader (replace with your own DataLoader configuration) + data_loader = DataLoader(dataset, batch_size=32, shuffle=True) + + # Monitor data loading speed and progress + speed_check(data_loader, max_steps=1000, frequency=50, peek=500) + ``` + + Args: + data_loader: The data loader to monitor. + max_steps: The maximum number of steps to iterate through the data loader. + frequency: The frequency (in steps) at which to log progress. + peek (optional): If specified, it indicates the frequency (in steps) at which to log + batch contents for inspection. + + Note: + - The function logs information about elapsed time, the number of examples processed, + and the processing speed in examples per second. + - If `peek` is provided, batch contents will be logged for inspection at the specified + frequency. + + See Also: + - `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and + iterating through datasets. + """ num_examples = 0 prev = time.perf_counter() for idx, batch in enumerate(data_loader): @@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int] def pa_to_torch(array: pa.array) -> torch.Tensor: + """ + Convert a PyArrow Array to a PyTorch Tensor. + + Args: + array (pa.array): The PyArrow Array to convert. + + Returns: + torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array. + + Example: + ```python + import pyarrow as pa + import torch + + # Create a PyArrow Array + arrow_array = pa.array([1, 2, 3]) + + # Convert it to a PyTorch Tensor + torch_tensor = pa_to_torch(arrow_array) + ``` + """ return torch.from_numpy(array.to_numpy()) def create_default_pa_to_batch(schema) -> DataclassBatch: - """ """ + """ + Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data. 
+
+  Args:
+    schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.
+
+  Returns:
+    callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.
+
+  Example:
+    ```python
+    import pandas as pd
+    import pyarrow as pa
+    from dataclass_batch import DataclassBatch
+
+    # Define a PyArrow schema
+    schema = pa.schema([
+        ("feature1", pa.float64()),
+        ("feature2", pa.int64()),
+        ("label", pa.int64()),
+    ])
+
+    # Create the conversion function
+    pa_to_batch = create_default_pa_to_batch(schema)
+
+    # Create a PyArrow RecordBatch
+    record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
+        "feature1": [1.0, 2.0, None],
+        "feature2": [10, 20, 30],
+        "label": [0, 1, None],
+    }))
+
+    # Convert the RecordBatch to a custom DataclassBatch
+    custom_batch = pa_to_batch(record_batch)
+    ```
+  """
   _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
   def get_imputation_value(pa_type):