This commit is contained in:
rajveer43 2023-09-21 22:53:34 +05:30
parent cc73f5fcb7
commit f7f26d0c20
9 changed files with 953 additions and 724 deletions

View File

@ -54,12 +54,20 @@ def maybe_shard_model(
model, model,
device: torch.device, device: torch.device,
): ):
"""Set up and apply DistributedModelParallel to a model if running in a distributed environment. """
Set up and apply DistributedModelParallel to a model if running in a distributed environment.
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
DistributedModelParallel. DistributedModelParallel.
If not in a distributed environment, returns model directly. If not in a distributed environment, returns the model directly.
Args:
model: The PyTorch model.
device: The target device (e.g., 'cuda').
Returns:
The model wrapped with DistributedModelParallel if in a distributed environment, else the original model.
""" """
if dist.is_initialized(): if dist.is_initialized():
logging.info("***** Wrapping in DistributedModelParallel *****") logging.info("***** Wrapping in DistributedModelParallel *****")
@ -74,13 +82,14 @@ def maybe_shard_model(
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None: def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
"""Handy function to log the content of EBC embedding layer. """
Handy function to log the content of an EBC (EmbeddingBagCollection) embedding layer.
Only works for single GPU machines. Only works for single GPU machines.
Args: Args:
weight_name: name of tensor, as defined in model weight_name: Name of the tensor, as defined in the model.
table_name: name of the EBC table the weight is taken from table_name: Name of the EBC table the weight is taken from.
weight_tensor: embedding weight tensor weight_tensor: Embedding weight tensor.
""" """
logging.info(f"{weight_name}, {table_name}", rank=-1) logging.info(f"{weight_name}, {table_name}", rank=-1)
logging.info(f"{weight_tensor.metadata()}", rank=-1) logging.info(f"{weight_tensor.metadata()}", rank=-1)

View File

@ -11,8 +11,10 @@ class ExplicitDateInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_date and days of data.""" """Arguments to select train/validation data using end_date and days of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.") data_root: str = pydantic.Field(..., description="Data path prefix.")
end_date: str = pydantic.Field(..., description="Data end date, inclusive.") end_date: str = pydantic.Field(...,
days: int = pydantic.Field(..., description="Number of days of data for dataset.") description="Data end date, inclusive.")
days: int = pydantic.Field(...,
description="Number of days of data for dataset.")
num_missing_days_tol: int = pydantic.Field( num_missing_days_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_days_tol days of missing data." 0, description="We tolerate <= num_missing_days_tol days of missing data."
) )
@ -22,8 +24,10 @@ class ExplicitDatetimeInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_datetime and hours of data.""" """Arguments to select train/validation data using end_datetime and hours of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.") data_root: str = pydantic.Field(..., description="Data path prefix.")
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.") end_datetime: str = pydantic.Field(...,
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.") description="Data end datetime, inclusive.")
hours: int = pydantic.Field(...,
description="Number of hours of data for dataset.")
num_missing_hours_tol: int = pydantic.Field( num_missing_hours_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_hours_tol hours of missing data." 0, description="We tolerate <= num_missing_hours_tol hours of missing data."
) )
@ -42,7 +46,8 @@ class DatasetConfig(base_config.BaseConfig):
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field( explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
None, one_of="date_inputs_format" None, one_of="date_inputs_format"
) )
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format") explicit_date_inputs: ExplicitDateInputs = pydantic.Field(
None, one_of="date_inputs_format")
global_batch_size: pydantic.PositiveInt global_batch_size: pydantic.PositiveInt
@ -52,7 +57,8 @@ class DatasetConfig(base_config.BaseConfig):
repeat_files: bool = pydantic.Field( repeat_files: bool = pydantic.Field(
True, description="DEPRICATED. Files are repeated no matter what this is set to." True, description="DEPRICATED. Files are repeated no matter what this is set to."
) )
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size") file_batch_size: pydantic.PositiveInt = pydantic.Field(
16, description="File batch size")
cache: bool = pydantic.Field( cache: bool = pydantic.Field(
False, False,
@ -70,7 +76,8 @@ class DatasetConfig(base_config.BaseConfig):
) )
# tf.data.Dataset options # tf.data.Dataset options
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.") examples_shuffle_buffer_size: int = pydantic.Field(
1024, description="Size of shuffle buffers.")
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
None, description="Number of parallel calls." None, description="Number of parallel calls."
) )
@ -125,7 +132,8 @@ class TaskData(base_config.BaseConfig):
class SegDenseSchema(base_config.BaseConfig): class SegDenseSchema(base_config.BaseConfig):
schema_path: str = pydantic.Field(..., description="Path to feature config json.") schema_path: str = pydantic.Field(...,
description="Path to feature config json.")
features: typing.List[str] = pydantic.Field( features: typing.List[str] = pydantic.Field(
[], [],
description="List of features (in addition to the renamed features) to read from schema path above.", description="List of features (in addition to the renamed features) to read from schema path above.",
@ -192,8 +200,10 @@ class DownsampleNegatives(base_config.BaseConfig):
class Preprocess(base_config.BaseConfig): class Preprocess(base_config.BaseConfig):
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.") truncate_and_slice: TruncateAndSlice = pydantic.Field(
downcast: DownCast = pydantic.Field(None, description="Down cast to features.") None, description="Truncation and slicing.")
downcast: DownCast = pydantic.Field(
None, description="Down cast to features.")
rectify_labels: RectifyLabels = pydantic.Field( rectify_labels: RectifyLabels = pydantic.Field(
None, description="Rectify labels for a given overlap window" None, description="Rectify labels for a given overlap window"
) )
@ -242,5 +252,6 @@ class RecapDataConfig(DatasetConfig):
if values.get("evaluation_tasks") is not None: if values.get("evaluation_tasks") is not None:
for task in values["evaluation_tasks"]: for task in values["evaluation_tasks"]:
if task not in values["tasks"]: if task not in values["tasks"]:
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}") raise KeyError(
f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
return values return values

