"""
Preprocessors applied on DDS workers in order to modify the dataset on the fly.

Some of these preprocessors are also applied to the model at serving time.
"""
from tml.projects.home.recap import config as config_mod

from absl import logging
import tensorflow as tf
import numpy as np


def _load_mask_indices(mask_path):
    """Load a saved numpy mask and return the positions of its nonzero entries.

    Args:
        mask_path: Location of the mask file, opened with `tf.io.gfile.GFile` and
            read with `np.load`. A falsy value (e.g. None or "") means no mask.

    Returns:
        `np.load(...).nonzero()[0]` (the indices of nonzero entries along the
        mask's first axis), or None when `mask_path` is falsy.
    """
    if not mask_path:
        return None
    with tf.io.gfile.GFile(mask_path, "rb") as f:
        return np.load(f).nonzero()[0]


class TruncateAndSlice(tf.keras.Model):
    """Truncate and/or column-slice the "continuous" and "binary" feature tensors.

    For each of the two tensors, two optional steps are applied in this order:
      1. Truncation: keep only the first `*_feature_truncation` columns.
      2. Slicing: keep only the columns listed by the mask loaded from
         `*_feature_mask_path`. The indices refer to the columns that remain
         after truncation.

    Args:
        truncate_and_slice_config: A configuration object specifying how to
            truncate and slice features.
    """

    def __init__(self, truncate_and_slice_config):
        """Initializes the TruncateAndSlice model and loads any configured masks.

        Args:
            truncate_and_slice_config: A configuration object specifying how to
                truncate and slice features.
        """
        super().__init__()
        self._truncate_and_slice_config = truncate_and_slice_config

        # Each mask is held as the array of column indices to keep, or None.
        self._continuous_mask = _load_mask_indices(
            truncate_and_slice_config.continuous_feature_mask_path
        )
        if self._continuous_mask is not None:
            # The number of kept features is the number of indices, so count them
            # rather than summing the index values.
            logging.info(f"Slicing {len(self._continuous_mask)} continuous features.")

        self._binary_mask = _load_mask_indices(truncate_and_slice_config.binary_feature_mask_path)
        if self._binary_mask is not None:
            logging.info(f"Slicing {len(self._binary_mask)} binary features.")

    def call(self, inputs, training=None, mask=None):
        """Applies truncation and slicing to the input features based on the configuration.

        Args:
            inputs: A dictionary of input features. "continuous" and "binary" are
                indexed along axis 1 when configured, so they presumably have shape
                (batch, num_features). TODO confirm with callers.
            training: Unused; part of the Keras `call` signature.
            mask: Unused; part of the Keras `call` signature.

        Returns:
            A new dictionary with the same keys. "continuous" and "binary" are
            truncated and sliced as configured; all other entries pass through.
        """
        config = self._truncate_and_slice_config
        # Rebuild the structure from its leaves so the caller's dict is not mutated.
        outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))

        if config.continuous_feature_truncation:
            logging.info("Truncating continuous")
            outputs["continuous"] = outputs["continuous"][
                :, : config.continuous_feature_truncation
            ]
        if config.binary_feature_truncation:
            logging.info("Truncating binary")
            outputs["binary"] = outputs["binary"][:, : config.binary_feature_truncation]

        if self._continuous_mask is not None:
            outputs["continuous"] = tf.gather(outputs["continuous"], self._continuous_mask, axis=1)
        if self._binary_mask is not None:
            outputs["binary"] = tf.gather(outputs["binary"], self._binary_mask, axis=1)

        return outputs


class DownCast(tf.keras.Model):
    """Downcast selected features before serialization and transfer to the training host.

    Depending on the data type and the actual data range, the downcasting can be
    lossless or not. It is strongly recommended to compare the metrics before and
    after downcasting.

    Args:
        downcast_config: A configuration object whose `features` mapping goes from
            feature name to target type name ("bfloat16" or "bool").
    """

    def __init__(self, downcast_config):
        """Initializes the DownCast model.

        Args:
            downcast_config: A configuration object specifying the features and
                their target data types.
        """
        super().__init__()
        self.config = downcast_config
        # Supported target type names mapped to TensorFlow dtypes.
        self._type_map = {
            "bfloat16": tf.bfloat16,
            "bool": tf.bool,
        }

    def call(self, inputs, training=None, mask=None):
        """Casts each configured feature to its target dtype.

        Args:
            inputs: A dictionary of input features.
            training: Unused; part of the Keras `call` signature.
            mask: Unused; part of the Keras `call` signature.

        Returns:
            A new dictionary with the configured features cast to their target
            dtypes; all other entries pass through unchanged.

        Raises:
            ValueError: If a configured target type name is not supported.
        """
        outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
        for feature, type_str in self.config.features.items():
            # Validate explicitly: an `assert` would be stripped under `python -O`.
            if type_str not in self._type_map:
                raise ValueError(
                    f"Unsupported downcast type {type_str!r} for feature {feature!r}; "
                    f"supported types are {sorted(self._type_map)}."
                )
            if type_str == "bfloat16":
                logging.warning(
                    "Although bfloat16 and float32 have the same number of exponent bits, this down casting is not 100% lossless. Please double check metrics."
                )
            outputs[feature] = tf.cast(outputs[feature], dtype=self._type_map[type_str])
        return outputs


