diff --git a/ml_logging/absl_logging.py b/ml_logging/absl_logging.py index e4d9e9e..07be8bb 100644 --- a/ml_logging/absl_logging.py +++ b/ml_logging/absl_logging.py @@ -14,7 +14,24 @@ from absl import logging as logging def setup_absl_logging(): - """Make sure that absl logging pushes to stdout rather than stderr.""" + """ + Configure absl-py logging to direct log messages to stdout and apply a custom log message format. + + This function ensures that log messages generated by the absl-py library are written to stdout + rather than stderr. It also applies a custom log message format that includes module, function, + line number, log level, and the log message content. + + Note: + This function should be called once at the beginning of your script or application to + configure absl-py logging. + + Example: + To use this function, simply call it at the start of your script: + ``` + setup_absl_logging() + ``` + + """ logging.get_absl_handler().python_handler.stream = sys.stdout formatter = py_logging.Formatter( fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s" diff --git a/ml_logging/test_torch_logging.py b/ml_logging/test_torch_logging.py index e3a2e90..36c141d 100644 --- a/ml_logging/test_torch_logging.py +++ b/ml_logging/test_torch_logging.py @@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging class Testtlogging(unittest.TestCase): def test_warn_once(self): + """ + Test that warning messages are logged only once when using the assertLogs context manager. + + This unit test checks the behavior of the logging system when warning messages are issued + multiple times within the same context. It uses the assertLogs context manager to capture + log messages at the INFO level and verifies that warning messages are logged only once. + + Example: + To use this test case, call it using a test runner like unittest: + ``` + python -m unittest your_test_module.TestLogging.test_warn_once + ``` + + """ + with self.assertLogs(level="INFO") as captured_logs: logging.info("first info") logging.warning("first warning") diff --git a/ml_logging/torch_logging.py b/ml_logging/torch_logging.py index e791c46..2f60114 100644 --- a/ml_logging/torch_logging.py +++ b/ml_logging/torch_logging.py @@ -18,7 +18,35 @@ import torch.distributed as dist def rank_specific(logger): - """Ensures that we only override a given logger once.""" + """ + Customize logger behavior based on the distributed environment and rank. + + This function allows for customizing the behavior of a logger based on the distributed environment and the rank + of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages + depending on the rank or limit the number of redundant logs. + + Args: + logger: The logger object to customize. + + Returns: + The customized logger. + + Example: + To use this function with the `logging` module: + ```python + import logging + from rank_specific_logging import rank_specific + + logger = logging.getLogger(__name__) + rank_specific(logger) + ``` + + Customization: + - Messages are only logged if the distributed environment is not initialized or if the rank matches. + - The 'warning' method is limited to logging a single redundant warning. + - Logging from rank -1 is redirected to include the rank information. 
+ + """ if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"): return logger diff --git a/optimizers/config.py b/optimizers/config.py index f5011f0..dc5c94a 100644 --- a/optimizers/config.py +++ b/optimizers/config.py @@ -8,11 +8,60 @@ import pydantic class PiecewiseConstant(base_config.BaseConfig): + """ + Configuration for a piecewise constant learning rate schedule. + + This configuration class allows you to specify a piecewise constant learning rate schedule + by defining boundaries and corresponding learning rate values. + + Attributes: + learning_rate_boundaries (List[int], optional): List of step boundaries at which + the learning rate will change. If None, no boundaries are defined. + learning_rate_values (List[float], optional): List of learning rate values + corresponding to the boundaries. If None, no values are defined. + + Example: + To configure a piecewise constant learning rate schedule, create an instance of this class + and set the attributes accordingly. For example: + + ```python + piecewise_lr = PiecewiseConstant( + learning_rate_boundaries=[1000, 2000, 3000], + learning_rate_values=[0.1, 0.05, 0.01, 0.001] + ) + ``` + + Note: + The number of learning rate values should be one more than the number of boundaries. + + """ learning_rate_boundaries: typing.List[int] = pydantic.Field(None) learning_rate_values: typing.List[float] = pydantic.Field(None) class LinearRampToConstant(base_config.BaseConfig): + """ + Configuration for a linear ramp-up to constant learning rate schedule. + + This configuration class allows you to specify a learning rate schedule that ramps up linearly + from zero to a constant value over a specified number of steps. + + Attributes: + learning_rate (float): The final constant learning rate. + num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero. + + Example: + To configure a linear ramp-up to a constant learning rate, create an instance of this class + and set the attributes accordingly. For example: + + ```python + linear_ramp_lr = LinearRampToConstant( + learning_rate=0.1, + num_ramp_steps=1000 + ) + ``` + + """ learning_rate: float num_ramp_steps: pydantic.PositiveInt = pydantic.Field( description="Number of steps to ramp this up from zero." @@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig): class LinearRampToCosine(base_config.BaseConfig): + """ + Configuration for a linear ramp-up to cosine decay learning rate schedule. + + This configuration class allows you to specify a learning rate schedule that ramps up linearly + from zero, then decays following a cosine schedule to a final constant learning rate. + + Attributes: + learning_rate (float): The initial learning rate at the start of ramp-up. + final_learning_rate (float): The final constant learning rate after decay. + num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero. + final_num_steps (PositiveInt): Final number of steps where decay stops. + + Example: + To configure a linear ramp-up to cosine decay learning rate, create an instance of this + class and set the attributes accordingly. 
For example: + + ```python + ramp_to_cosine_lr = LinearRampToCosine( + learning_rate=0.01, + final_learning_rate=0.001, + num_ramp_steps=1000, + final_num_steps=5000 + ) + ``` + + """ learning_rate: float final_learning_rate: float num_ramp_steps: pydantic.PositiveInt = pydantic.Field( @@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig): class LearningRate(base_config.BaseConfig): + """ + Learning rate configuration for training. + + This configuration class allows you to specify different learning rate schedules + for your training process. + + Attributes: + constant (float, optional): Constant learning rate to be used throughout training. + linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly + and then decays following a cosine schedule. + linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up + linearly and then remains constant. + piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified + boundaries with corresponding values. + + Example: + To configure a learning rate schedule, create an instance of this class and set the + attributes accordingly. For example: + + ```python + learning_rate = LearningRate( + constant=0.01, + linear_ramp_to_cosine=LinearRampToCosine( + learning_rate=0.1, + final_learning_rate=0.001, + num_ramp_steps=1000, + final_num_steps=5000 + ) + ) + ``` + + Note: + Each learning rate schedule attribute can be set to `None` if not needed. + + """ constant: float = pydantic.Field(None, one_of="lr") linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr") linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr") @@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig): class OptimizerAlgorithmConfig(base_config.BaseConfig): - """Base class for optimizer configurations.""" + """ + Base class for optimizer configurations. + + This base configuration class provides a structure for specifying various optimizer-related + settings, including the learning rate and different learning rate schedules. + + Attributes: + lr (float): The base learning rate used by the optimizer. + + Subclasses should inherit from this base class and define additional attributes specific to + the optimizer algorithm they represent. + + Example: + To create a custom optimizer configuration, create a subclass of this base class and + define the necessary attributes. For example: + + ```python + class MyOptimizerConfig(OptimizerAlgorithmConfig): + momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.") + ``` + + Note: + This base class does not include specific optimizer settings. Subclasses should define + the optimizer-specific attributes as needed. + + """ lr: float ... class AdamConfig(OptimizerAlgorithmConfig): - # see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam + """ + Configuration for the Adam optimizer. + + This configuration class allows you to specify the hyperparameters for the Adam optimizer. + + Attributes: + lr (float): The learning rate for optimization. + betas (Tuple[float, float], optional): Coefficients used for computing running averages + of gradient and squared gradient. Defaults to (0.9, 0.999). + eps (float, optional): A small constant added to the denominator for numerical stability. + Defaults to 1e-7. + + Example: + To configure the Adam optimizer, create an instance of this class and set the attributes + accordingly. 
For example:
+
+    ```python
+    adam_optimizer = AdamConfig(
+        lr=0.001,
+        betas=(0.9, 0.999),
+        eps=1e-8
+    )
+    ```
+
+  See Also:
+    [PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
+
+  """
   lr: float
   betas: typing.Tuple[float, float] = [0.9, 0.999]
   eps: float = 1e-7  # Numerical stability in denominator.
 
 
 class SgdConfig(OptimizerAlgorithmConfig):
+  """
+  Configuration for the Stochastic Gradient Descent (SGD) optimizer.
+
+  This configuration class allows you to specify the hyperparameters for the SGD optimizer.
+
+  Attributes:
+    lr (float): The learning rate for optimization.
+    momentum (float, optional): The momentum factor for SGD. Defaults to 0.0.
+
+  Example:
+    To configure the SGD optimizer, create an instance of this class and set the attributes
+    accordingly. For example:
+
+    ```python
+    sgd_optimizer = SgdConfig(
+        lr=0.01,
+        momentum=0.9
+    )
+    ```
+
+  """
   lr: float
   momentum: float = 0.0
 
 
 class AdagradConfig(OptimizerAlgorithmConfig):
+  """
+  Configuration for the Adagrad optimizer.
+
+  This configuration class allows you to specify the hyperparameters for the Adagrad optimizer.
+
+  Attributes:
+    lr (float): The learning rate for optimization.
+    eps (float, optional): A small constant added to the denominator for numerical stability.
+      Defaults to 0.
+
+  Example:
+    To configure the Adagrad optimizer, create an instance of this class and set the attributes
+    accordingly. For example:
+
+    ```python
+    adagrad_optimizer = AdagradConfig(
+        lr=0.01,
+        eps=1e-10
+    )
+    ```
+
+  """
   lr: float
   eps: float = 0
 
 
 class OptimizerConfig(base_config.BaseConfig):
+  """
+  Configuration for defining different optimizer algorithms and their parameters.
+
+  This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad,
+  along with their respective hyperparameters.
+
+  Attributes:
+    learning_rate (LearningRate): The learning rate configuration, which can include
+      constant learning rates or other learning rate schedules.
+    adam (AdamConfig): Configuration for the Adam optimizer.
+    sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
+    adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
+
+  Example:
+    ```python
+    optimizer_config = OptimizerConfig(
+        learning_rate=LearningRate(constant=0.001),
+        adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8),
+    )
+    ```
+
+  Note:
+    You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an
+    `OptimizerConfig` instance.
+
+  See Also:
+    - `LearningRate`: Configuration for specifying learning rates.
+    - `AdamConfig`: Configuration for the Adam optimizer.
+    - `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer.
+    - `AdagradConfig`: Configuration for the Adagrad optimizer.
+ + """ learning_rate: LearningRate = pydantic.Field( None, description="Constant learning rates", @@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig): def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig): + """ + Get the optimizer algorithm configuration from the given `OptimizerConfig`. + + This function extracts and returns the specific optimizer algorithm configuration + (e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`. + + Args: + optimizer_config (OptimizerConfig): The optimizer configuration object containing + one of the optimizer algorithm configurations. + + Returns: + Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm + configuration extracted from `optimizer_config`. + + Raises: + ValueError: If no optimizer algorithm is selected in `optimizer_config`. + + Example: + ```python + optimizer_config = OptimizerConfig( + adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8) + ) + algorithm_config = get_optimizer_algorithm_config(optimizer_config) + # `algorithm_config` will be an instance of `AdamConfig`. + ``` + + """ if optimizer_config.adam is not None: return optimizer_config.adam elif optimizer_config.sgd is not None: diff --git a/optimizers/optimizer.py b/optimizers/optimizer.py index 4517368..e930f73 100644 --- a/optimizers/optimizer.py +++ b/optimizers/optimizer.py @@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging def compute_lr(lr_config, step): - """Compute a learning rate.""" + """ + Compute the learning rate based on the specified learning rate configuration. + + This function calculates the learning rate according to the given configuration, which can include + constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing. + + Args: + lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule. + step (int): The current training step or iteration. + + Returns: + float: The computed learning rate for the current step. + + Raises: + ValueError: If the `lr_config` is invalid or contains conflicting options. + + Example: + ```python + lr_schedule = LearningRate( + constant=0.001, + piecewise_constant=PiecewiseConstant( + learning_rate_boundaries=[1000, 2000, 3000], + learning_rate_values=[0.1, 0.05, 0.01, 0.001] + ) + ) + current_step = 2500 + learning_rate = compute_lr(lr_schedule, current_step) + ``` + """ if lr_config.constant is not None: return lr_config.constant elif lr_config.piecewise_constant is not None: @@ -46,11 +74,54 @@ def compute_lr(lr_config, step): class LRShim(_LRScheduler): - """Shim to get learning rates into a LRScheduler. - - This adheres to the torch.optim scheduler API and can be plugged anywhere that - e.g. exponential decay can be used. """ + Learning Rate Scheduler Shim to adjust learning rates during training. + + This class acts as a shim to apply different learning rates to individual parameter groups + within an optimizer. It adheres to the torch.optim scheduler API and can be used with various + optimizers, allowing fine-grained control over learning rates based on configuration. + + Args: + optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted. + lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their + corresponding learning rate configurations. + last_epoch (int, optional): The index of the last epoch. Default is -1. 
+ verbose (bool, optional): If True, prints a warning message when accessing learning rates + using the deprecated `get_lr()` method. Default is False. + + Raises: + ValueError: If the number of parameter groups in the optimizer does not match the number + of learning rate configurations provided. + + Note: + To obtain the last computed learning rates, please use `get_last_lr()`. + + Example: + ```python + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + lr_schedule = { + 'main': LearningRate(constant=0.01), + 'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant( + learning_rate_boundaries=[1000, 2000], + learning_rate_values=[0.01, 0.001] + )) + } + lr_shim = LRShim(optimizer, lr_schedule) + + for epoch in range(num_epochs): + # Train the model + train(...) + # Update learning rates at the end of each epoch + lr_shim.step(epoch) + + final_lr_main = lr_shim.get_last_lr()['main'] + final_lr_auxiliary = lr_shim.get_last_lr()['auxiliary'] + ``` + + See Also: + - `LearningRate`: Configuration for specifying learning rates. + - `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules. + """ def __init__( self, @@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig): def build_optimizer( model: torch.nn.Module, optimizer_config: OptimizerConfig ) -> Tuple[Optimizer, _LRScheduler]: - """Builds an optimizer and LR scheduler from an OptimizerConfig. - Note: use this when you want the same optimizer and learning rate schedule for all your parameters. """ + Build an optimizer and learning rate scheduler based on the provided optimizer configuration. + + Args: + model (torch.nn.Module): The PyTorch model for which the optimizer will be created. + optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer + algorithm and learning rate settings. + + Returns: + Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler + objects. + + Note: + This function is intended for cases where you want the same optimizer and learning rate + schedule for all model parameters. + + Example: + ```python + model = MyModel() + optimizer_config = OptimizerConfig( + learning_rate=LearningRate(constant=0.01), + sgd=SgdConfig(lr=0.01, momentum=0.9) + ) + optimizer, scheduler = build_optimizer(model, optimizer_config) + + for epoch in range(num_epochs): + # Train the model with the optimizer + train(model, optimizer, ...) + # Update learning rates at the end of each epoch + scheduler.step(epoch) + ``` + + See Also: + - `OptimizerConfig`: Configuration for specifying optimizer settings. + - `LRShim`: Learning rate scheduler shim for fine-grained learning rate control. + """ optimizer_class = get_optimizer_class(optimizer_config) optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict()) # We're passing everything in as one group here diff --git a/projects/twhin/data/config.py b/projects/twhin/data/config.py index 55e3cec..05c70c4 100644 --- a/projects/twhin/data/config.py +++ b/projects/twhin/data/config.py @@ -4,6 +4,17 @@ import pydantic class TwhinDataConfig(base_config.BaseConfig): + """ + Configuration for Twhin model training data. + + Args: + data_root (str): The root directory for the training data. + per_replica_batch_size (pydantic.PositiveInt): Batch size per replica. + global_negatives (int): The number of global negatives. + in_batch_negatives (int): The number of in-batch negatives. + limit (pydantic.PositiveInt): The limit on the number of data points to use. 
+ offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None. + """ data_root: str per_replica_batch_size: pydantic.PositiveInt global_negatives: int diff --git a/projects/twhin/data/data.py b/projects/twhin/data/data.py index 37bf8f5..b0afe7f 100644 --- a/projects/twhin/data/data.py +++ b/projects/twhin/data/data.py @@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig): + """ + Create a dataset for Twhin model training. + + Args: + data_config (TwhinDataConfig): The data configuration for the dataset. + model_config (TwhinModelConfig): The model configuration containing embeddings and relations. + + Returns: + EdgesDataset: The dataset for Twhin model training. + """ tables = model_config.embeddings.tables table_sizes = {table.name: table.num_embeddings for table in tables} relations = model_config.relations diff --git a/projects/twhin/data/edges.py b/projects/twhin/data/edges.py index f7864b1..778a188 100644 --- a/projects/twhin/data/edges.py +++ b/projects/twhin/data/edges.py @@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @dataclass class EdgeBatch(DataclassBatch): + """ + Batch data structure for edge-based models. + + Args: + nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings. + labels (torch.Tensor): Tensor containing labels. + rels (torch.Tensor): Tensor containing relation information. + weights (torch.Tensor): Tensor containing weights. + """ nodes: KeyedJaggedTensor labels: torch.Tensor rels: torch.Tensor @@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch): class EdgesDataset(Dataset): + """ + Dataset for edge-based models. + + Args: + file_pattern (str): The file pattern for the dataset. + table_sizes (Dict[str, int]): A dictionary of table names and their sizes. + relations (List[Relation]): A list of relations between tables. + lhs_column_name (str): The name of the left-hand-side column. + rhs_column_name (str): The name of the right-hand-side column. + rel_column_name (str): The name of the relation column. + **dataset_kwargs: Additional keyword arguments for the parent Dataset class. + """ rng = np.random.default_rng() def __init__( @@ -56,6 +77,15 @@ class EdgesDataset(Dataset): super().__init__(file_pattern=file_pattern, **dataset_kwargs) def pa_to_batch(self, batch: pa.RecordBatch): + """ + Converts a pyarrow RecordBatch to an EdgeBatch. + + Args: + batch (pa.RecordBatch): A pyarrow RecordBatch containing data. + + Returns: + EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights. + """ lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy()) rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy()) rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy()) @@ -74,6 +104,14 @@ class EdgesDataset(Dataset): ) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: """Process edges that contain lhs index, rhs index, relation index. + + Args: + lhs (torch.Tensor): Tensor containing left-hand-side indices. + rhs (torch.Tensor): Tensor containing right-hand-side indices. + rel (torch.Tensor): Tensor containing relation indices. + + Returns: + Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs. Example: ``` @@ -147,6 +185,12 @@ class EdgesDataset(Dataset): return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths) def to_batches(self): + """ + Converts data to batches. 
+ + Yields: + pa.RecordBatch: A pyarrow RecordBatch containing data. + """ ds = super().to_batches() batch_size = self._dataset_kwargs["batch_size"] diff --git a/projects/twhin/models/config.py b/projects/twhin/models/config.py index aafc9e6..c7b2e1c 100644 --- a/projects/twhin/models/config.py +++ b/projects/twhin/models/config.py @@ -10,8 +10,29 @@ from pydantic import validator class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): + """ + Configuration class for Twhin model embeddings. + + This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types + for all tables in the Twhin model embeddings configuration match. + + Attributes: + tables (List[TableConfig]): A list of table configurations for the model's embeddings. + """ @validator("tables") def embedding_dims_match(cls, tables): + """ + Validate that embedding dimensions and data types match for all tables. + + Args: + tables (List[TableConfig]): List of table configurations. + + Returns: + List[TableConfig]: The list of validated table configurations. + + Raises: + AssertionError: If embedding dimensions or data types do not match. + """ embedding_dim = tables[0].embedding_dim data_type = tables[0].data_type for table in tables: @@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): class Operator(str, enum.Enum): + """ + Enumeration of operator types. + + This enumeration defines different types of operators that can be applied to Twhin model relations. + """ TRANSLATION = "translation" class Relation(pydantic.BaseModel): - """graph relationship properties and operator""" + """ + Configuration class for graph relationships in the Twhin model. + + This class defines properties and operators for graph relationships in the Twhin model. + + Attributes: + name (str): The name of the relationship. + lhs (str): The name of the entity on the left-hand side of the relation. + rhs (str): The name of the entity on the right-hand side of the relation. + operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product. + """ name: str = pydantic.Field(..., description="Relationship name.") lhs: str = pydantic.Field( @@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel): class TwhinModelConfig(base_config.BaseConfig): + """ + Configuration class for the Twhin model. + + This class defines configuration options specific to the Twhin model. + + Attributes: + embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings. + relations (List[Relation]): List of graph relationship configurations. + translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation. + """ embeddings: TwhinEmbeddingsConfig relations: typing.List[Relation] translation_optimizer: OptimizerConfig @validator("relations", each_item=True) def valid_node_types(cls, relation, values, **kwargs): + """ + Validate that the specified node types in relations are valid table names in embeddings. + + Args: + relation (Relation): A single relation configuration. + values (dict): The values dictionary containing the "embeddings" configuration. + + Returns: + Relation: The validated relation configuration. + + Raises: + AssertionError: If the specified node types are not valid table names in embeddings. 
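+
+    Example:
+      A minimal sketch of a relation that passes this check, assuming hypothetical
+      embedding tables named ``user`` and ``tweet`` in ``embeddings.tables``:
+
+      ```python
+      relation = Relation(
+          name="favorites",
+          lhs="user",
+          rhs="tweet",
+          operator=Operator.TRANSLATION,
+      )
+      ```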
+ """ table_names = [table.name for table in values["embeddings"].tables] assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}" assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}" diff --git a/projects/twhin/models/models.py b/projects/twhin/models/models.py index b1dfaa7..0b5f1cb 100644 --- a/projects/twhin/models/models.py +++ b/projects/twhin/models/models.py @@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa class TwhinModel(nn.Module): + """ + Twhin model for graph-based entity embeddings and translation. + + This class defines the Twhin model, which is used for learning embeddings of entities in a graph + and applying translations to these embeddings based on graph relationships. + + Args: + model_config (TwhinModelConfig): Configuration for the Twhin model. + data_config (TwhinDataConfig): Configuration for the data used by the model. + + Attributes: + batch_size (int): The batch size used for training. + table_names (List[str]): Names of tables in the model's embeddings. + large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings. + embedding_dim (int): Dimensionality of entity embeddings. + num_tables (int): Number of tables in the model's embeddings. + in_batch_negatives (int): Number of in-batch negative samples to use during training. + global_negatives (int): Number of global negative samples to use during training. + num_relations (int): Number of graph relationships in the model. + all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings. + + """ def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig): super().__init__() self.batch_size = data_config.per_replica_batch_size @@ -31,7 +53,17 @@ class TwhinModel(nn.Module): ) def forward(self, batch: EdgeBatch): + """ + Forward pass of the Twhin model. + Args: + batch (EdgeBatch): Input batch containing graph edge information. + + Returns: + dict: A dictionary containing model output with "logits" and "probabilities". + - "logits" (torch.Tensor): Logit scores. + - "probabilities" (torch.Tensor): Sigmoid probabilities. + """ # B x D trans_embs = self.all_trans_embs.data[batch.rels] @@ -98,6 +130,18 @@ class TwhinModel(nn.Module): def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig): + """ + Apply optimizers to the Twhin model's embeddings. + + This function applies optimizers to the embeddings of the Twhin model based on the provided configuration. + + Args: + model (TwhinModel): The Twhin model to apply optimizers to. + model_config (TwhinModelConfig): Configuration for the Twhin model. + + Returns: + TwhinModel: The Twhin model with optimizers applied to its embeddings. + """ for table in model_config.embeddings.tables: optimizer_class = get_optimizer_class(table.optimizer) optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict() @@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module): device: torch.device, ) -> None: """ - Args: - model: torch module to wrap. - loss_fn: Function for calculating loss, should accept logits and labels. - """ + Initialize a TwhinModelAndLoss module. + + Args: + model: The torch module to wrap. + loss_fn: A function for calculating loss, should accept logits and labels. + data_config: Configuration for Twhin data. + device: The torch device to use for calculations. 
+ """ super().__init__() self.model = model self.loss_fn = loss_fn @@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module): self.device = device def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] - """Runs model forward and calculates loss according to given loss_fn. - - NOTE: The input signature here needs to be a Pipelineable object for - prefetching purposes during training using torchrec's pipeline. However - the underlying model signature needs to be exportable to onnx, requiring - generic python types. see https://pytorch.org/docs/stable/onnx.html#types. - """ + Run the model forward and calculate the loss according to the given loss_fn. + + NOTE: The input signature here needs to be a Pipelineable object for + prefetching purposes during training using torchrec's pipeline. However + the underlying model signature needs to be exportable to onnx, requiring + generic python types. see https://pytorch.org/docs/stable/onnx.html#types + + Args: + batch ("RecapBatch"): The input batch for model inference. + + Returns: + Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of + additional outputs including logits, labels, and weights. + """ outputs = self.model(batch) logits = outputs["logits"] diff --git a/projects/twhin/models/test_models.py b/projects/twhin/models/test_models.py index f8203ef..05e9b2b 100644 --- a/projects/twhin/models/test_models.py +++ b/projects/twhin/models/test_models.py @@ -18,6 +18,12 @@ EMB_DIM = 128 def twhin_model_config() -> TwhinModelConfig: + """ + Create a configuration for the Twhin model. + + Returns: + TwhinModelConfig: The Twhin model configuration. + """ sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01)) sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02)) @@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig: def twhin_data_config() -> TwhinDataConfig: + """ + Create a configuration for the Twhin data. + + Returns: + TwhinDataConfig: The Twhin data configuration. + """ data_config = TwhinDataConfig( data_root="/", per_replica_batch_size=10, @@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig: def test_twhin_model(): + """ + Test the Twhin model creation and optimization. + + This function creates a Twhin model using the specified configuration and tests its optimization. It also checks + the device placement of model parameters. + + Returns: + None + """ model_config = twhin_model_config() loss_fn = F.binary_cross_entropy_with_logits diff --git a/projects/twhin/optimizer.py b/projects/twhin/optimizer.py index 7c43b56..e39e0d0 100644 --- a/projects/twhin/optimizer.py +++ b/projects/twhin/optimizer.py @@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt" def _lr_from_config(optimizer_config): + """Get the learning rate from an optimizer configuration. + + Args: + optimizer_config: Optimizer configuration. + + Returns: + Learning rate from the optimizer configuration. + """ if optimizer_config.learning_rate is not None: return optimizer_config.learning_rate else: @@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config): def build_optimizer(model: TwhinModel, config: TwhinModelConfig): """Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations. - Args: - model: TwhinModel to build optimizer for. - config: TwhinConfig for model. + Args: + model: TwhinModel to build optimizer for. + config: TwhinModelConfig for model. - Returns: - Optimizer for model. - """ + Returns: + Optimizer for model. 
+ """ translation_optimizer_fn = functools.partial( get_optimizer_class(config.translation_optimizer), **get_optimizer_algorithm_config(config.translation_optimizer).dict(), diff --git a/projects/twhin/run.py b/projects/twhin/run.py index 12081e0..6124bc1 100644 --- a/projects/twhin/run.py +++ b/projects/twhin/run.py @@ -37,6 +37,12 @@ def run( all_config: TwhinConfig, save_dir: Optional[str] = None, ): + """Run the training process for TwhinModel. + + Args: + all_config (TwhinConfig): The configuration for the entire Twhin model. + save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None. + """ train_dataset = create_dataset(all_config.train_data, all_config.model) if env.is_reader(): @@ -80,6 +86,11 @@ def run( def main(argv): + """Main entry point for the Twhin training script. + + Args: + argv: Command-line arguments. + """ logging.info("Starting") logging.info(f"parsing config from {FLAGS.config_yaml_path}...") diff --git a/reader/dataset.py b/reader/dataset.py index 6e811cc..5d873cd 100644 --- a/reader/dataset.py +++ b/reader/dataset.py @@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging class _Reader(pa.flight.FlightServerBase): - """Distributed reader flight server wrapping a dataset.""" + """ + Distributed reader flight server wrapping a dataset. + + This class implements a Flight server that wraps a dataset, allowing clients to retrieve data + from the dataset over the Flight protocol. It is designed to be used in a distributed environment + for efficient data access. + + Args: + location (str): The location of the Flight server. + ds (Dataset): The dataset to be wrapped by the Flight server. + + Attributes: + _location (str): The location of the Flight server. + _ds (Dataset): The dataset wrapped by the Flight server. + + Methods: + do_get(_, __): Handles Flight requests for data retrieval. + + Note: + Flight is an Apache Arrow project that provides a framework for efficient data transfer. + This class allows clients to retrieve data from the dataset using Flight. + + """ def __init__(self, location: str, ds: "Dataset"): + """ + Initialize a new _Reader instance. + + Args: + location (str): The location of the Flight server. + ds (Dataset): The dataset to be wrapped by the Flight server. + """ super().__init__(location=location) self._location = location self._ds = ds def do_get(self, _, __): + """ + Handle Flight requests for data retrieval. + + This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol. + + Args: + _: Unused argument. + __: Unused argument. + + Returns: + pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset. + + Note: + An updated schema (to account for column selection) must be given to the stream. + """ # NB: An updated schema (to account for column selection) has to be given the stream. schema = next(iter(self._ds.to_batches())).schema batches = self._ds.to_batches() @@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase): class Dataset(torch.utils.data.IterableDataset): + """ + A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading. + + This class enables efficient loading of data from Parquet files using PyArrow. + It is designed to be used as an IterableDataset in PyTorch for training and inference. + + Args: + file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset. + **dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method. 
+ Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset + for more details. + + Attributes: + LOCATION (str): The default location for the Flight server used for data distribution. + _file_pattern (str): The glob pattern specifying Parquet files in the dataset. + _fs: The filesystem object used for file operations. + _dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method. + _files (list): A list of file paths matching the glob pattern. + _schema (pa.Schema): The schema of the Parquet dataset. + + Methods: + serve(): Start serving the dataset using a Flight server. + to_batches(): Generate batches of data from the Parquet dataset. + pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch. + dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset. + + Note: + This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch + to create DataLoader instances for training or inference. + """ LOCATION = "grpc://0.0.0.0:2222" def __init__(self, file_pattern: str, **dataset_kwargs) -> None: - """Specify batch size and column to select for. + """ + Initialize a new Dataset instance. Specify batch size and column to select for. Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset. - """ + + + Args: + file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset. + **dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method. + """ self._file_pattern = file_pattern self._fs = infer_fs(self._file_pattern) self._dataset_kwargs = dataset_kwargs @@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset): self._validate_columns() def _validate_columns(self): + """ + Validate the specified columns against the dataset schema. + + Raises: + Exception: If any specified columns are not found in the dataset schema. + """ columns = set(self._dataset_kwargs.get("columns", [])) wrong_columns = set(columns) - set(self._schema.names) if wrong_columns: raise Exception(f"Specified columns {list(wrong_columns)} not in schema.") def serve(self): + """Start serving the dataset using a Flight server.""" self.reader = _Reader(location=self.LOCATION, ds=self) self.reader.serve() def _create_dataset(self): + """Create a PyArrow dataset for data retrieval.""" + return pads.dataset( source=random.sample(self._files, len(self._files))[0], format="parquet", @@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset): @abc.abstractmethod def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch: + """ + Convert a Parquet RecordBatch to a custom data batch. + + Args: + batch (pa.RecordBatch): A batch of data from the Parquet dataset. + + Returns: + DataclassBatch: A custom data batch used in PyTorch training. + + Raises: + NotImplementedError: This method must be implemented in derived classes. + """ raise NotImplementedError def dataloader(self, remote: bool = False): + """ + Create a PyTorch DataLoader for iterating through the dataset. + + Args: + remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training. + + Returns: + DataLoader: A PyTorch DataLoader for iterating through the dataset. + + Note: + If `remote` is True, a remote DataLoader is created for distributed training using Flight. 
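+
+    Example:
+      A minimal local-iteration sketch, assuming a concrete subclass ``MyDataset`` that
+      implements ``pa_to_batch`` and a hypothetical file pattern:
+
+      ```python
+      ds = MyDataset(file_pattern="data/*.parquet", batch_size=1024)
+      for batch in ds.dataloader(remote=False):
+          ...
+      ```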
+ """ if not remote: return map(self.pa_to_batch, self.to_batches()) readers = get_readers(2) @@ -117,6 +230,25 @@ GRPC_OPTIONS = [ def get_readers(num_readers_per_worker: int): + """ + Get Flight readers for distributed data loading. + + This function retrieves Flight readers for distributed data loading in a PyTorch environment. + + Args: + num_readers_per_worker (int): The number of Flight readers to retrieve per worker. + + Returns: + List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading. + + Note: + Flight readers are used to fetch data in a distributed manner for efficient data loading. + + Example: + To obtain Flight readers, use the following code: + + >>> readers = get_readers(num_readers_per_worker=2) + """ addresses = env.get_flight_server_addresses() readers = [] diff --git a/reader/dds.py b/reader/dds.py index 7b49893..54b2c02 100644 --- a/reader/dds.py +++ b/reader/dds.py @@ -21,6 +21,16 @@ import torch.distributed as dist def maybe_start_dataset_service(): + """ + Start the dataset service if readers are available and required dependencies are met. + + This function checks if readers are available and if the required TensorFlow version is >= 2.5. + If both conditions are met and the current environment is the dispatcher or reader, it starts + the TensorFlow dataset service. + + Raises: + Exception: If the required TensorFlow version is not met (>= 2.5). + """ if not env.has_readers(): return @@ -59,6 +69,24 @@ def maybe_start_dataset_service(): def register_dataset( dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO" ): + """ + Register a dataset with the distributed dataset service. + + This function registers a dataset with the distributed dataset service and broadcasts the dataset ID + and job name to all processes in the distributed environment. + + Args: + dataset (tf.data.Dataset): The dataset to be registered. + dataset_service (str): The name of the dataset service. + compression (Optional[str]): The compression type for the dataset (default is "AUTO"). + + Returns: + Tuple[int, str]: A tuple containing the dataset ID and job name. + + Note: + This function should be called on the rank 0 process. + + """ if dist.get_rank() == 0: dataset_id = _register_dataset( service=dataset_service, @@ -82,6 +110,23 @@ def distribute_from_dataset_id( compression: Optional[str] = "AUTO", prefetch: Optional[int] = tf.data.experimental.AUTOTUNE, ) -> tf.data.Dataset: + """ + Distribute a dataset from a registered dataset ID. + + This function consumes a dataset from the distributed dataset service using the provided dataset ID + and job name. It also supports prefetching for improved performance. + + Args: + dataset_service (str): The name of the dataset service. + dataset_id (int): The ID of the dataset to be consumed. + job_name (Optional[str]): The name of the job associated with the dataset (optional). + compression (Optional[str]): The compression type for the dataset (default is "AUTO"). + prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE). + + Returns: + tf.data.Dataset: The distributed dataset. + + """ logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}") dataset = _from_dataset_id( processing_mode="parallel_epochs", @@ -97,15 +142,28 @@ def distribute_from_dataset_id( def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset: - """Torch-compatible and distributed-training-aware dataset service distributor. 
- - - rank 0 process will register the given dataset. - - rank 0 process will broadcast job name and dataset id. - - all rank processes will consume from the same job/dataset. - - Without this, dataset workers will try to serve 1 job per rank process and OOM. - """ + Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption. + + This function is used to distribute a dataset in a distributed training environment. It performs the + following steps: + - On the rank 0 process, it registers the given dataset with the distributed dataset service. + - It broadcasts the job name and dataset ID to all rank processes. + - All rank processes then consume the same dataset from the distributed dataset service. + + Args: + dataset (tf.data.Dataset): The TensorFlow dataset to be distributed. + + Returns: + tf.data.Dataset: The distributed TensorFlow dataset. + + Note: + - If there are no reader processes in the distributed environment, the original dataset is returned + without any distribution. + - This function is intended for use in distributed training environments to prevent out-of-memory (OOM) + issues caused by each rank process trying to serve one job. + + """ if not env.has_readers(): return dataset dataset_service = env.get_dds() diff --git a/reader/test_dataset.py b/reader/test_dataset.py index 85e756f..d9400e4 100644 --- a/reader/test_dataset.py +++ b/reader/test_dataset.py @@ -12,6 +12,17 @@ import torch def create_dataset(tmpdir): + """ + Create a mock dataset for testing. + + This function creates a mock dataset using PyArrow and Parquet for testing purposes. + + Args: + tmpdir: A temporary directory where the dataset will be created. + + Returns: + MockDataset: A mock dataset for testing. + """ table = pa.table( { @@ -34,6 +45,14 @@ def create_dataset(tmpdir): def test_dataset(tmpdir): + """ + Test the created dataset. + + This function tests the created mock dataset and checks if it behaves as expected. + + Args: + tmpdir: A temporary directory used for testing. + """ ds = create_dataset(tmpdir) batch = next(iter(ds.dataloader(remote=False))) assert batch.batch_size == 2 @@ -46,6 +65,14 @@ def test_dataset(tmpdir): reason="Multiprocessing doesn't work on github yet.", ) def test_distributed_dataset(tmpdir): + """ + Test the distributed dataset. + + This function tests the distributed version of the mock dataset using multiprocessing. + + Args: + tmpdir: A temporary directory used for testing. + """ MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"} def _client(): diff --git a/reader/utils.py b/reader/utils.py index fc0e34c..c5c9721 100644 --- a/reader/utils.py +++ b/reader/utils.py @@ -11,11 +11,55 @@ import torch def roundrobin(*iterables): - """Round robin through provided iterables, useful for simple load balancing. - - Adapted from https://docs.python.org/3/library/itertools.html. - """ + Iterate through provided iterables in a round-robin fashion. + + This function takes multiple iterables and returns an iterator that yields elements from + each iterable in a round-robin manner. It continues cycling through the iterables until + all of them are exhausted. + + Adapted from https://docs.python.org/3/library/itertools.html. + + Args: + *iterables: One or more iterable objects to iterate through. + + Yields: + Elements from the provided iterables in a round-robin fashion. + + Raises: + StopIteration: If all provided iterables are exhausted. 
+ + Example: + ```python + iterable1 = [1, 2, 3] + iterable2 = ['a', 'b', 'c'] + iterable3 = [0.1, 0.2, 0.3] + + for item in roundrobin(iterable1, iterable2, iterable3): + print(item) + + # Output: + # 1 + # 'a' + # 0.1 + # 2 + # 'b' + # 0.2 + # 3 + # 'c' + # 0.3 + ``` + + Note: + - If one of the provided iterables is shorter than the others, the function will + continue iterating through the remaining iterables until all are exhausted. + - If an iterable raises an exception during iteration, a warning message is logged, + and the function continues with the next iterable. + + See Also: + - `itertools.cycle`: A function that repeatedly cycles through elements of an iterable. + - `itertools.islice`: A function to slice an iterable to limit the number of iterations. + """ num_active = len(iterables) nexts = itertools.cycle(iter(it).__next__ for it in iterables) while num_active: @@ -35,6 +79,48 @@ def roundrobin(*iterables): def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]): + """ + Monitor the speed and progress of data loading using a data loader. + + This function iterates through a data loader for a specified number of steps or until + the end of the data loader is reached, periodically logging progress information. + + Args: + data_loader: The data loader to monitor. + max_steps: The maximum number of steps to iterate through the data loader. + frequency: The frequency (in steps) at which to log progress. + peek (optional): If specified, it indicates the frequency (in steps) at which to log + batch contents for inspection. + + Example: + ```python + import torch + from torch.utils.data import DataLoader + + # Create a data loader (replace with your own DataLoader configuration) + data_loader = DataLoader(dataset, batch_size=32, shuffle=True) + + # Monitor data loading speed and progress + speed_check(data_loader, max_steps=1000, frequency=50, peek=500) + ``` + + Args: + data_loader: The data loader to monitor. + max_steps: The maximum number of steps to iterate through the data loader. + frequency: The frequency (in steps) at which to log progress. + peek (optional): If specified, it indicates the frequency (in steps) at which to log + batch contents for inspection. + + Note: + - The function logs information about elapsed time, the number of examples processed, + and the processing speed in examples per second. + - If `peek` is provided, batch contents will be logged for inspection at the specified + frequency. + + See Also: + - `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and + iterating through datasets. + """ num_examples = 0 prev = time.perf_counter() for idx, batch in enumerate(data_loader): @@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int] def pa_to_torch(array: pa.array) -> torch.Tensor: + """ + Convert a PyArrow Array to a PyTorch Tensor. + + Args: + array (pa.array): The PyArrow Array to convert. + + Returns: + torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array. + + Example: + ```python + import pyarrow as pa + import torch + + # Create a PyArrow Array + arrow_array = pa.array([1, 2, 3]) + + # Convert it to a PyTorch Tensor + torch_tensor = pa_to_torch(arrow_array) + ``` + """ return torch.from_numpy(array.to_numpy()) def create_default_pa_to_batch(schema) -> DataclassBatch: - """ """ + """ + Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data. 
+
+  Args:
+    schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.
+
+  Returns:
+    callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.
+
+  Example:
+    ```python
+    import pandas as pd
+    import pyarrow as pa
+    from dataclass_batch import DataclassBatch
+
+    # Define a PyArrow schema
+    schema = pa.schema([
+        ("feature1", pa.float64()),
+        ("feature2", pa.int64()),
+        ("label", pa.int64()),
+    ])
+
+    # Create the conversion function
+    pa_to_batch = create_default_pa_to_batch(schema)
+
+    # Create a PyArrow RecordBatch
+    record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
+        "feature1": [1.0, 2.0, None],
+        "feature2": [10, 20, 30],
+        "label": [0, 1, None],
+    }))
+
+    # Convert the RecordBatch to a custom DataclassBatch
+    custom_batch = pa_to_batch(record_batch)
+    ```
+  """
   _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
   def get_imputation_value(pa_type):