This commit is contained in:
rajveer43 2023-09-21 22:53:34 +05:30
parent cc73f5fcb7
commit f7f26d0c20
9 changed files with 953 additions and 724 deletions

View File

@ -54,12 +54,20 @@ def maybe_shard_model(
model,
device: torch.device,
):
"""Set up and apply DistributedModelParallel to a model if running in a distributed environment.
"""
Set up and apply DistributedModelParallel to a model if running in a distributed environment.
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
DistributedModelParallel.
If not in a distributed environment, returns model directly.
If not in a distributed environment, returns the model directly.
Args:
model: The PyTorch model.
device: The target device (e.g., 'cuda').
Returns:
The model wrapped with DistributedModelParallel if in a distributed environment, else the original model.
"""
if dist.is_initialized():
logging.info("***** Wrapping in DistributedModelParallel *****")
@ -74,13 +82,14 @@ def maybe_shard_model(
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
"""Handy function to log the content of EBC embedding layer.
"""
Handy function to log the content of an EBC (Embedding Bag Concatenation) embedding layer.
Only works for single GPU machines.
Args:
weight_name: name of tensor, as defined in model
table_name: name of the EBC table the weight is taken from
weight_tensor: embedding weight tensor
weight_name: Name of the tensor, as defined in the model.
table_name: Name of the EBC table the weight is taken from.
weight_tensor: Embedding weight tensor.
"""
logging.info(f"{weight_name}, {table_name}", rank=-1)
logging.info(f"{weight_tensor.metadata()}", rank=-1)

View File