class RectifyLabels(tf.keras.Model):
    """Zero out labels whose impression or engagement falls outside a time window.

    For each (label, engaged timestamp field) pair in
    `label_to_engaged_timestamp_field`, the label value is kept only where
      * the label is > 0,
      * impressed_ts - served_ts < window, and
      * engaged_ts - served_ts < window.
    Everywhere else the label is replaced by 0.

    Args:
        rectify_label_config: A configuration object with
            `label_rectification_window_in_hours`, `served_timestamp_field`,
            `impressed_timestamp_field` and `label_to_engaged_timestamp_field`
            (a mapping from label name to its engaged-timestamp field).
    """

    def __init__(self, rectify_label_config):
        """Initializes the RectifyLabels model.

        Args:
            rectify_label_config: A configuration object specifying the window,
                the timestamp fields and the label-to-engaged-timestamp mapping.
        """
        super().__init__()
        self._config = rectify_label_config
        # The window is converted from hours to milliseconds, so the timestamp
        # features are presumably epoch milliseconds. TODO confirm against the schema.
        self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)

    def call(self, inputs, training=None, mask=None):
        """Rectifies the configured labels. Note: `inputs` is modified in place.

        Args:
            inputs: A dictionary of input features, including the timestamp fields
                and labels named in the configuration.
            training: Unused; part of the Keras `call` signature.
            mask: Unused; part of the Keras `call` signature.

        Returns:
            `inputs` itself, with each configured label rewritten.
        """
        served_ts_field = self._config.served_timestamp_field
        impressed_ts_field = self._config.impressed_timestamp_field
        for label, engaged_ts_field in self._config.label_to_engaged_timestamp_field.items():
            impressed = inputs[impressed_ts_field]
            served = inputs[served_ts_field]
            engaged = inputs[engaged_ts_field]
            keep = tf.math.logical_and(inputs[label] > 0, impressed - served < self._window)
            keep = tf.math.logical_and(keep, engaged - served < self._window)
            inputs[label] = tf.where(keep, inputs[label], tf.zeros_like(inputs[label]))
        return inputs


class ExtractFeatures(tf.keras.Model):
    """Expose single columns of existing tensors as standalone named features.

    For every row of `extract_feature_table`, `inputs[row.name]` is set to column
    `row.index` (axis 1) of `inputs[row.source_tensor]`.

    Args:
        extract_features_config: A configuration object with an
            `extract_feature_table` whose rows have `name`, `source_tensor` and
            `index` attributes.
    """

    def __init__(self, extract_features_config):
        """Initializes the ExtractFeatures model.

        Args:
            extract_features_config: A configuration object holding the
                `extract_feature_table`.
        """
        super().__init__()
        self._config = extract_features_config

    def call(self, inputs, training=None, mask=None):
        """Adds the extracted features. Note: `inputs` is modified in place.

        Args:
            inputs: A dictionary of input features containing every
                `row.source_tensor` named in the table.
            training: Unused; part of the Keras `call` signature.
            mask: Unused; part of the Keras `call` signature.

        Returns:
            `inputs` itself, with one new entry per table row.
        """
        for row in self._config.extract_feature_table:
            inputs[row.name] = inputs[row.source_tensor][:, row.index]
        return inputs