View File

@ -9,9 +9,20 @@ import numpy as np
class TruncateAndSlice(tf.keras.Model): class TruncateAndSlice(tf.keras.Model):
"""Class for truncating and slicing.""" """
A class for truncating and slicing input features based on the provided configuration.
Args:
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
"""
def __init__(self, truncate_and_slice_config): def __init__(self, truncate_and_slice_config):
"""
Initializes the TruncateAndSlice model.
Args:
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
"""
super().__init__() super().__init__()
self._truncate_and_slice_config = truncate_and_slice_config self._truncate_and_slice_config = truncate_and_slice_config
@ -32,6 +43,17 @@ class TruncateAndSlice(tf.keras.Model):
self._binary_mask = None self._binary_mask = None
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
"""
Applies truncation and slicing to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of truncated and sliced input features.
"""
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
if self._truncate_and_slice_config.continuous_feature_truncation: if self._truncate_and_slice_config.continuous_feature_truncation:
logging.info("Truncating continuous") logging.info("Truncating continuous")
@ -51,12 +73,23 @@ class TruncateAndSlice(tf.keras.Model):
class DownCast(tf.keras.Model): class DownCast(tf.keras.Model):
"""Class for Down casting dataset before serialization and transferring to training host. """
Depends on the data type and the actual data range, the down casting can be lossless or not. A class for downcasting dataset before serialization and transferring to the training host.
Depending on the data type and the actual data range, the downcasting can be lossless or not.
It is strongly recommended to compare the metrics before and after downcasting. It is strongly recommended to compare the metrics before and after downcasting.
Args:
downcast_config: A configuration object specifying the features and their target data types.
""" """
def __init__(self, downcast_config): def __init__(self, downcast_config):
"""
Initializes the DownCast model.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
super().__init__() super().__init__()
self.config = downcast_config self.config = downcast_config
self._type_map = { self._type_map = {
@ -65,6 +98,17 @@ class DownCast(tf.keras.Model):
} }
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
"""
Applies downcasting to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of downcasted input features.
"""
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
for feature, type_str in self.config.features.items(): for feature, type_str in self.config.features.items():
assert type_str in self._type_map assert type_str in self._type_map
@ -78,14 +122,39 @@ class DownCast(tf.keras.Model):
class RectifyLabels(tf.keras.Model): class RectifyLabels(tf.keras.Model):
"""Class for rectifying labels""" """
A class for rectifying labels.
Labels are adjusted using the configured served and impressed timestamp fields and a
rectification window that is configured in hours and converted to milliseconds.
Args:
rectify_label_config: A configuration object specifying the timestamp fields and the label rectification window in hours.
"""
def __init__(self, rectify_label_config): def __init__(self, rectify_label_config):
"""
Initializes the RectifyLabels model.
Args:
rectify_label_config: A configuration object specifying the timestamp fields and the label rectification window in hours.
"""
super().__init__() super().__init__()
self._config = rectify_label_config self._config = rectify_label_config
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000) self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
"""
Rectifies label values based on the configured served and impressed timestamp fields.
Args:
inputs: A dictionary of input features including timestamp fields and labels.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of input features with rectified label values.
"""
served_ts_field = self._config.served_timestamp_field served_ts_field = self._config.served_timestamp_field
impressed_ts_field = self._config.impressed_timestamp_field impressed_ts_field = self._config.impressed_timestamp_field
@ -102,13 +171,37 @@ class RectifyLabels(tf.keras.Model):
class ExtractFeatures(tf.keras.Model): class ExtractFeatures(tf.keras.Model):
"""Class for extracting individual features from dense tensors by their index.""" """
A class for extracting individual features from dense tensors by their index.
Each row of the configured extract feature table names a source tensor, a column index into it,
and the name under which the extracted feature is stored.
Args:
extract_features_config: A configuration object whose extract_feature_table lists the features to extract.
"""
def __init__(self, extract_features_config): def __init__(self, extract_features_config):
"""
Initializes the ExtractFeatures model.
Args:
extract_features_config: A configuration object whose extract_feature_table lists the features to extract.
"""
super().__init__() super().__init__()
self._config = extract_features_config self._config = extract_features_config
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
"""
Extracts individual features from dense tensors by index and stores them under their configured names.
Args:
inputs: A dictionary of input features containing the source tensors named in the extract feature table.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
The input features with the extracted features added (the inputs dictionary is updated in place).
"""
for row in self._config.extract_feature_table: for row in self._config.extract_feature_table:
inputs[row.name] = inputs[row.source_tensor][:, row.index] inputs[row.name] = inputs[row.source_tensor][:, row.index]
@ -168,7 +261,16 @@ class DownsampleNegatives(tf.keras.Model):
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN): def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
"""Builds a preprocess model to apply all preprocessing stages.""" """
Builds a preprocess model to apply all preprocessing stages.
Args:
preprocess_config: A configuration object specifying the preprocessing parameters.
mode: A mode indicating the current job mode (TRAIN or INFERENCE).
Returns:
A preprocess model that applies all specified preprocessing stages, or None when mode is INFERENCE.
"""
if mode == config_mod.JobMode.INFERENCE: if mode == config_mod.JobMode.INFERENCE:
logging.info("Not building preprocessors for dataloading since we are in Inference mode.") logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
return None return None

View File

@ -8,7 +8,8 @@ import tensorflow as tf
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""} DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string} DTYPE_MAP = {"int64_list": tf.int64,
"float_list": tf.float32, "bytes_list": tf.string}
def create_tf_example_schema( def create_tf_example_schema(
@ -27,7 +28,8 @@ def create_tf_example_schema(
segdense_config = data_config.seg_dense_schema segdense_config = data_config.seg_dense_schema
labels = list(data_config.tasks.keys()) labels = list(data_config.tasks.keys())
used_features = ( used_features = (
segdense_config.features + list(segdense_config.renamed_features.values()) + labels segdense_config.features +
list(segdense_config.renamed_features.values()) + labels
) )
logging.info(used_features) logging.info(used_features)
@ -40,19 +42,22 @@ def create_tf_example_schema(
dtype = entry["dtype"] dtype = entry["dtype"]
if feature_name in labels: if feature_name in labels:
logging.info(f"Label: feature name is {feature_name} type is {dtype}") logging.info(
f"Label: feature name is {feature_name} type is {dtype}")
tfe_schema[feature_name] = tf.io.FixedLenFeature( tfe_schema[feature_name] = tf.io.FixedLenFeature(
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype] length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
) )
elif length == -1: elif length == -1:
tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype]) tfe_schema[feature_name] = tf.io.VarLenFeature(
DTYPE_MAP[dtype])
else: else:
tfe_schema[feature_name] = tf.io.FixedLenFeature( tfe_schema[feature_name] = tf.io.FixedLenFeature(
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
) )
for feature_name in used_features: for feature_name in used_features:
if feature_name not in tfe_schema: if feature_name not in tfe_schema:
raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.") raise ValueError(
f"{feature_name} missing from schema: {segdense_config.schema_path}.")
return tfe_schema return tfe_schema
@ -82,7 +87,8 @@ def parse_tf_example(
Returns: Returns:
Dictionary of tensors to be used as model input. Dictionary of tensors to be used as model input.
""" """
inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema) inputs = tf.io.parse_example(
serialized=serialized_example, features=tfe_schema)
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items(): for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
inputs[new_feature_name] = inputs.pop(old_feature_name) inputs[new_feature_name] = inputs.pop(old_feature_name)
@ -90,7 +96,8 @@ def parse_tf_example(
# This should not actually be used except for experimentation with low precision floats. # This should not actually be used except for experimentation with low precision floats.
if "mask_mantissa_features" in seg_dense_schema_config: if "mask_mantissa_features" in seg_dense_schema_config:
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items(): for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length) inputs[feature_name] = mask_mantissa(
inputs[feature_name], mask_length)
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
# at TF level. # at TF level.

View File

@ -9,44 +9,59 @@ def keyed_tensor_from_tensors_dict(
tensor_map: Mapping[str, torch.Tensor] tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedTensor": ) -> "torchrec.KeyedTensor":
""" """
Convert a dictionary of torch tensor to torchrec keyed tensor Convert a dictionary of torch tensors to a torchrec KeyedTensor.
Args: Args:
tensor_map: tensor_map: A mapping of tensor names to torch tensors.
Returns: Returns:
A torchrec KeyedTensor.
""" """
keys = list(tensor_map.keys()) keys = list(tensor_map.keys())
# We expect batch size to be first dim. However, if we get a shape [Batch_size], # We expect batch size to be first dim. However, if we get a shape [Batch_size],
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
# [Batch_size x 1]. # [Batch_size x 1].
values = [ values = [
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1) tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(
tensor_map[key], -1)
for key in keys for key in keys
] ]
return torchrec.KeyedTensor.from_tensor_list(keys, values) return torchrec.KeyedTensor.from_tensor_list(keys, values)
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute a jagged tensor from a torch tensor.
Args:
tensor: Input torch tensor.
Returns:
A tuple containing the values and lengths of the jagged tensor.
"""
if tensor.is_sparse: if tensor.is_sparse:
x = tensor.coalesce() # Ensure that the indices are ordered. x = tensor.coalesce() # Ensure that the indices are ordered.
lengths = torch.bincount(x.indices()[0]) lengths = torch.bincount(x.indices()[0])
values = x.values() values = x.values()
else: else:
values = tensor values = tensor
lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device) lengths = torch.ones(
tensor.shape[0], dtype=torch.int32, device=tensor.device)
return values, lengths return values, lengths
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor": def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
""" """
Convert a torch tensor to torchrec jagged tensor. Convert a torch tensor to a torchrec jagged tensor.
Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x N] for dense tensors.
dense_shape of the sparse tensor can be arbitrary. For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary.
Args: Args:
tensor: a torch (sparse) tensor. tensor: A torch (sparse) tensor.
Returns: Returns:
A torchrec JaggedTensor.
""" """
values, lengths = _compute_jagged_tensor_from_tensor(tensor) values, lengths = _compute_jagged_tensor_from_tensor(tensor)
return torchrec.JaggedTensor(values=values, lengths=lengths) return torchrec.JaggedTensor(values=values, lengths=lengths)
@ -56,15 +71,16 @@ def keyed_jagged_tensor_from_tensors_dict(
tensor_map: Mapping[str, torch.Tensor] tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedJaggedTensor": ) -> "torchrec.KeyedJaggedTensor":
""" """
Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor. Convert a dictionary of (sparse) torch tensors to a torchrec keyed jagged tensor.
Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x 1] for dense tensors.
dense_shape of the sparse tensor can be arbitrary. For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary.
Args: Args:
tensor_map: tensor_map: A mapping of tensor names to torch tensors.
Returns: Returns:
A torchrec KeyedJaggedTensor.
""" """
if not tensor_map: if not tensor_map:
@ -91,10 +107,29 @@ def keyed_jagged_tensor_from_tensors_dict(
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray: def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
"""
Convert a TensorFlow tensor to a NumPy array.
Args:
tf_tensor: TensorFlow tensor.
Returns:
NumPy array.
"""
return tf_tensor._numpy() # noqa return tf_tensor._numpy() # noqa
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor: def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
"""
Convert a dense TensorFlow tensor to a PyTorch tensor.
Args:
tensor: Dense TensorFlow tensor.
pin_memory: Whether to pin the tensor in memory (for CUDA).
Returns:
PyTorch tensor.
"""
tensor = _tf_to_numpy(tensor) tensor = _tf_to_numpy(tensor)
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent # Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
if tensor.dtype.name == "bfloat16": if tensor.dtype.name == "bfloat16":
@ -109,6 +144,16 @@ def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
def sparse_or_dense_tf_to_torch( def sparse_or_dense_tf_to_torch(
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
) -> torch.Tensor: ) -> torch.Tensor:
"""
Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor.
Args:
tensor: TensorFlow tensor (sparse or dense).
pin_memory: Whether to pin the tensor in memory (for CUDA).
Returns:
PyTorch tensor.
"""
if isinstance(tensor, tf.SparseTensor): if isinstance(tensor, tf.SparseTensor):
tensor = torch.sparse_coo_tensor( tensor = torch.sparse_coo_tensor(
_dense_tf_to_torch(tensor.indices, pin_memory).t(), _dense_tf_to_torch(tensor.indices, pin_memory).t(),

View File

@ -69,6 +69,7 @@ class MaskBlock(torch.nn.Module):
This class is intended for internal use within neural network architectures and should not be This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code. directly accessed or modified by external code.
""" """
def __init__( def __init__(
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
) -> None: ) -> None:
@ -94,11 +95,13 @@ class MaskBlock(torch.nn.Module):
self._input_layer_norm = None self._input_layer_norm = None
if mask_block_config.reduction_factor: if mask_block_config.reduction_factor:
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor) aggregation_size = int(
mask_input_dim * mask_block_config.reduction_factor)
elif mask_block_config.aggregation_size is not None: elif mask_block_config.aggregation_size is not None:
aggregation_size = mask_block_config.aggregation_size aggregation_size = mask_block_config.aggregation_size
else: else:
raise ValueError("Need one of reduction factor or aggregation size.") raise ValueError(
"Need one of reduction factor or aggregation size.")
self._mask_layer = torch.nn.Sequential( self._mask_layer = torch.nn.Sequential(
torch.nn.Linear(mask_input_dim, aggregation_size), torch.nn.Linear(mask_input_dim, aggregation_size),
@ -123,7 +126,8 @@ class MaskBlock(torch.nn.Module):
""" """
if self._input_layer_norm: if self._input_layer_norm:
net = self._input_layer_norm(net) net = self._input_layer_norm(net)
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input)) hidden_layer_output = self._hidden_layer(
net * self._mask_layer(mask_input))
return self._layer_norm(hidden_layer_output) return self._layer_norm(hidden_layer_output)
@ -170,6 +174,7 @@ class MaskNet(torch.nn.Module):
This class is intended for internal use within neural network architectures and should not be This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code. directly accessed or modified by external code.
""" """
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int): def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
""" """
Initializes the MaskNet module. Initializes the MaskNet module.
@ -189,20 +194,23 @@ class MaskNet(torch.nn.Module):
if mask_net_config.use_parallel: if mask_net_config.use_parallel:
total_output_mask_blocks = 0 total_output_mask_blocks = 0
for mask_block_config in mask_net_config.mask_blocks: for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features)) mask_blocks.append(
MaskBlock(mask_block_config, in_features, in_features))
total_output_mask_blocks += mask_block_config.output_size total_output_mask_blocks += mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks) self._mask_blocks = torch.nn.ModuleList(mask_blocks)
else: else:
input_size = in_features input_size = in_features
for mask_block_config in mask_net_config.mask_blocks: for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features)) mask_blocks.append(
MaskBlock(mask_block_config, input_size, in_features))
input_size = mask_block_config.output_size input_size = mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks) self._mask_blocks = torch.nn.ModuleList(mask_blocks)
total_output_mask_blocks = mask_block_config.output_size total_output_mask_blocks = mask_block_config.output_size
if mask_net_config.mlp: if mask_net_config.mlp:
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp) self._dense_layers = mlp.Mlp(
total_output_mask_blocks, mask_net_config.mlp)
self.out_features = mask_net_config.mlp.layer_sizes[-1] self.out_features = mask_net_config.mlp.layer_sizes[-1]
else: else:
self.out_features = total_output_mask_blocks self.out_features = total_output_mask_blocks
@ -235,5 +243,6 @@ class MaskNet(torch.nn.Module):
for mask_layer in self._mask_blocks: for mask_layer in self._mask_blocks:
net = mask_layer(net=net, mask_input=inputs) net = mask_layer(net=net, mask_input=inputs)
# Share the output of the stacked MaskBlocks. # Share the output of the stacked MaskBlocks.
output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"] output = net if self.mask_net_config.mlp is None else self._dense_layers[
net]["output"]
return {"output": output, "shared_layer": net} return {"output": output, "shared_layer": net}

