mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-10-01 09:09:52 +02:00
Compare commits
1 Commits
a8258a5da4
...
d3e9477fb0
Author | SHA1 | Date | |
---|---|---|---|
|
d3e9477fb0 |
21
model.py
21
model.py
@ -54,20 +54,12 @@ def maybe_shard_model(
|
|||||||
model,
|
model,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
"""
|
"""Set up and apply DistributedModelParallel to a model if running in a distributed environment.
|
||||||
Set up and apply DistributedModelParallel to a model if running in a distributed environment.
|
|
||||||
|
|
||||||
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
|
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
|
||||||
DistributedModelParallel.
|
DistributedModelParallel.
|
||||||
|
|
||||||
If not in a distributed environment, returns the model directly.
|
If not in a distributed environment, returns model directly.
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The PyTorch model.
|
|
||||||
device: The target device (e.g., 'cuda').
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The model wrapped with DistributedModelParallel if in a distributed environment, else the original model.
|
|
||||||
"""
|
"""
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
logging.info("***** Wrapping in DistributedModelParallel *****")
|
logging.info("***** Wrapping in DistributedModelParallel *****")
|
||||||
@ -82,14 +74,13 @@ def maybe_shard_model(
|
|||||||
|
|
||||||
|
|
||||||
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
|
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
|
||||||
"""
|
"""Handy function to log the content of EBC embedding layer.
|
||||||
Handy function to log the content of an EBC (Embedding Bag Concatenation) embedding layer.
|
|
||||||
Only works for single GPU machines.
|
Only works for single GPU machines.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
weight_name: Name of the tensor, as defined in the model.
|
weight_name: name of tensor, as defined in model
|
||||||
table_name: Name of the EBC table the weight is taken from.
|
table_name: name of the EBC table the weight is taken from
|
||||||
weight_tensor: Embedding weight tensor.
|
weight_tensor: embedding weight tensor
|
||||||
"""
|
"""
|
||||||
logging.info(f"{weight_name}, {table_name}", rank=-1)
|
logging.info(f"{weight_name}, {table_name}", rank=-1)
|
||||||
logging.info(f"{weight_tensor.metadata()}", rank=-1)
|
logging.info(f"{weight_tensor.metadata()}", rank=-1)
|
||||||
|
@ -11,10 +11,8 @@ class ExplicitDateInputs(base_config.BaseConfig):
|
|||||||
"""Arguments to select train/validation data using end_date and days of data."""
|
"""Arguments to select train/validation data using end_date and days of data."""
|
||||||
|
|
||||||
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||||||
end_date: str = pydantic.Field(...,
|
end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
|
||||||
description="Data end date, inclusive.")
|
days: int = pydantic.Field(..., description="Number of days of data for dataset.")
|
||||||
days: int = pydantic.Field(...,
|
|
||||||
description="Number of days of data for dataset.")
|
|
||||||
num_missing_days_tol: int = pydantic.Field(
|
num_missing_days_tol: int = pydantic.Field(
|
||||||
0, description="We tolerate <= num_missing_days_tol days of missing data."
|
0, description="We tolerate <= num_missing_days_tol days of missing data."
|
||||||
)
|
)
|
||||||
@ -24,10 +22,8 @@ class ExplicitDatetimeInputs(base_config.BaseConfig):
|
|||||||
"""Arguments to select train/validation data using end_datetime and hours of data."""
|
"""Arguments to select train/validation data using end_datetime and hours of data."""
|
||||||
|
|
||||||
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||||||
end_datetime: str = pydantic.Field(...,
|
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
|
||||||
description="Data end datetime, inclusive.")
|
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
|
||||||
hours: int = pydantic.Field(...,
|
|
||||||
description="Number of hours of data for dataset.")
|
|
||||||
num_missing_hours_tol: int = pydantic.Field(
|
num_missing_hours_tol: int = pydantic.Field(
|
||||||
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
|
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
|
||||||
)
|
)
|
||||||
@ -46,8 +42,7 @@ class DatasetConfig(base_config.BaseConfig):
|
|||||||
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
|
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
|
||||||
None, one_of="date_inputs_format"
|
None, one_of="date_inputs_format"
|
||||||
)
|
)
|
||||||
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(
|
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
|
||||||
None, one_of="date_inputs_format")
|
|
||||||
|
|
||||||
global_batch_size: pydantic.PositiveInt
|
global_batch_size: pydantic.PositiveInt
|
||||||
|
|
||||||
@ -57,8 +52,7 @@ class DatasetConfig(base_config.BaseConfig):
|
|||||||
repeat_files: bool = pydantic.Field(
|
repeat_files: bool = pydantic.Field(
|
||||||
True, description="DEPRICATED. Files are repeated no matter what this is set to."
|
True, description="DEPRICATED. Files are repeated no matter what this is set to."
|
||||||
)
|
)
|
||||||
file_batch_size: pydantic.PositiveInt = pydantic.Field(
|
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
|
||||||
16, description="File batch size")
|
|
||||||
|
|
||||||
cache: bool = pydantic.Field(
|
cache: bool = pydantic.Field(
|
||||||
False,
|
False,
|
||||||
@ -76,8 +70,7 @@ class DatasetConfig(base_config.BaseConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# tf.data.Dataset options
|
# tf.data.Dataset options
|
||||||
examples_shuffle_buffer_size: int = pydantic.Field(
|
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
|
||||||
1024, description="Size of shuffle buffers.")
|
|
||||||
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
||||||
None, description="Number of parallel calls."
|
None, description="Number of parallel calls."
|
||||||
)
|
)
|
||||||
@ -132,8 +125,7 @@ class TaskData(base_config.BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
class SegDenseSchema(base_config.BaseConfig):
|
class SegDenseSchema(base_config.BaseConfig):
|
||||||
schema_path: str = pydantic.Field(...,
|
schema_path: str = pydantic.Field(..., description="Path to feature config json.")
|
||||||
description="Path to feature config json.")
|
|
||||||
features: typing.List[str] = pydantic.Field(
|
features: typing.List[str] = pydantic.Field(
|
||||||
[],
|
[],
|
||||||
description="List of features (in addition to the renamed features) to read from schema path above.",
|
description="List of features (in addition to the renamed features) to read from schema path above.",
|
||||||
@ -200,10 +192,8 @@ class DownsampleNegatives(base_config.BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
class Preprocess(base_config.BaseConfig):
|
class Preprocess(base_config.BaseConfig):
|
||||||
truncate_and_slice: TruncateAndSlice = pydantic.Field(
|
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
|
||||||
None, description="Truncation and slicing.")
|
downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
|
||||||
downcast: DownCast = pydantic.Field(
|
|
||||||
None, description="Down cast to features.")
|
|
||||||
rectify_labels: RectifyLabels = pydantic.Field(
|
rectify_labels: RectifyLabels = pydantic.Field(
|
||||||
None, description="Rectify labels for a given overlap window"
|
None, description="Rectify labels for a given overlap window"
|
||||||
)
|
)
|
||||||
@ -252,6 +242,5 @@ class RecapDataConfig(DatasetConfig):
|
|||||||
if values.get("evaluation_tasks") is not None:
|
if values.get("evaluation_tasks") is not None:
|
||||||
for task in values["evaluation_tasks"]:
|
for task in values["evaluation_tasks"]:
|
||||||
if task not in values["tasks"]:
|
if task not in values["tasks"]:
|
||||||
raise KeyError(
|
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
|
||||||
f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
|
|
||||||
return values
|
return values
|
||||||
|
@ -9,20 +9,9 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
class TruncateAndSlice(tf.keras.Model):
|
class TruncateAndSlice(tf.keras.Model):
|
||||||
"""
|
"""Class for truncating and slicing."""
|
||||||
A class for truncating and slicing input features based on the provided configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, truncate_and_slice_config):
|
def __init__(self, truncate_and_slice_config):
|
||||||
"""
|
|
||||||
Initializes the TruncateAndSlice model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._truncate_and_slice_config = truncate_and_slice_config
|
self._truncate_and_slice_config = truncate_and_slice_config
|
||||||
|
|
||||||
@ -43,17 +32,6 @@ class TruncateAndSlice(tf.keras.Model):
|
|||||||
self._binary_mask = None
|
self._binary_mask = None
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
"""
|
|
||||||
Applies truncation and slicing to the input features based on the configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: A dictionary of input features.
|
|
||||||
training: A boolean indicating whether the model is in training mode.
|
|
||||||
mask: A mask tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary of truncated and sliced input features.
|
|
||||||
"""
|
|
||||||
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
||||||
if self._truncate_and_slice_config.continuous_feature_truncation:
|
if self._truncate_and_slice_config.continuous_feature_truncation:
|
||||||
logging.info("Truncating continuous")
|
logging.info("Truncating continuous")
|
||||||
@ -73,23 +51,12 @@ class TruncateAndSlice(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
class DownCast(tf.keras.Model):
|
class DownCast(tf.keras.Model):
|
||||||
"""
|
"""Class for Down casting dataset before serialization and transferring to training host.
|
||||||
A class for downcasting dataset before serialization and transferring to the training host.
|
Depends on the data type and the actual data range, the down casting can be lossless or not.
|
||||||
|
|
||||||
Depending on the data type and the actual data range, the downcasting can be lossless or not.
|
|
||||||
It is strongly recommended to compare the metrics before and after down casting.
|
It is strongly recommended to compare the metrics before and after down casting.
|
||||||
|
|
||||||
Args:
|
|
||||||
downcast_config: A configuration object specifying the features and their target data types.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, downcast_config):
|
def __init__(self, downcast_config):
|
||||||
"""
|
|
||||||
Initializes the DownCast model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
downcast_config: A configuration object specifying the features and their target data types.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = downcast_config
|
self.config = downcast_config
|
||||||
self._type_map = {
|
self._type_map = {
|
||||||
@ -98,17 +65,6 @@ class DownCast(tf.keras.Model):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
"""
|
|
||||||
Applies downcasting to the input features based on the configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: A dictionary of input features.
|
|
||||||
training: A boolean indicating whether the model is in training mode.
|
|
||||||
mask: A mask tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary of downcasted input features.
|
|
||||||
"""
|
|
||||||
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
||||||
for feature, type_str in self.config.features.items():
|
for feature, type_str in self.config.features.items():
|
||||||
assert type_str in self._type_map
|
assert type_str in self._type_map
|
||||||
@ -122,39 +78,14 @@ class DownCast(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
class RectifyLabels(tf.keras.Model):
|
class RectifyLabels(tf.keras.Model):
|
||||||
"""
|
"""Class for rectifying labels"""
|
||||||
A class for downcasting dataset before serialization and transferring to the training host.
|
|
||||||
|
|
||||||
Depending on the data type and the actual data range, the downcasting can be lossless or not.
|
|
||||||
It is strongly recommended to compare the metrics before and after downcasting.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
downcast_config: A configuration object specifying the features and their target data types.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, rectify_label_config):
|
def __init__(self, rectify_label_config):
|
||||||
"""
|
|
||||||
Initializes the DownCast model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
downcast_config: A configuration object specifying the features and their target data types.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._config = rectify_label_config
|
self._config = rectify_label_config
|
||||||
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
|
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
"""
|
|
||||||
Applies downcasting to the input features based on the configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: A dictionary of input features.
|
|
||||||
training: A boolean indicating whether the model is in training mode.
|
|
||||||
mask: A mask tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary of downcasted input features.
|
|
||||||
"""
|
|
||||||
served_ts_field = self._config.served_timestamp_field
|
served_ts_field = self._config.served_timestamp_field
|
||||||
impressed_ts_field = self._config.impressed_timestamp_field
|
impressed_ts_field = self._config.impressed_timestamp_field
|
||||||
|
|
||||||
@ -171,37 +102,13 @@ class RectifyLabels(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
class ExtractFeatures(tf.keras.Model):
|
class ExtractFeatures(tf.keras.Model):
|
||||||
"""
|
"""Class for extracting individual features from dense tensors by their index."""
|
||||||
A class for rectifying labels based on specified conditions.
|
|
||||||
|
|
||||||
This class is used to adjust label values in a dataset based on configured conditions involving timestamps.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, extract_features_config):
|
def __init__(self, extract_features_config):
|
||||||
"""
|
|
||||||
Initializes the RectifyLabels model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._config = extract_features_config
|
self._config = extract_features_config
|
||||||
|
|
||||||
def call(self, inputs, training=None, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
"""
|
|
||||||
Rectifies label values based on the specified conditions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: A dictionary of input features including timestamp fields and labels.
|
|
||||||
training: A boolean indicating whether the model is in training mode.
|
|
||||||
mask: A mask tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary of input features with rectified label values.
|
|
||||||
"""
|
|
||||||
|
|
||||||
for row in self._config.extract_feature_table:
|
for row in self._config.extract_feature_table:
|
||||||
inputs[row.name] = inputs[row.source_tensor][:, row.index]
|
inputs[row.name] = inputs[row.source_tensor][:, row.index]
|
||||||
@ -261,16 +168,7 @@ class DownsampleNegatives(tf.keras.Model):
|
|||||||
|
|
||||||
|
|
||||||
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
|
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
|
||||||
"""
|
"""Builds a preprocess model to apply all preprocessing stages."""
|
||||||
Builds a preprocess model to apply all preprocessing stages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
preprocess_config: A configuration object specifying the preprocessing parameters.
|
|
||||||
mode: A mode indicating the current job mode (TRAIN or INFERENCE).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A preprocess model that applies all specified preprocessing stages.
|
|
||||||
"""
|
|
||||||
if mode == config_mod.JobMode.INFERENCE:
|
if mode == config_mod.JobMode.INFERENCE:
|
||||||
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
|
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
|
||||||
return None
|
return None
|
||||||
|
@ -8,8 +8,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
|
|
||||||
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
|
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
|
||||||
DTYPE_MAP = {"int64_list": tf.int64,
|
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}
|
||||||
"float_list": tf.float32, "bytes_list": tf.string}
|
|
||||||
|
|
||||||
|
|
||||||
def create_tf_example_schema(
|
def create_tf_example_schema(
|
||||||
@ -28,8 +27,7 @@ def create_tf_example_schema(
|
|||||||
segdense_config = data_config.seg_dense_schema
|
segdense_config = data_config.seg_dense_schema
|
||||||
labels = list(data_config.tasks.keys())
|
labels = list(data_config.tasks.keys())
|
||||||
used_features = (
|
used_features = (
|
||||||
segdense_config.features +
|
segdense_config.features + list(segdense_config.renamed_features.values()) + labels
|
||||||
list(segdense_config.renamed_features.values()) + labels
|
|
||||||
)
|
)
|
||||||
logging.info(used_features)
|
logging.info(used_features)
|
||||||
|
|
||||||
@ -42,22 +40,19 @@ def create_tf_example_schema(
|
|||||||
dtype = entry["dtype"]
|
dtype = entry["dtype"]
|
||||||
|
|
||||||
if feature_name in labels:
|
if feature_name in labels:
|
||||||
logging.info(
|
logging.info(f"Label: feature name is {feature_name} type is {dtype}")
|
||||||
f"Label: feature name is {feature_name} type is {dtype}")
|
|
||||||
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
||||||
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
|
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
|
||||||
)
|
)
|
||||||
elif length == -1:
|
elif length == -1:
|
||||||
tfe_schema[feature_name] = tf.io.VarLenFeature(
|
tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
|
||||||
DTYPE_MAP[dtype])
|
|
||||||
else:
|
else:
|
||||||
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
||||||
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
|
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
|
||||||
)
|
)
|
||||||
for feature_name in used_features:
|
for feature_name in used_features:
|
||||||
if feature_name not in tfe_schema:
|
if feature_name not in tfe_schema:
|
||||||
raise ValueError(
|
raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
|
||||||
f"{feature_name} missing from schema: {segdense_config.schema_path}.")
|
|
||||||
return tfe_schema
|
return tfe_schema
|
||||||
|
|
||||||
|
|
||||||
@ -87,8 +82,7 @@ def parse_tf_example(
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary of tensors to be used as model input.
|
Dictionary of tensors to be used as model input.
|
||||||
"""
|
"""
|
||||||
inputs = tf.io.parse_example(
|
inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)
|
||||||
serialized=serialized_example, features=tfe_schema)
|
|
||||||
|
|
||||||
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
|
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
|
||||||
inputs[new_feature_name] = inputs.pop(old_feature_name)
|
inputs[new_feature_name] = inputs.pop(old_feature_name)
|
||||||
@ -96,8 +90,7 @@ def parse_tf_example(
|
|||||||
# This should not actually be used except for experimentation with low precision floats.
|
# This should not actually be used except for experimentation with low precision floats.
|
||||||
if "mask_mantissa_features" in seg_dense_schema_config:
|
if "mask_mantissa_features" in seg_dense_schema_config:
|
||||||
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
|
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
|
||||||
inputs[feature_name] = mask_mantissa(
|
inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)
|
||||||
inputs[feature_name], mask_length)
|
|
||||||
|
|
||||||
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
|
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
|
||||||
# at TF level.
|
# at TF level.
|
||||||
|
@ -9,59 +9,44 @@ def keyed_tensor_from_tensors_dict(
|
|||||||
tensor_map: Mapping[str, torch.Tensor]
|
tensor_map: Mapping[str, torch.Tensor]
|
||||||
) -> "torchrec.KeyedTensor":
|
) -> "torchrec.KeyedTensor":
|
||||||
"""
|
"""
|
||||||
Convert a dictionary of torch tensors to a torchrec KeyedTensor.
|
Convert a dictionary of torch tensor to torchrec keyed tensor
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor_map: A mapping of tensor names to torch tensors.
|
tensor_map:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A torchrec KeyedTensor.
|
|
||||||
"""
|
"""
|
||||||
keys = list(tensor_map.keys())
|
keys = list(tensor_map.keys())
|
||||||
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
|
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
|
||||||
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
|
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
|
||||||
# [Batch_size x 1].
|
# [Batch_size x 1].
|
||||||
values = [
|
values = [
|
||||||
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(
|
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
|
||||||
tensor_map[key], -1)
|
|
||||||
for key in keys
|
for key in keys
|
||||||
]
|
]
|
||||||
return torchrec.KeyedTensor.from_tensor_list(keys, values)
|
return torchrec.KeyedTensor.from_tensor_list(keys, values)
|
||||||
|
|
||||||
|
|
||||||
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
|
||||||
Compute a jagged tensor from a torch tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor: Input torch tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing the values and lengths of the jagged tensor.
|
|
||||||
"""
|
|
||||||
if tensor.is_sparse:
|
if tensor.is_sparse:
|
||||||
x = tensor.coalesce() # Ensure that the indices are ordered.
|
x = tensor.coalesce() # Ensure that the indices are ordered.
|
||||||
lengths = torch.bincount(x.indices()[0])
|
lengths = torch.bincount(x.indices()[0])
|
||||||
values = x.values()
|
values = x.values()
|
||||||
else:
|
else:
|
||||||
values = tensor
|
values = tensor
|
||||||
lengths = torch.ones(
|
lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
|
||||||
tensor.shape[0], dtype=torch.int32, device=tensor.device)
|
|
||||||
return values, lengths
|
return values, lengths
|
||||||
|
|
||||||
|
|
||||||
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
|
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
|
||||||
"""
|
"""
|
||||||
Convert a torch tensor to a torchrec jagged tensor.
|
Convert a torch tensor to torchrec jagged tensor.
|
||||||
|
Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
|
||||||
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x N] for dense tensors.
|
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the
|
||||||
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary.
|
dense_shape of the sparse tensor can be arbitrary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor: A torch (sparse) tensor.
|
tensor: a torch (sparse) tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A torchrec JaggedTensor.
|
|
||||||
"""
|
"""
|
||||||
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
|
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
|
||||||
return torchrec.JaggedTensor(values=values, lengths=lengths)
|
return torchrec.JaggedTensor(values=values, lengths=lengths)
|
||||||
@ -71,16 +56,15 @@ def keyed_jagged_tensor_from_tensors_dict(
|
|||||||
tensor_map: Mapping[str, torch.Tensor]
|
tensor_map: Mapping[str, torch.Tensor]
|
||||||
) -> "torchrec.KeyedJaggedTensor":
|
) -> "torchrec.KeyedJaggedTensor":
|
||||||
"""
|
"""
|
||||||
Convert a dictionary of (sparse) torch tensors to a torchrec keyed jagged tensor.
|
Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor.
|
||||||
|
Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.
|
||||||
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x 1] for dense tensors.
|
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the
|
||||||
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary.
|
dense_shape of the sparse tensor can be arbitrary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor_map: A mapping of tensor names to torch tensors.
|
tensor_map:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A torchrec KeyedJaggedTensor.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not tensor_map:
|
if not tensor_map:
|
||||||
@ -107,29 +91,10 @@ def keyed_jagged_tensor_from_tensors_dict(
|
|||||||
|
|
||||||
|
|
||||||
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
|
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
|
||||||
"""
|
|
||||||
Convert a TensorFlow tensor to a NumPy array.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tf_tensor: TensorFlow tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
NumPy array.
|
|
||||||
"""
|
|
||||||
return tf_tensor._numpy() # noqa
|
return tf_tensor._numpy() # noqa
|
||||||
|
|
||||||
|
|
||||||
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
||||||
"""
|
|
||||||
Convert a dense TensorFlow tensor to a PyTorch tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor: Dense TensorFlow tensor.
|
|
||||||
pin_memory: Whether to pin the tensor in memory (for CUDA).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PyTorch tensor.
|
|
||||||
"""
|
|
||||||
tensor = _tf_to_numpy(tensor)
|
tensor = _tf_to_numpy(tensor)
|
||||||
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
|
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
|
||||||
if tensor.dtype.name == "bfloat16":
|
if tensor.dtype.name == "bfloat16":
|
||||||
@ -144,16 +109,6 @@ def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
|||||||
def sparse_or_dense_tf_to_torch(
|
def sparse_or_dense_tf_to_torch(
|
||||||
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
|
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
|
||||||
Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor: TensorFlow tensor (sparse or dense).
|
|
||||||
pin_memory: Whether to pin the tensor in memory (for CUDA).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PyTorch tensor.
|
|
||||||
"""
|
|
||||||
if isinstance(tensor, tf.SparseTensor):
|
if isinstance(tensor, tf.SparseTensor):
|
||||||
tensor = torch.sparse_coo_tensor(
|
tensor = torch.sparse_coo_tensor(
|
||||||
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
|
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
|
||||||
|
@ -69,7 +69,6 @@ class MaskBlock(torch.nn.Module):
|
|||||||
This class is intended for internal use within neural network architectures and should not be
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
directly accessed or modified by external code.
|
directly accessed or modified by external code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -95,13 +94,11 @@ class MaskBlock(torch.nn.Module):
|
|||||||
self._input_layer_norm = None
|
self._input_layer_norm = None
|
||||||
|
|
||||||
if mask_block_config.reduction_factor:
|
if mask_block_config.reduction_factor:
|
||||||
aggregation_size = int(
|
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
|
||||||
mask_input_dim * mask_block_config.reduction_factor)
|
|
||||||
elif mask_block_config.aggregation_size is not None:
|
elif mask_block_config.aggregation_size is not None:
|
||||||
aggregation_size = mask_block_config.aggregation_size
|
aggregation_size = mask_block_config.aggregation_size
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Need one of reduction factor or aggregation size.")
|
||||||
"Need one of reduction factor or aggregation size.")
|
|
||||||
|
|
||||||
self._mask_layer = torch.nn.Sequential(
|
self._mask_layer = torch.nn.Sequential(
|
||||||
torch.nn.Linear(mask_input_dim, aggregation_size),
|
torch.nn.Linear(mask_input_dim, aggregation_size),
|
||||||
@ -126,8 +123,7 @@ class MaskBlock(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
if self._input_layer_norm:
|
if self._input_layer_norm:
|
||||||
net = self._input_layer_norm(net)
|
net = self._input_layer_norm(net)
|
||||||
hidden_layer_output = self._hidden_layer(
|
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
|
||||||
net * self._mask_layer(mask_input))
|
|
||||||
return self._layer_norm(hidden_layer_output)
|
return self._layer_norm(hidden_layer_output)
|
||||||
|
|
||||||
|
|
||||||
@ -174,7 +170,6 @@ class MaskNet(torch.nn.Module):
|
|||||||
This class is intended for internal use within neural network architectures and should not be
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
directly accessed or modified by external code.
|
directly accessed or modified by external code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
||||||
"""
|
"""
|
||||||
Initializes the MaskNet module.
|
Initializes the MaskNet module.
|
||||||
@ -194,23 +189,20 @@ class MaskNet(torch.nn.Module):
|
|||||||
if mask_net_config.use_parallel:
|
if mask_net_config.use_parallel:
|
||||||
total_output_mask_blocks = 0
|
total_output_mask_blocks = 0
|
||||||
for mask_block_config in mask_net_config.mask_blocks:
|
for mask_block_config in mask_net_config.mask_blocks:
|
||||||
mask_blocks.append(
|
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
|
||||||
MaskBlock(mask_block_config, in_features, in_features))
|
|
||||||
total_output_mask_blocks += mask_block_config.output_size
|
total_output_mask_blocks += mask_block_config.output_size
|
||||||
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
||||||
else:
|
else:
|
||||||
input_size = in_features
|
input_size = in_features
|
||||||
for mask_block_config in mask_net_config.mask_blocks:
|
for mask_block_config in mask_net_config.mask_blocks:
|
||||||
mask_blocks.append(
|
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
|
||||||
MaskBlock(mask_block_config, input_size, in_features))
|
|
||||||
input_size = mask_block_config.output_size
|
input_size = mask_block_config.output_size
|
||||||
|
|
||||||
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
||||||
total_output_mask_blocks = mask_block_config.output_size
|
total_output_mask_blocks = mask_block_config.output_size
|
||||||
|
|
||||||
if mask_net_config.mlp:
|
if mask_net_config.mlp:
|
||||||
self._dense_layers = mlp.Mlp(
|
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
|
||||||
total_output_mask_blocks, mask_net_config.mlp)
|
|
||||||
self.out_features = mask_net_config.mlp.layer_sizes[-1]
|
self.out_features = mask_net_config.mlp.layer_sizes[-1]
|
||||||
else:
|
else:
|
||||||
self.out_features = total_output_mask_blocks
|
self.out_features = total_output_mask_blocks
|
||||||
@ -243,6 +235,5 @@ class MaskNet(torch.nn.Module):
|
|||||||
for mask_layer in self._mask_blocks:
|
for mask_layer in self._mask_blocks:
|
||||||
net = mask_layer(net=net, mask_input=inputs)
|
net = mask_layer(net=net, mask_input=inputs)
|
||||||
# Share the output of the stacked MaskBlocks.
|
# Share the output of the stacked MaskBlocks.
|
||||||
output = net if self.mask_net_config.mlp is None else self._dense_layers[
|
output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"]
|
||||||
net]["output"]
|
|
||||||
return {"output": output, "shared_layer": net}
|
return {"output": output, "shared_layer": net}
|
||||||
|
@ -52,7 +52,6 @@ class ModelAndLoss(torch.nn.Module):
|
|||||||
This class is intended for internal use within neural network architectures and should not be
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
directly accessed or modified by external code.
|
directly accessed or modified by external code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
@ -93,16 +92,13 @@ class ModelAndLoss(torch.nn.Module):
|
|||||||
labels=batch.labels,
|
labels=batch.labels,
|
||||||
weights=batch.weights,
|
weights=batch.weights,
|
||||||
)
|
)
|
||||||
losses = self.loss_fn(
|
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
|
||||||
outputs["logits"], batch.labels.float(), batch.weights.float())
|
|
||||||
|
|
||||||
if self.stratifiers:
|
if self.stratifiers:
|
||||||
logging.info(
|
logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}")
|
||||||
f"***** Adding stratifiers *****\n {self.stratifiers}")
|
|
||||||
outputs["stratifiers"] = {}
|
outputs["stratifiers"] = {}
|
||||||
for stratifier in self.stratifiers:
|
for stratifier in self.stratifiers:
|
||||||
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:,
|
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index]
|
||||||
stratifier.index]
|
|
||||||
|
|
||||||
# In general, we can have a large number of losses returned by our loss function.
|
# In general, we can have a large number of losses returned by our loss function.
|
||||||
if isinstance(losses, dict):
|
if isinstance(losses, dict):
|
||||||
|
@ -39,7 +39,6 @@ class NumericCalibration(torch.nn.Module):
|
|||||||
This class is intended for internal use within neural network architectures and should not be
|
This class is intended for internal use within neural network architectures and should not be
|
||||||
directly accessed or modified by external code.
|
directly accessed or modified by external code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pos_downsampling_rate: float,
|
pos_downsampling_rate: float,
|
||||||
|
47
tools/pq.py
47
tools/pq.py
@ -38,15 +38,6 @@ import pyarrow.parquet as pq
|
|||||||
|
|
||||||
|
|
||||||
def _create_dataset(path: str):
|
def _create_dataset(path: str):
|
||||||
"""
|
|
||||||
Create a PyArrow dataset from Parquet files located at the specified path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str): The path to the Parquet files.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
pyarrow.dataset.Dataset: The PyArrow dataset.
|
|
||||||
"""
|
|
||||||
fs = infer_fs(path)
|
fs = infer_fs(path)
|
||||||
files = fs.glob(path)
|
files = fs.glob(path)
|
||||||
return pads.dataset(files, format="parquet", filesystem=fs)
|
return pads.dataset(files, format="parquet", filesystem=fs)
|
||||||
@ -56,27 +47,12 @@ class PqReader:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
|
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Initialize a Parquet Reader.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str): The path to the Parquet files.
|
|
||||||
num (int): The maximum number of rows to read.
|
|
||||||
batch_size (int): The batch size for reading data.
|
|
||||||
columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns).
|
|
||||||
"""
|
|
||||||
self._ds = _create_dataset(path)
|
self._ds = _create_dataset(path)
|
||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self._num = num
|
self._num = num
|
||||||
self._columns = columns
|
self._columns = columns
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""
|
|
||||||
Iterate through the Parquet data and yield batches of rows.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
pyarrow.RecordBatch: A batch of rows.
|
|
||||||
"""
|
|
||||||
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
|
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
|
||||||
rows_seen = 0
|
rows_seen = 0
|
||||||
for count, record in enumerate(batches):
|
for count, record in enumerate(batches):
|
||||||
@ -86,12 +62,6 @@ class PqReader:
|
|||||||
rows_seen += record.data.num_rows
|
rows_seen += record.data.num_rows
|
||||||
|
|
||||||
def _head(self):
|
def _head(self):
|
||||||
"""
|
|
||||||
Get the first `num` rows of the Parquet data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
pyarrow.RecordBatch: A batch of rows.
|
|
||||||
"""
|
|
||||||
total_read = self._num * self.bytes_per_row
|
total_read = self._num * self.bytes_per_row
|
||||||
if total_read >= int(500e6):
|
if total_read >= int(500e6):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
@ -101,12 +71,6 @@ class PqReader:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def bytes_per_row(self) -> int:
|
def bytes_per_row(self) -> int:
|
||||||
"""
|
|
||||||
Calculate the estimated bytes per row in the dataset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The estimated bytes per row.
|
|
||||||
"""
|
|
||||||
nbits = 0
|
nbits = 0
|
||||||
for t in self._ds.schema.types:
|
for t in self._ds.schema.types:
|
||||||
try:
|
try:
|
||||||
@ -117,22 +81,17 @@ class PqReader:
|
|||||||
return nbits // 8
|
return nbits // 8
|
||||||
|
|
||||||
def schema(self):
|
def schema(self):
|
||||||
"""
|
|
||||||
Display the schema of the Parquet dataset.
|
|
||||||
"""
|
|
||||||
print(f"\n# Schema\n{self._ds.schema}")
|
print(f"\n# Schema\n{self._ds.schema}")
|
||||||
|
|
||||||
def head(self):
|
def head(self):
|
||||||
"""
|
"""Displays first --num rows."""
|
||||||
Display the first `num` rows of the Parquet data as a pandas DataFrame.
|
|
||||||
"""
|
|
||||||
print(self._head().to_pandas())
|
print(self._head().to_pandas())
|
||||||
|
|
||||||
def distinct(self):
|
def distinct(self):
|
||||||
"""
|
"""Displays unique values seen in specified columns in the first `--num` rows.
|
||||||
Display unique values seen in specified columns in the first `num` rows.
|
|
||||||
|
|
||||||
Useful for getting an approximate vocabulary for certain columns.
|
Useful for getting an approximate vocabulary for certain columns.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
for col_name, column in zip(self._head().column_names, self._head().columns):
|
for col_name, column in zip(self._head().column_names, self._head().columns):
|
||||||
print(col_name)
|
print(col_name)
|
||||||
|
Loading…
Reference in New Issue
Block a user