mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-24 21:01:11 +01:00
Updates
This commit is contained in:
parent
cc73f5fcb7
commit
f7f26d0c20
21
model.py
21
model.py
@ -54,12 +54,20 @@ def maybe_shard_model(
|
|||||||
model,
|
model,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
"""Set up and apply DistributedModelParallel to a model if running in a distributed environment.
|
"""
|
||||||
|
Set up and apply DistributedModelParallel to a model if running in a distributed environment.
|
||||||
|
|
||||||
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
|
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
|
||||||
DistributedModelParallel.
|
DistributedModelParallel.
|
||||||
|
|
||||||
If not in a distributed environment, returns model directly.
|
If not in a distributed environment, returns the model directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The PyTorch model.
|
||||||
|
device: The target device (e.g., 'cuda').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The model wrapped with DistributedModelParallel if in a distributed environment, else the original model.
|
||||||
"""
|
"""
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
logging.info("***** Wrapping in DistributedModelParallel *****")
|
logging.info("***** Wrapping in DistributedModelParallel *****")
|
||||||
@ -74,13 +82,14 @@ def maybe_shard_model(
|
|||||||
|
|
||||||
|
|
||||||
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
|
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
|
||||||
"""Handy function to log the content of EBC embedding layer.
|
"""
|
||||||
|
Handy function to log the content of an EBC (Embedding Bag Concatenation) embedding layer.
|
||||||
Only works for single GPU machines.
|
Only works for single GPU machines.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
weight_name: name of tensor, as defined in model
|
weight_name: Name of the tensor, as defined in the model.
|
||||||
table_name: name of the EBC table the weight is taken from
|
table_name: Name of the EBC table the weight is taken from.
|
||||||
weight_tensor: embedding weight tensor
|
weight_tensor: Embedding weight tensor.
|
||||||
"""
|
"""
|
||||||
logging.info(f"{weight_name}, {table_name}", rank=-1)
|
logging.info(f"{weight_name}, {table_name}", rank=-1)
|
||||||
logging.info(f"{weight_tensor.metadata()}", rank=-1)
|
logging.info(f"{weight_tensor.metadata()}", rank=-1)
|
||||||
|
@ -11,8 +11,10 @@ class ExplicitDateInputs(base_config.BaseConfig):
|
|||||||
"""Arguments to select train/validation data using end_date and days of data."""
|
"""Arguments to select train/validation data using end_date and days of data."""
|
||||||
|
|
||||||
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||||||
end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
|
end_date: str = pydantic.Field(...,
|
||||||
days: int = pydantic.Field(..., description="Number of days of data for dataset.")
|
description="Data end date, inclusive.")
|
||||||
|
days: int = pydantic.Field(...,
|
||||||
|
description="Number of days of data for dataset.")
|
||||||
num_missing_days_tol: int = pydantic.Field(
|
num_missing_days_tol: int = pydantic.Field(
|
||||||
0, description="We tolerate <= num_missing_days_tol days of missing data."
|
0, description="We tolerate <= num_missing_days_tol days of missing data."
|
||||||
)
|
)
|
||||||
@ -22,8 +24,10 @@ class ExplicitDatetimeInputs(base_config.BaseConfig):
|
|||||||
"""Arguments to select train/validation data using end_datetime and hours of data."""
|
"""Arguments to select train/validation data using end_datetime and hours of data."""
|
||||||
|
|
||||||
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||||||
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
|
end_datetime: str = pydantic.Field(...,
|
||||||
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
|
description="Data end datetime, inclusive.")
|
||||||
|
hours: int = pydantic.Field(...,
|
||||||
|
description="Number of hours of data for dataset.")
|
||||||
num_missing_hours_tol: int = pydantic.Field(
|
num_missing_hours_tol: int = pydantic.Field(
|
||||||
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
|
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
|
||||||
)
|
)
|
||||||
@ -42,7 +46,8 @@ class DatasetConfig(base_config.BaseConfig):
|
|||||||
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
|
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
|
||||||
None, one_of="date_inputs_format"
|
None, one_of="date_inputs_format"
|
||||||
)
|
)
|
||||||
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
|
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(
|
||||||
|
None, one_of="date_inputs_format")
|
||||||
|
|
||||||
global_batch_size: pydantic.PositiveInt
|
global_batch_size: pydantic.PositiveInt
|
||||||
|
|
||||||
@ -52,7 +57,8 @@ class DatasetConfig(base_config.BaseConfig):
|
|||||||
repeat_files: bool = pydantic.Field(
|
repeat_files: bool = pydantic.Field(
|
||||||
True, description="DEPRICATED. Files are repeated no matter what this is set to."
|
True, description="DEPRICATED. Files are repeated no matter what this is set to."
|
||||||
)
|
)
|
||||||
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
|
file_batch_size: pydantic.PositiveInt = pydantic.Field(
|
||||||
|
16, description="File batch size")
|
||||||
|
|
||||||
cache: bool = pydantic.Field(
|
cache: bool = pydantic.Field(
|
||||||
False,
|
False,
|
||||||
@ -70,7 +76,8 @@ class DatasetConfig(base_config.BaseConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# tf.data.Dataset options
|
# tf.data.Dataset options
|
||||||
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
|
examples_shuffle_buffer_size: int = pydantic.Field(
|
||||||
|
1024, description="Size of shuffle buffers.")
|
||||||
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
||||||
None, description="Number of parallel calls."
|
None, description="Number of parallel calls."
|
||||||
)
|
)
|
||||||
@ -125,7 +132,8 @@ class TaskData(base_config.BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
class SegDenseSchema(base_config.BaseConfig):
|
class SegDenseSchema(base_config.BaseConfig):
|
||||||
schema_path: str = pydantic.Field(..., description="Path to feature config json.")
|
schema_path: str = pydantic.Field(...,
|
||||||
|
description="Path to feature config json.")
|
||||||
features: typing.List[str] = pydantic.Field(
|
features: typing.List[str] = pydantic.Field(
|
||||||
[],
|
[],
|
||||||
description="List of features (in addition to the renamed features) to read from schema path above.",
|
description="List of features (in addition to the renamed features) to read from schema path above.",
|
||||||
@ -192,8 +200,10 @@ class DownsampleNegatives(base_config.BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
class Preprocess(base_config.BaseConfig):
|
class Preprocess(base_config.BaseConfig):
|
||||||
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
|
truncate_and_slice: TruncateAndSlice = pydantic.Field(
|
||||||
downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
|
None, description="Truncation and slicing.")
|
||||||
|
downcast: DownCast = pydantic.Field(
|
||||||
|
None, description="Down cast to features.")
|
||||||
rectify_labels: RectifyLabels = pydantic.Field(
|
rectify_labels: RectifyLabels = pydantic.Field(
|
||||||
None, description="Rectify labels for a given overlap window"
|
None, description="Rectify labels for a given overlap window"
|
||||||
)
|
)
|
||||||
@ -242,5 +252,6 @@ class RecapDataConfig(DatasetConfig):
|
|||||||
if values.get("evaluation_tasks") is not None:
|
if values.get("evaluation_tasks") is not None:
|
||||||
for task in values["evaluation_tasks"]:
|
for task in values["evaluation_tasks"]:
|
||||||
if task not in values["tasks"]:
|
if task not in values["tasks"]:
|
||||||
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
|
raise KeyError(
|
||||||
|
f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
|
||||||
return values
|
return values
|
||||||
|
@ -9,9 +9,20 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
class TruncateAndSlice(tf.keras.Model):
|
class TruncateAndSlice(tf.keras.Model):
|
||||||
"""Class for truncating and slicing."""
|
"""
|
||||||
|
A class for truncating and slicing input features based on the provided configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, truncate_and_slice_config):
|
def __init__(self, truncate_and_slice_config):
|
||||||
|
"""
|
||||||
|
Initializes the TruncateAndSlice model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._truncate_and_slice_config = truncate_and_slice_config
|
self._truncate_and_slice_config = truncate_and_slice_config
|
||||||
|
|
||||||
@ -32,6 +43,17 @@ class TruncateAndSlice(tf.keras.Model):
|
|||||||
self._binary_mask = None
|
self._binary_mask = None
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
|
"""
|
||||||
|
Applies truncation and slicing to the input features based on the configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A dictionary of input features.
|
||||||
|
training: A boolean indicating whether the model is in training mode.
|
||||||
|
mask: A mask tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of truncated and sliced input features.
|
||||||
|
"""
|
||||||
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
||||||
if self._truncate_and_slice_config.continuous_feature_truncation:
|
if self._truncate_and_slice_config.continuous_feature_truncation:
|
||||||
logging.info("Truncating continuous")
|
logging.info("Truncating continuous")
|
||||||
@ -51,12 +73,23 @@ class TruncateAndSlice(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
class DownCast(tf.keras.Model):
|
class DownCast(tf.keras.Model):
|
||||||
"""Class for Down casting dataset before serialization and transferring to training host.
|
"""
|
||||||
Depends on the data type and the actual data range, the down casting can be lossless or not.
|
A class for downcasting dataset before serialization and transferring to the training host.
|
||||||
It is strongly recommended to compare the metrics before and after down casting.
|
|
||||||
|
Depending on the data type and the actual data range, the downcasting can be lossless or not.
|
||||||
|
It is strongly recommended to compare the metrics before and after downcasting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
downcast_config: A configuration object specifying the features and their target data types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, downcast_config):
|
def __init__(self, downcast_config):
|
||||||
|
"""
|
||||||
|
Initializes the DownCast model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
downcast_config: A configuration object specifying the features and their target data types.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = downcast_config
|
self.config = downcast_config
|
||||||
self._type_map = {
|
self._type_map = {
|
||||||
@ -65,6 +98,17 @@ class DownCast(tf.keras.Model):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
|
"""
|
||||||
|
Applies downcasting to the input features based on the configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A dictionary of input features.
|
||||||
|
training: A boolean indicating whether the model is in training mode.
|
||||||
|
mask: A mask tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of downcasted input features.
|
||||||
|
"""
|
||||||
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
||||||
for feature, type_str in self.config.features.items():
|
for feature, type_str in self.config.features.items():
|
||||||
assert type_str in self._type_map
|
assert type_str in self._type_map
|
||||||
@ -78,14 +122,39 @@ class DownCast(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
class RectifyLabels(tf.keras.Model):
|
class RectifyLabels(tf.keras.Model):
|
||||||
"""Class for rectifying labels"""
|
"""
|
||||||
|
A class for downcasting dataset before serialization and transferring to the training host.
|
||||||
|
|
||||||
|
Depending on the data type and the actual data range, the downcasting can be lossless or not.
|
||||||
|
It is strongly recommended to compare the metrics before and after downcasting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
downcast_config: A configuration object specifying the features and their target data types.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, rectify_label_config):
|
def __init__(self, rectify_label_config):
|
||||||
|
"""
|
||||||
|
Initializes the DownCast model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
downcast_config: A configuration object specifying the features and their target data types.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._config = rectify_label_config
|
self._config = rectify_label_config
|
||||||
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
|
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
|
"""
|
||||||
|
Applies downcasting to the input features based on the configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A dictionary of input features.
|
||||||
|
training: A boolean indicating whether the model is in training mode.
|
||||||
|
mask: A mask tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of downcasted input features.
|
||||||
|
"""
|
||||||
served_ts_field = self._config.served_timestamp_field
|
served_ts_field = self._config.served_timestamp_field
|
||||||
impressed_ts_field = self._config.impressed_timestamp_field
|
impressed_ts_field = self._config.impressed_timestamp_field
|
||||||
|
|
||||||
@ -102,13 +171,37 @@ class RectifyLabels(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
class ExtractFeatures(tf.keras.Model):
|
class ExtractFeatures(tf.keras.Model):
|
||||||
"""Class for extracting individual features from dense tensors by their index."""
|
"""
|
||||||
|
A class for rectifying labels based on specified conditions.
|
||||||
|
|
||||||
|
This class is used to adjust label values in a dataset based on configured conditions involving timestamps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, extract_features_config):
|
def __init__(self, extract_features_config):
|
||||||
|
"""
|
||||||
|
Initializes the RectifyLabels model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._config = extract_features_config
|
self._config = extract_features_config
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
|
"""
|
||||||
|
Rectifies label values based on the specified conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A dictionary of input features including timestamp fields and labels.
|
||||||
|
training: A boolean indicating whether the model is in training mode.
|
||||||
|
mask: A mask tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of input features with rectified label values.
|
||||||
|
"""
|
||||||
|
|
||||||
for row in self._config.extract_feature_table:
|
for row in self._config.extract_feature_table:
|
||||||
inputs[row.name] = inputs[row.source_tensor][:, row.index]
|
inputs[row.name] = inputs[row.source_tensor][:, row.index]
|
||||||
@ -168,7 +261,16 @@ class DownsampleNegatives(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
|
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
|
||||||
"""Builds a preprocess model to apply all preprocessing stages."""
|
"""
|
||||||
|
Builds a preprocess model to apply all preprocessing stages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preprocess_config: A configuration object specifying the preprocessing parameters.
|
||||||
|
mode: A mode indicating the current job mode (TRAIN or INFERENCE).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A preprocess model that applies all specified preprocessing stages.
|
||||||
|
"""
|
||||||
if mode == config_mod.JobMode.INFERENCE:
|
if mode == config_mod.JobMode.INFERENCE:
|
||||||
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
|
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
|
||||||
return None
|
return None
|
||||||
|
@ -8,7 +8,8 @@ import tensorflow as tf
|
|||||||
|
|
||||||
|
|
||||||
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
|
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
|
||||||
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}
|
DTYPE_MAP = {"int64_list": tf.int64,
|
||||||
|
"float_list": tf.float32, "bytes_list": tf.string}
|
||||||
|
|
||||||
|
|
||||||
def create_tf_example_schema(
|
def create_tf_example_schema(
|
||||||
@ -27,7 +28,8 @@ def create_tf_example_schema(
|
|||||||
segdense_config = data_config.seg_dense_schema
|
segdense_config = data_config.seg_dense_schema
|
||||||
labels = list(data_config.tasks.keys())
|
labels = list(data_config.tasks.keys())
|
||||||
used_features = (
|
used_features = (
|
||||||
segdense_config.features + list(segdense_config.renamed_features.values()) + labels
|
segdense_config.features +
|
||||||
|
list(segdense_config.renamed_features.values()) + labels
|
||||||
)
|
)
|
||||||
logging.info(used_features)
|
logging.info(used_features)
|
||||||
|
|
||||||
@ -40,19 +42,22 @@ def create_tf_example_schema(
|
|||||||
dtype = entry["dtype"]
|
dtype = entry["dtype"]
|
||||||
|
|
||||||
if feature_name in labels:
|
if feature_name in labels:
|
||||||
logging.info(f"Label: feature name is {feature_name} type is {dtype}")
|
logging.info(
|
||||||
|
f"Label: feature name is {feature_name} type is {dtype}")
|
||||||
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
||||||
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
|
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
|
||||||
)
|
)
|
||||||
elif length == -1:
|
elif length == -1:
|
||||||
tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
|
tfe_schema[feature_name] = tf.io.VarLenFeature(
|
||||||
|
DTYPE_MAP[dtype])
|
||||||
else:
|
else:
|
||||||
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
||||||
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
|
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
|
||||||
)
|
)
|
||||||
for feature_name in used_features:
|
for feature_name in used_features:
|
||||||
if feature_name not in tfe_schema:
|
if feature_name not in tfe_schema:
|
||||||
raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
|
raise ValueError(
|
||||||
|
f"{feature_name} missing from schema: {segdense_config.schema_path}.")
|
||||||
return tfe_schema
|
return tfe_schema
|
||||||
|
|
||||||
|
|
||||||
@ -82,7 +87,8 @@ def parse_tf_example(
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary of tensors to be used as model input.
|
Dictionary of tensors to be used as model input.
|
||||||
"""
|
"""
|
||||||
inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)
|
inputs = tf.io.parse_example(
|
||||||
|
serialized=serialized_example, features=tfe_schema)
|
||||||
|
|
||||||
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
|
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
|
||||||
inputs[new_feature_name] = inputs.pop(old_feature_name)
|
inputs[new_feature_name] = inputs.pop(old_feature_name)
|
||||||
@ -90,7 +96,8 @@ def parse_tf_example(
|
|||||||
# This should not actually be used except for experimentation with low precision floats.
|
# This should not actually be used except for experimentation with low precision floats.
|
||||||
if "mask_mantissa_features" in seg_dense_schema_config:
|
if "mask_mantissa_features" in seg_dense_schema_config:
|
||||||
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
|
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
|
||||||
inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)
|
inputs[feature_name] = mask_mantissa(
|
||||||
|
inputs[feature_name], mask_length)
|
||||||
|
|
||||||
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
|
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
|
||||||
# at TF level.
|
# at TF level.
|
||||||
|
@ -9,44 +9,59 @@ def keyed_tensor_from_tensors_dict(
|
|||||||
tensor_map: Mapping[str, torch.Tensor]
|
tensor_map: Mapping[str, torch.Tensor]
|
||||||
) -> "torchrec.KeyedTensor":
|
) -> "torchrec.KeyedTensor":
|
||||||
"""
|
"""
|
||||||
Convert a dictionary of torch tensor to torchrec keyed tensor
|
Convert a dictionary of torch tensors to a torchrec KeyedTensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor_map:
|
tensor_map: A mapping of tensor names to torch tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
A torchrec KeyedTensor.
|
||||||
"""
|
"""
|
||||||
keys = list(tensor_map.keys())
|
keys = list(tensor_map.keys())
|
||||||
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
|
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
|
||||||
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
|
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
|
||||||
# [Batch_size x 1].
|
# [Batch_size x 1].
|
||||||
values = [
|
values = [
|
||||||
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
|
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(
|
||||||
|
tensor_map[key], -1)
|
||||||
for key in keys
|
for key in keys
|
||||||
]
|
]
|
||||||
return torchrec.KeyedTensor.from_tensor_list(keys, values)
|
return torchrec.KeyedTensor.from_tensor_list(keys, values)
|
||||||
|
|
||||||
|
|
||||||
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Compute a jagged tensor from a torch tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: Input torch tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing the values and lengths of the jagged tensor.
|
||||||
|
"""
|
||||||
if tensor.is_sparse:
|
if tensor.is_sparse:
|
||||||
x = tensor.coalesce() # Ensure that the indices are ordered.
|
x = tensor.coalesce() # Ensure that the indices are ordered.
|
||||||
lengths = torch.bincount(x.indices()[0])
|
lengths = torch.bincount(x.indices()[0])
|
||||||
values = x.values()
|
values = x.values()
|
||||||
else:
|
else:
|
||||||
values = tensor
|
values = tensor
|
||||||
lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
|
lengths = torch.ones(
|
||||||
|
tensor.shape[0], dtype=torch.int32, device=tensor.device)
|
||||||
return values, lengths
|
return values, lengths
|
||||||
|
|
||||||
|
|
||||||
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
|
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
|
||||||
"""
|
"""
|
||||||
Convert a torch tensor to torchrec jagged tensor.
|
Convert a torch tensor to a torchrec jagged tensor.
|
||||||
Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
|
|
||||||
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the
|
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x N] for dense tensors.
|
||||||
dense_shape of the sparse tensor can be arbitrary.
|
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor: a torch (sparse) tensor.
|
tensor: A torch (sparse) tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
A torchrec JaggedTensor.
|
||||||
"""
|
"""
|
||||||
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
|
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
|
||||||
return torchrec.JaggedTensor(values=values, lengths=lengths)
|
return torchrec.JaggedTensor(values=values, lengths=lengths)
|
||||||
@ -56,15 +71,16 @@ def keyed_jagged_tensor_from_tensors_dict(
|
|||||||
tensor_map: Mapping[str, torch.Tensor]
|
tensor_map: Mapping[str, torch.Tensor]
|
||||||
) -> "torchrec.KeyedJaggedTensor":
|
) -> "torchrec.KeyedJaggedTensor":
|
||||||
"""
|
"""
|
||||||
Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor.
|
Convert a dictionary of (sparse) torch tensors to a torchrec keyed jagged tensor.
|
||||||
Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.
|
|
||||||
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the
|
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x 1] for dense tensors.
|
||||||
dense_shape of the sparse tensor can be arbitrary.
|
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor_map:
|
tensor_map: A mapping of tensor names to torch tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
A torchrec KeyedJaggedTensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not tensor_map:
|
if not tensor_map:
|
||||||
@ -91,10 +107,29 @@ def keyed_jagged_tensor_from_tensors_dict(
|
|||||||
|
|
||||||
|
|
||||||
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
|
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert a TensorFlow tensor to a NumPy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tf_tensor: TensorFlow tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NumPy array.
|
||||||
|
"""
|
||||||
return tf_tensor._numpy() # noqa
|
return tf_tensor._numpy() # noqa
|
||||||
|
|
||||||
|
|
||||||
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Convert a dense TensorFlow tensor to a PyTorch tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: Dense TensorFlow tensor.
|
||||||
|
pin_memory: Whether to pin the tensor in memory (for CUDA).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyTorch tensor.
|
||||||
|
"""
|
||||||
tensor = _tf_to_numpy(tensor)
|
tensor = _tf_to_numpy(tensor)
|
||||||
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
|
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
|
||||||
if tensor.dtype.name == "bfloat16":
|
if tensor.dtype.name == "bfloat16":
|
||||||
@ -109,6 +144,16 @@ def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
|||||||
def sparse_or_dense_tf_to_torch(
|
def sparse_or_dense_tf_to_torch(
|
||||||
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
|
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: TensorFlow tensor (sparse or dense).
|
||||||
|
pin_memory: Whether to pin the tensor in memory (for CUDA).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyTorch tensor.
|
||||||
|
"""
|
||||||
if isinstance(tensor, tf.SparseTensor):
|
if isinstance(tensor, tf.SparseTensor):
|
||||||
tensor = torch.sparse_coo_tensor(
|
tensor = torch.sparse_coo_tensor(
|
||||||
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
|
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
|
||||||
|
@ -69,6 +69,7 @@ class MaskBlock(torch.nn.Module):
|
|||||||
This class is intended for internal use within neural network architectures and should not be
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
directly accessed or modified by external code.
|
directly accessed or modified by external code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -94,11 +95,13 @@ class MaskBlock(torch.nn.Module):
|
|||||||
self._input_layer_norm = None
|
self._input_layer_norm = None
|
||||||
|
|
||||||
if mask_block_config.reduction_factor:
|
if mask_block_config.reduction_factor:
|
||||||
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
|
aggregation_size = int(
|
||||||
|
mask_input_dim * mask_block_config.reduction_factor)
|
||||||
elif mask_block_config.aggregation_size is not None:
|
elif mask_block_config.aggregation_size is not None:
|
||||||
aggregation_size = mask_block_config.aggregation_size
|
aggregation_size = mask_block_config.aggregation_size
|
||||||
else:
|
else:
|
||||||
raise ValueError("Need one of reduction factor or aggregation size.")
|
raise ValueError(
|
||||||
|
"Need one of reduction factor or aggregation size.")
|
||||||
|
|
||||||
self._mask_layer = torch.nn.Sequential(
|
self._mask_layer = torch.nn.Sequential(
|
||||||
torch.nn.Linear(mask_input_dim, aggregation_size),
|
torch.nn.Linear(mask_input_dim, aggregation_size),
|
||||||
@ -123,7 +126,8 @@ class MaskBlock(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
if self._input_layer_norm:
|
if self._input_layer_norm:
|
||||||
net = self._input_layer_norm(net)
|
net = self._input_layer_norm(net)
|
||||||
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
|
hidden_layer_output = self._hidden_layer(
|
||||||
|
net * self._mask_layer(mask_input))
|
||||||
return self._layer_norm(hidden_layer_output)
|
return self._layer_norm(hidden_layer_output)
|
||||||
|
|
||||||
|
|
||||||
@ -170,6 +174,7 @@ class MaskNet(torch.nn.Module):
|
|||||||
This class is intended for internal use within neural network architectures and should not be
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
directly accessed or modified by external code.
|
directly accessed or modified by external code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
||||||
"""
|
"""
|
||||||
Initializes the MaskNet module.
|
Initializes the MaskNet module.
|
||||||
@ -189,20 +194,23 @@ class MaskNet(torch.nn.Module):
|
|||||||
if mask_net_config.use_parallel:
|
if mask_net_config.use_parallel:
|
||||||
total_output_mask_blocks = 0
|
total_output_mask_blocks = 0
|
||||||
for mask_block_config in mask_net_config.mask_blocks:
|
for mask_block_config in mask_net_config.mask_blocks:
|
||||||
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
|
mask_blocks.append(
|
||||||
|
MaskBlock(mask_block_config, in_features, in_features))
|
||||||
total_output_mask_blocks += mask_block_config.output_size
|
total_output_mask_blocks += mask_block_config.output_size
|
||||||
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
||||||
else:
|
else:
|
||||||
input_size = in_features
|
input_size = in_features
|
||||||
for mask_block_config in mask_net_config.mask_blocks:
|
for mask_block_config in mask_net_config.mask_blocks:
|
||||||
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
|
mask_blocks.append(
|
||||||
|
MaskBlock(mask_block_config, input_size, in_features))
|
||||||
input_size = mask_block_config.output_size
|
input_size = mask_block_config.output_size
|
||||||
|
|
||||||
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
||||||
total_output_mask_blocks = mask_block_config.output_size
|
total_output_mask_blocks = mask_block_config.output_size
|
||||||
|
|
||||||
if mask_net_config.mlp:
|
if mask_net_config.mlp:
|
||||||
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
|
self._dense_layers = mlp.Mlp(
|
||||||
|
total_output_mask_blocks, mask_net_config.mlp)
|
||||||
self.out_features = mask_net_config.mlp.layer_sizes[-1]
|
self.out_features = mask_net_config.mlp.layer_sizes[-1]
|
||||||
else:
|
else:
|
||||||
self.out_features = total_output_mask_blocks
|
self.out_features = total_output_mask_blocks
|
||||||
@ -235,5 +243,6 @@ class MaskNet(torch.nn.Module):
|
|||||||
for mask_layer in self._mask_blocks:
|
for mask_layer in self._mask_blocks:
|
||||||
net = mask_layer(net=net, mask_input=inputs)
|
net = mask_layer(net=net, mask_input=inputs)
|
||||||
# Share the output of the stacked MaskBlocks.
|
# Share the output of the stacked MaskBlocks.
|
||||||
output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"]
|
output = net if self.mask_net_config.mlp is None else self._dense_layers[
|
||||||
|
net]["output"]
|
||||||
return {"output": output, "shared_layer": net}
|
return {"output": output, "shared_layer": net}
|
||||||
|
@ -52,6 +52,7 @@ class ModelAndLoss(torch.nn.Module):
|
|||||||
This class is intended for internal use within neural network architectures and should not be
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
directly accessed or modified by external code.
|
directly accessed or modified by external code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
@ -92,13 +93,16 @@ class ModelAndLoss(torch.nn.Module):
|
|||||||
labels=batch.labels,
|
labels=batch.labels,
|
||||||
weights=batch.weights,
|
weights=batch.weights,
|
||||||
)
|
)
|
||||||
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
|
losses = self.loss_fn(
|
||||||
|
outputs["logits"], batch.labels.float(), batch.weights.float())
|
||||||
|
|
||||||
if self.stratifiers:
|
if self.stratifiers:
|
||||||
logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}")
|
logging.info(
|
||||||
|
f"***** Adding stratifiers *****\n {self.stratifiers}")
|
||||||
outputs["stratifiers"] = {}
|
outputs["stratifiers"] = {}
|
||||||
for stratifier in self.stratifiers:
|
for stratifier in self.stratifiers:
|
||||||
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index]
|
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:,
|
||||||
|
stratifier.index]
|
||||||
|
|
||||||
# In general, we can have a large number of losses returned by our loss function.
|
# In general, we can have a large number of losses returned by our loss function.
|
||||||
if isinstance(losses, dict):
|
if isinstance(losses, dict):
|
||||||
|
@ -39,6 +39,7 @@ class NumericCalibration(torch.nn.Module):
|
|||||||
This class is intended for internal use within neural network architectures and should not be
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
directly accessed or modified by external code.
|
directly accessed or modified by external code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pos_downsampling_rate: float,
|
pos_downsampling_rate: float,
|
||||||
|
47
tools/pq.py
47
tools/pq.py
@ -38,6 +38,15 @@ import pyarrow.parquet as pq
|
|||||||
|
|
||||||
|
|
||||||
def _create_dataset(path: str):
|
def _create_dataset(path: str):
|
||||||
|
"""
|
||||||
|
Create a PyArrow dataset from Parquet files located at the specified path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): The path to the Parquet files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pyarrow.dataset.Dataset: The PyArrow dataset.
|
||||||
|
"""
|
||||||
fs = infer_fs(path)
|
fs = infer_fs(path)
|
||||||
files = fs.glob(path)
|
files = fs.glob(path)
|
||||||
return pads.dataset(files, format="parquet", filesystem=fs)
|
return pads.dataset(files, format="parquet", filesystem=fs)
|
||||||
@ -47,12 +56,27 @@ class PqReader:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
|
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize a Parquet Reader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): The path to the Parquet files.
|
||||||
|
num (int): The maximum number of rows to read.
|
||||||
|
batch_size (int): The batch size for reading data.
|
||||||
|
columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns).
|
||||||
|
"""
|
||||||
self._ds = _create_dataset(path)
|
self._ds = _create_dataset(path)
|
||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self._num = num
|
self._num = num
|
||||||
self._columns = columns
|
self._columns = columns
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
"""
|
||||||
|
Iterate through the Parquet data and yield batches of rows.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
pyarrow.RecordBatch: A batch of rows.
|
||||||
|
"""
|
||||||
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
|
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
|
||||||
rows_seen = 0
|
rows_seen = 0
|
||||||
for count, record in enumerate(batches):
|
for count, record in enumerate(batches):
|
||||||
@ -62,6 +86,12 @@ class PqReader:
|
|||||||
rows_seen += record.data.num_rows
|
rows_seen += record.data.num_rows
|
||||||
|
|
||||||
def _head(self):
|
def _head(self):
|
||||||
|
"""
|
||||||
|
Get the first `num` rows of the Parquet data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pyarrow.RecordBatch: A batch of rows.
|
||||||
|
"""
|
||||||
total_read = self._num * self.bytes_per_row
|
total_read = self._num * self.bytes_per_row
|
||||||
if total_read >= int(500e6):
|
if total_read >= int(500e6):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
@ -71,6 +101,12 @@ class PqReader:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def bytes_per_row(self) -> int:
|
def bytes_per_row(self) -> int:
|
||||||
|
"""
|
||||||
|
Calculate the estimated bytes per row in the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The estimated bytes per row.
|
||||||
|
"""
|
||||||
nbits = 0
|
nbits = 0
|
||||||
for t in self._ds.schema.types:
|
for t in self._ds.schema.types:
|
||||||
try:
|
try:
|
||||||
@ -81,17 +117,22 @@ class PqReader:
|
|||||||
return nbits // 8
|
return nbits // 8
|
||||||
|
|
||||||
def schema(self):
|
def schema(self):
|
||||||
|
"""
|
||||||
|
Display the schema of the Parquet dataset.
|
||||||
|
"""
|
||||||
print(f"\n# Schema\n{self._ds.schema}")
|
print(f"\n# Schema\n{self._ds.schema}")
|
||||||
|
|
||||||
def head(self):
|
def head(self):
|
||||||
"""Displays first --num rows."""
|
"""
|
||||||
|
Display the first `num` rows of the Parquet data as a pandas DataFrame.
|
||||||
|
"""
|
||||||
print(self._head().to_pandas())
|
print(self._head().to_pandas())
|
||||||
|
|
||||||
def distinct(self):
|
def distinct(self):
|
||||||
"""Displays unique values seen in specified columns in the first `--num` rows.
|
"""
|
||||||
|
Display unique values seen in specified columns in the first `num` rows.
|
||||||
|
|
||||||
Useful for getting an approximate vocabulary for certain columns.
|
Useful for getting an approximate vocabulary for certain columns.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
for col_name, column in zip(self._head().column_names, self._head().columns):
|
for col_name, column in zip(self._head().column_names, self._head().columns):
|
||||||
print(col_name)
|
print(col_name)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user