mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-24 12:51:10 +01:00
Updates
This commit is contained in:
parent
cc73f5fcb7
commit
f7f26d0c20
21
model.py
21
model.py
@ -54,12 +54,20 @@ def maybe_shard_model(
|
||||
model,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Set up and apply DistributedModelParallel to a model if running in a distributed environment.
|
||||
"""
|
||||
Set up and apply DistributedModelParallel to a model if running in a distributed environment.
|
||||
|
||||
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
|
||||
DistributedModelParallel.
|
||||
|
||||
If not in a distributed environment, returns model directly.
|
||||
If not in a distributed environment, returns the model directly.
|
||||
|
||||
Args:
|
||||
model: The PyTorch model.
|
||||
device: The target device (e.g., 'cuda').
|
||||
|
||||
Returns:
|
||||
The model wrapped with DistributedModelParallel if in a distributed environment, else the original model.
|
||||
"""
|
||||
if dist.is_initialized():
|
||||
logging.info("***** Wrapping in DistributedModelParallel *****")
|
||||
@ -74,13 +82,14 @@ def maybe_shard_model(
|
||||
|
||||
|
||||
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
|
||||
"""Handy function to log the content of EBC embedding layer.
|
||||
"""
|
||||
Handy function to log the content of an EBC (Embedding Bag Concatenation) embedding layer.
|
||||
Only works for single GPU machines.
|
||||
|
||||
Args:
|
||||
weight_name: name of tensor, as defined in model
|
||||
table_name: name of the EBC table the weight is taken from
|
||||
weight_tensor: embedding weight tensor
|
||||
weight_name: Name of the tensor, as defined in the model.
|
||||
table_name: Name of the EBC table the weight is taken from.
|
||||
weight_tensor: Embedding weight tensor.
|
||||
"""
|
||||
logging.info(f"{weight_name}, {table_name}", rank=-1)
|
||||
logging.info(f"{weight_tensor.metadata()}", rank=-1)
|
||||
|
@ -11,8 +11,10 @@ class ExplicitDateInputs(base_config.BaseConfig):
|
||||
"""Arguments to select train/validation data using end_date and days of data."""
|
||||
|
||||
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||||
end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
|
||||
days: int = pydantic.Field(..., description="Number of days of data for dataset.")
|
||||
end_date: str = pydantic.Field(...,
|
||||
description="Data end date, inclusive.")
|
||||
days: int = pydantic.Field(...,
|
||||
description="Number of days of data for dataset.")
|
||||
num_missing_days_tol: int = pydantic.Field(
|
||||
0, description="We tolerate <= num_missing_days_tol days of missing data."
|
||||
)
|
||||
@ -22,8 +24,10 @@ class ExplicitDatetimeInputs(base_config.BaseConfig):
|
||||
"""Arguments to select train/validation data using end_datetime and hours of data."""
|
||||
|
||||
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||||
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
|
||||
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
|
||||
end_datetime: str = pydantic.Field(...,
|
||||
description="Data end datetime, inclusive.")
|
||||
hours: int = pydantic.Field(...,
|
||||
description="Number of hours of data for dataset.")
|
||||
num_missing_hours_tol: int = pydantic.Field(
|
||||
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
|
||||
)
|
||||
@ -42,7 +46,8 @@ class DatasetConfig(base_config.BaseConfig):
|
||||
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
|
||||
None, one_of="date_inputs_format"
|
||||
)
|
||||
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
|
||||
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(
|
||||
None, one_of="date_inputs_format")
|
||||
|
||||
global_batch_size: pydantic.PositiveInt
|
||||
|
||||
@ -52,7 +57,8 @@ class DatasetConfig(base_config.BaseConfig):
|
||||
repeat_files: bool = pydantic.Field(
|
||||
True, description="DEPRICATED. Files are repeated no matter what this is set to."
|
||||
)
|
||||
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
|
||||
file_batch_size: pydantic.PositiveInt = pydantic.Field(
|
||||
16, description="File batch size")
|
||||
|
||||
cache: bool = pydantic.Field(
|
||||
False,
|
||||
@ -70,7 +76,8 @@ class DatasetConfig(base_config.BaseConfig):
|
||||
)
|
||||
|
||||
# tf.data.Dataset options
|
||||
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
|
||||
examples_shuffle_buffer_size: int = pydantic.Field(
|
||||
1024, description="Size of shuffle buffers.")
|
||||
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
||||
None, description="Number of parallel calls."
|
||||
)
|
||||
@ -125,7 +132,8 @@ class TaskData(base_config.BaseConfig):
|
||||
|
||||
|
||||
class SegDenseSchema(base_config.BaseConfig):
|
||||
schema_path: str = pydantic.Field(..., description="Path to feature config json.")
|
||||
schema_path: str = pydantic.Field(...,
|
||||
description="Path to feature config json.")
|
||||
features: typing.List[str] = pydantic.Field(
|
||||
[],
|
||||
description="List of features (in addition to the renamed features) to read from schema path above.",
|
||||
@ -192,8 +200,10 @@ class DownsampleNegatives(base_config.BaseConfig):
|
||||
|
||||
|
||||
class Preprocess(base_config.BaseConfig):
|
||||
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
|
||||
downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
|
||||
truncate_and_slice: TruncateAndSlice = pydantic.Field(
|
||||
None, description="Truncation and slicing.")
|
||||
downcast: DownCast = pydantic.Field(
|
||||
None, description="Down cast to features.")
|
||||
rectify_labels: RectifyLabels = pydantic.Field(
|
||||
None, description="Rectify labels for a given overlap window"
|
||||
)
|
||||
@ -242,5 +252,6 @@ class RecapDataConfig(DatasetConfig):
|
||||
if values.get("evaluation_tasks") is not None:
|
||||
for task in values["evaluation_tasks"]:
|
||||
if task not in values["tasks"]:
|
||||
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
|
||||
raise KeyError(
|
||||
f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
|
||||
return values
|
||||
|
@ -9,9 +9,20 @@ import numpy as np
|
||||
|
||||
|
||||
class TruncateAndSlice(tf.keras.Model):
|
||||
"""Class for truncating and slicing."""
|
||||
"""
|
||||
A class for truncating and slicing input features based on the provided configuration.
|
||||
|
||||
Args:
|
||||
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
|
||||
"""
|
||||
|
||||
def __init__(self, truncate_and_slice_config):
|
||||
"""
|
||||
Initializes the TruncateAndSlice model.
|
||||
|
||||
Args:
|
||||
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
|
||||
"""
|
||||
super().__init__()
|
||||
self._truncate_and_slice_config = truncate_and_slice_config
|
||||
|
||||
@ -32,6 +43,17 @@ class TruncateAndSlice(tf.keras.Model):
|
||||
self._binary_mask = None
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
"""
|
||||
Applies truncation and slicing to the input features based on the configuration.
|
||||
|
||||
Args:
|
||||
inputs: A dictionary of input features.
|
||||
training: A boolean indicating whether the model is in training mode.
|
||||
mask: A mask tensor.
|
||||
|
||||
Returns:
|
||||
A dictionary of truncated and sliced input features.
|
||||
"""
|
||||
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
||||
if self._truncate_and_slice_config.continuous_feature_truncation:
|
||||
logging.info("Truncating continuous")
|
||||
@ -51,12 +73,23 @@ class TruncateAndSlice(tf.keras.Model):
|
||||
|
||||
|
||||
class DownCast(tf.keras.Model):
|
||||
"""Class for Down casting dataset before serialization and transferring to training host.
|
||||
Depends on the data type and the actual data range, the down casting can be lossless or not.
|
||||
"""
|
||||
A class for downcasting dataset before serialization and transferring to the training host.
|
||||
|
||||
Depending on the data type and the actual data range, the downcasting can be lossless or not.
|
||||
It is strongly recommended to compare the metrics before and after downcasting.
|
||||
|
||||
Args:
|
||||
downcast_config: A configuration object specifying the features and their target data types.
|
||||
"""
|
||||
|
||||
def __init__(self, downcast_config):
|
||||
"""
|
||||
Initializes the DownCast model.
|
||||
|
||||
Args:
|
||||
downcast_config: A configuration object specifying the features and their target data types.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = downcast_config
|
||||
self._type_map = {
|
||||
@ -65,6 +98,17 @@ class DownCast(tf.keras.Model):
|
||||
}
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
"""
|
||||
Applies downcasting to the input features based on the configuration.
|
||||
|
||||
Args:
|
||||
inputs: A dictionary of input features.
|
||||
training: A boolean indicating whether the model is in training mode.
|
||||
mask: A mask tensor.
|
||||
|
||||
Returns:
|
||||
A dictionary of downcasted input features.
|
||||
"""
|
||||
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
||||
for feature, type_str in self.config.features.items():
|
||||
assert type_str in self._type_map
|
||||
@ -78,14 +122,39 @@ class DownCast(tf.keras.Model):
|
||||
|
||||
|
||||
class RectifyLabels(tf.keras.Model):
|
||||
"""Class for rectifying labels"""
|
||||
"""
|
||||
A class for downcasting dataset before serialization and transferring to the training host.
|
||||
|
||||
Depending on the data type and the actual data range, the downcasting can be lossless or not.
|
||||
It is strongly recommended to compare the metrics before and after downcasting.
|
||||
|
||||
Args:
|
||||
downcast_config: A configuration object specifying the features and their target data types.
|
||||
"""
|
||||
|
||||
def __init__(self, rectify_label_config):
|
||||
"""
|
||||
Initializes the DownCast model.
|
||||
|
||||
Args:
|
||||
downcast_config: A configuration object specifying the features and their target data types.
|
||||
"""
|
||||
super().__init__()
|
||||
self._config = rectify_label_config
|
||||
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
"""
|
||||
Applies downcasting to the input features based on the configuration.
|
||||
|
||||
Args:
|
||||
inputs: A dictionary of input features.
|
||||
training: A boolean indicating whether the model is in training mode.
|
||||
mask: A mask tensor.
|
||||
|
||||
Returns:
|
||||
A dictionary of downcasted input features.
|
||||
"""
|
||||
served_ts_field = self._config.served_timestamp_field
|
||||
impressed_ts_field = self._config.impressed_timestamp_field
|
||||
|
||||
@ -102,13 +171,37 @@ class RectifyLabels(tf.keras.Model):
|
||||
|
||||
|
||||
class ExtractFeatures(tf.keras.Model):
|
||||
"""Class for extracting individual features from dense tensors by their index."""
|
||||
"""
|
||||
A class for rectifying labels based on specified conditions.
|
||||
|
||||
This class is used to adjust label values in a dataset based on configured conditions involving timestamps.
|
||||
|
||||
Args:
|
||||
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
|
||||
"""
|
||||
|
||||
def __init__(self, extract_features_config):
|
||||
"""
|
||||
Initializes the RectifyLabels model.
|
||||
|
||||
Args:
|
||||
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
|
||||
"""
|
||||
super().__init__()
|
||||
self._config = extract_features_config
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
"""
|
||||
Rectifies label values based on the specified conditions.
|
||||
|
||||
Args:
|
||||
inputs: A dictionary of input features including timestamp fields and labels.
|
||||
training: A boolean indicating whether the model is in training mode.
|
||||
mask: A mask tensor.
|
||||
|
||||
Returns:
|
||||
A dictionary of input features with rectified label values.
|
||||
"""
|
||||
|
||||
for row in self._config.extract_feature_table:
|
||||
inputs[row.name] = inputs[row.source_tensor][:, row.index]
|
||||
@ -168,7 +261,16 @@ class DownsampleNegatives(tf.keras.Model):
|
||||
|
||||
|
||||
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
|
||||
"""Builds a preprocess model to apply all preprocessing stages."""
|
||||
"""
|
||||
Builds a preprocess model to apply all preprocessing stages.
|
||||
|
||||
Args:
|
||||
preprocess_config: A configuration object specifying the preprocessing parameters.
|
||||
mode: A mode indicating the current job mode (TRAIN or INFERENCE).
|
||||
|
||||
Returns:
|
||||
A preprocess model that applies all specified preprocessing stages.
|
||||
"""
|
||||
if mode == config_mod.JobMode.INFERENCE:
|
||||
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
|
||||
return None
|
||||
|
@ -8,7 +8,8 @@ import tensorflow as tf
|
||||
|
||||
|
||||
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
|
||||
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}
|
||||
DTYPE_MAP = {"int64_list": tf.int64,
|
||||
"float_list": tf.float32, "bytes_list": tf.string}
|
||||
|
||||
|
||||
def create_tf_example_schema(
|
||||
@ -27,7 +28,8 @@ def create_tf_example_schema(
|
||||
segdense_config = data_config.seg_dense_schema
|
||||
labels = list(data_config.tasks.keys())
|
||||
used_features = (
|
||||
segdense_config.features + list(segdense_config.renamed_features.values()) + labels
|
||||
segdense_config.features +
|
||||
list(segdense_config.renamed_features.values()) + labels
|
||||
)
|
||||
logging.info(used_features)
|
||||
|
||||
@ -40,19 +42,22 @@ def create_tf_example_schema(
|
||||
dtype = entry["dtype"]
|
||||
|
||||
if feature_name in labels:
|
||||
logging.info(f"Label: feature name is {feature_name} type is {dtype}")
|
||||
logging.info(
|
||||
f"Label: feature name is {feature_name} type is {dtype}")
|
||||
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
||||
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
|
||||
)
|
||||
elif length == -1:
|
||||
tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
|
||||
tfe_schema[feature_name] = tf.io.VarLenFeature(
|
||||
DTYPE_MAP[dtype])
|
||||
else:
|
||||
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
||||
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
|
||||
)
|
||||
for feature_name in used_features:
|
||||
if feature_name not in tfe_schema:
|
||||
raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
|
||||
raise ValueError(
|
||||
f"{feature_name} missing from schema: {segdense_config.schema_path}.")
|
||||
return tfe_schema
|
||||
|
||||
|
||||
@ -82,7 +87,8 @@ def parse_tf_example(
|
||||
Returns:
|
||||
Dictionary of tensors to be used as model input.
|
||||
"""
|
||||
inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)
|
||||
inputs = tf.io.parse_example(
|
||||
serialized=serialized_example, features=tfe_schema)
|
||||
|
||||
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
|
||||
inputs[new_feature_name] = inputs.pop(old_feature_name)
|
||||
@ -90,7 +96,8 @@ def parse_tf_example(
|
||||
# This should not actually be used except for experimentation with low precision floats.
|
||||
if "mask_mantissa_features" in seg_dense_schema_config:
|
||||
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
|
||||
inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)
|
||||
inputs[feature_name] = mask_mantissa(
|
||||
inputs[feature_name], mask_length)
|
||||
|
||||
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
|
||||
# at TF level.
|
||||
|
@ -9,44 +9,59 @@ def keyed_tensor_from_tensors_dict(
|
||||
tensor_map: Mapping[str, torch.Tensor]
|
||||
) -> "torchrec.KeyedTensor":
|
||||
"""
|
||||
Convert a dictionary of torch tensor to torchrec keyed tensor
|
||||
Convert a dictionary of torch tensors to a torchrec KeyedTensor.
|
||||
|
||||
Args:
|
||||
tensor_map:
|
||||
tensor_map: A mapping of tensor names to torch tensors.
|
||||
|
||||
Returns:
|
||||
|
||||
A torchrec KeyedTensor.
|
||||
"""
|
||||
keys = list(tensor_map.keys())
|
||||
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
|
||||
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
|
||||
# [Batch_size x 1].
|
||||
values = [
|
||||
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
|
||||
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(
|
||||
tensor_map[key], -1)
|
||||
for key in keys
|
||||
]
|
||||
return torchrec.KeyedTensor.from_tensor_list(keys, values)
|
||||
|
||||
|
||||
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute a jagged tensor from a torch tensor.
|
||||
|
||||
Args:
|
||||
tensor: Input torch tensor.
|
||||
|
||||
Returns:
|
||||
A tuple containing the values and lengths of the jagged tensor.
|
||||
"""
|
||||
if tensor.is_sparse:
|
||||
x = tensor.coalesce() # Ensure that the indices are ordered.
|
||||
lengths = torch.bincount(x.indices()[0])
|
||||
values = x.values()
|
||||
else:
|
||||
values = tensor
|
||||
lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
|
||||
lengths = torch.ones(
|
||||
tensor.shape[0], dtype=torch.int32, device=tensor.device)
|
||||
return values, lengths
|
||||
|
||||
|
||||
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
|
||||
"""
|
||||
Convert a torch tensor to torchrec jagged tensor.
|
||||
Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
|
||||
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the
|
||||
dense_shape of the sparse tensor can be arbitrary.
|
||||
Convert a torch tensor to a torchrec jagged tensor.
|
||||
|
||||
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x N] for dense tensors.
|
||||
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary.
|
||||
|
||||
Args:
|
||||
tensor: a torch (sparse) tensor.
|
||||
tensor: A torch (sparse) tensor.
|
||||
|
||||
Returns:
|
||||
A torchrec JaggedTensor.
|
||||
"""
|
||||
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
|
||||
return torchrec.JaggedTensor(values=values, lengths=lengths)
|
||||
@ -56,15 +71,16 @@ def keyed_jagged_tensor_from_tensors_dict(
|
||||
tensor_map: Mapping[str, torch.Tensor]
|
||||
) -> "torchrec.KeyedJaggedTensor":
|
||||
"""
|
||||
Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor.
|
||||
Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.
|
||||
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the
|
||||
dense_shape of the sparse tensor can be arbitrary.
|
||||
Convert a dictionary of (sparse) torch tensors to a torchrec keyed jagged tensor.
|
||||
|
||||
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x 1] for dense tensors.
|
||||
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary.
|
||||
|
||||
Args:
|
||||
tensor_map:
|
||||
tensor_map: A mapping of tensor names to torch tensors.
|
||||
|
||||
Returns:
|
||||
|
||||
A torchrec KeyedJaggedTensor.
|
||||
"""
|
||||
|
||||
if not tensor_map:
|
||||
@ -91,10 +107,29 @@ def keyed_jagged_tensor_from_tensors_dict(
|
||||
|
||||
|
||||
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
|
||||
"""
|
||||
Convert a TensorFlow tensor to a NumPy array.
|
||||
|
||||
Args:
|
||||
tf_tensor: TensorFlow tensor.
|
||||
|
||||
Returns:
|
||||
NumPy array.
|
||||
"""
|
||||
return tf_tensor._numpy() # noqa
|
||||
|
||||
|
||||
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
||||
"""
|
||||
Convert a dense TensorFlow tensor to a PyTorch tensor.
|
||||
|
||||
Args:
|
||||
tensor: Dense TensorFlow tensor.
|
||||
pin_memory: Whether to pin the tensor in memory (for CUDA).
|
||||
|
||||
Returns:
|
||||
PyTorch tensor.
|
||||
"""
|
||||
tensor = _tf_to_numpy(tensor)
|
||||
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
|
||||
if tensor.dtype.name == "bfloat16":
|
||||
@ -109,6 +144,16 @@ def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
||||
def sparse_or_dense_tf_to_torch(
|
||||
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor.
|
||||
|
||||
Args:
|
||||
tensor: TensorFlow tensor (sparse or dense).
|
||||
pin_memory: Whether to pin the tensor in memory (for CUDA).
|
||||
|
||||
Returns:
|
||||
PyTorch tensor.
|
||||
"""
|
||||
if isinstance(tensor, tf.SparseTensor):
|
||||
tensor = torch.sparse_coo_tensor(
|
||||
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
|
||||
|
@ -69,6 +69,7 @@ class MaskBlock(torch.nn.Module):
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
||||
) -> None:
|
||||
@ -94,11 +95,13 @@ class MaskBlock(torch.nn.Module):
|
||||
self._input_layer_norm = None
|
||||
|
||||
if mask_block_config.reduction_factor:
|
||||
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
|
||||
aggregation_size = int(
|
||||
mask_input_dim * mask_block_config.reduction_factor)
|
||||
elif mask_block_config.aggregation_size is not None:
|
||||
aggregation_size = mask_block_config.aggregation_size
|
||||
else:
|
||||
raise ValueError("Need one of reduction factor or aggregation size.")
|
||||
raise ValueError(
|
||||
"Need one of reduction factor or aggregation size.")
|
||||
|
||||
self._mask_layer = torch.nn.Sequential(
|
||||
torch.nn.Linear(mask_input_dim, aggregation_size),
|
||||
@ -123,7 +126,8 @@ class MaskBlock(torch.nn.Module):
|
||||
"""
|
||||
if self._input_layer_norm:
|
||||
net = self._input_layer_norm(net)
|
||||
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
|
||||
hidden_layer_output = self._hidden_layer(
|
||||
net * self._mask_layer(mask_input))
|
||||
return self._layer_norm(hidden_layer_output)
|
||||
|
||||
|
||||
@ -170,6 +174,7 @@ class MaskNet(torch.nn.Module):
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
|
||||
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
||||
"""
|
||||
Initializes the MaskNet module.
|
||||
@ -189,20 +194,23 @@ class MaskNet(torch.nn.Module):
|
||||
if mask_net_config.use_parallel:
|
||||
total_output_mask_blocks = 0
|
||||
for mask_block_config in mask_net_config.mask_blocks:
|
||||
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
|
||||
mask_blocks.append(
|
||||
MaskBlock(mask_block_config, in_features, in_features))
|
||||
total_output_mask_blocks += mask_block_config.output_size
|
||||
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
||||
else:
|
||||
input_size = in_features
|
||||
for mask_block_config in mask_net_config.mask_blocks:
|
||||
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
|
||||
mask_blocks.append(
|
||||
MaskBlock(mask_block_config, input_size, in_features))
|
||||
input_size = mask_block_config.output_size
|
||||
|
||||
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
||||
total_output_mask_blocks = mask_block_config.output_size
|
||||
|
||||
if mask_net_config.mlp:
|
||||
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
|
||||
self._dense_layers = mlp.Mlp(
|
||||
total_output_mask_blocks, mask_net_config.mlp)
|
||||
self.out_features = mask_net_config.mlp.layer_sizes[-1]
|
||||
else:
|
||||
self.out_features = total_output_mask_blocks
|
||||
@ -235,5 +243,6 @@ class MaskNet(torch.nn.Module):
|
||||
for mask_layer in self._mask_blocks:
|
||||
net = mask_layer(net=net, mask_input=inputs)
|
||||
# Share the output of the stacked MaskBlocks.
|
||||
output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"]
|
||||
output = net if self.mask_net_config.mlp is None else self._dense_layers[
|
||||
net]["output"]
|
||||
return {"output": output, "shared_layer": net}
|
||||
|
@ -52,6 +52,7 @@ class ModelAndLoss(torch.nn.Module):
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
@ -92,13 +93,16 @@ class ModelAndLoss(torch.nn.Module):
|
||||
labels=batch.labels,
|
||||
weights=batch.weights,
|
||||
)
|
||||
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
|
||||
losses = self.loss_fn(
|
||||
outputs["logits"], batch.labels.float(), batch.weights.float())
|
||||
|
||||
if self.stratifiers:
|
||||
logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}")
|
||||
logging.info(
|
||||
f"***** Adding stratifiers *****\n {self.stratifiers}")
|
||||
outputs["stratifiers"] = {}
|
||||
for stratifier in self.stratifiers:
|
||||
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index]
|
||||
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:,
|
||||
stratifier.index]
|
||||
|
||||
# In general, we can have a large number of losses returned by our loss function.
|
||||
if isinstance(losses, dict):
|
||||
|
@ -39,6 +39,7 @@ class NumericCalibration(torch.nn.Module):
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pos_downsampling_rate: float,
|
||||
|
47
tools/pq.py
47
tools/pq.py
@ -38,6 +38,15 @@ import pyarrow.parquet as pq
|
||||
|
||||
|
||||
def _create_dataset(path: str):
|
||||
"""
|
||||
Create a PyArrow dataset from Parquet files located at the specified path.
|
||||
|
||||
Args:
|
||||
path (str): The path to the Parquet files.
|
||||
|
||||
Returns:
|
||||
pyarrow.dataset.Dataset: The PyArrow dataset.
|
||||
"""
|
||||
fs = infer_fs(path)
|
||||
files = fs.glob(path)
|
||||
return pads.dataset(files, format="parquet", filesystem=fs)
|
||||
@ -47,12 +56,27 @@ class PqReader:
|
||||
def __init__(
|
||||
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
|
||||
):
|
||||
"""
|
||||
Initialize a Parquet Reader.
|
||||
|
||||
Args:
|
||||
path (str): The path to the Parquet files.
|
||||
num (int): The maximum number of rows to read.
|
||||
batch_size (int): The batch size for reading data.
|
||||
columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns).
|
||||
"""
|
||||
self._ds = _create_dataset(path)
|
||||
self._batch_size = batch_size
|
||||
self._num = num
|
||||
self._columns = columns
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Iterate through the Parquet data and yield batches of rows.
|
||||
|
||||
Yields:
|
||||
pyarrow.RecordBatch: A batch of rows.
|
||||
"""
|
||||
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
|
||||
rows_seen = 0
|
||||
for count, record in enumerate(batches):
|
||||
@ -62,6 +86,12 @@ class PqReader:
|
||||
rows_seen += record.data.num_rows
|
||||
|
||||
def _head(self):
|
||||
"""
|
||||
Get the first `num` rows of the Parquet data.
|
||||
|
||||
Returns:
|
||||
pyarrow.RecordBatch: A batch of rows.
|
||||
"""
|
||||
total_read = self._num * self.bytes_per_row
|
||||
if total_read >= int(500e6):
|
||||
raise Exception(
|
||||
@ -71,6 +101,12 @@ class PqReader:
|
||||
|
||||
@property
|
||||
def bytes_per_row(self) -> int:
|
||||
"""
|
||||
Calculate the estimated bytes per row in the dataset.
|
||||
|
||||
Returns:
|
||||
int: The estimated bytes per row.
|
||||
"""
|
||||
nbits = 0
|
||||
for t in self._ds.schema.types:
|
||||
try:
|
||||
@ -81,17 +117,22 @@ class PqReader:
|
||||
return nbits // 8
|
||||
|
||||
def schema(self):
|
||||
"""
|
||||
Display the schema of the Parquet dataset.
|
||||
"""
|
||||
print(f"\n# Schema\n{self._ds.schema}")
|
||||
|
||||
def head(self):
|
||||
"""Displays first --num rows."""
|
||||
"""
|
||||
Display the first `num` rows of the Parquet data as a pandas DataFrame.
|
||||
"""
|
||||
print(self._head().to_pandas())
|
||||
|
||||
def distinct(self):
|
||||
"""Displays unique values seen in specified columns in the first `--num` rows.
|
||||
"""
|
||||
Display unique values seen in specified columns in the first `num` rows.
|
||||
|
||||
Useful for getting an approximate vocabulary for certain columns.
|
||||
|
||||
"""
|
||||
for col_name, column in zip(self._head().column_names, self._head().columns):
|
||||
print(col_name)
|
||||
|
Loading…
x
Reference in New Issue
Block a user