# Original path: the-algorithm-ml/projects/home/recap/data/preprocessors.py
# (309 lines, 11 KiB, Python)

"""
Preprocessors applied on DDS workers in order to modify the dataset on the fly.
Some of these preprocessors are also applied to the model at serving time.
"""
from tml.projects.home.recap import config as config_mod
from absl import logging
import tensorflow as tf
import numpy as np
class TruncateAndSlice(tf.keras.Model):
    """Truncates and slices the "continuous" and "binary" dense feature tensors.

    For each of the two feature tensors, two optional steps are applied in order:

    1. Truncation: keep only the first N columns (``[:, :N]``).
    2. Slicing: gather the columns whose entries are non-zero in a mask loaded
       from a NumPy file (``np.load``). Mask positions refer to columns of the
       already-truncated tensor.

    Args:
        truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
    """

    def __init__(self, truncate_and_slice_config):
        """Initializes the TruncateAndSlice model.

        Mask files, if configured, are read once here rather than on every call.

        Args:
            truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
        """
        super().__init__()
        self._truncate_and_slice_config = truncate_and_slice_config
        self._continuous_mask = self._load_mask_indices(
            truncate_and_slice_config.continuous_feature_mask_path, "continuous"
        )
        self._binary_mask = self._load_mask_indices(
            truncate_and_slice_config.binary_feature_mask_path, "binary"
        )

    @staticmethod
    def _load_mask_indices(mask_path, feature_kind):
        """Returns positions of the non-zero entries of the mask at `mask_path`, or None if no path is set.

        Args:
            mask_path: Path to a ``.npy`` mask file; falsy means "no mask".
            feature_kind: Human-readable feature group name, used only for logging.
        """
        if not mask_path:
            return None
        with tf.io.gfile.GFile(mask_path, "rb") as f:
            indices = np.load(f).nonzero()[0]
        # `indices` holds column positions, so the number of kept features is
        # its length (summing the positions would give a meaningless number).
        logging.info(f"Slicing {len(indices)} {feature_kind} features.")
        return indices

    def call(self, inputs, training=None, mask=None):
        """Applies truncation and slicing to the input features based on the configuration.

        Args:
            inputs: A dictionary of input features; must contain "continuous" /
                "binary" when the corresponding step is configured.
            training: Unused; accepted for the Keras `call` signature.
            mask: Unused; accepted for the Keras `call` signature.

        Returns:
            A new dictionary with truncated and sliced features; `inputs` itself is not modified.
        """
        config = self._truncate_and_slice_config
        # Shallow copy so that reassigning keys does not mutate the caller's dict.
        outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
        if config.continuous_feature_truncation:
            logging.info("Truncating continuous")
            outputs["continuous"] = outputs["continuous"][:, : config.continuous_feature_truncation]
        if config.binary_feature_truncation:
            logging.info("Truncating binary")
            outputs["binary"] = outputs["binary"][:, : config.binary_feature_truncation]
        if self._continuous_mask is not None:
            outputs["continuous"] = tf.gather(outputs["continuous"], self._continuous_mask, axis=1)
        if self._binary_mask is not None:
            outputs["binary"] = tf.gather(outputs["binary"], self._binary_mask, axis=1)
        return outputs
class DownCast(tf.keras.Model):
    """Downcasts selected features before serialization and transfer to the training host.

    Depending on the data type and the actual data range, the downcasting can be lossless or not.
    It is strongly recommended to compare the metrics before and after downcasting.

    Supported target types are "bfloat16" and "bool".

    Args:
        downcast_config: A configuration object whose `features` maps feature
            name -> target type string.
    """

    def __init__(self, downcast_config):
        """Initializes the DownCast model.

        Args:
            downcast_config: A configuration object specifying the features and their target data types.
        """
        super().__init__()
        self.config = downcast_config
        # Config type string -> TensorFlow dtype.
        self._type_map = {
            "bfloat16": tf.bfloat16,
            "bool": tf.bool,
        }

    def call(self, inputs, training=None, mask=None):
        """Applies downcasting to the input features based on the configuration.

        Args:
            inputs: A dictionary of input features.
            training: Unused; accepted for the Keras `call` signature.
            mask: Unused; accepted for the Keras `call` signature.

        Returns:
            A new dictionary of features with the configured ones cast; `inputs` is not modified.

        Raises:
            ValueError: If a configured target type is not supported.
        """
        # Shallow copy so that reassigning keys does not mutate the caller's dict.
        outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
        for feature, type_str in self.config.features.items():
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if type_str not in self._type_map:
                raise ValueError(
                    f"Unsupported downcast type {type_str!r} for feature {feature!r}; "
                    f"expected one of {sorted(self._type_map)}."
                )
            if type_str == "bfloat16":
                logging.warning(
                    "Although bfloat16 and float32 have the same number of exponent bits, this down casting is not 100% lossless. Please double check metrics."
                )
            outputs[feature] = tf.cast(outputs[feature], dtype=self._type_map[type_str])
        return outputs