View File

@ -52,6 +52,7 @@ class ModelAndLoss(torch.nn.Module):
This class is intended for internal use within neural network architectures and should not be This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code. directly accessed or modified by external code.
""" """
def __init__( def __init__(
self, self,
model, model,
@ -92,13 +93,16 @@ class ModelAndLoss(torch.nn.Module):
labels=batch.labels, labels=batch.labels,
weights=batch.weights, weights=batch.weights,
) )
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float()) losses = self.loss_fn(
outputs["logits"], batch.labels.float(), batch.weights.float())
if self.stratifiers: if self.stratifiers:
logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}") logging.info(
f"***** Adding stratifiers *****\n {self.stratifiers}")
outputs["stratifiers"] = {} outputs["stratifiers"] = {}
for stratifier in self.stratifiers: for stratifier in self.stratifiers:
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index] outputs["stratifiers"][stratifier.name] = batch.discrete_features[:,
stratifier.index]
# In general, we can have a large number of losses returned by our loss function. # In general, we can have a large number of losses returned by our loss function.
if isinstance(losses, dict): if isinstance(losses, dict):

View File

@ -39,6 +39,7 @@ class NumericCalibration(torch.nn.Module):
This class is intended for internal use within neural network architectures and should not be This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code. directly accessed or modified by external code.
""" """
def __init__( def __init__(
self, self,
pos_downsampling_rate: float, pos_downsampling_rate: float,

View File

@ -38,6 +38,15 @@ import pyarrow.parquet as pq
def _create_dataset(path: str): def _create_dataset(path: str):
"""
Create a PyArrow dataset from Parquet files located at the specified path.
Args:
path (str): The path to the Parquet files.
Returns:
pyarrow.dataset.Dataset: The PyArrow dataset.
"""
fs = infer_fs(path) fs = infer_fs(path)
files = fs.glob(path) files = fs.glob(path)
return pads.dataset(files, format="parquet", filesystem=fs) return pads.dataset(files, format="parquet", filesystem=fs)
@ -47,12 +56,27 @@ class PqReader:
def __init__( def __init__(
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
): ):
"""
Initialize a Parquet Reader.
Args:
path (str): The path to the Parquet files.
num (int): The maximum number of rows to read.
batch_size (int): The batch size for reading data.
columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns).
"""
self._ds = _create_dataset(path) self._ds = _create_dataset(path)
self._batch_size = batch_size self._batch_size = batch_size
self._num = num self._num = num
self._columns = columns self._columns = columns
def __iter__(self): def __iter__(self):
"""
Iterate through the Parquet data and yield batches of rows.
Yields:
pyarrow.RecordBatch: A batch of rows.
"""
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns) batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
rows_seen = 0 rows_seen = 0
for count, record in enumerate(batches): for count, record in enumerate(batches):
@ -62,6 +86,12 @@ class PqReader:
rows_seen += record.data.num_rows rows_seen += record.data.num_rows
def _head(self): def _head(self):
"""
Get the first `num` rows of the Parquet data.
Returns:
pyarrow.RecordBatch: A batch of rows.
"""
total_read = self._num * self.bytes_per_row total_read = self._num * self.bytes_per_row
if total_read >= int(500e6): if total_read >= int(500e6):
raise Exception( raise Exception(
@ -71,6 +101,12 @@ class PqReader:
@property @property
def bytes_per_row(self) -> int: def bytes_per_row(self) -> int:
"""
Calculate the estimated bytes per row in the dataset.
Returns:
int: The estimated bytes per row.
"""
nbits = 0 nbits = 0
for t in self._ds.schema.types: for t in self._ds.schema.types:
try: try:
@ -81,17 +117,22 @@ class PqReader:
return nbits // 8 return nbits // 8
def schema(self): def schema(self):
"""
Display the schema of the Parquet dataset.
"""
print(f"\n# Schema\n{self._ds.schema}") print(f"\n# Schema\n{self._ds.schema}")
def head(self): def head(self):
"""Displays first --num rows.""" """
Display the first `num` rows of the Parquet data as a pandas DataFrame.
"""
print(self._head().to_pandas()) print(self._head().to_pandas())
def distinct(self): def distinct(self):
"""Displays unique values seen in specified columns in the first `--num` rows. """
Display unique values seen in specified columns in the first `num` rows.
Useful for getting an approximate vocabulary for certain columns. Useful for getting an approximate vocabulary for certain columns.
""" """
for col_name, column in zip(self._head().column_names, self._head().columns): for col_name, column in zip(self._head().column_names, self._head().columns):
print(col_name) print(col_name)