class DownsampleNegatives(tf.keras.Model):
    """Class for down-sampling/dropping negatives and updating the weights.

    If inputs['fav'] = [1, 0, 0, 0] and inputs['weights'] = [1.0, 1.0, 1.0, 1.0]
    inputs are transformed to inputs['fav'] = [1, 0] and inputs['weights'] = [1.0, 3.0]
    when batch_multiplier=2 and engagements_list=['fav'].

    It supports multiple engagements (union/logical_or is used to aggregate
    engagements), so we don't drop positives for any engagement.

    The target batch size is (rows in the input) / batch_multiplier. Positive rows
    are placed first, then negatives, and the result is truncated to the target
    batch size. Kept negatives are up-weighted by
    (batch_multiplier * batch_size - n_positives) / (batch_size - n_positives).
    """

    def __init__(self, downsample_negatives_config):
        super().__init__()
        self.config = downsample_negatives_config

    def call(self, inputs, training=None, mask=None):
        """Drops surplus negatives and reweights the rest. `inputs` is modified in place.

        Args:
            inputs: A dictionary of batched tensors. Every entry is reordered and
                truncated along axis 0. Each label in `engagements_list` is
                compared with 1 to decide positivity.
            training: Unused; part of the Keras `call` signature.
            mask: Unused. The positive-row mask is computed from the labels.

        Returns:
            `inputs` itself, with every tensor truncated to `batch_size` rows.
            "weights" is rescaled; if it was absent, it is first created as ones
            of shape [rows, num_engagements].
        """
        labels = self.config.engagements_list
        # Union of engagements: a row is positive if any engagement label equals 1.
        # Named to avoid shadowing the (unused) Keras `mask` argument.
        is_positive = tf.squeeze(
            tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1)
        )
        n_positives = tf.reduce_sum(tf.cast(is_positive, tf.int32))
        # Number of rows kept after down-sampling.
        batch_size = tf.cast(
            tf.shape(inputs[labels[0]])[0] / self.config.batch_multiplier, tf.int32
        )
        # Kept negatives stand in for all negatives of the larger input batch.
        negative_weights = tf.math.divide_no_nan(
            tf.cast(self.config.batch_multiplier * batch_size - n_positives, tf.float32),
            tf.cast(batch_size - n_positives, tf.float32),
        )
        new_weights = (
            tf.cast(is_positive, tf.float32)
            + (1 - tf.cast(is_positive, tf.float32)) * negative_weights
        )

        def _split_by_label_concatenate_and_truncate(input_tensor):
            # Takes positive examples, concatenates negative examples and truncates.
            # DANGER: if n_positives > batch_size down-sampling is incorrect (do not use pb_50)
            return tf.concat(
                [
                    input_tensor[is_positive],
                    input_tensor[tf.math.logical_not(is_positive)],
                ],
                0,
            )[:batch_size]

        if "weights" not in inputs:
            # Add a placeholder so the logic below applies even if weights aren't present.
            inputs["weights"] = tf.ones(
                [tf.shape(inputs[labels[0]])[0], self.config.num_engagements]
            )
        for key in inputs:
            if key == "weights":
                inputs[key] = inputs[key] * tf.reshape(new_weights, [-1, 1])
            inputs[key] = _split_by_label_concatenate_and_truncate(inputs[key])
        return inputs


def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
    """Builds a preprocess model that applies all configured preprocessing stages.

    Enabled stages run in this fixed order: DownsampleNegatives, TruncateAndSlice,
    DownCast, RectifyLabels, ExtractFeatures.

    Args:
        preprocess_config: A configuration object specifying the preprocessing parameters.
        mode: The current job mode. No preprocessor is built in INFERENCE mode.

    Returns:
        A Keras model chaining the enabled stages, or None in INFERENCE mode.

    Raises:
        ValueError: If the configuration enables no preprocessing stage.
    """
    if mode == config_mod.JobMode.INFERENCE:
        logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
        return None

    preprocess_models = []
    if preprocess_config.downsample_negatives:
        preprocess_models.append(DownsampleNegatives(preprocess_config.downsample_negatives))
    if preprocess_config.truncate_and_slice:
        preprocess_models.append(TruncateAndSlice(preprocess_config.truncate_and_slice))
    if preprocess_config.downcast:
        preprocess_models.append(DownCast(preprocess_config.downcast))
    if preprocess_config.rectify_labels:
        preprocess_models.append(RectifyLabels(preprocess_config.rectify_labels))
    if preprocess_config.extract_features:
        preprocess_models.append(ExtractFeatures(preprocess_config.extract_features))

    if not preprocess_models:
        raise ValueError("No known preprocessor.")

    class PreprocessModel(tf.keras.Model):
        """Chains the preprocessors, feeding each stage's output into the next."""

        def __init__(self, preprocess_models):
            super().__init__()
            self.preprocess_models = preprocess_models

        def call(self, inputs, training=None, mask=None):
            outputs = inputs
            for model in self.preprocess_models:
                outputs = model(outputs, training, mask)
            return outputs

    if len(preprocess_models) > 1:
        logging.warning(
            "With multiple preprocessing models, we apply these models in a predefined order. Future works may introduce customized models and orders."
        )
    return PreprocessModel(preprocess_models)