class RectifyLabels(tf.keras.Model):
    """Zeroes out positive labels whose timestamps fall outside a rectification window.

    A label value is kept only when all of the following hold:
      * the label is > 0,
      * impressed_timestamp - served_timestamp < window, and
      * engaged_timestamp - served_timestamp < window,
    where window = `label_rectification_window_in_hours` converted to milliseconds.
    Otherwise it is replaced by zero.

    Args:
        rectify_label_config: A configuration object providing
            `label_rectification_window_in_hours`, `served_timestamp_field`,
            `impressed_timestamp_field` and `label_to_engaged_timestamp_field`
            (label name -> engaged-timestamp field name).
    """

    def __init__(self, rectify_label_config):
        """Initializes the RectifyLabels model.

        Args:
            rectify_label_config: A configuration object specifying the window and the
                timestamp / label field names.
        """
        super().__init__()
        self._config = rectify_label_config
        # Hours -> milliseconds. NOTE(review): assumes the timestamp features are
        # in milliseconds — confirm against the data schema.
        self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)

    def call(self, inputs, training=None, mask=None):
        """Rectifies label values based on the served/impressed/engaged timestamps.

        Args:
            inputs: A dictionary of input features including timestamp fields and labels.
            training: Unused; accepted for the Keras `call` signature.
            mask: Unused; accepted for the Keras `call` signature.

        Returns:
            A new dictionary with rectified label values; `inputs` is not modified.
        """
        # Shallow copy, consistent with the other preprocessors, so the caller's
        # dict is not mutated.
        outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
        served = outputs[self._config.served_timestamp_field]
        impressed = outputs[self._config.impressed_timestamp_field]
        # The impression-window condition is the same for every label; compute it once.
        impressed_in_window = impressed - served < self._window
        for label, engaged_ts_field in self._config.label_to_engaged_timestamp_field.items():
            engaged = outputs[engaged_ts_field]
            keep = tf.math.logical_and(outputs[label] > 0, impressed_in_window)
            keep = tf.math.logical_and(keep, engaged - served < self._window)
            outputs[label] = tf.where(keep, outputs[label], tf.zeros_like(outputs[label]))
        return outputs
class ExtractFeatures(tf.keras.Model):
    """Extracts single columns of existing tensors into new named features.

    For every row of `extract_feature_table`, adds
    ``outputs[row.name] = outputs[row.source_tensor][:, row.index]``.
    Rows are processed in table order, so a later row can read a feature
    produced by an earlier one.

    Args:
        extract_features_config: A configuration object with an
            `extract_feature_table` whose rows have `name`, `source_tensor` and `index`.
    """

    def __init__(self, extract_features_config):
        """Initializes the ExtractFeatures model.

        Args:
            extract_features_config: A configuration object with the feature extraction table.
        """
        super().__init__()
        self._config = extract_features_config

    def call(self, inputs, training=None, mask=None):
        """Adds the extracted features to a copy of `inputs`.

        Args:
            inputs: A dictionary of input features containing every `row.source_tensor`.
            training: Unused; accepted for the Keras `call` signature.
            mask: Unused; accepted for the Keras `call` signature.

        Returns:
            A new dictionary with the original features plus the extracted ones;
            `inputs` is not modified.
        """
        # Shallow copy, consistent with the other preprocessors.
        outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
        for row in self._config.extract_feature_table:
            outputs[row.name] = outputs[row.source_tensor][:, row.index]
        return outputs
