mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-11-16 21:29:24 +01:00
207 lines
8.1 KiB
Python
207 lines
8.1 KiB
Python
"""
|
|
Preprocessors applied on DDS workers in order to modify the dataset on the fly.
|
|
Some of these preprocessors are also applied to the model at serving time.
|
|
"""
|
|
from tml.projects.home.recap import config as config_mod
|
|
from absl import logging
|
|
import tensorflow as tf
|
|
import numpy as np
|
|
|
|
|
|
class TruncateAndSlice(tf.keras.Model):
|
|
"""Class for truncating and slicing."""
|
|
|
|
def __init__(self, truncate_and_slice_config):
|
|
super().__init__()
|
|
self._truncate_and_slice_config = truncate_and_slice_config
|
|
|
|
if self._truncate_and_slice_config.continuous_feature_mask_path:
|
|
with tf.io.gfile.GFile(
|
|
self._truncate_and_slice_config.continuous_feature_mask_path, "rb"
|
|
) as f:
|
|
self._continuous_mask = np.load(f).nonzero()[0]
|
|
logging.info(f"Slicing {np.sum(self._continuous_mask)} continuous features.")
|
|
else:
|
|
self._continuous_mask = None
|
|
|
|
if self._truncate_and_slice_config.binary_feature_mask_path:
|
|
with tf.io.gfile.GFile(self._truncate_and_slice_config.binary_feature_mask_path, "rb") as f:
|
|
self._binary_mask = np.load(f).nonzero()[0]
|
|
logging.info(f"Slicing {np.sum(self._binary_mask)} binary features.")
|
|
else:
|
|
self._binary_mask = None
|
|
|
|
def call(self, inputs, training=None, mask=None):
|
|
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
|
if self._truncate_and_slice_config.continuous_feature_truncation:
|
|
logging.info("Truncating continuous")
|
|
outputs["continuous"] = outputs["continuous"][
|
|
:, : self._truncate_and_slice_config.continuous_feature_truncation
|
|
]
|
|
if self._truncate_and_slice_config.binary_feature_truncation:
|
|
logging.info("Truncating binary")
|
|
outputs["binary"] = outputs["binary"][
|
|
:, : self._truncate_and_slice_config.binary_feature_truncation
|
|
]
|
|
if self._continuous_mask is not None:
|
|
outputs["continuous"] = tf.gather(outputs["continuous"], self._continuous_mask, axis=1)
|
|
if self._binary_mask is not None:
|
|
outputs["binary"] = tf.gather(outputs["binary"], self._binary_mask, axis=1)
|
|
return outputs
|
|
|
|
|
|
class DownCast(tf.keras.Model):
|
|
"""Class for Down casting dataset before serialization and transferring to training host.
|
|
Depends on the data type and the actual data range, the down casting can be lossless or not.
|
|
It is strongly recommended to compare the metrics before and after down casting.
|
|
"""
|
|
|
|
def __init__(self, downcast_config):
|
|
super().__init__()
|
|
self.config = downcast_config
|
|
self._type_map = {
|
|
"bfloat16": tf.bfloat16,
|
|
"bool": tf.bool,
|
|
}
|
|
|
|
def call(self, inputs, training=None, mask=None):
|
|
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
|
for feature, type_str in self.config.features.items():
|
|
assert type_str in self._type_map
|
|
if type_str == "bfloat16":
|
|
logging.warning(
|
|
"Although bfloat16 and float32 have the same number of exponent bits, this down casting is not 100% lossless. Please double check metrics."
|
|
)
|
|
down_cast_data_type = self._type_map[type_str]
|
|
outputs[feature] = tf.cast(outputs[feature], dtype=down_cast_data_type)
|
|
return outputs
|
|
|
|
|
|
class RectifyLabels(tf.keras.Model):
|
|
"""Class for rectifying labels"""
|
|
|
|
def __init__(self, rectify_label_config):
|
|
super().__init__()
|
|
self._config = rectify_label_config
|
|
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
|
|
|
|
def call(self, inputs, training=None, mask=None):
|
|
served_ts_field = self._config.served_timestamp_field
|
|
impressed_ts_field = self._config.impressed_timestamp_field
|
|
|
|
for label, engaged_ts_field in self._config.label_to_engaged_timestamp_field.items():
|
|
impressed = inputs[impressed_ts_field]
|
|
served = inputs[served_ts_field]
|
|
engaged = inputs[engaged_ts_field]
|
|
|
|
keep = tf.math.logical_and(inputs[label] > 0, impressed - served < self._window)
|
|
keep = tf.math.logical_and(keep, engaged - served < self._window)
|
|
inputs[label] = tf.where(keep, inputs[label], tf.zeros_like(inputs[label]))
|
|
|
|
return inputs
|
|
|
|
|
|
class ExtractFeatures(tf.keras.Model):
|
|
"""Class for extracting individual features from dense tensors by their index."""
|
|
|
|
def __init__(self, extract_features_config):
|
|
super().__init__()
|
|
self._config = extract_features_config
|
|
|
|
def call(self, inputs, training=None, mask=None):
|
|
|
|
for row in self._config.extract_feature_table:
|
|
inputs[row.name] = inputs[row.source_tensor][:, row.index]
|
|
|
|
return inputs
|
|
|
|
|
|
class DownsampleNegatives(tf.keras.Model):
|
|
"""Class for down-sampling/dropping negatives and updating the weights.
|
|
|
|
If inputs['fav'] = [1, 0, 0, 0] and inputs['weights'] = [1.0, 1.0, 1.0, 1.0]
|
|
inputs are transformed to inputs['fav'] = [1, 0] and inputs['weights'] = [1.0, 3.0]
|
|
when batch_multiplier=2 and engagements_list=['fav']
|
|
|
|
It supports multiple engagements (union/logical_or is used to aggregate engagements), so we don't
|
|
drop positives for any engagement.
|
|
"""
|
|
|
|
def __init__(self, downsample_negatives_config):
|
|
super().__init__()
|
|
self.config = downsample_negatives_config
|
|
|
|
def call(self, inputs, training=None, mask=None):
|
|
labels = self.config.engagements_list
|
|
# union of engagements
|
|
mask = tf.squeeze(tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1))
|
|
n_positives = tf.reduce_sum(tf.cast(mask, tf.int32))
|
|
batch_size = tf.cast(tf.shape(inputs[labels[0]])[0] / self.config.batch_multiplier, tf.int32)
|
|
negative_weights = tf.math.divide_no_nan(
|
|
tf.cast(self.config.batch_multiplier * batch_size - n_positives, tf.float32),
|
|
tf.cast(batch_size - n_positives, tf.float32),
|
|
)
|
|
new_weights = tf.cast(mask, tf.float32) + (1 - tf.cast(mask, tf.float32)) * negative_weights
|
|
|
|
def _split_by_label_concatenate_and_truncate(input_tensor):
|
|
# takes positive examples and concatenate with negative examples and truncate
|
|
# DANGER: if n_positives > batch_size down-sampling is incorrect (do not use pb_50)
|
|
return tf.concat(
|
|
[
|
|
input_tensor[mask],
|
|
input_tensor[tf.math.logical_not(mask)],
|
|
],
|
|
0,
|
|
)[:batch_size]
|
|
|
|
if "weights" not in inputs:
|
|
# add placeholder so logic below applies even if weights aren't present in inputs
|
|
inputs["weights"] = tf.ones([tf.shape(inputs[labels[0]])[0], self.config.num_engagements])
|
|
|
|
for tensor in inputs:
|
|
if tensor == "weights":
|
|
inputs[tensor] = inputs[tensor] * tf.reshape(new_weights, [-1, 1])
|
|
|
|
inputs[tensor] = _split_by_label_concatenate_and_truncate(inputs[tensor])
|
|
|
|
return inputs
|
|
|
|
|
|
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
|
|
"""Builds a preprocess model to apply all preprocessing stages."""
|
|
if mode == config_mod.JobMode.INFERENCE:
|
|
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
|
|
return None
|
|
|
|
preprocess_models = []
|
|
if preprocess_config.downsample_negatives:
|
|
preprocess_models.append(DownsampleNegatives(preprocess_config.downsample_negatives))
|
|
if preprocess_config.truncate_and_slice:
|
|
preprocess_models.append(TruncateAndSlice(preprocess_config.truncate_and_slice))
|
|
if preprocess_config.downcast:
|
|
preprocess_models.append(DownCast(preprocess_config.downcast))
|
|
if preprocess_config.rectify_labels:
|
|
preprocess_models.append(RectifyLabels(preprocess_config.rectify_labels))
|
|
if preprocess_config.extract_features:
|
|
preprocess_models.append(ExtractFeatures(preprocess_config.extract_features))
|
|
|
|
if len(preprocess_models) == 0:
|
|
raise ValueError("No known preprocessor.")
|
|
|
|
class PreprocessModel(tf.keras.Model):
|
|
def __init__(self, preprocess_models):
|
|
super().__init__()
|
|
self.preprocess_models = preprocess_models
|
|
|
|
def call(self, inputs, training=None, mask=None):
|
|
outputs = inputs
|
|
for model in self.preprocess_models:
|
|
outputs = model(outputs, training, mask)
|
|
return outputs
|
|
|
|
if len(preprocess_models) > 1:
|
|
logging.warning(
|
|
"With multiple preprocessing models, we apply these models in a predefined order. Future works may introduce customized models and orders."
|
|
)
|
|
return PreprocessModel(preprocess_models)
|