From f7f26d0c2046b34540ab631b08f32111baed11bc Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Thu, 21 Sep 2023 22:53:34 +0530 Subject: [PATCH] Updates --- model.py | 29 +- projects/home/recap/data/config.py | 357 +++++++-------- projects/home/recap/data/preprocessors.py | 116 ++++- projects/home/recap/data/tfe_parsing.py | 179 ++++---- projects/home/recap/data/util.py | 211 +++++---- projects/home/recap/model/mask_net.py | 413 +++++++++--------- projects/home/recap/model/model_and_loss.py | 210 ++++----- .../home/recap/model/numeric_calibration.py | 111 ++--- tools/pq.py | 51 ++- 9 files changed, 953 insertions(+), 724 deletions(-) diff --git a/model.py b/model.py index 9df13bc..b8941dd 100644 --- a/model.py +++ b/model.py @@ -54,13 +54,21 @@ def maybe_shard_model( model, device: torch.device, ): - """Set up and apply DistributedModelParallel to a model if running in a distributed environment. + """ + Set up and apply DistributedModelParallel to a model if running in a distributed environment. If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies DistributedModelParallel. - If not in a distributed environment, returns model directly. - """ + If not in a distributed environment, returns the model directly. + + Args: + model: The PyTorch model. + device: The target device (e.g., 'cuda'). + + Returns: + The model wrapped with DistributedModelParallel if in a distributed environment, else the original model. + """ if dist.is_initialized(): logging.info("***** Wrapping in DistributedModelParallel *****") logging.info(f"Model before wrapping: {model}") @@ -74,14 +82,15 @@ def maybe_shard_model( def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None: - """Handy function to log the content of EBC embedding layer. - Only works for single GPU machines. - - Args: - weight_name: name of tensor, as defined in model - table_name: name of the EBC table the weight is taken from - weight_tensor: embedding weight tensor """ + Handy function to log the content of an EBC (Embedding Bag Concatenation) embedding layer. + Only works for single GPU machines. + + Args: + weight_name: Name of the tensor, as defined in the model. + table_name: Name of the EBC table the weight is taken from. + weight_tensor: Embedding weight tensor. + """ logging.info(f"{weight_name}, {table_name}", rank=-1) logging.info(f"{weight_tensor.metadata()}", rank=-1) output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device("cuda:0")) diff --git a/projects/home/recap/data/config.py b/projects/home/recap/data/config.py index 27ef3ed..c5ee4c0 100644 --- a/projects/home/recap/data/config.py +++ b/projects/home/recap/data/config.py @@ -8,239 +8,250 @@ import pydantic class ExplicitDateInputs(base_config.BaseConfig): - """Arguments to select train/validation data using end_date and days of data.""" + """Arguments to select train/validation data using end_date and days of data.""" - data_root: str = pydantic.Field(..., description="Data path prefix.") - end_date: str = pydantic.Field(..., description="Data end date, inclusive.") - days: int = pydantic.Field(..., description="Number of days of data for dataset.") - num_missing_days_tol: int = pydantic.Field( - 0, description="We tolerate <= num_missing_days_tol days of missing data." 
- ) + data_root: str = pydantic.Field(..., description="Data path prefix.") + end_date: str = pydantic.Field(..., + description="Data end date, inclusive.") + days: int = pydantic.Field(..., + description="Number of days of data for dataset.") + num_missing_days_tol: int = pydantic.Field( + 0, description="We tolerate <= num_missing_days_tol days of missing data." + ) class ExplicitDatetimeInputs(base_config.BaseConfig): - """Arguments to select train/validation data using end_datetime and hours of data.""" + """Arguments to select train/validation data using end_datetime and hours of data.""" - data_root: str = pydantic.Field(..., description="Data path prefix.") - end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.") - hours: int = pydantic.Field(..., description="Number of hours of data for dataset.") - num_missing_hours_tol: int = pydantic.Field( - 0, description="We tolerate <= num_missing_hours_tol hours of missing data." - ) + data_root: str = pydantic.Field(..., description="Data path prefix.") + end_datetime: str = pydantic.Field(..., + description="Data end datetime, inclusive.") + hours: int = pydantic.Field(..., + description="Number of hours of data for dataset.") + num_missing_hours_tol: int = pydantic.Field( + 0, description="We tolerate <= num_missing_hours_tol hours of missing data." + ) class DdsCompressionOption(str, Enum): - """The only valid compression option is 'AUTO'""" + """The only valid compression option is 'AUTO'""" - AUTO = "AUTO" + AUTO = "AUTO" class DatasetConfig(base_config.BaseConfig): - inputs: str = pydantic.Field( - None, description="A glob for selecting data.", one_of="date_inputs_format" - ) - explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field( - None, one_of="date_inputs_format" - ) - explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format") + inputs: str = pydantic.Field( + None, description="A glob for selecting data.", one_of="date_inputs_format" + ) + explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field( + None, one_of="date_inputs_format" + ) + explicit_date_inputs: ExplicitDateInputs = pydantic.Field( + None, one_of="date_inputs_format") - global_batch_size: pydantic.PositiveInt + global_batch_size: pydantic.PositiveInt - num_files_to_keep: pydantic.PositiveInt = pydantic.Field( - None, description="Number of shards to keep." - ) - repeat_files: bool = pydantic.Field( - True, description="DEPRICATED. Files are repeated no matter what this is set to." - ) - file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size") + num_files_to_keep: pydantic.PositiveInt = pydantic.Field( + None, description="Number of shards to keep." + ) + repeat_files: bool = pydantic.Field( + True, description="DEPRICATED. Files are repeated no matter what this is set to." + ) + file_batch_size: pydantic.PositiveInt = pydantic.Field( + 16, description="File batch size") - cache: bool = pydantic.Field( - False, - description="Cache dataset in memory. Careful to only use this when you" - " have enough memory to fit entire dataset.", - ) + cache: bool = pydantic.Field( + False, + description="Cache dataset in memory. Careful to only use this when you" + " have enough memory to fit entire dataset.", + ) - data_service_dispatcher: str = pydantic.Field(None) - ignore_data_errors: bool = pydantic.Field( - False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs." 
- ) - dataset_service_compression: DdsCompressionOption = pydantic.Field( - None, - description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'", - ) + data_service_dispatcher: str = pydantic.Field(None) + ignore_data_errors: bool = pydantic.Field( + False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs." + ) + dataset_service_compression: DdsCompressionOption = pydantic.Field( + None, + description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'", + ) - # tf.data.Dataset options - examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.") - map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( - None, description="Number of parallel calls." - ) - interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( - None, description="Number of shards to interleave." - ) + # tf.data.Dataset options + examples_shuffle_buffer_size: int = pydantic.Field( + 1024, description="Size of shuffle buffers.") + map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( + None, description="Number of parallel calls." + ) + interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( + None, description="Number of shards to interleave." + ) class TruncateAndSlice(base_config.BaseConfig): - # Apply truncation and then slice. - continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field( - None, description="Experimental. Truncates continuous features to this amount for efficiency." - ) - binary_feature_truncation: pydantic.PositiveInt = pydantic.Field( - None, description="Experimental. Truncates binary features to this amount for efficiency." - ) + # Apply truncation and then slice. + continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field( + None, description="Experimental. Truncates continuous features to this amount for efficiency." + ) + binary_feature_truncation: pydantic.PositiveInt = pydantic.Field( + None, description="Experimental. Truncates binary features to this amount for efficiency." + ) - continuous_feature_mask_path: str = pydantic.Field( - None, description="Path of mask used to slice input continuous features." - ) - binary_feature_mask_path: str = pydantic.Field( - None, description="Path of mask used to slice input binary features." - ) + continuous_feature_mask_path: str = pydantic.Field( + None, description="Path of mask used to slice input continuous features." + ) + binary_feature_mask_path: str = pydantic.Field( + None, description="Path of mask used to slice input binary features." + ) class DataType(str, Enum): - BFLOAT16 = "bfloat16" - BOOL = "bool" + BFLOAT16 = "bfloat16" + BOOL = "bool" - FLOAT32 = "float32" - FLOAT16 = "float16" + FLOAT32 = "float32" + FLOAT16 = "float16" - UINT8 = "uint8" + UINT8 = "uint8" class DownCast(base_config.BaseConfig): - # Apply down casting to selected features. - features: typing.Dict[str, DataType] = pydantic.Field( - None, description="Map features to down cast data types." - ) + # Apply down casting to selected features. + features: typing.Dict[str, DataType] = pydantic.Field( + None, description="Map features to down cast data types." 
+ ) class TaskData(base_config.BaseConfig): - pos_downsampling_rate: float = pydantic.Field( - 1.0, - description="Downsampling rate of positives used to generate dataset.", - ) - neg_downsampling_rate: float = pydantic.Field( - 1.0, - description="Downsampling rate of negatives used to generate dataset.", - ) + pos_downsampling_rate: float = pydantic.Field( + 1.0, + description="Downsampling rate of positives used to generate dataset.", + ) + neg_downsampling_rate: float = pydantic.Field( + 1.0, + description="Downsampling rate of negatives used to generate dataset.", + ) class SegDenseSchema(base_config.BaseConfig): - schema_path: str = pydantic.Field(..., description="Path to feature config json.") - features: typing.List[str] = pydantic.Field( - [], - description="List of features (in addition to the renamed features) to read from schema path above.", - ) - renamed_features: typing.Dict[str, str] = pydantic.Field( - {}, description="Dictionary of renamed features." - ) - mask_mantissa_features: typing.Dict[str, int] = pydantic.Field( - {}, - description="(experimental) Number of mantissa bits to mask to simulate lower precision data.", - ) + schema_path: str = pydantic.Field(..., + description="Path to feature config json.") + features: typing.List[str] = pydantic.Field( + [], + description="List of features (in addition to the renamed features) to read from schema path above.", + ) + renamed_features: typing.Dict[str, str] = pydantic.Field( + {}, description="Dictionary of renamed features." + ) + mask_mantissa_features: typing.Dict[str, int] = pydantic.Field( + {}, + description="(experimental) Number of mantissa bits to mask to simulate lower precision data.", + ) class RectifyLabels(base_config.BaseConfig): - label_rectification_window_in_hours: float = pydantic.Field( - 3.0, description="overlap time in hours for which to flip labels" - ) - served_timestamp_field: str = pydantic.Field( - ..., description="input field corresponding to served time" - ) - impressed_timestamp_field: str = pydantic.Field( - ..., description="input field corresponding to impressed time" - ) - label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field( - ..., description="label to the input field corresponding to engagement time" - ) + label_rectification_window_in_hours: float = pydantic.Field( + 3.0, description="overlap time in hours for which to flip labels" + ) + served_timestamp_field: str = pydantic.Field( + ..., description="input field corresponding to served time" + ) + impressed_timestamp_field: str = pydantic.Field( + ..., description="input field corresponding to impressed time" + ) + label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field( + ..., description="label to the input field corresponding to engagement time" + ) class ExtractFeaturesRow(base_config.BaseConfig): - name: str = pydantic.Field( - ..., - description="name of the new field name to be created", - ) - source_tensor: str = pydantic.Field( - ..., - description="name of the dense tensor to look for the feature", - ) - index: int = pydantic.Field( - ..., - description="index of the feature in the dense tensor", - ) + name: str = pydantic.Field( + ..., + description="name of the new field name to be created", + ) + source_tensor: str = pydantic.Field( + ..., + description="name of the dense tensor to look for the feature", + ) + index: int = pydantic.Field( + ..., + description="index of the feature in the dense tensor", + ) class ExtractFeatures(base_config.BaseConfig): - extract_feature_table: 
typing.List[ExtractFeaturesRow] = pydantic.Field( - [], - description="list of features to be extracted with their name, source tensor and index", - ) + extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field( + [], + description="list of features to be extracted with their name, source tensor and index", + ) class DownsampleNegatives(base_config.BaseConfig): - batch_multiplier: int = pydantic.Field( - None, - description="batch multiplier", - ) - engagements_list: typing.List[str] = pydantic.Field( - [], - description="engagements with kept positives", - ) - num_engagements: int = pydantic.Field( - ..., - description="number engagements used in the model, including ones excluded in engagements_list", - ) + batch_multiplier: int = pydantic.Field( + None, + description="batch multiplier", + ) + engagements_list: typing.List[str] = pydantic.Field( + [], + description="engagements with kept positives", + ) + num_engagements: int = pydantic.Field( + ..., + description="number engagements used in the model, including ones excluded in engagements_list", + ) class Preprocess(base_config.BaseConfig): - truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.") - downcast: DownCast = pydantic.Field(None, description="Down cast to features.") - rectify_labels: RectifyLabels = pydantic.Field( - None, description="Rectify labels for a given overlap window" - ) - extract_features: ExtractFeatures = pydantic.Field( - None, description="Extract features from dense tensors." - ) - downsample_negatives: DownsampleNegatives = pydantic.Field( - None, description="Downsample negatives." - ) + truncate_and_slice: TruncateAndSlice = pydantic.Field( + None, description="Truncation and slicing.") + downcast: DownCast = pydantic.Field( + None, description="Down cast to features.") + rectify_labels: RectifyLabels = pydantic.Field( + None, description="Rectify labels for a given overlap window" + ) + extract_features: ExtractFeatures = pydantic.Field( + None, description="Extract features from dense tensors." + ) + downsample_negatives: DownsampleNegatives = pydantic.Field( + None, description="Downsample negatives." + ) class Sampler(base_config.BaseConfig): - """Assumes function is defined in data/samplers.py. + """Assumes function is defined in data/samplers.py. - Only use this for quick experimentation. - If samplers are useful, we should sample from upstream data generation. + Only use this for quick experimentation. + If samplers are useful, we should sample from upstream data generation. - DEPRICATED, DO NOT USE. - """ + DEPRICATED, DO NOT USE. + """ - name: str - kwargs: typing.Dict + name: str + kwargs: typing.Dict class RecapDataConfig(DatasetConfig): - seg_dense_schema: SegDenseSchema + seg_dense_schema: SegDenseSchema - tasks: typing.Dict[str, TaskData] = pydantic.Field( - description="Description of individual tasks in this dataset." - ) - evaluation_tasks: typing.List[str] = pydantic.Field( - [], description="If specified, lists the tasks we're generating metrics for." - ) + tasks: typing.Dict[str, TaskData] = pydantic.Field( + description="Description of individual tasks in this dataset." + ) + evaluation_tasks: typing.List[str] = pydantic.Field( + [], description="If specified, lists the tasks we're generating metrics for." + ) - preprocess: Preprocess = pydantic.Field( - None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference." 
- ) + preprocess: Preprocess = pydantic.Field( + None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference." + ) - sampler: Sampler = pydantic.Field( - None, - description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""", - ) + sampler: Sampler = pydantic.Field( + None, + description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""", + ) - @pydantic.root_validator() - def _validate_evaluation_tasks(cls, values): - if values.get("evaluation_tasks") is not None: - for task in values["evaluation_tasks"]: - if task not in values["tasks"]: - raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}") - return values + @pydantic.root_validator() + def _validate_evaluation_tasks(cls, values): + if values.get("evaluation_tasks") is not None: + for task in values["evaluation_tasks"]: + if task not in values["tasks"]: + raise KeyError( + f"Evaluation task {task} must be in tasks. Received {values['tasks']}") + return values diff --git a/projects/home/recap/data/preprocessors.py b/projects/home/recap/data/preprocessors.py index d5720e2..efa3042 100644 --- a/projects/home/recap/data/preprocessors.py +++ b/projects/home/recap/data/preprocessors.py @@ -9,9 +9,20 @@ import numpy as np class TruncateAndSlice(tf.keras.Model): - """Class for truncating and slicing.""" + """ + A class for truncating and slicing input features based on the provided configuration. + + Args: + truncate_and_slice_config: A configuration object specifying how to truncate and slice features. + """ def __init__(self, truncate_and_slice_config): + """ + Initializes the TruncateAndSlice model. + + Args: + truncate_and_slice_config: A configuration object specifying how to truncate and slice features. + """ super().__init__() self._truncate_and_slice_config = truncate_and_slice_config @@ -32,6 +43,17 @@ class TruncateAndSlice(tf.keras.Model): self._binary_mask = None def call(self, inputs, training=None, mask=None): + """ + Applies truncation and slicing to the input features based on the configuration. + + Args: + inputs: A dictionary of input features. + training: A boolean indicating whether the model is in training mode. + mask: A mask tensor. + + Returns: + A dictionary of truncated and sliced input features. + """ outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) if self._truncate_and_slice_config.continuous_feature_truncation: logging.info("Truncating continuous") @@ -51,12 +73,23 @@ class TruncateAndSlice(tf.keras.Model): class DownCast(tf.keras.Model): - """Class for Down casting dataset before serialization and transferring to training host. - Depends on the data type and the actual data range, the down casting can be lossless or not. - It is strongly recommended to compare the metrics before and after down casting. """ + A class for downcasting dataset before serialization and transferring to the training host. + + Depending on the data type and the actual data range, the downcasting can be lossless or not. + It is strongly recommended to compare the metrics before and after downcasting. + + Args: + downcast_config: A configuration object specifying the features and their target data types. + """ def __init__(self, downcast_config): + """ + Initializes the DownCast model. + + Args: + downcast_config: A configuration object specifying the features and their target data types. 
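As a side note on the lossless-vs-lossy caveat in the `DownCast` description above, whether a cast round-trips exactly depends on the target type's precision and range. The sketch below is illustrative only (invented values, plain numpy, not part of this patch):

```python
import numpy as np

# Hypothetical continuous feature values. float16 keeps roughly 3 significant
# decimal digits and overflows above ~65504, so some casts round-trip exactly
# while others lose information.
original = np.array([0.5, 1234.567, 65504.0, 70000.0], dtype=np.float32)
roundtrip = original.astype(np.float16).astype(np.float32)

for before, after in zip(original, roundtrip):
    status = "lossless" if before == after else "lossy"
    print(f"{before:>12.3f} -> {after:>12.3f} ({status})")
# 0.5 and 65504.0 survive unchanged; 1234.567 is rounded; 70000.0 becomes inf.
```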
+    """
     super().__init__()
     self.config = downcast_config
     self._type_map = {
@@ -65,6 +98,17 @@ class DownCast(tf.keras.Model):
     }

   def call(self, inputs, training=None, mask=None):
+    """
+    Applies downcasting to the input features based on the configuration.
+
+    Args:
+        inputs: A dictionary of input features.
+        training: A boolean indicating whether the model is in training mode.
+        mask: A mask tensor.
+
+    Returns:
+        A dictionary of downcasted input features.
+    """
     outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
     for feature, type_str in self.config.features.items():
       assert type_str in self._type_map
@@ -78,14 +122,39 @@ class DownCast(tf.keras.Model):


 class RectifyLabels(tf.keras.Model):
-  """Class for rectifying labels"""
+  """
+  A class for rectifying labels based on configured timestamp conditions.
+
+  Labels are adjusted using the served, impressed, and engagement timestamps so that an engagement
+  only counts as positive if it falls within the configured rectification window.
+
+  Args:
+      rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
+  """

   def __init__(self, rectify_label_config):
+    """
+    Initializes the RectifyLabels model.
+
+    Args:
+        rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
+    """
     super().__init__()
     self._config = rectify_label_config
     self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)

   def call(self, inputs, training=None, mask=None):
+    """
+    Rectifies label values based on the configured timestamp conditions.
+
+    Args:
+        inputs: A dictionary of input features including timestamp fields and labels.
+        training: A boolean indicating whether the model is in training mode.
+        mask: A mask tensor.
+
+    Returns:
+        A dictionary of input features with rectified label values.
+    """
     served_ts_field = self._config.served_timestamp_field
     impressed_ts_field = self._config.impressed_timestamp_field

@@ -102,13 +171,37 @@ class RectifyLabels(tf.keras.Model):


 class ExtractFeatures(tf.keras.Model):
-  """Class for extracting individual features from dense tensors by their index."""
+  """
+  A class for extracting individual features from dense tensors by their index.
+
+  New named features are created by selecting single columns from larger dense source tensors.
+
+  Args:
+      extract_features_config: A configuration object listing each new feature's name, source tensor, and index.
+  """

   def __init__(self, extract_features_config):
+    """
+    Initializes the ExtractFeatures model.
+
+    Args:
+        extract_features_config: A configuration object listing each new feature's name, source tensor, and index.
+    """
     super().__init__()
     self._config = extract_features_config

   def call(self, inputs, training=None, mask=None):
+    """
+    Extracts the configured features from their source dense tensors.
+
+    Args:
+        inputs: A dictionary of input features including the dense source tensors.
+        training: A boolean indicating whether the model is in training mode.
+        mask: A mask tensor.
+
+    Returns:
+        A dictionary of input features augmented with the extracted features.
+    """
     for row in self._config.extract_feature_table:
       inputs[row.name] = inputs[row.source_tensor][:, row.index]

@@ -168,7 +261,16 @@ class DownsampleNegatives(tf.keras.Model):


 def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
-  """Builds a preprocess model to apply all preprocessing stages."""
+  """
+  Builds a preprocess model to apply all preprocessing stages.
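To make the `RectifyLabels` window above concrete, here is a small standalone sketch (invented timestamps, plain Python, not the TF implementation): an engagement keeps its positive label only if it happened within the rectification window after the item was shown.

```python
from typing import Optional

WINDOW_HOURS = 3.0
WINDOW_MS = int(WINDOW_HOURS * 60 * 60 * 1000)  # same ms conversion as the class above

def rectified_label(shown_ts_ms: int, engaged_ts_ms: Optional[int]) -> int:
    """Return 1 only if the engagement falls inside the rectification window."""
    if engaged_ts_ms is None:
        return 0
    return int(0 <= engaged_ts_ms - shown_ts_ms <= WINDOW_MS)

shown = 1_695_000_000_000                                  # hypothetical timestamp
print(rectified_label(shown, shown + 30 * 60 * 1000))      # 30 minutes later -> 1
print(rectified_label(shown, shown + 5 * 60 * 60 * 1000))  # 5 hours later    -> 0
print(rectified_label(shown, None))                        # never engaged    -> 0
```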
+ + Args: + preprocess_config: A configuration object specifying the preprocessing parameters. + mode: A mode indicating the current job mode (TRAIN or INFERENCE). + + Returns: + A preprocess model that applies all specified preprocessing stages. + """ if mode == config_mod.JobMode.INFERENCE: logging.info("Not building preprocessors for dataloading since we are in Inference mode.") return None diff --git a/projects/home/recap/data/tfe_parsing.py b/projects/home/recap/data/tfe_parsing.py index f597746..aff73d2 100644 --- a/projects/home/recap/data/tfe_parsing.py +++ b/projects/home/recap/data/tfe_parsing.py @@ -8,122 +8,129 @@ import tensorflow as tf DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""} -DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string} +DTYPE_MAP = {"int64_list": tf.int64, + "float_list": tf.float32, "bytes_list": tf.string} def create_tf_example_schema( - data_config: recap_data_config.SegDenseSchema, - segdense_schema, + data_config: recap_data_config.SegDenseSchema, + segdense_schema, ): - """Generate schema for deseralizing tf.Example. + """Generate schema for deseralizing tf.Example. - Args: - segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length). - labels: List of strings denoting labels. + Args: + segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length). + labels: List of strings denoting labels. - Returns: - A dictionary schema suitable for deserializing tf.Example. - """ - segdense_config = data_config.seg_dense_schema - labels = list(data_config.tasks.keys()) - used_features = ( - segdense_config.features + list(segdense_config.renamed_features.values()) + labels - ) - logging.info(used_features) + Returns: + A dictionary schema suitable for deserializing tf.Example. 
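For illustration, a toy version of how such a deserialization schema can be assembled (feature names and lengths are made up, not this project's real segdense schema): entries with length == -1 become `VarLenFeature`, everything else a `FixedLenFeature` with per-element defaults.

```python
import tensorflow as tf

DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}

segdense_schema = [
    {"feature_name": "continuous", "dtype": "float_list", "length": 8},
    {"feature_name": "id_list", "dtype": "int64_list", "length": -1},
]

tfe_schema = {}
for entry in segdense_schema:
    name, dtype, length = entry["feature_name"], entry["dtype"], entry["length"]
    if length == -1:
        # Unknown length: parse as a variable-length (sparse) feature.
        tfe_schema[name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
    else:
        # Known length: parse as a fixed-length feature with per-element defaults.
        tfe_schema[name] = tf.io.FixedLenFeature(
            length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
        )
print(tfe_schema)
```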
+ """ + segdense_config = data_config.seg_dense_schema + labels = list(data_config.tasks.keys()) + used_features = ( + segdense_config.features + + list(segdense_config.renamed_features.values()) + labels + ) + logging.info(used_features) - tfe_schema = {} - for entry in segdense_schema: - feature_name = entry["feature_name"] + tfe_schema = {} + for entry in segdense_schema: + feature_name = entry["feature_name"] - if feature_name in used_features: - length = entry["length"] - dtype = entry["dtype"] + if feature_name in used_features: + length = entry["length"] + dtype = entry["dtype"] - if feature_name in labels: - logging.info(f"Label: feature name is {feature_name} type is {dtype}") - tfe_schema[feature_name] = tf.io.FixedLenFeature( - length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype] - ) - elif length == -1: - tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype]) - else: - tfe_schema[feature_name] = tf.io.FixedLenFeature( - length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length - ) - for feature_name in used_features: - if feature_name not in tfe_schema: - raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.") - return tfe_schema + if feature_name in labels: + logging.info( + f"Label: feature name is {feature_name} type is {dtype}") + tfe_schema[feature_name] = tf.io.FixedLenFeature( + length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype] + ) + elif length == -1: + tfe_schema[feature_name] = tf.io.VarLenFeature( + DTYPE_MAP[dtype]) + else: + tfe_schema[feature_name] = tf.io.FixedLenFeature( + length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length + ) + for feature_name in used_features: + if feature_name not in tfe_schema: + raise ValueError( + f"{feature_name} missing from schema: {segdense_config.schema_path}.") + return tfe_schema @functools.lru_cache(1) def make_mantissa_mask(mask_length: int) -> tf.Tensor: - """For experimentating with emulating bfloat16 or less precise types.""" - return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32) + """For experimentating with emulating bfloat16 or less precise types.""" + return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32) def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor: - """For experimentating with emulating bfloat16 or less precise types.""" - mask: tf.Tensor = make_mantissa_mask(mask_length) - return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype) + """For experimentating with emulating bfloat16 or less precise types.""" + mask: tf.Tensor = make_mantissa_mask(mask_length) + return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype) def parse_tf_example( - serialized_example, - tfe_schema, - seg_dense_schema_config, + serialized_example, + tfe_schema, + seg_dense_schema_config, ): - """Parse serialized tf.Example into dict of tensors. + """Parse serialized tf.Example into dict of tensors. - Args: - serialized_example: Serialized tf.Example to be parsed. - tfe_schema: Dictionary schema suitable for deserializing tf.Example. + Args: + serialized_example: Serialized tf.Example to be parsed. + tfe_schema: Dictionary schema suitable for deserializing tf.Example. - Returns: - Dictionary of tensors to be used as model input. - """ - inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema) + Returns: + Dictionary of tensors to be used as model input. 
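The mantissa masking above can be reproduced outside TensorFlow; this numpy sketch is my own (only the sample values are assumptions) and zeroes the low bits of each float32's bit pattern, which with a 16-bit mask leaves the 7 mantissa bits of bfloat16.

```python
import numpy as np

def mask_mantissa_np(x: np.ndarray, mask_length: int) -> np.ndarray:
    """Zero the lowest `mask_length` bits of each float32 value's bit pattern."""
    mask = np.uint32((1 << 32) - (1 << mask_length))  # same mask as the TF version
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & mask).view(np.float32)

x = np.array([3.14159265, 0.1], dtype=np.float32)
print(x)                        # full float32 precision
print(mask_mantissa_np(x, 16))  # ~bfloat16 precision, e.g. 3.140625
```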
+ """ + inputs = tf.io.parse_example( + serialized=serialized_example, features=tfe_schema) - for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items(): - inputs[new_feature_name] = inputs.pop(old_feature_name) + for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items(): + inputs[new_feature_name] = inputs.pop(old_feature_name) - # This should not actually be used except for experimentation with low precision floats. - if "mask_mantissa_features" in seg_dense_schema_config: - for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items(): - inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length) + # This should not actually be used except for experimentation with low precision floats. + if "mask_mantissa_features" in seg_dense_schema_config: + for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items(): + inputs[feature_name] = mask_mantissa( + inputs[feature_name], mask_length) - # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible - # at TF level. - # We should not return empty tensors if we dont use embeddings. - # Otherwise, it breaks numpy->pt conversion - renamed_keys = list(seg_dense_schema_config.renamed_features.keys()) - for renamed_key in renamed_keys: - if "embedding" in renamed_key and (renamed_key not in inputs): - inputs[renamed_key] = tf.zeros([], tf.float32) + # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible + # at TF level. + # We should not return empty tensors if we dont use embeddings. + # Otherwise, it breaks numpy->pt conversion + renamed_keys = list(seg_dense_schema_config.renamed_features.keys()) + for renamed_key in renamed_keys: + if "embedding" in renamed_key and (renamed_key not in inputs): + inputs[renamed_key] = tf.zeros([], tf.float32) - logging.info(f"parsed example and inputs are {inputs}") - return inputs + logging.info(f"parsed example and inputs are {inputs}") + return inputs def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig): - """Placeholder for seg dense. + """Placeholder for seg dense. - In the future, when we use more seg dense variations, we can change this. - """ - with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f: - seg_dense_schema = json.load(f)["schema"] + In the future, when we use more seg dense variations, we can change this. 
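A hedged sketch of how a schema-bound parse function like `parse_tf_example` is typically built with `functools.partial` and consumed with `tf.data` (the toy record, schema, and batch layout below are mine, not this repo's pipeline):

```python
import functools
import tensorflow as tf

def parse_batch(serialized, schema):
    return tf.io.parse_example(serialized=serialized, features=schema)

# Hypothetical schema; in the real pipeline it is derived from the segdense JSON.
schema = {"continuous": tf.io.FixedLenFeature([4], tf.float32, [0.0] * 4)}
parse = functools.partial(parse_batch, schema=schema)

# Serialize one toy record so the snippet runs end to end.
record = tf.train.Example(
    features=tf.train.Features(
        feature={
            "continuous": tf.train.Feature(
                float_list=tf.train.FloatList(value=[1.0, 2.0, 3.0, 4.0])
            )
        }
    )
).SerializeToString()

dataset = tf.data.Dataset.from_tensors([record]).map(parse)
for batch in dataset:
    print(batch["continuous"].numpy())  # [[1. 2. 3. 4.]]
```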
+ """ + with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f: + seg_dense_schema = json.load(f)["schema"] - tf_example_schema = create_tf_example_schema( - data_config, - seg_dense_schema, - ) + tf_example_schema = create_tf_example_schema( + data_config, + seg_dense_schema, + ) - logging.info("***** TF Example Schema *****") - logging.info(tf_example_schema) + logging.info("***** TF Example Schema *****") + logging.info(tf_example_schema) - parse = functools.partial( - parse_tf_example, - tfe_schema=tf_example_schema, - seg_dense_schema_config=data_config.seg_dense_schema, - ) - return parse + parse = functools.partial( + parse_tf_example, + tfe_schema=tf_example_schema, + seg_dense_schema_config=data_config.seg_dense_schema, + ) + return parse diff --git a/projects/home/recap/data/util.py b/projects/home/recap/data/util.py index a9fd51e..c5616c4 100644 --- a/projects/home/recap/data/util.py +++ b/projects/home/recap/data/util.py @@ -6,115 +6,160 @@ import tensorflow as tf def keyed_tensor_from_tensors_dict( - tensor_map: Mapping[str, torch.Tensor] + tensor_map: Mapping[str, torch.Tensor] ) -> "torchrec.KeyedTensor": - """ - Convert a dictionary of torch tensor to torchrec keyed tensor - Args: - tensor_map: + """ + Convert a dictionary of torch tensors to a torchrec KeyedTensor. - Returns: + Args: + tensor_map: A mapping of tensor names to torch tensors. - """ - keys = list(tensor_map.keys()) - # We expect batch size to be first dim. However, if we get a shape [Batch_size], - # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is - # [Batch_size x 1]. - values = [ - tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1) - for key in keys - ] - return torchrec.KeyedTensor.from_tensor_list(keys, values) + Returns: + A torchrec KeyedTensor. + """ + keys = list(tensor_map.keys()) + # We expect batch size to be first dim. However, if we get a shape [Batch_size], + # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is + # [Batch_size x 1]. + values = [ + tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze( + tensor_map[key], -1) + for key in keys + ] + return torchrec.KeyedTensor.from_tensor_list(keys, values) def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if tensor.is_sparse: - x = tensor.coalesce() # Ensure that the indices are ordered. - lengths = torch.bincount(x.indices()[0]) - values = x.values() - else: - values = tensor - lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device) - return values, lengths + """ + Compute a jagged tensor from a torch tensor. + + Args: + tensor: Input torch tensor. + + Returns: + A tuple containing the values and lengths of the jagged tensor. + """ + if tensor.is_sparse: + x = tensor.coalesce() # Ensure that the indices are ordered. + lengths = torch.bincount(x.indices()[0]) + values = x.values() + else: + values = tensor + lengths = torch.ones( + tensor.shape[0], dtype=torch.int32, device=tensor.device) + return values, lengths def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor": - """ - Convert a torch tensor to torchrec jagged tensor. - Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors. - For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the - dense_shape of the sparse tensor can be arbitrary. 
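To illustrate the sparse-to-jagged conversion in `_compute_jagged_tensor_from_tensor` with concrete numbers (tensor contents invented): per-row lengths come from `torch.bincount` over the first index dimension of a coalesced sparse tensor, and the values are its `.values()`.

```python
import torch

# A batch of 3 rows holding 2, 1, and 0 values respectively, as a sparse COO tensor.
indices = torch.tensor([[0, 0, 1], [0, 1, 0]])
values = torch.tensor([10.0, 11.0, 12.0])
sparse = torch.sparse_coo_tensor(indices, values, size=(3, 2)).coalesce()

lengths = torch.bincount(sparse.indices()[0], minlength=sparse.size(0))
print(sparse.values())  # tensor([10., 11., 12.])
print(lengths)          # tensor([2, 1, 0])
# The helper above uses plain bincount; minlength here only keeps the trailing
# empty row visible in the printout.
```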
- Args: - tensor: a torch (sparse) tensor. - Returns: - """ - values, lengths = _compute_jagged_tensor_from_tensor(tensor) - return torchrec.JaggedTensor(values=values, lengths=lengths) + """ + Convert a torch tensor to a torchrec jagged tensor. + + Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x N] for dense tensors. + For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary. + + Args: + tensor: A torch (sparse) tensor. + + Returns: + A torchrec JaggedTensor. + """ + values, lengths = _compute_jagged_tensor_from_tensor(tensor) + return torchrec.JaggedTensor(values=values, lengths=lengths) def keyed_jagged_tensor_from_tensors_dict( - tensor_map: Mapping[str, torch.Tensor] + tensor_map: Mapping[str, torch.Tensor] ) -> "torchrec.KeyedJaggedTensor": - """ - Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor. - Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors. - For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the - dense_shape of the sparse tensor can be arbitrary. - Args: - tensor_map: + """ + Convert a dictionary of (sparse) torch tensors to a torchrec keyed jagged tensor. - Returns: + Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x 1] for dense tensors. + For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary. - """ + Args: + tensor_map: A mapping of tensor names to torch tensors. + + Returns: + A torchrec KeyedJaggedTensor. + """ + + if not tensor_map: + return torchrec.KeyedJaggedTensor( + keys=[], + values=torch.zeros(0, dtype=torch.int), + lengths=torch.zeros(0, dtype=torch.int), + ) + values = [] + lengths = [] + for tensor in tensor_map.values(): + tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor) + values.append(torch.squeeze(tensor_val)) + lengths.append(tensor_len) + + values = torch.cat(values, axis=0) + lengths = torch.cat(lengths, axis=0) - if not tensor_map: return torchrec.KeyedJaggedTensor( - keys=[], - values=torch.zeros(0, dtype=torch.int), - lengths=torch.zeros(0, dtype=torch.int), + keys=list(tensor_map.keys()), + values=values, + lengths=lengths, ) - values = [] - lengths = [] - for tensor in tensor_map.values(): - tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor) - values.append(torch.squeeze(tensor_val)) - lengths.append(tensor_len) - - values = torch.cat(values, axis=0) - lengths = torch.cat(lengths, axis=0) - - return torchrec.KeyedJaggedTensor( - keys=list(tensor_map.keys()), - values=values, - lengths=lengths, - ) def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray: - return tf_tensor._numpy() # noqa + """ + Convert a TensorFlow tensor to a NumPy array. + + Args: + tf_tensor: TensorFlow tensor. + + Returns: + NumPy array. + """ + return tf_tensor._numpy() # noqa def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor: - tensor = _tf_to_numpy(tensor) - # Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent - if tensor.dtype.name == "bfloat16": - tensor = tensor.astype(np.float32) + """ + Convert a dense TensorFlow tensor to a PyTorch tensor. 
- tensor = torch.from_numpy(tensor) - if pin_memory: - tensor = tensor.pin_memory() - return tensor + Args: + tensor: Dense TensorFlow tensor. + pin_memory: Whether to pin the tensor in memory (for CUDA). + + Returns: + PyTorch tensor. + """ + tensor = _tf_to_numpy(tensor) + # Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent + if tensor.dtype.name == "bfloat16": + tensor = tensor.astype(np.float32) + + tensor = torch.from_numpy(tensor) + if pin_memory: + tensor = tensor.pin_memory() + return tensor def sparse_or_dense_tf_to_torch( - tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool + tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool ) -> torch.Tensor: - if isinstance(tensor, tf.SparseTensor): - tensor = torch.sparse_coo_tensor( - _dense_tf_to_torch(tensor.indices, pin_memory).t(), - _dense_tf_to_torch(tensor.values, pin_memory), - torch.Size(_tf_to_numpy(tensor.dense_shape)), - ) - else: - tensor = _dense_tf_to_torch(tensor, pin_memory) - return tensor + """ + Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor. + + Args: + tensor: TensorFlow tensor (sparse or dense). + pin_memory: Whether to pin the tensor in memory (for CUDA). + + Returns: + PyTorch tensor. + """ + if isinstance(tensor, tf.SparseTensor): + tensor = torch.sparse_coo_tensor( + _dense_tf_to_torch(tensor.indices, pin_memory).t(), + _dense_tf_to_torch(tensor.values, pin_memory), + torch.Size(_tf_to_numpy(tensor.dense_shape)), + ) + else: + tensor = _dense_tf_to_torch(tensor, pin_memory) + return tensor diff --git a/projects/home/recap/model/mask_net.py b/projects/home/recap/model/mask_net.py index 9a1d233..951007e 100644 --- a/projects/home/recap/model/mask_net.py +++ b/projects/home/recap/model/mask_net.py @@ -6,234 +6,243 @@ import torch def _init_weights(module): - """Initializes weights - - Example - - ```python - import torch - import torch.nn as nn + """Initializes weights - # Define a simple linear layer - linear_layer = nn.Linear(64, 32) + Example - # Initialize the weights and biases using _init_weights - _init_weights(linear_layer) - ``` - - """ - if isinstance(module, torch.nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - torch.nn.init.constant_(module.bias, 0) + ```python + import torch + import torch.nn as nn + + # Define a simple linear layer + linear_layer = nn.Linear(64, 32) + + # Initialize the weights and biases using _init_weights + _init_weights(linear_layer) + ``` + + """ + if isinstance(module, torch.nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + torch.nn.init.constant_(module.bias, 0) class MaskBlock(torch.nn.Module): - """ - MaskBlock module in a mask-based neural network. + """ + MaskBlock module in a mask-based neural network. - This module represents a MaskBlock, which applies a masking operation to the input data and then - passes it through a hidden layer. It is typically used as a building block within a MaskNet. + This module represents a MaskBlock, which applies a masking operation to the input data and then + passes it through a hidden layer. It is typically used as a building block within a MaskNet. - Args: - mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock. - input_dim (int): Dimensionality of the input data. - mask_input_dim (int): Dimensionality of the mask input. + Args: + mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock. + input_dim (int): Dimensionality of the input data. 
+ mask_input_dim (int): Dimensionality of the mask input. - Example: - To create and use a MaskBlock within a MaskNet, follow these steps: + Example: + To create and use a MaskBlock within a MaskNet, follow these steps: - ```python - # Define the configuration for the MaskBlock - mask_block_config = MaskBlockConfig( - input_layer_norm=True, # Apply input layer normalization - reduction_factor=0.5 # Reduce input dimensionality by 50% + ```python + # Define the configuration for the MaskBlock + mask_block_config = MaskBlockConfig( + input_layer_norm=True, # Apply input layer normalization + reduction_factor=0.5 # Reduce input dimensionality by 50% + ) + + # Create an instance of the MaskBlock + mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32) + + # Generate input tensors + input_data = torch.randn(batch_size, 64) + mask_input = torch.randn(batch_size, 32) + + # Perform a forward pass through the MaskBlock + output = mask_block(input_data, mask_input) + ``` + + Note: + The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking + operation that combines the input and mask input. Then, it passes the result through a hidden layer + with optional dimensionality reduction. + + Warning: + This class is intended for internal use within neural network architectures and should not be + directly accessed or modified by external code. + """ + + def __init__( + self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int + ) -> None: + """ + Initializes the MaskBlock module. + + Args: + mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock. + input_dim (int): Dimensionality of the input data. + mask_input_dim (int): Dimensionality of the mask input. + + Returns: + None + """ + + super(MaskBlock, self).__init__() + self.mask_block_config = mask_block_config + output_size = mask_block_config.output_size + + if mask_block_config.input_layer_norm: + self._input_layer_norm = torch.nn.LayerNorm(input_dim) + else: + self._input_layer_norm = None + + if mask_block_config.reduction_factor: + aggregation_size = int( + mask_input_dim * mask_block_config.reduction_factor) + elif mask_block_config.aggregation_size is not None: + aggregation_size = mask_block_config.aggregation_size + else: + raise ValueError( + "Need one of reduction factor or aggregation size.") + + self._mask_layer = torch.nn.Sequential( + torch.nn.Linear(mask_input_dim, aggregation_size), + torch.nn.ReLU(), + torch.nn.Linear(aggregation_size, input_dim), ) + self._mask_layer.apply(_init_weights) + self._hidden_layer = torch.nn.Linear(input_dim, output_size) + self._hidden_layer.apply(_init_weights) + self._layer_norm = torch.nn.LayerNorm(output_size) - # Create an instance of the MaskBlock - mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32) - - # Generate input tensors - input_data = torch.randn(batch_size, 64) - mask_input = torch.randn(batch_size, 32) - - # Perform a forward pass through the MaskBlock - output = mask_block(input_data, mask_input) - ``` - - Note: - The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking - operation that combines the input and mask input. Then, it passes the result through a hidden layer - with optional dimensionality reduction. - - Warning: - This class is intended for internal use within neural network architectures and should not be - directly accessed or modified by external code. 
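As a complement to the usage example, the forward computation of a mask block is compact enough to sketch in plain PyTorch (dimensions below are invented, and this is a functional sketch rather than the class itself): a small bottleneck MLP maps the mask input to an elementwise mask, the mask gates the (optionally layer-normalized) input, and the result passes through the hidden projection and a final LayerNorm.

```python
import torch

torch.manual_seed(0)
batch, input_dim, mask_dim, aggregation, output_size = 4, 64, 64, 32, 16

net = torch.randn(batch, input_dim)
mask_input = torch.randn(batch, mask_dim)

mask_layer = torch.nn.Sequential(
    torch.nn.Linear(mask_dim, aggregation),
    torch.nn.ReLU(),
    torch.nn.Linear(aggregation, input_dim),
)
hidden_layer = torch.nn.Linear(input_dim, output_size)
layer_norm = torch.nn.LayerNorm(output_size)

output = layer_norm(hidden_layer(net * mask_layer(mask_input)))
print(output.shape)  # torch.Size([4, 16])
```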
- """ - def __init__( - self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int - ) -> None: - """ - Initializes the MaskBlock module. - - Args: - mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock. - input_dim (int): Dimensionality of the input data. - mask_input_dim (int): Dimensionality of the mask input. - - Returns: - None + def forward(self, net: torch.Tensor, mask_input: torch.Tensor): """ + Performs a forward pass through the MaskBlock. - super(MaskBlock, self).__init__() - self.mask_block_config = mask_block_config - output_size = mask_block_config.output_size + Args: + net (torch.Tensor): Input data tensor. + mask_input (torch.Tensor): Mask input tensor. - if mask_block_config.input_layer_norm: - self._input_layer_norm = torch.nn.LayerNorm(input_dim) - else: - self._input_layer_norm = None - - if mask_block_config.reduction_factor: - aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor) - elif mask_block_config.aggregation_size is not None: - aggregation_size = mask_block_config.aggregation_size - else: - raise ValueError("Need one of reduction factor or aggregation size.") - - self._mask_layer = torch.nn.Sequential( - torch.nn.Linear(mask_input_dim, aggregation_size), - torch.nn.ReLU(), - torch.nn.Linear(aggregation_size, input_dim), - ) - self._mask_layer.apply(_init_weights) - self._hidden_layer = torch.nn.Linear(input_dim, output_size) - self._hidden_layer.apply(_init_weights) - self._layer_norm = torch.nn.LayerNorm(output_size) - - def forward(self, net: torch.Tensor, mask_input: torch.Tensor): - """ - Performs a forward pass through the MaskBlock. - - Args: - net (torch.Tensor): Input data tensor. - mask_input (torch.Tensor): Mask input tensor. - - Returns: - torch.Tensor: Output tensor of the MaskBlock. - """ - if self._input_layer_norm: - net = self._input_layer_norm(net) - hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input)) - return self._layer_norm(hidden_layer_output) + Returns: + torch.Tensor: Output tensor of the MaskBlock. + """ + if self._input_layer_norm: + net = self._input_layer_norm(net) + hidden_layer_output = self._hidden_layer( + net * self._mask_layer(mask_input)) + return self._layer_norm(hidden_layer_output) class MaskNet(torch.nn.Module): - """ - MaskNet module in a mask-based neural network. - - This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to - create mask-based neural networks with parallel or stacked MaskBlocks. - - Args: - mask_net_config (config.MaskNetConfig): Configuration for the MaskNet. - in_features (int): Dimensionality of the input data. - - Example: - To create and use a MaskNet, you can follow these steps: - - ```python - # Define the configuration for the MaskNet - mask_net_config = MaskNetConfig( - use_parallel=True, # Use parallel MaskBlocks - mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs - ) - - # Create an instance of the MaskNet - mask_net = MaskNet(mask_net_config, in_features=64) - - # Generate input tensors - input_data = torch.randn(batch_size, 64) - - # Perform a forward pass through the MaskNet - outputs = mask_net(input_data) - - # Access the output and shared layer - output = outputs["output"] - shared_layer = outputs["shared_layer"] - ``` - - Note: - The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked - MaskBlocks. You can also optionally apply an MLP to the outputs for further processing. 
- - Warning: - This class is intended for internal use within neural network architectures and should not be - directly accessed or modified by external code. """ - def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int): - """ - Initializes the MaskNet module. + MaskNet module in a mask-based neural network. - Args: - mask_net_config (config.MaskNetConfig): Configuration for the MaskNet. - in_features (int): Dimensionality of the input data. + This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to + create mask-based neural networks with parallel or stacked MaskBlocks. - Returns: - None + Args: + mask_net_config (config.MaskNetConfig): Configuration for the MaskNet. + in_features (int): Dimensionality of the input data. + + Example: + To create and use a MaskNet, you can follow these steps: + + ```python + # Define the configuration for the MaskNet + mask_net_config = MaskNetConfig( + use_parallel=True, # Use parallel MaskBlocks + mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs + ) + + # Create an instance of the MaskNet + mask_net = MaskNet(mask_net_config, in_features=64) + + # Generate input tensors + input_data = torch.randn(batch_size, 64) + + # Perform a forward pass through the MaskNet + outputs = mask_net(input_data) + + # Access the output and shared layer + output = outputs["output"] + shared_layer = outputs["shared_layer"] + ``` + + Note: + The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked + MaskBlocks. You can also optionally apply an MLP to the outputs for further processing. + + Warning: + This class is intended for internal use within neural network architectures and should not be + directly accessed or modified by external code. + """ + + def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int): """ + Initializes the MaskNet module. - super().__init__() - self.mask_net_config = mask_net_config - mask_blocks = [] + Args: + mask_net_config (config.MaskNetConfig): Configuration for the MaskNet. + in_features (int): Dimensionality of the input data. 
- if mask_net_config.use_parallel: - total_output_mask_blocks = 0 - for mask_block_config in mask_net_config.mask_blocks: - mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features)) - total_output_mask_blocks += mask_block_config.output_size - self._mask_blocks = torch.nn.ModuleList(mask_blocks) - else: - input_size = in_features - for mask_block_config in mask_net_config.mask_blocks: - mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features)) - input_size = mask_block_config.output_size + Returns: + None + """ - self._mask_blocks = torch.nn.ModuleList(mask_blocks) - total_output_mask_blocks = mask_block_config.output_size + super().__init__() + self.mask_net_config = mask_net_config + mask_blocks = [] - if mask_net_config.mlp: - self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp) - self.out_features = mask_net_config.mlp.layer_sizes[-1] - else: - self.out_features = total_output_mask_blocks - self.shared_size = total_output_mask_blocks + if mask_net_config.use_parallel: + total_output_mask_blocks = 0 + for mask_block_config in mask_net_config.mask_blocks: + mask_blocks.append( + MaskBlock(mask_block_config, in_features, in_features)) + total_output_mask_blocks += mask_block_config.output_size + self._mask_blocks = torch.nn.ModuleList(mask_blocks) + else: + input_size = in_features + for mask_block_config in mask_net_config.mask_blocks: + mask_blocks.append( + MaskBlock(mask_block_config, input_size, in_features)) + input_size = mask_block_config.output_size - def forward(self, inputs: torch.Tensor): - """ - Performs a forward pass through the MaskNet. + self._mask_blocks = torch.nn.ModuleList(mask_blocks) + total_output_mask_blocks = mask_block_config.output_size - Args: - inputs (torch.Tensor): Input data tensor. + if mask_net_config.mlp: + self._dense_layers = mlp.Mlp( + total_output_mask_blocks, mask_net_config.mlp) + self.out_features = mask_net_config.mlp.layer_sizes[-1] + else: + self.out_features = total_output_mask_blocks + self.shared_size = total_output_mask_blocks - Returns: - torch.Tensor: Output tensor of the MaskNet. + def forward(self, inputs: torch.Tensor): """ - if self.mask_net_config.use_parallel: - mask_outputs = [] - for mask_layer in self._mask_blocks: - mask_outputs.append(mask_layer(mask_input=inputs, net=inputs)) - # Share the outputs of the MaskBlocks. - all_mask_outputs = torch.cat(mask_outputs, dim=1) - output = ( - all_mask_outputs - if self.mask_net_config.mlp is None - else self._dense_layers(all_mask_outputs)["output"] - ) - return {"output": output, "shared_layer": all_mask_outputs} - else: - net = inputs - for mask_layer in self._mask_blocks: - net = mask_layer(net=net, mask_input=inputs) - # Share the output of the stacked MaskBlocks. - output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"] - return {"output": output, "shared_layer": net} + Performs a forward pass through the MaskNet. + + Args: + inputs (torch.Tensor): Input data tensor. + + Returns: + torch.Tensor: Output tensor of the MaskNet. + """ + if self.mask_net_config.use_parallel: + mask_outputs = [] + for mask_layer in self._mask_blocks: + mask_outputs.append(mask_layer(mask_input=inputs, net=inputs)) + # Share the outputs of the MaskBlocks. 
+ all_mask_outputs = torch.cat(mask_outputs, dim=1) + output = ( + all_mask_outputs + if self.mask_net_config.mlp is None + else self._dense_layers(all_mask_outputs)["output"] + ) + return {"output": output, "shared_layer": all_mask_outputs} + else: + net = inputs + for mask_layer in self._mask_blocks: + net = mask_layer(net=net, mask_input=inputs) + # Share the output of the stacked MaskBlocks. + output = net if self.mask_net_config.mlp is None else self._dense_layers[ + net]["output"] + return {"output": output, "shared_layer": net} diff --git a/projects/home/recap/model/model_and_loss.py b/projects/home/recap/model/model_and_loss.py index bcfcc0c..e16ad09 100644 --- a/projects/home/recap/model/model_and_loss.py +++ b/projects/home/recap/model/model_and_loss.py @@ -5,113 +5,117 @@ from absl import logging class ModelAndLoss(torch.nn.Module): - """ - PyTorch module that combines a neural network model and loss function. - - This module wraps a neural network model and facilitates the forward pass through the model - while also calculating the loss based on the model's predictions and provided labels. - - Args: - model: The torch module to wrap. - loss_fn (Callable): Function for calculating the loss, which should accept logits and labels. - stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations - for metrics stratification. Each stratifier config includes the name and index of discrete features - to emit for stratification. - - Example: - To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model - and loss function as arguments: - - ```python - # Create a neural network model - model = YourNeuralNetworkModel() - - # Define a loss function - loss_fn = torch.nn.CrossEntropyLoss() - - # Create an instance of ModelAndLoss - model_and_loss = ModelAndLoss(model, loss_fn) - - # Generate a batch of training data (e.g., RecapBatch) - batch = generate_training_batch() - - # Perform a forward pass through the model and calculate the loss - loss, outputs = model_and_loss(batch) - - # You can now backpropagate and optimize using the computed loss - loss.backward() - optimizer.step() - ``` - - Note: - The `ModelAndLoss` class simplifies the process of running forward passes through a model and - calculating loss, making it easier to integrate the model into your training loop. Additionally, - it supports the addition of stratifiers for metrics stratification, if needed. - - Warning: - This class is intended for internal use within neural network architectures and should not be - directly accessed or modified by external code. """ - def __init__( - self, - model, - loss_fn: Callable, - stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None, - ) -> None: - """ - Initializes the ModelAndLoss module. + PyTorch module that combines a neural network model and loss function. - Args: - model: The torch module to wrap. - loss_fn (Callable): Function for calculating the loss, which should accept logits and labels. - stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations - for metrics stratification. - """ - super().__init__() - self.model = model - self.loss_fn = loss_fn - self.stratifiers = stratifiers + This module wraps a neural network model and facilitates the forward pass through the model + while also calculating the loss based on the model's predictions and provided labels. 
- def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] - """Runs model forward and calculates loss according to given loss_fn. + Args: + model: The torch module to wrap. + loss_fn (Callable): Function for calculating the loss, which should accept logits and labels. + stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations + for metrics stratification. Each stratifier config includes the name and index of discrete features + to emit for stratification. - NOTE: The input signature here needs to be a Pipelineable object for - prefetching purposes during training using torchrec's pipeline. However - the underlying model signature needs to be exportable to onnx, requiring - generic python types. see https://pytorch.org/docs/stable/onnx.html#types. + Example: + To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model + and loss function as arguments: - """ - outputs = self.model( - continuous_features=batch.continuous_features, - binary_features=batch.binary_features, - discrete_features=batch.discrete_features, - sparse_features=batch.sparse_features, - user_embedding=batch.user_embedding, - user_eng_embedding=batch.user_eng_embedding, - author_embedding=batch.author_embedding, - labels=batch.labels, - weights=batch.weights, - ) - losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float()) + ```python + # Create a neural network model + model = YourNeuralNetworkModel() - if self.stratifiers: - logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}") - outputs["stratifiers"] = {} - for stratifier in self.stratifiers: - outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index] + # Define a loss function + loss_fn = torch.nn.CrossEntropyLoss() - # In general, we can have a large number of losses returned by our loss function. - if isinstance(losses, dict): - return losses["loss"], { - **outputs, - **losses, - "labels": batch.labels, - "weights": batch.weights, - } - else: # Assume that this is a float. - return losses, { - **outputs, - "loss": losses, - "labels": batch.labels, - "weights": batch.weights, - } + # Create an instance of ModelAndLoss + model_and_loss = ModelAndLoss(model, loss_fn) + + # Generate a batch of training data (e.g., RecapBatch) + batch = generate_training_batch() + + # Perform a forward pass through the model and calculate the loss + loss, outputs = model_and_loss(batch) + + # You can now backpropagate and optimize using the computed loss + loss.backward() + optimizer.step() + ``` + + Note: + The `ModelAndLoss` class simplifies the process of running forward passes through a model and + calculating loss, making it easier to integrate the model into your training loop. Additionally, + it supports the addition of stratifiers for metrics stratification, if needed. + + Warning: + This class is intended for internal use within neural network architectures and should not be + directly accessed or modified by external code. + """ + + def __init__( + self, + model, + loss_fn: Callable, + stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None, + ) -> None: + """ + Initializes the ModelAndLoss module. + + Args: + model: The torch module to wrap. + loss_fn (Callable): Function for calculating the loss, which should accept logits and labels. + stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations + for metrics stratification. 
+ """ + super().__init__() + self.model = model + self.loss_fn = loss_fn + self.stratifiers = stratifiers + + def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] + """Runs model forward and calculates loss according to given loss_fn. + + NOTE: The input signature here needs to be a Pipelineable object for + prefetching purposes during training using torchrec's pipeline. However + the underlying model signature needs to be exportable to onnx, requiring + generic python types. see https://pytorch.org/docs/stable/onnx.html#types. + + """ + outputs = self.model( + continuous_features=batch.continuous_features, + binary_features=batch.binary_features, + discrete_features=batch.discrete_features, + sparse_features=batch.sparse_features, + user_embedding=batch.user_embedding, + user_eng_embedding=batch.user_eng_embedding, + author_embedding=batch.author_embedding, + labels=batch.labels, + weights=batch.weights, + ) + losses = self.loss_fn( + outputs["logits"], batch.labels.float(), batch.weights.float()) + + if self.stratifiers: + logging.info( + f"***** Adding stratifiers *****\n {self.stratifiers}") + outputs["stratifiers"] = {} + for stratifier in self.stratifiers: + outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, + stratifier.index] + + # In general, we can have a large number of losses returned by our loss function. + if isinstance(losses, dict): + return losses["loss"], { + **outputs, + **losses, + "labels": batch.labels, + "weights": batch.weights, + } + else: # Assume that this is a float. + return losses, { + **outputs, + "loss": losses, + "labels": batch.labels, + "weights": batch.weights, + } diff --git a/projects/home/recap/model/numeric_calibration.py b/projects/home/recap/model/numeric_calibration.py index 34f819e..a66869f 100644 --- a/projects/home/recap/model/numeric_calibration.py +++ b/projects/home/recap/model/numeric_calibration.py @@ -2,64 +2,65 @@ import torch class NumericCalibration(torch.nn.Module): - """ - Numeric calibration module for adjusting probability scores. - - This module scales probability scores to correct for imbalanced datasets, where positive and negative samples - may be underrepresented or have different ratios. It is designed to be used as a component in a neural network - for tasks such as binary classification. - - Args: - pos_downsampling_rate (float): The downsampling rate for positive samples. - neg_downsampling_rate (float): The downsampling rate for negative samples. - - Example: - To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability - scores like this: - - ```python - # Create a NumericCalibration instance with downsampling rates - calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2) - - # Generate probability scores (e.g., from a neural network) - raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9]) - - # Apply numeric calibration to adjust the probabilities - calibrated_probs = calibration(raw_probs) - - # The `calibrated_probs` now contains the adjusted probability scores - ``` - - Note: - The `NumericCalibration` module is used to adjust probability scores to account for differences in - the number of positive and negative samples in a dataset. It can help improve the calibration of - probability estimates in imbalanced classification problems. - - Warning: - This class is intended for internal use within neural network architectures and should not be - directly accessed or modified by external code. 
""" - def __init__( - self, - pos_downsampling_rate: float, - neg_downsampling_rate: float, - ): - """ - Apply numeric calibration to probability scores. + Numeric calibration module for adjusting probability scores. - Args: - probs (torch.Tensor): Probability scores to be calibrated. + This module scales probability scores to correct for imbalanced datasets, where positive and negative samples + may be underrepresented or have different ratios. It is designed to be used as a component in a neural network + for tasks such as binary classification. - Returns: - torch.Tensor: Calibrated probability scores. + Args: + pos_downsampling_rate (float): The downsampling rate for positive samples. + neg_downsampling_rate (float): The downsampling rate for negative samples. + + Example: + To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability + scores like this: + + ```python + # Create a NumericCalibration instance with downsampling rates + calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2) + + # Generate probability scores (e.g., from a neural network) + raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9]) + + # Apply numeric calibration to adjust the probabilities + calibrated_probs = calibration(raw_probs) + + # The `calibrated_probs` now contains the adjusted probability scores + ``` + + Note: + The `NumericCalibration` module is used to adjust probability scores to account for differences in + the number of positive and negative samples in a dataset. It can help improve the calibration of + probability estimates in imbalanced classification problems. + + Warning: + This class is intended for internal use within neural network architectures and should not be + directly accessed or modified by external code. + """ + + def __init__( + self, + pos_downsampling_rate: float, + neg_downsampling_rate: float, + ): """ - super().__init__() + Apply numeric calibration to probability scores. - # Using buffer to make sure they are on correct device (and not moved every time). - # Will also be part of state_dict. - self.register_buffer( - "ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True - ) + Args: + probs (torch.Tensor): Probability scores to be calibrated. - def forward(self, probs: torch.Tensor): - return probs * self.ratio / (1.0 - probs + (self.ratio * probs)) + Returns: + torch.Tensor: Calibrated probability scores. + """ + super().__init__() + + # Using buffer to make sure they are on correct device (and not moved every time). + # Will also be part of state_dict. + self.register_buffer( + "ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True + ) + + def forward(self, probs: torch.Tensor): + return probs * self.ratio / (1.0 - probs + (self.ratio * probs)) diff --git a/tools/pq.py b/tools/pq.py index 24c6345..b18f68a 100644 --- a/tools/pq.py +++ b/tools/pq.py @@ -38,6 +38,15 @@ import pyarrow.parquet as pq def _create_dataset(path: str): + """ + Create a PyArrow dataset from Parquet files located at the specified path. + + Args: + path (str): The path to the Parquet files. + + Returns: + pyarrow.dataset.Dataset: The PyArrow dataset. + """ fs = infer_fs(path) files = fs.glob(path) return pads.dataset(files, format="parquet", filesystem=fs) @@ -47,12 +56,27 @@ class PqReader: def __init__( self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None ): + """ + Initialize a Parquet Reader. 
+
+        Args:
+            path (str): The path to the Parquet files.
+            num (int): The maximum number of rows to read.
+            batch_size (int): The batch size for reading data.
+            columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns).
+        """
     self._ds = _create_dataset(path)
     self._batch_size = batch_size
     self._num = num
     self._columns = columns
 
   def __iter__(self):
+    """
+    Iterate through the Parquet data and yield batches of rows.
+
+    Yields:
+        pyarrow.RecordBatch: A batch of rows.
+    """
     batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
     rows_seen = 0
     for count, record in enumerate(batches):
@@ -62,6 +86,12 @@ class PqReader:
       rows_seen += record.data.num_rows
 
   def _head(self):
+    """
+    Get the first `num` rows of the Parquet data.
+
+    Returns:
+        pyarrow.RecordBatch: A batch of rows.
+    """
     total_read = self._num * self.bytes_per_row
     if total_read >= int(500e6):
       raise Exception(
@@ -71,6 +101,12 @@ class PqReader:
 
   @property
   def bytes_per_row(self) -> int:
+    """
+    Calculate the estimated bytes per row in the dataset.
+
+    Returns:
+        int: The estimated bytes per row.
+    """
     nbits = 0
     for t in self._ds.schema.types:
       try:
@@ -81,18 +117,24 @@ class PqReader:
     return nbits // 8
 
   def schema(self):
+    """
+    Display the schema of the Parquet dataset.
+    """
     print(f"\n# Schema\n{self._ds.schema}")
 
   def head(self):
-    """Displays first --num rows."""
+    """
+    Display the first `num` rows of the Parquet data as a pandas DataFrame.
+    """
     print(self._head().to_pandas())
 
   def distinct(self):
-    """Displays unique values seen in specified columns in the first `--num` rows.
-
-    Useful for getting an approximate vocabulary for certain columns.
-    """
+    """
+    Display unique values seen in specified columns in the first `num` rows.
+
+    Useful for getting an approximate vocabulary for certain columns.
+    """
     for col_name, column in zip(self._head().column_names, self._head().columns):
       print(col_name)
       print("unique:", column.unique().to_pylist())
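A quick numeric sanity check of the calibration formula in `numeric_calibration.py` may help reviewers confirm what the `ratio` buffer does. This is an illustrative sketch only: the downsampling rates and probabilities below are made up, and it assumes the module is importable from the repo root as `projects.home.recap.model.numeric_calibration`.

```python
import torch

from projects.home.recap.model.numeric_calibration import NumericCalibration

# Hypothetical training setup: all positives kept, 1 in 10 negatives kept.
calibration = NumericCalibration(pos_downsampling_rate=1.0, neg_downsampling_rate=0.1)

# With 10x negative downsampling, a true 10% positive rate is observed as
# roughly 52.6% by a model trained on the downsampled data.
raw_probs = torch.tensor([0.526, 0.9, 0.5])

# forward() applies p * r / (1 - p + r * p) with r = neg_rate / pos_rate = 0.1,
# which maps the first entry back to ~0.10, i.e. the pre-downsampling probability.
calibrated = calibration(raw_probs)
print(calibrated)
```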
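Similarly, the refreshed docstrings in `tools/pq.py` can be exercised with a short, self-contained snippet. This is a minimal sketch under stated assumptions: the Parquet glob and column names are hypothetical, and any files readable by the filesystem that `infer_fs` resolves would work; passing `columns=None` reads every column.

```python
from tools.pq import PqReader

# Hypothetical glob; point this at any local or remote Parquet shards.
reader = PqReader(
    path="/tmp/example_dataset/*.parquet",
    num=10,                        # rows shown by head()/distinct()
    batch_size=1024,               # rows per Arrow record batch when iterating
    columns=["user_id", "label"],  # hypothetical column subset; None reads all columns
)

reader.schema()    # print the Arrow schema of the dataset
reader.head()      # print the first `num` rows as a pandas DataFrame
reader.distinct()  # print unique values per selected column within those rows
```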