rajveer43 2023-09-21 22:53:34 +05:30
parent cc73f5fcb7
commit f7f26d0c20
9 changed files with 953 additions and 724 deletions

View File

@@ -54,13 +54,21 @@ def maybe_shard_model(
  model,
  device: torch.device,
):
  """
  Set up and apply DistributedModelParallel to a model if running in a distributed environment.

  If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
  DistributedModelParallel.
  If not in a distributed environment, returns the model directly.

  Args:
    model: The PyTorch model.
    device: The target device (e.g., 'cuda').

  Returns:
    The model wrapped with DistributedModelParallel if in a distributed environment, else the original model.
  """
  if dist.is_initialized():
    logging.info("***** Wrapping in DistributedModelParallel *****")
    logging.info(f"Model before wrapping: {model}")
@@ -74,14 +82,15 @@ def maybe_shard_model(
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
  """
  Handy function to log the content of an EBC (EmbeddingBagCollection) embedding layer.
  Only works for single-GPU machines.

  Args:
    weight_name: Name of the tensor, as defined in the model.
    table_name: Name of the EBC table the weight is taken from.
    weight_tensor: Embedding weight tensor.
  """
  logging.info(f"{weight_name}, {table_name}", rank=-1)
  logging.info(f"{weight_tensor.metadata()}", rank=-1)
  output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device("cuda:0"))

View File

@@ -8,239 +8,250 @@ import pydantic


class ExplicitDateInputs(base_config.BaseConfig):
  """Arguments to select train/validation data using end_date and days of data."""

  data_root: str = pydantic.Field(..., description="Data path prefix.")
  end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
  days: int = pydantic.Field(..., description="Number of days of data for dataset.")
  num_missing_days_tol: int = pydantic.Field(
    0, description="We tolerate <= num_missing_days_tol days of missing data."
  )


class ExplicitDatetimeInputs(base_config.BaseConfig):
  """Arguments to select train/validation data using end_datetime and hours of data."""

  data_root: str = pydantic.Field(..., description="Data path prefix.")
  end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
  hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
  num_missing_hours_tol: int = pydantic.Field(
    0, description="We tolerate <= num_missing_hours_tol hours of missing data."
  )


class DdsCompressionOption(str, Enum):
  """The only valid compression option is 'AUTO'."""

  AUTO = "AUTO"


class DatasetConfig(base_config.BaseConfig):
  inputs: str = pydantic.Field(
    None, description="A glob for selecting data.", one_of="date_inputs_format"
  )
  explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
    None, one_of="date_inputs_format"
  )
  explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")

  global_batch_size: pydantic.PositiveInt

  num_files_to_keep: pydantic.PositiveInt = pydantic.Field(
    None, description="Number of shards to keep."
  )
  repeat_files: bool = pydantic.Field(
    True, description="DEPRECATED. Files are repeated no matter what this is set to."
  )
  file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")

  cache: bool = pydantic.Field(
    False,
    description="Cache dataset in memory. Careful to only use this when you"
    " have enough memory to fit the entire dataset.",
  )

  data_service_dispatcher: str = pydantic.Field(None)
  ignore_data_errors: bool = pydantic.Field(
    False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs."
  )
  dataset_service_compression: DdsCompressionOption = pydantic.Field(
    None,
    description="Compress the dataset for DDS worker -> training host. Disabled by default; the only valid option is 'AUTO'.",
  )

  # tf.data.Dataset options
  examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
  map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
    None, description="Number of parallel calls."
  )
  interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
    None, description="Number of shards to interleave."
  )


class TruncateAndSlice(base_config.BaseConfig):
  # Apply truncation and then slice.
  continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field(
    None, description="Experimental. Truncates continuous features to this amount for efficiency."
  )
  binary_feature_truncation: pydantic.PositiveInt = pydantic.Field(
    None, description="Experimental. Truncates binary features to this amount for efficiency."
  )

  continuous_feature_mask_path: str = pydantic.Field(
    None, description="Path of mask used to slice input continuous features."
  )
  binary_feature_mask_path: str = pydantic.Field(
    None, description="Path of mask used to slice input binary features."
  )


class DataType(str, Enum):
  BFLOAT16 = "bfloat16"
  BOOL = "bool"

  FLOAT32 = "float32"
  FLOAT16 = "float16"

  UINT8 = "uint8"


class DownCast(base_config.BaseConfig):
  # Apply down casting to selected features.
  features: typing.Dict[str, DataType] = pydantic.Field(
    None, description="Map features to down cast data types."
  )


class TaskData(base_config.BaseConfig):
  pos_downsampling_rate: float = pydantic.Field(
    1.0,
    description="Downsampling rate of positives used to generate dataset.",
  )
  neg_downsampling_rate: float = pydantic.Field(
    1.0,
    description="Downsampling rate of negatives used to generate dataset.",
  )


class SegDenseSchema(base_config.BaseConfig):
  schema_path: str = pydantic.Field(..., description="Path to feature config json.")
  features: typing.List[str] = pydantic.Field(
    [],
    description="List of features (in addition to the renamed features) to read from schema path above.",
  )
  renamed_features: typing.Dict[str, str] = pydantic.Field(
    {}, description="Dictionary of renamed features."
  )
  mask_mantissa_features: typing.Dict[str, int] = pydantic.Field(
    {},
    description="(experimental) Number of mantissa bits to mask to simulate lower precision data.",
  )


class RectifyLabels(base_config.BaseConfig):
  label_rectification_window_in_hours: float = pydantic.Field(
    3.0, description="Overlap time in hours for which to flip labels."
  )
  served_timestamp_field: str = pydantic.Field(
    ..., description="Input field corresponding to served time."
  )
  impressed_timestamp_field: str = pydantic.Field(
    ..., description="Input field corresponding to impressed time."
  )
  label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field(
    ..., description="Label to the input field corresponding to engagement time."
  )


class ExtractFeaturesRow(base_config.BaseConfig):
  name: str = pydantic.Field(
    ...,
    description="Name of the new field to be created.",
  )
  source_tensor: str = pydantic.Field(
    ...,
    description="Name of the dense tensor in which to look for the feature.",
  )
  index: int = pydantic.Field(
    ...,
    description="Index of the feature in the dense tensor.",
  )


class ExtractFeatures(base_config.BaseConfig):
  extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field(
    [],
    description="List of features to be extracted, with their name, source tensor and index.",
  )


class DownsampleNegatives(base_config.BaseConfig):
  batch_multiplier: int = pydantic.Field(
    None,
    description="Batch multiplier.",
  )
  engagements_list: typing.List[str] = pydantic.Field(
    [],
    description="Engagements with kept positives.",
  )
  num_engagements: int = pydantic.Field(
    ...,
    description="Number of engagements used in the model, including ones excluded in engagements_list.",
  )


class Preprocess(base_config.BaseConfig):
  truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
  downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
  rectify_labels: RectifyLabels = pydantic.Field(
    None, description="Rectify labels for a given overlap window."
  )
  extract_features: ExtractFeatures = pydantic.Field(
    None, description="Extract features from dense tensors."
  )
  downsample_negatives: DownsampleNegatives = pydantic.Field(
    None, description="Downsample negatives."
  )


class Sampler(base_config.BaseConfig):
  """Assumes function is defined in data/samplers.py.

  Only use this for quick experimentation.
  If samplers are useful, we should sample from upstream data generation.

  DEPRECATED, DO NOT USE.
  """

  name: str
  kwargs: typing.Dict


class RecapDataConfig(DatasetConfig):
  seg_dense_schema: SegDenseSchema

  tasks: typing.Dict[str, TaskData] = pydantic.Field(
    description="Description of individual tasks in this dataset."
  )
  evaluation_tasks: typing.List[str] = pydantic.Field(
    [], description="If specified, lists the tasks we're generating metrics for."
  )

  preprocess: Preprocess = pydantic.Field(
    None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference."
  )

  sampler: Sampler = pydantic.Field(
    None,
    description="""DEPRECATED, DO NOT USE. Sampling function for offline experiments.""",
  )

  @pydantic.root_validator()
  def _validate_evaluation_tasks(cls, values):
    if values.get("evaluation_tasks") is not None:
      for task in values["evaluation_tasks"]:
        if task not in values["tasks"]:
          raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
    return values
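To make the validation behavior concrete, a minimal sketch of building a RecapDataConfig-style object from a plain dict; the field values below are invented for illustration, and it is assumed that BaseConfig behaves like a standard pydantic model (nested dicts are coerced into the nested config classes).

```python
# Hypothetical example values; not taken from this commit.
config_dict = {
  "inputs": "gs://some-bucket/recap/*.gz",   # made-up glob
  "global_batch_size": 1024,
  "seg_dense_schema": {"schema_path": "/path/to/schema.json"},
  "tasks": {"fav": {"pos_downsampling_rate": 0.5}},
  "evaluation_tasks": ["fav"],
}

config = RecapDataConfig(**config_dict)

# The root validator rejects evaluation tasks that are not defined in `tasks`:
# RecapDataConfig(**{**config_dict, "evaluation_tasks": ["unknown"]})  -> raises KeyError
```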

View File

@@ -9,9 +9,20 @@ import numpy as np

class TruncateAndSlice(tf.keras.Model):
  """
  A class for truncating and slicing input features based on the provided configuration.

  Args:
    truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
  """

  def __init__(self, truncate_and_slice_config):
    """
    Initializes the TruncateAndSlice model.

    Args:
      truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
    """
    super().__init__()
    self._truncate_and_slice_config = truncate_and_slice_config

@@ -32,6 +43,17 @@ class TruncateAndSlice(tf.keras.Model):
      self._binary_mask = None

  def call(self, inputs, training=None, mask=None):
    """
    Applies truncation and slicing to the input features based on the configuration.

    Args:
      inputs: A dictionary of input features.
      training: A boolean indicating whether the model is in training mode.
      mask: A mask tensor.

    Returns:
      A dictionary of truncated and sliced input features.
    """
    outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
    if self._truncate_and_slice_config.continuous_feature_truncation:
      logging.info("Truncating continuous")
@@ -51,12 +73,23 @@ class TruncateAndSlice(tf.keras.Model):

class DownCast(tf.keras.Model):
  """
  A class for downcasting a dataset before serialization and transfer to the training host.

  Depending on the data type and the actual data range, the downcasting can be lossless or not.
  It is strongly recommended to compare the metrics before and after downcasting.

  Args:
    downcast_config: A configuration object specifying the features and their target data types.
  """

  def __init__(self, downcast_config):
    """
    Initializes the DownCast model.

    Args:
      downcast_config: A configuration object specifying the features and their target data types.
    """
    super().__init__()
    self.config = downcast_config
    self._type_map = {

@@ -65,6 +98,17 @@ class DownCast(tf.keras.Model):
    }

  def call(self, inputs, training=None, mask=None):
    """
    Applies downcasting to the input features based on the configuration.

    Args:
      inputs: A dictionary of input features.
      training: A boolean indicating whether the model is in training mode.
      mask: A mask tensor.

    Returns:
      A dictionary of downcast input features.
    """
    outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
    for feature, type_str in self.config.features.items():
      assert type_str in self._type_map
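For intuition, a minimal sketch of the kind of cast the DownCast stage applies per configured feature; the feature name, values, and target dtype are examples only.

```python
import tensorflow as tf

features = {"video_watch_time": tf.constant([1.5, 2.25, 3.125], dtype=tf.float32)}

# Down cast a single configured feature to bfloat16, as DownCast does for every
# entry in its `features` map. These particular values survive the cast exactly;
# in general bfloat16 keeps only 8 mantissa bits, so the cast can be lossy.
features["video_watch_time"] = tf.cast(features["video_watch_time"], tf.bfloat16)

print(features["video_watch_time"].dtype)  # <dtype: 'bfloat16'>
```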
@@ -78,14 +122,39 @@ class DownCast(tf.keras.Model):

class RectifyLabels(tf.keras.Model):
  """
  A class for rectifying labels based on specified conditions.

  This class is used to adjust label values in a dataset based on configured conditions involving timestamps.

  Args:
    rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
  """

  def __init__(self, rectify_label_config):
    """
    Initializes the RectifyLabels model.

    Args:
      rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
    """
    super().__init__()
    self._config = rectify_label_config
    self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)

  def call(self, inputs, training=None, mask=None):
    """
    Rectifies label values based on the specified conditions.

    Args:
      inputs: A dictionary of input features including timestamp fields and labels.
      training: A boolean indicating whether the model is in training mode.
      mask: A mask tensor.

    Returns:
      A dictionary of input features with rectified label values.
    """
    served_ts_field = self._config.served_timestamp_field
    impressed_ts_field = self._config.impressed_timestamp_field
@@ -102,13 +171,37 @@ class RectifyLabels(tf.keras.Model):

class ExtractFeatures(tf.keras.Model):
  """
  A class for extracting individual features from dense tensors by their index.

  Args:
    extract_features_config: A configuration object listing, for each feature to extract, its name, source tensor and index.
  """

  def __init__(self, extract_features_config):
    """
    Initializes the ExtractFeatures model.

    Args:
      extract_features_config: A configuration object listing, for each feature to extract, its name, source tensor and index.
    """
    super().__init__()
    self._config = extract_features_config

  def call(self, inputs, training=None, mask=None):
    """
    Extracts the configured features from their source dense tensors.

    Args:
      inputs: A dictionary of input features including the source dense tensors.
      training: A boolean indicating whether the model is in training mode.
      mask: A mask tensor.

    Returns:
      A dictionary of input features with the extracted features added.
    """
    for row in self._config.extract_feature_table:
      inputs[row.name] = inputs[row.source_tensor][:, row.index]
@@ -168,7 +261,16 @@ class DownsampleNegatives(tf.keras.Model):

def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
  """
  Builds a preprocess model to apply all preprocessing stages.

  Args:
    preprocess_config: A configuration object specifying the preprocessing parameters.
    mode: A mode indicating the current job mode (TRAIN or INFERENCE).

  Returns:
    A preprocess model that applies all specified preprocessing stages.
  """
  if mode == config_mod.JobMode.INFERENCE:
    logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
    return None
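As an illustration of what the ExtractFeatures stage does to a batch, a small, self-contained sketch in plain TensorFlow; the feature names, values, and index are invented for the example.

```python
import tensorflow as tf

# A batch of 2 examples with a 3-wide dense tensor (values are made up).
inputs = {"user_features": tf.constant([[0.1, 0.2, 0.3],
                                        [0.4, 0.5, 0.6]])}

# An ExtractFeaturesRow like
#   name="user_age_bucket", source_tensor="user_features", index=2
# amounts to the following column slice:
inputs["user_age_bucket"] = inputs["user_features"][:, 2]

print(inputs["user_age_bucket"].numpy())  # [0.3 0.6]
```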

View File

@@ -8,122 +8,129 @@ import tensorflow as tf

DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}


def create_tf_example_schema(
  data_config: recap_data_config.SegDenseSchema,
  segdense_schema,
):
  """Generate schema for deserializing tf.Example.

  Args:
    data_config: Data config carrying the seg dense schema and the task labels.
    segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length).

  Returns:
    A dictionary schema suitable for deserializing tf.Example.
  """
  segdense_config = data_config.seg_dense_schema
  labels = list(data_config.tasks.keys())
  used_features = (
    segdense_config.features + list(segdense_config.renamed_features.values()) + labels
  )
  logging.info(used_features)

  tfe_schema = {}
  for entry in segdense_schema:
    feature_name = entry["feature_name"]

    if feature_name in used_features:
      length = entry["length"]
      dtype = entry["dtype"]

      if feature_name in labels:
        logging.info(f"Label: feature name is {feature_name} type is {dtype}")
        tfe_schema[feature_name] = tf.io.FixedLenFeature(
          length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
        )
      elif length == -1:
        tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
      else:
        tfe_schema[feature_name] = tf.io.FixedLenFeature(
          length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
        )
  for feature_name in used_features:
    if feature_name not in tfe_schema:
      raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
  return tfe_schema


@functools.lru_cache(1)
def make_mantissa_mask(mask_length: int) -> tf.Tensor:
  """For experimenting with emulating bfloat16 or less precise types."""
  return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32)


def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:
  """For experimenting with emulating bfloat16 or less precise types."""
  mask: tf.Tensor = make_mantissa_mask(mask_length)
  return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype)


def parse_tf_example(
  serialized_example,
  tfe_schema,
  seg_dense_schema_config,
):
  """Parse a serialized tf.Example into a dict of tensors.

  Args:
    serialized_example: Serialized tf.Example to be parsed.
    tfe_schema: Dictionary schema suitable for deserializing tf.Example.
    seg_dense_schema_config: Seg dense schema config with renamed and mantissa-masked features.

  Returns:
    Dictionary of tensors to be used as model input.
  """
  inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)

  for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
    inputs[new_feature_name] = inputs.pop(old_feature_name)

  # This should not actually be used except for experimentation with low precision floats.
  if "mask_mantissa_features" in seg_dense_schema_config:
    for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
      inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)

  # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
  # at TF level.
  # We should not return empty tensors if we don't use embeddings.
  # Otherwise, it breaks numpy->pt conversion.
  renamed_keys = list(seg_dense_schema_config.renamed_features.keys())
  for renamed_key in renamed_keys:
    if "embedding" in renamed_key and (renamed_key not in inputs):
      inputs[renamed_key] = tf.zeros([], tf.float32)

  logging.info(f"parsed example and inputs are {inputs}")
  return inputs


def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig):
  """Placeholder for seg dense.

  In the future, when we use more seg dense variations, we can change this.
  """
  with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f:
    seg_dense_schema = json.load(f)["schema"]

  tf_example_schema = create_tf_example_schema(
    data_config,
    seg_dense_schema,
  )
  logging.info("***** TF Example Schema *****")
  logging.info(tf_example_schema)

  parse = functools.partial(
    parse_tf_example,
    tfe_schema=tf_example_schema,
    seg_dense_schema_config=data_config.seg_dense_schema,
  )
  return parse
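To make the mantissa-masking trick concrete, a small sketch of the same bit manipulation in plain TensorFlow. The mask_length value is just an example, and the mask constant is written as a negative two's-complement int32 (equivalent bit pattern to the `(1 << 32) - (1 << mask_length)` expression above) so it fits the int32 range without relying on wraparound.

```python
import tensorflow as tf

x = tf.constant([1.2345678], dtype=tf.float32)

# Zero out the low 16 mantissa bits of a float32, emulating bfloat16 precision:
# bitcast float32 -> int32, AND with a mask whose low 16 bits are zero, bitcast back.
mask_length = 16
mask = tf.constant(-(1 << mask_length), dtype=tf.int32)  # bit pattern 0xFFFF0000
masked = tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(x, tf.int32), mask), tf.float32)

print(masked.numpy())  # roughly [1.234375], the truncated bfloat16-precision value
```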

View File

@@ -6,115 +6,160 @@ import tensorflow as tf

def keyed_tensor_from_tensors_dict(
  tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedTensor":
  """
  Convert a dictionary of torch tensors to a torchrec KeyedTensor.

  Args:
    tensor_map: A mapping of tensor names to torch tensors.

  Returns:
    A torchrec KeyedTensor.
  """
  keys = list(tensor_map.keys())
  # We expect batch size to be first dim. However, if we get a shape [Batch_size],
  # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
  # [Batch_size x 1].
  values = [
    tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
    for key in keys
  ]
  return torchrec.KeyedTensor.from_tensor_list(keys, values)


def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  Compute the values and lengths of a jagged tensor from a torch tensor.

  Args:
    tensor: Input torch tensor.

  Returns:
    A tuple containing the values and lengths of the jagged tensor.
  """
  if tensor.is_sparse:
    x = tensor.coalesce()  # Ensure that the indices are ordered.
    lengths = torch.bincount(x.indices()[0])
    values = x.values()
  else:
    values = tensor
    lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
  return values, lengths


def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
  """
  Convert a torch tensor to a torchrec JaggedTensor.

  Note: Currently, this function only supports input tensors with shapes of [Batch_size] or
  [Batch_size x N] for dense tensors. For sparse tensors, the shape of .values() should be
  [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary.

  Args:
    tensor: A torch (sparse) tensor.

  Returns:
    A torchrec JaggedTensor.
  """
  values, lengths = _compute_jagged_tensor_from_tensor(tensor)
  return torchrec.JaggedTensor(values=values, lengths=lengths)


def keyed_jagged_tensor_from_tensors_dict(
  tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedJaggedTensor":
  """
  Convert a dictionary of (sparse) torch tensors to a torchrec KeyedJaggedTensor.

  Note: Currently, this function only supports input tensors with shapes of [Batch_size] or
  [Batch_size x 1] for dense tensors. For sparse tensors, the shape of .values() should be
  [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary.

  Args:
    tensor_map: A mapping of tensor names to torch tensors.

  Returns:
    A torchrec KeyedJaggedTensor.
  """
  if not tensor_map:
    return torchrec.KeyedJaggedTensor(
      keys=[],
      values=torch.zeros(0, dtype=torch.int),
      lengths=torch.zeros(0, dtype=torch.int),
    )
  values = []
  lengths = []
  for tensor in tensor_map.values():
    tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor)
    values.append(torch.squeeze(tensor_val))
    lengths.append(tensor_len)
  values = torch.cat(values, axis=0)
  lengths = torch.cat(lengths, axis=0)
  return torchrec.KeyedJaggedTensor(
    keys=list(tensor_map.keys()),
    values=values,
    lengths=lengths,
  )


def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
  """
  Convert a TensorFlow tensor to a NumPy array.

  Args:
    tf_tensor: TensorFlow tensor.

  Returns:
    NumPy array.
  """
  return tf_tensor._numpy()  # noqa


def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
  """
  Convert a dense TensorFlow tensor to a PyTorch tensor.

  Args:
    tensor: Dense TensorFlow tensor.
    pin_memory: Whether to pin the tensor in memory (for CUDA).

  Returns:
    PyTorch tensor.
  """
  tensor = _tf_to_numpy(tensor)
  # torch.from_numpy does not support bfloat16; upcast to float32 to keep the same number of
  # exponent bits.
  if tensor.dtype.name == "bfloat16":
    tensor = tensor.astype(np.float32)

  tensor = torch.from_numpy(tensor)
  if pin_memory:
    tensor = tensor.pin_memory()
  return tensor


def sparse_or_dense_tf_to_torch(
  tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
) -> torch.Tensor:
  """
  Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor.

  Args:
    tensor: TensorFlow tensor (sparse or dense).
    pin_memory: Whether to pin the tensor in memory (for CUDA).

  Returns:
    PyTorch tensor.
  """
  if isinstance(tensor, tf.SparseTensor):
    tensor = torch.sparse_coo_tensor(
      _dense_tf_to_torch(tensor.indices, pin_memory).t(),
      _dense_tf_to_torch(tensor.values, pin_memory),
      torch.Size(_tf_to_numpy(tensor.dense_shape)),
    )
  else:
    tensor = _dense_tf_to_torch(tensor, pin_memory)
  return tensor
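As a standalone illustration of how the jagged lengths are derived from a sparse tensor (the bincount over the first index dimension used in `_compute_jagged_tensor_from_tensor`), a small sketch in plain PyTorch; the ids and shapes are invented, and `minlength` is added here only to make the empty middle row explicit.

```python
import torch

# A "ragged" batch of 3 examples with 2, 0 and 1 ids respectively, expressed as a
# sparse COO tensor whose first index dimension is the batch dimension.
indices = torch.tensor([[0, 0, 2],   # batch row of each value
                        [0, 1, 0]])  # position within the row
values = torch.tensor([11, 12, 13])
sparse = torch.sparse_coo_tensor(indices, values, size=(3, 2))

x = sparse.coalesce()
lengths = torch.bincount(x.indices()[0], minlength=sparse.shape[0])
print(lengths)     # tensor([2, 0, 1])
print(x.values())  # tensor([11, 12, 13])
```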

View File

@@ -6,234 +6,243 @@ import torch

def _init_weights(module):
  """Initializes weights.

  Example
  ```python
  import torch
  import torch.nn as nn

  # Define a simple linear layer
  linear_layer = nn.Linear(64, 32)

  # Initialize the weights and biases using _init_weights
  _init_weights(linear_layer)
  ```
  """
  if isinstance(module, torch.nn.Linear):
    torch.nn.init.xavier_uniform_(module.weight)
    torch.nn.init.constant_(module.bias, 0)


class MaskBlock(torch.nn.Module):
  """
  MaskBlock module in a mask-based neural network.

  This module represents a MaskBlock, which applies a masking operation to the input data and then
  passes it through a hidden layer. It is typically used as a building block within a MaskNet.

  Args:
    mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
    input_dim (int): Dimensionality of the input data.
    mask_input_dim (int): Dimensionality of the mask input.

  Example:
    To create and use a MaskBlock within a MaskNet, follow these steps:

    ```python
    # Define the configuration for the MaskBlock
    mask_block_config = MaskBlockConfig(
      input_layer_norm=True,  # Apply input layer normalization
      reduction_factor=0.5    # Reduce input dimensionality by 50%
    )

    # Create an instance of the MaskBlock
    mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)

    # Generate input tensors
    input_data = torch.randn(batch_size, 64)
    mask_input = torch.randn(batch_size, 32)

    # Perform a forward pass through the MaskBlock
    output = mask_block(input_data, mask_input)
    ```

  Note:
    The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
    operation that combines the input and mask input. Then, it passes the result through a hidden layer
    with optional dimensionality reduction.

  Warning:
    This class is intended for internal use within neural network architectures and should not be
    directly accessed or modified by external code.
  """

  def __init__(
    self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
  ) -> None:
    """
    Initializes the MaskBlock module.

    Args:
      mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
      input_dim (int): Dimensionality of the input data.
      mask_input_dim (int): Dimensionality of the mask input.

    Returns:
      None
    """
    super(MaskBlock, self).__init__()
    self.mask_block_config = mask_block_config
    output_size = mask_block_config.output_size

    if mask_block_config.input_layer_norm:
      self._input_layer_norm = torch.nn.LayerNorm(input_dim)
    else:
      self._input_layer_norm = None

    if mask_block_config.reduction_factor:
      aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
    elif mask_block_config.aggregation_size is not None:
      aggregation_size = mask_block_config.aggregation_size
    else:
      raise ValueError("Need one of reduction factor or aggregation size.")

    self._mask_layer = torch.nn.Sequential(
      torch.nn.Linear(mask_input_dim, aggregation_size),
      torch.nn.ReLU(),
      torch.nn.Linear(aggregation_size, input_dim),
    )
    self._mask_layer.apply(_init_weights)
    self._hidden_layer = torch.nn.Linear(input_dim, output_size)
    self._hidden_layer.apply(_init_weights)
    self._layer_norm = torch.nn.LayerNorm(output_size)

  def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
    """
    Performs a forward pass through the MaskBlock.

    Args:
      net (torch.Tensor): Input data tensor.
      mask_input (torch.Tensor): Mask input tensor.

    Returns:
      torch.Tensor: Output tensor of the MaskBlock.
    """
    if self._input_layer_norm:
      net = self._input_layer_norm(net)
    hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
    return self._layer_norm(hidden_layer_output)


class MaskNet(torch.nn.Module):
  """
  MaskNet module in a mask-based neural network.

  This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
  create mask-based neural networks with parallel or stacked MaskBlocks.

  Args:
    mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
    in_features (int): Dimensionality of the input data.

  Example:
    To create and use a MaskNet, you can follow these steps:

    ```python
    # Define the configuration for the MaskNet
    mask_net_config = MaskNetConfig(
      use_parallel=True,                    # Use parallel MaskBlocks
      mlp=MlpConfig(layer_sizes=[128, 64])  # Optional MLP on the outputs
    )

    # Create an instance of the MaskNet
    mask_net = MaskNet(mask_net_config, in_features=64)

    # Generate input tensors
    input_data = torch.randn(batch_size, 64)

    # Perform a forward pass through the MaskNet
    outputs = mask_net(input_data)

    # Access the output and shared layer
    output = outputs["output"]
    shared_layer = outputs["shared_layer"]
    ```

  Note:
    The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
    MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.

  Warning:
    This class is intended for internal use within neural network architectures and should not be
    directly accessed or modified by external code.
  """

  def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
    """
    Initializes the MaskNet module.

    Args:
      mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
      in_features (int): Dimensionality of the input data.

    Returns:
      None
    """
    super().__init__()
    self.mask_net_config = mask_net_config
    mask_blocks = []

    if mask_net_config.use_parallel:
      total_output_mask_blocks = 0
      for mask_block_config in mask_net_config.mask_blocks:
        mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
        total_output_mask_blocks += mask_block_config.output_size
      self._mask_blocks = torch.nn.ModuleList(mask_blocks)
    else:
      input_size = in_features
      for mask_block_config in mask_net_config.mask_blocks:
        mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
        input_size = mask_block_config.output_size

      self._mask_blocks = torch.nn.ModuleList(mask_blocks)
      total_output_mask_blocks = mask_block_config.output_size

    if mask_net_config.mlp:
      self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
      self.out_features = mask_net_config.mlp.layer_sizes[-1]
    else:
      self.out_features = total_output_mask_blocks
    self.shared_size = total_output_mask_blocks

  def forward(self, inputs: torch.Tensor):
    """
    Performs a forward pass through the MaskNet.

    Args:
      inputs (torch.Tensor): Input data tensor.

    Returns:
      A dict with the MaskNet output under "output" and the shared MaskBlock outputs under
      "shared_layer".
    """
    if self.mask_net_config.use_parallel:
      mask_outputs = []
      for mask_layer in self._mask_blocks:
        mask_outputs.append(mask_layer(mask_input=inputs, net=inputs))
      # Share the outputs of the MaskBlocks.
      all_mask_outputs = torch.cat(mask_outputs, dim=1)
      output = (
        all_mask_outputs
        if self.mask_net_config.mlp is None
        else self._dense_layers(all_mask_outputs)["output"]
      )
      return {"output": output, "shared_layer": all_mask_outputs}
    else:
      net = inputs
      for mask_layer in self._mask_blocks:
        net = mask_layer(net=net, mask_input=inputs)
      # Share the output of the stacked MaskBlocks.
      output = net if self.mask_net_config.mlp is None else self._dense_layers(net)["output"]
      return {"output": output, "shared_layer": net}

View File

@ -5,113 +5,117 @@ from absl import logging
class ModelAndLoss(torch.nn.Module):
    """
    PyTorch module that combines a neural network model and loss function.

    This module wraps a neural network model and facilitates the forward pass through the model
    while also calculating the loss based on the model's predictions and provided labels.

    Args:
        model: The torch module to wrap.
        loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
        stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
            for metrics stratification. Each stratifier config includes the name and index of discrete features
            to emit for stratification.

    Example:
        To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model
        and loss function as arguments:

        ```python
        # Create a neural network model
        model = YourNeuralNetworkModel()

        # Define a loss function
        loss_fn = torch.nn.CrossEntropyLoss()

        # Create an instance of ModelAndLoss
        model_and_loss = ModelAndLoss(model, loss_fn)

        # Generate a batch of training data (e.g., RecapBatch)
        batch = generate_training_batch()

        # Perform a forward pass through the model and calculate the loss
        loss, outputs = model_and_loss(batch)

        # You can now backpropagate and optimize using the computed loss
        loss.backward()
        optimizer.step()
        ```

    Note:
        The `ModelAndLoss` class simplifies the process of running forward passes through a model and
        calculating loss, making it easier to integrate the model into your training loop. Additionally,
        it supports the addition of stratifiers for metrics stratification, if needed.

    Warning:
        This class is intended for internal use within neural network architectures and should not be
        directly accessed or modified by external code.
    """

    def __init__(
        self,
        model,
        loss_fn: Callable,
        stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
    ) -> None:
        """
        Initializes the ModelAndLoss module.

        Args:
            model: The torch module to wrap.
            loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
            stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
                for metrics stratification.
        """
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.stratifiers = stratifiers

    def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]
        """Runs model forward and calculates loss according to given loss_fn.

        NOTE: The input signature here needs to be a Pipelineable object for
        prefetching purposes during training using torchrec's pipeline. However
        the underlying model signature needs to be exportable to onnx, requiring
        generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
        """
        outputs = self.model(
            continuous_features=batch.continuous_features,
            binary_features=batch.binary_features,
            discrete_features=batch.discrete_features,
            sparse_features=batch.sparse_features,
            user_embedding=batch.user_embedding,
            user_eng_embedding=batch.user_eng_embedding,
            author_embedding=batch.author_embedding,
            labels=batch.labels,
            weights=batch.weights,
        )
        losses = self.loss_fn(
            outputs["logits"], batch.labels.float(), batch.weights.float())

        if self.stratifiers:
            logging.info(
                f"***** Adding stratifiers *****\n {self.stratifiers}")
            outputs["stratifiers"] = {}
            for stratifier in self.stratifiers:
                outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index]

        # In general, we can have a large number of losses returned by our loss function.
        if isinstance(losses, dict):
            return losses["loss"], {
                **outputs,
                **losses,
                "labels": batch.labels,
                "weights": batch.weights,
            }
        else:  # Assume that this is a float.
            return losses, {
                **outputs,
                "loss": losses,
                "labels": batch.labels,
                "weights": batch.weights,
            }
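Because `loss_fn` may return either a single tensor or a dictionary containing a "loss" entry, the wrapper always yields a `(loss, outputs)` pair. A hedged training-step sketch follows; `my_model`, `batch`, and the `weighted_bce_loss` helper are illustrative assumptions that match the `(logits, labels, weights)` call signature used above.

```python
import torch
import torch.nn.functional as F

def weighted_bce_loss(logits, labels, weights):
    # Assumed loss helper; matches the (logits, labels, weights) call made in forward().
    return F.binary_cross_entropy_with_logits(logits, labels, weight=weights)

model_and_loss = ModelAndLoss(model=my_model, loss_fn=weighted_bce_loss)
optimizer = torch.optim.Adam(model_and_loss.parameters())

loss, outputs = model_and_loss(batch)  # `outputs` also carries logits, labels, and weights
optimizer.zero_grad()
loss.backward()
optimizer.step()
```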

View File

@@ -2,64 +2,65 @@ import torch
class NumericCalibration(torch.nn.Module):
    """
    Numeric calibration module for adjusting probability scores.

    This module scales probability scores to correct for imbalanced datasets, where positive and negative samples
    may be underrepresented or have different ratios. It is designed to be used as a component in a neural network
    for tasks such as binary classification.

    Args:
        pos_downsampling_rate (float): The downsampling rate for positive samples.
        neg_downsampling_rate (float): The downsampling rate for negative samples.

    Example:
        To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability
        scores like this:

        ```python
        # Create a NumericCalibration instance with downsampling rates
        calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)

        # Generate probability scores (e.g., from a neural network)
        raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])

        # Apply numeric calibration to adjust the probabilities
        calibrated_probs = calibration(raw_probs)

        # The `calibrated_probs` now contains the adjusted probability scores
        ```

    Note:
        The `NumericCalibration` module is used to adjust probability scores to account for differences in
        the number of positive and negative samples in a dataset. It can help improve the calibration of
        probability estimates in imbalanced classification problems.

    Warning:
        This class is intended for internal use within neural network architectures and should not be
        directly accessed or modified by external code.
    """

    def __init__(
        self,
        pos_downsampling_rate: float,
        neg_downsampling_rate: float,
    ):
        """
        Initializes NumericCalibration with the positive and negative downsampling rates.

        Args:
            pos_downsampling_rate (float): The downsampling rate for positive samples.
            neg_downsampling_rate (float): The downsampling rate for negative samples.
        """
        super().__init__()

        # Using buffer to make sure they are on correct device (and not moved every time).
        # Will also be part of state_dict.
        self.register_buffer(
            "ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True
        )

    def forward(self, probs: torch.Tensor):
        """
        Apply numeric calibration to probability scores.

        Args:
            probs (torch.Tensor): Probability scores to be calibrated.

        Returns:
            torch.Tensor: Calibrated probability scores.
        """
        return probs * self.ratio / (1.0 - probs + (self.ratio * probs))
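The forward pass is equivalent to multiplying the odds `p / (1 - p)` by `ratio = neg_downsampling_rate / pos_downsampling_rate` and converting back to a probability. A hedged numeric check, with rates chosen purely for illustration:

```python
import torch

# Negatives kept at 5%, positives kept in full -> ratio = 0.05.
calibration = NumericCalibration(pos_downsampling_rate=1.0, neg_downsampling_rate=0.05)
raw = torch.tensor([0.2, 0.5, 0.9])
calibrated = calibration(raw)

# Equivalent odds-space computation for comparison.
odds = raw / (1.0 - raw)
expected = (0.05 * odds) / (1.0 + 0.05 * odds)
assert torch.allclose(calibrated, expected)
```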

View File

@@ -38,6 +38,15 @@ import pyarrow.parquet as pq
def _create_dataset(path: str):
    """
    Create a PyArrow dataset from Parquet files located at the specified path.

    Args:
        path (str): The path to the Parquet files.

    Returns:
        pyarrow.dataset.Dataset: The PyArrow dataset.
    """
    fs = infer_fs(path)
    files = fs.glob(path)
    return pads.dataset(files, format="parquet", filesystem=fs)
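A small sketch of what the helper returns, assuming `infer_fs` resolves a filesystem for the given glob; the local path is invented for illustration.

```python
ds = _create_dataset("/tmp/training_data/*.parquet")  # illustrative local glob
print(ds.schema)        # combined pyarrow schema of the matched files
print(ds.count_rows())  # total number of rows across all matched files
```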
@@ -47,12 +56,27 @@ class PqReader:
    def __init__(
        self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
    ):
        """
        Initialize a Parquet Reader.

        Args:
            path (str): The path to the Parquet files.
            num (int): The maximum number of rows to read.
            batch_size (int): The batch size for reading data.
            columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns).
        """
        self._ds = _create_dataset(path)
        self._batch_size = batch_size
        self._num = num
        self._columns = columns

    def __iter__(self):
        """
        Iterate through the Parquet data and yield batches of rows.

        Yields:
            pyarrow.RecordBatch: A batch of rows.
        """
        batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
        rows_seen = 0
        for count, record in enumerate(batches):
@@ -62,6 +86,12 @@ class PqReader:
            rows_seen += record.data.num_rows

    def _head(self):
        """
        Get the first `num` rows of the Parquet data.

        Returns:
            pyarrow.RecordBatch: A batch of rows.
        """
        total_read = self._num * self.bytes_per_row
        if total_read >= int(500e6):
            raise Exception(
@@ -71,6 +101,12 @@ class PqReader:
    @property
    def bytes_per_row(self) -> int:
        """
        Calculate the estimated bytes per row in the dataset.

        Returns:
            int: The estimated bytes per row.
        """
        nbits = 0
        for t in self._ds.schema.types:
            try:
@@ -81,18 +117,23 @@ class PqReader:
        return nbits // 8

    def schema(self):
        """
        Display the schema of the Parquet dataset.
        """
        print(f"\n# Schema\n{self._ds.schema}")

    def head(self):
        """
        Display the first `num` rows of the Parquet data as a pandas DataFrame.
        """
        print(self._head().to_pandas())

    def distinct(self):
        """
        Display unique values seen in specified columns in the first `num` rows.

        Useful for getting an approximate vocabulary for certain columns.
        """
        for col_name, column in zip(self._head().column_names, self._head().columns):
            print(col_name)
            print("unique:", column.unique().to_pylist())