class DownsampleNegatives(tf.keras.Model):
    """Down-samples (drops) negatives and re-weights the negatives that are kept.

    If inputs['fav'] = [1, 0, 0, 0] and inputs['weights'] = [1.0, 1.0, 1.0, 1.0]
    inputs are transformed to inputs['fav'] = [1, 0] and inputs['weights'] = [1.0, 3.0]
    when batch_multiplier=2 and engagements_list=['fav'].

    It supports multiple engagements (union/logical_or is used to aggregate engagements),
    so we don't drop positives for any engagement.

    Given an incoming batch of ``batch_multiplier * batch_size`` rows with
    ``n_positives`` positive rows, the output keeps all positives followed by the
    first ``batch_size - n_positives`` negatives in batch order, truncated to
    ``batch_size`` rows. Each kept negative's weights are multiplied by
    ``(batch_multiplier * batch_size - n_positives) / (batch_size - n_positives)``
    so the kept negatives stand in for all incoming negatives.

    Args:
        downsample_negatives_config: A configuration object with `engagements_list`,
            `batch_multiplier` and `num_engagements`.
    """

    def __init__(self, downsample_negatives_config):
        """Initializes the DownsampleNegatives model.

        Args:
            downsample_negatives_config: The down-sampling configuration.
        """
        super().__init__()
        self.config = downsample_negatives_config

    def call(self, inputs, training=None, mask=None):
        """Drops negatives to shrink the batch and updates `weights` accordingly.

        Args:
            inputs: A dictionary of batched tensors that all share the same leading
                (batch) dimension; must contain every label in `engagements_list`.
            training: Unused; accepted for the Keras `call` signature.
            mask: Unused; accepted for the Keras `call` signature.

        Returns:
            A new dictionary in which every tensor is down-sampled along the batch
            dimension and "weights" (created as all-ones of shape
            ``[batch, num_engagements]`` if absent) is re-weighted.
        """
        labels = self.config.engagements_list
        # A row is positive if it is positive for any engagement (union). Reshape to
        # rank 1 instead of `tf.squeeze`, which would turn a batch of one into a
        # scalar and break boolean masking below.
        is_positive = tf.reshape(
            tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1), [-1]
        )
        is_negative = tf.math.logical_not(is_positive)
        n_positives = tf.reduce_sum(tf.cast(is_positive, tf.int32))
        # Target (down-sampled) batch size; the cast truncates a fractional quotient.
        batch_size = tf.cast(tf.shape(inputs[labels[0]])[0] / self.config.batch_multiplier, tf.int32)
        # divide_no_nan yields 0 when batch_size == n_positives (no negatives are kept then).
        negative_weights = tf.math.divide_no_nan(
            tf.cast(self.config.batch_multiplier * batch_size - n_positives, tf.float32),
            tf.cast(batch_size - n_positives, tf.float32),
        )
        positive_float = tf.cast(is_positive, tf.float32)
        new_weights = positive_float + (1 - positive_float) * negative_weights

        def _positives_then_negatives(tensor):
            # Positives first, then negatives, truncated to the target batch size.
            # DANGER: if n_positives > batch_size down-sampling is incorrect (do not use pb_50):
            # positives get truncated and negative_weights becomes negative.
            return tf.concat(
                [
                    tf.boolean_mask(tensor, is_positive),
                    tf.boolean_mask(tensor, is_negative),
                ],
                0,
            )[:batch_size]

        # Shallow copy so the caller's dict is not mutated.
        outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
        if "weights" not in outputs:
            # Placeholder so the re-weighting below applies even without input weights.
            outputs["weights"] = tf.ones([tf.shape(outputs[labels[0]])[0], self.config.num_engagements])
        for key in list(outputs):
            value = outputs[key]
            if key == "weights":
                value = value * tf.reshape(new_weights, [-1, 1])
            outputs[key] = _positives_then_negatives(value)
        return outputs
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
    """Chains every configured preprocessor into a single Keras model.

    Stages run in a fixed order: downsample_negatives, truncate_and_slice,
    downcast, rectify_labels, extract_features. A stage is included only when
    its sub-config on `preprocess_config` is truthy.

    Args:
        preprocess_config: A configuration object holding one (possibly empty)
            sub-config per preprocessing stage.
        mode: A `config_mod.JobMode`; INFERENCE skips building entirely.

    Returns:
        A `tf.keras.Model` applying all configured stages in order, or None in INFERENCE mode.

    Raises:
        ValueError: If no preprocessing stage is configured.
    """
    if mode == config_mod.JobMode.INFERENCE:
        logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
        return None

    # (config attribute, preprocessor class) in the order the stages are applied.
    stage_table = (
        ("downsample_negatives", DownsampleNegatives),
        ("truncate_and_slice", TruncateAndSlice),
        ("downcast", DownCast),
        ("rectify_labels", RectifyLabels),
        ("extract_features", ExtractFeatures),
    )
    preprocess_models = []
    for attr_name, stage_cls in stage_table:
        stage_config = getattr(preprocess_config, attr_name)
        if stage_config:
            preprocess_models.append(stage_cls(stage_config))

    if not preprocess_models:
        raise ValueError("No known preprocessor.")

    class PreprocessModel(tf.keras.Model):
        """Runs each wrapped preprocessor on the previous one's output."""

        def __init__(self, preprocess_models):
            super().__init__()
            self.preprocess_models = preprocess_models

        def call(self, inputs, training=None, mask=None):
            current = inputs
            for stage in self.preprocess_models:
                current = stage(current, training=training, mask=mask)
            return current

    if len(preprocess_models) > 1:
        logging.warning(
            "With multiple preprocessing models, we apply these models in a predefined order. Future works may introduce customized models and orders."
        )
    return PreprocessModel(preprocess_models)