@ -11,8 +11,10 @@ class ExplicitDateInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_date and days of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.")
end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
days: int = pydantic.Field(..., description="Number of days of data for dataset.")
end_date: str = pydantic.Field(...,
description="Data end date, inclusive.")
days: int = pydantic.Field(...,
description="Number of days of data for dataset.")
num_missing_days_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_days_tol days of missing data."
)
@ -22,8 +24,10 @@ class ExplicitDatetimeInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_datetime and hours of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.")
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
end_datetime: str = pydantic.Field(...,
description="Data end datetime, inclusive.")
hours: int = pydantic.Field(...,
description="Number of hours of data for dataset.")
num_missing_hours_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
)
@ -42,7 +46,8 @@ class DatasetConfig(base_config.BaseConfig):
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
None, one_of="date_inputs_format"
)
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(
None, one_of="date_inputs_format")
global_batch_size: pydantic.PositiveInt
@ -52,7 +57,8 @@ class DatasetConfig(base_config.BaseConfig):
repeat_files: bool = pydantic.Field(
True, description="DEPRICATED. Files are repeated no matter what this is set to."
)
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
file_batch_size: pydantic.PositiveInt = pydantic.Field(
16, description="File batch size")
cache: bool = pydantic.Field(
False,
@ -70,7 +76,8 @@ class DatasetConfig(base_config.BaseConfig):
)
# tf.data.Dataset options
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
examples_shuffle_buffer_size: int = pydantic.Field(
1024, description="Size of shuffle buffers.")
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
None, description="Number of parallel calls."
)
@ -125,7 +132,8 @@ class TaskData(base_config.BaseConfig):
class SegDenseSchema(base_config.BaseConfig):
schema_path: str = pydantic.Field(..., description="Path to feature config json.")
schema_path: str = pydantic.Field(...,
description="Path to feature config json.")
features: typing.List[str] = pydantic.Field(
[],
description="List of features (in addition to the renamed features) to read from schema path above.",
@ -192,8 +200,10 @@ class DownsampleNegatives(base_config.BaseConfig):
class Preprocess(base_config.BaseConfig):
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
truncate_and_slice: TruncateAndSlice = pydantic.Field(
None, description="Truncation and slicing.")
downcast: DownCast = pydantic.Field(
None, description="Down cast to features.")
rectify_labels: RectifyLabels = pydantic.Field(
None, description="Rectify labels for a given overlap window"
)
@ -242,5 +252,6 @@ class RecapDataConfig(DatasetConfig):
if values.get("evaluation_tasks") is not None:
for task in values["evaluation_tasks"]:
if task not in values["tasks"]:
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
raise KeyError(
f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
return values

View File

@ -9,9 +9,20 @@ import numpy as np
class TruncateAndSlice(tf.keras.Model):
"""Class for truncating and slicing."""
"""
A class for truncating and slicing input features based on the provided configuration.
Args:
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
"""
def __init__(self, truncate_and_slice_config):
"""
Initializes the TruncateAndSlice model.
Args:
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
"""
super().__init__()
self._truncate_and_slice_config = truncate_and_slice_config
@ -32,6 +43,17 @@ class TruncateAndSlice(tf.keras.Model):
self._binary_mask = None
def call(self, inputs, training=None, mask=None):
"""
Applies truncation and slicing to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of truncated and sliced input features.
"""
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
if self._truncate_and_slice_config.continuous_feature_truncation:
logging.info("Truncating continuous")
@ -51,12 +73,23 @@ class TruncateAndSlice(tf.keras.Model):
class DownCast(tf.keras.Model):
"""Class for Down casting dataset before serialization and transferring to training host.
Depends on the data type and the actual data range, the down casting can be lossless or not.
It is strongly recommended to compare the metrics before and after down casting.
"""
A class for downcasting dataset before serialization and transferring to the training host.
Depending on the data type and the actual data range, the downcasting can be lossless or not.
It is strongly recommended to compare the metrics before and after downcasting.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
def __init__(self, downcast_config):
"""
Initializes the DownCast model.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
super().__init__()
self.config = downcast_config
self._type_map = {
@ -65,6 +98,17 @@ class DownCast(tf.keras.Model):
}
def call(self, inputs, training=None, mask=None):
"""
Applies downcasting to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of downcasted input features.
"""
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
for feature, type_str in self.config.features.items():
assert type_str in self._type_map
@ -78,14 +122,39 @@ class DownCast(tf.keras.Model):
class RectifyLabels(tf.keras.Model):
"""Class for rectifying labels"""
"""
A class for downcasting dataset before serialization and transferring to the training host.
Depending on the data type and the actual data range, the downcasting can be lossless or not.
It is strongly recommended to compare the metrics before and after downcasting.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
def __init__(self, rectify_label_config):
"""
Initializes the DownCast model.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
super().__init__()
self._config = rectify_label_config
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
def call(self, inputs, training=None, mask=None):
"""
Applies downcasting to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of downcasted input features.
"""
served_ts_field = self._config.served_timestamp_field
impressed_ts_field = self._config.impressed_timestamp_field
@ -102,13 +171,37 @@ class RectifyLabels(tf.keras.Model):
class ExtractFeatures(tf.keras.Model):
"""Class for extracting individual features from dense tensors by their index."""
"""
A class for rectifying labels based on specified conditions.
This class is used to adjust label values in a dataset based on configured conditions involving timestamps.
Args:
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
"""
def __init__(self, extract_features_config):
"""
Initializes the RectifyLabels model.
Args:
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
"""
super().__init__()
self._config = extract_features_config
def call(self, inputs, training=None, mask=None):
"""
Rectifies label values based on the specified conditions.
Args:
inputs: A dictionary of input features including timestamp fields and labels.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of input features with rectified label values.
"""
for row in self._config.extract_feature_table:
inputs[row.name] = inputs[row.source_tensor][:, row.index]
@ -168,7 +261,16 @@ class DownsampleNegatives(tf.keras.Model):
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
"""Builds a preprocess model to apply all preprocessing stages."""
"""
Builds a preprocess model to apply all preprocessing stages.
Args:
preprocess_config: A configuration object specifying the preprocessing parameters.
mode: A mode indicating the current job mode (TRAIN or INFERENCE).
Returns:
A preprocess model that applies all specified preprocessing stages.
"""
if mode == config_mod.JobMode.INFERENCE:
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
return None

View File

@ -8,7 +8,8 @@ import tensorflow as tf
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}
DTYPE_MAP = {"int64_list": tf.int64,
"float_list": tf.float32, "bytes_list": tf.string}
def create_tf_example_schema(
@ -27,7 +28,8 @@ def create_tf_example_schema(
segdense_config = data_config.seg_dense_schema
labels = list(data_config.tasks.keys())
used_features = (
segdense_config.features + list(segdense_config.renamed_features.values()) + labels
segdense_config.features +
list(segdense_config.renamed_features.values()) + labels
)
logging.info(used_features)
@ -40,19 +42,22 @@ def create_tf_example_schema(
dtype = entry["dtype"]
if feature_name in labels:
logging.info(f"Label: feature name is {feature_name} type is {dtype}")
logging.info(
f"Label: feature name is {feature_name} type is {dtype}")
tfe_schema[feature_name] = tf.io.FixedLenFeature(
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
)
elif length == -1:
tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
tfe_schema[feature_name] = tf.io.VarLenFeature(
DTYPE_MAP[dtype])
else:
tfe_schema[feature_name] = tf.io.FixedLenFeature(
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
)
for feature_name in used_features:
if feature_name not in tfe_schema:
raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
raise ValueError(
f"{feature_name} missing from schema: {segdense_config.schema_path}.")
return tfe_schema
@ -82,7 +87,8 @@ def parse_tf_example(
Returns:
Dictionary of tensors to be used as model input.
"""
inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)
inputs = tf.io.parse_example(
serialized=serialized_example, features=tfe_schema)
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
inputs[new_feature_name] = inputs.pop(old_feature_name)
@ -90,7 +96,8 @@ def parse_tf_example(
# This should not actually be used except for experimentation with low precision floats.
if "mask_mantissa_features" in seg_dense_schema_config:
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)
inputs[feature_name] = mask_mantissa(
inputs[feature_name], mask_length)
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
# at TF level.

View File

@ -9,44 +9,59 @@ def keyed_tensor_from_tensors_dict(
tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedTensor":
"""
Convert a dictionary of torch tensor to torchrec keyed tensor
Convert a dictionary of torch tensors to a torchrec KeyedTensor.
Args:
tensor_map:
tensor_map: A mapping of tensor names to torch tensors.
Returns:
A torchrec KeyedTensor.
"""
keys = list(tensor_map.keys())
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
# [Batch_size x 1].
values = [
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(
tensor_map[key], -1)
for key in keys
]
return torchrec.KeyedTensor.from_tensor_list(keys, values)
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute a jagged tensor from a torch tensor.
Args:
tensor: Input torch tensor.
Returns:
A tuple containing the values and lengths of the jagged tensor.
"""
if tensor.is_sparse:
x = tensor.coalesce() # Ensure that the indices are ordered.
lengths = torch.bincount(x.indices()[0])
values = x.values()
else:
values = tensor
lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
lengths = torch.ones(
tensor.shape[0], dtype=torch.int32, device=tensor.device)
return values, lengths
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
"""
Convert a torch tensor to torchrec jagged tensor.
Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the
dense_shape of the sparse tensor can be arbitrary.
Convert a torch tensor to a torchrec jagged tensor.
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x N] for dense tensors.
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary.
Args:
tensor: a torch (sparse) tensor.
tensor: A torch (sparse) tensor.
Returns:
A torchrec JaggedTensor.
"""
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
return torchrec.JaggedTensor(values=values, lengths=lengths)
@ -56,15 +71,16 @@ def keyed_jagged_tensor_from_tensors_dict(
tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedJaggedTensor":
"""
Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor.
Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the
dense_shape of the sparse tensor can be arbitrary.
Convert a dictionary of (sparse) torch tensors to a torchrec keyed jagged tensor.
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x 1] for dense tensors.
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary.
Args:
tensor_map:
tensor_map: A mapping of tensor names to torch tensors.
Returns:
A torchrec KeyedJaggedTensor.
"""
if not tensor_map:
@ -91,10 +107,29 @@ def keyed_jagged_tensor_from_tensors_dict(
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
"""
Convert a TensorFlow tensor to a NumPy array.
Args:
tf_tensor: TensorFlow tensor.
Returns:
NumPy array.
"""
return tf_tensor._numpy() # noqa
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
"""
Convert a dense TensorFlow tensor to a PyTorch tensor.
Args:
tensor: Dense TensorFlow tensor.
pin_memory: Whether to pin the tensor in memory (for CUDA).
Returns:
PyTorch tensor.
"""
tensor = _tf_to_numpy(tensor)
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
if tensor.dtype.name == "bfloat16":
@ -109,6 +144,16 @@ def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
def sparse_or_dense_tf_to_torch(
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
) -> torch.Tensor:
"""
Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor.
Args:
tensor: TensorFlow tensor (sparse or dense).
pin_memory: Whether to pin the tensor in memory (for CUDA).
Returns:
PyTorch tensor.
"""
if isinstance(tensor, tf.SparseTensor):
tensor = torch.sparse_coo_tensor(
_dense_tf_to_torch(tensor.indices, pin_memory).t(),

View File

@ -69,6 +69,7 @@ class MaskBlock(torch.nn.Module):
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
) -> None:
@ -94,11 +95,13 @@ class MaskBlock(torch.nn.Module):
self._input_layer_norm = None
if mask_block_config.reduction_factor:
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
aggregation_size = int(
mask_input_dim * mask_block_config.reduction_factor)
elif mask_block_config.aggregation_size is not None:
aggregation_size = mask_block_config.aggregation_size
else:
raise ValueError("Need one of reduction factor or aggregation size.")
raise ValueError(
"Need one of reduction factor or aggregation size.")
self._mask_layer = torch.nn.Sequential(
torch.nn.Linear(mask_input_dim, aggregation_size),
@ -123,7 +126,8 @@ class MaskBlock(torch.nn.Module):
"""
if self._input_layer_norm:
net = self._input_layer_norm(net)
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
hidden_layer_output = self._hidden_layer(
net * self._mask_layer(mask_input))
return self._layer_norm(hidden_layer_output)
@ -170,6 +174,7 @@ class MaskNet(torch.nn.Module):
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
"""
Initializes the MaskNet module.
@ -189,20 +194,23 @@ class MaskNet(torch.nn.Module):
if mask_net_config.use_parallel:
total_output_mask_blocks = 0
for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
mask_blocks.append(
MaskBlock(mask_block_config, in_features, in_features))
total_output_mask_blocks += mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
else:
input_size = in_features
for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
mask_blocks.append(
MaskBlock(mask_block_config, input_size, in_features))
input_size = mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
total_output_mask_blocks = mask_block_config.output_size
if mask_net_config.mlp:
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
self._dense_layers = mlp.Mlp(
total_output_mask_blocks, mask_net_config.mlp)
self.out_features = mask_net_config.mlp.layer_sizes[-1]
else:
self.out_features = total_output_mask_blocks
@ -235,5 +243,6 @@ class MaskNet(torch.nn.Module):
for mask_layer in self._mask_blocks:
net = mask_layer(net=net, mask_input=inputs)
# Share the output of the stacked MaskBlocks.
output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"]
output = net if self.mask_net_config.mlp is None else self._dense_layers[
net]["output"]
return {"output": output, "shared_layer": net}

View File

@ -52,6 +52,7 @@ class ModelAndLoss(torch.nn.Module):
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self,
model,
@ -92,13 +93,16 @@ class ModelAndLoss(torch.nn.Module):
labels=batch.labels,
weights=batch.weights,
)
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
losses = self.loss_fn(
outputs["logits"], batch.labels.float(), batch.weights.float())
if self.stratifiers:
logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}")
logging.info(
f"***** Adding stratifiers *****\n {self.stratifiers}")
outputs["stratifiers"] = {}
for stratifier in self.stratifiers:
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index]
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:,
stratifier.index]
# In general, we can have a large number of losses returned by our loss function.
if isinstance(losses, dict):

View File

@ -39,6 +39,7 @@ class NumericCalibration(torch.nn.Module):
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self,
pos_downsampling_rate: float,

View File

@ -38,6 +38,15 @@ import pyarrow.parquet as pq
def _create_dataset(path: str):
"""
Create a PyArrow dataset from Parquet files located at the specified path.
Args:
path (str): The path to the Parquet files.
Returns:
pyarrow.dataset.Dataset: The PyArrow dataset.
"""
fs = infer_fs(path)
files = fs.glob(path)
return pads.dataset(files, format="parquet", filesystem=fs)
@ -47,12 +56,27 @@ class PqReader:
def __init__(
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
):
"""
Initialize a Parquet Reader.
Args:
path (str): The path to the Parquet files.
num (int): The maximum number of rows to read.
batch_size (int): The batch size for reading data.
columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns).
"""
self._ds = _create_dataset(path)
self._batch_size = batch_size
self._num = num
self._columns = columns
def __iter__(self):
"""
Iterate through the Parquet data and yield batches of rows.
Yields:
pyarrow.RecordBatch: A batch of rows.
"""
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
rows_seen = 0
for count, record in enumerate(batches):
@ -62,6 +86,12 @@ class PqReader:
rows_seen += record.data.num_rows
def _head(self):
"""
Get the first `num` rows of the Parquet data.
Returns:
pyarrow.RecordBatch: A batch of rows.
"""
total_read = self._num * self.bytes_per_row
if total_read >= int(500e6):
raise Exception(
@ -71,6 +101,12 @@ class PqReader:
@property
def bytes_per_row(self) -> int:
"""
Calculate the estimated bytes per row in the dataset.
Returns:
int: The estimated bytes per row.
"""
nbits = 0
for t in self._ds.schema.types:
try:
@ -81,17 +117,22 @@ class PqReader:
return nbits // 8
def schema(self):
"""
Display the schema of the Parquet dataset.
"""
print(f"\n# Schema\n{self._ds.schema}")
def head(self):
"""Displays first --num rows."""
"""
Display the first `num` rows of the Parquet data as a pandas DataFrame.
"""
print(self._head().to_pandas())
def distinct(self):
"""Displays unique values seen in specified columns in the first `--num` rows.
"""
Display unique values seen in specified columns in the first `num` rows.
Useful for getting an approximate vocabulary for certain columns.
"""
for col_name, column in zip(self._head().column_names, self._head().columns):
print(col_name)