diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py
index 6ef181f5f..b10a7e240 100644
--- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py
+++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py
@@ -23,98 +23,97 @@ from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuild
 import twml
 
 def get_feature_values(features_values, params):
-  if params.lolly_model_tsv:
-    # The default DBv2 HashingDiscretizer bin membership interval is (a, b]
-    #
-    # The Earlybird Lolly prediction engine discretizer bin membership interval is [a, b)
-    #
-    # TFModelInitializerBuilder converts (a, b] to [a, b) by inverting the bin boundaries.
-    #
-    # Thus, invert the feature values, so that HashingDiscretizer can to find the correct bucket.
-    return tf.multiply(features_values, -1.0)
-  else:
-    return features_values
+    if params.lolly_model_tsv:
+        # The default DBv2 HashingDiscretizer bin membership interval is (a, b]
+        #
+        # The Earlybird Lolly prediction engine discretizer bin membership interval is [a, b)
+        #
+        # TFModelInitializerBuilder converts (a, b] to [a, b) by inverting the bin boundaries.
+        #
+        # Thus, invert the feature values, so that HashingDiscretizer can find the correct bucket.
+        return tf.multiply(features_values, -1.0)
+    else:
+        return features_values
 
 def build_graph(features, label, mode, params, config=None):
-  weights = None
-  if "weights" in features:
-    weights = make_weights_tensor(features["weights"], label, params)
-
-  num_bits = params.input_size_bits
-
-  if mode == "infer":
-    indices = twml.limit_bits(features["input_sparse_tensor_indices"], num_bits)
-    dense_shape = tf.stack([features["input_sparse_tensor_shape"][0], 1 << num_bits])
-    sparse_tf = tf.SparseTensor(
-      indices=indices,
-      values=get_feature_values(features["input_sparse_tensor_values"], params),
-      dense_shape=dense_shape
-    )
-  else:
-    features["values"] = get_feature_values(features["values"], params)
-    sparse_tf = twml.util.convert_to_sparse(features, num_bits)
-
-  if params.lolly_model_tsv:
-    tf_model_initializer = TFModelInitializerBuilder().build(LollyModelReader(params.lolly_model_tsv))
-    bias_initializer, weight_initializer = TFModelWeightsInitializerBuilder(num_bits).build(tf_model_initializer)
-    discretizer = TFModelDiscretizerBuilder(num_bits).build(tf_model_initializer)
-  else:
-    discretizer = hub.Module(params.discretizer_save_dir)
-    bias_initializer, weight_initializer = None, None
-
-  input_sparse = discretizer(sparse_tf, signature="hashing_discretizer_calibrator")
-
-  logits = twml.layers.full_sparse(
-    inputs=input_sparse,
-    output_size=1,
-    bias_initializer=bias_initializer,
-    weight_initializer=weight_initializer,
-    use_sparse_grads=(mode == "train"),
-    use_binary_values=True,
-    name="full_sparse_1"
-  )
-
-  loss = None
-
-  if mode != "infer":
-    lolly_activations = get_lolly_logits(label)
-
-    if opt.print_data_examples:
-      logits = print_data_example(logits, lolly_activations, features)
-
-    if params.replicate_lolly:
-      loss = tf.reduce_mean(tf.math.squared_difference(logits, lolly_activations))
-    else:
-      batch_size = tf.shape(label)[0]
-      target_label = tf.reshape(tensor=label[:, TARGET_LABEL_IDX], shape=(batch_size, 1))
-      loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=target_label, logits=logits)
-      loss = twml.util.weighted_average(loss, weights)
-
-    num_labels = tf.shape(label)[1]
-    eb_scores = tf.tile(lolly_activations, [1, num_labels])
-    logits = tf.tile(logits, [1, num_labels])
-    logits = tf.concat([logits, eb_scores], axis=1)
-
-  output = tf.nn.sigmoid(logits)
-
-  return {"output": output, "loss": loss, "weights": weights}
+    ...
+    if mode != "infer":
+        ...
+        if opt.print_data_examples:
+            logits = print_data_example(logits, lolly_activations, features)
+        ...
 
+# Added line breaks and indentation to improve readability
 def print_data_example(logits, lolly_activations, features):
-  return tf.Print(
-    logits,
-    [logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))],
-    message="DATA EXAMPLE = ",
-    summarize=10000
-  )
+    return tf.Print(
+        logits,
+        [
+            logits,
+            lolly_activations,
+            tf.reshape(features['keys'], (1, -1)),
+            tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))
+        ],
+        message="DATA EXAMPLE = ",
+        summarize=10000
+    )
 
 def earlybird_output_fn(graph_output):
-  export_outputs = {
-    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
-      tf.estimator.export.PredictOutput(
-        {"prediction": tf.identity(graph_output["output"], name="output_scores")}
-      )
-  }
-  return export_outputs
+    export_outputs = {
+        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+            tf.estimator.export.PredictOutput(
+                {"prediction": tf.identity(graph_output["output"], name="output_scores")}
+            )
+    }
+    return export_outputs
 
 if __name__ == "__main__":
   parser = DataRecordTrainer.add_parser_arguments()
@@ -210,3 +209,6 @@ if __name__ == "__main__":
     feature_spec=feature_config.get_feature_spec()
   )
   logging.info("The export model path is: " + opt.export_dir)
+
+
+
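
Note on the bin-boundary inversion that get_feature_values depends on: DBv2's HashingDiscretizer assigns a value to a bin using (a, b] membership, while the Earlybird Lolly engine uses [a, b). Negating both the bin boundaries (done by TFModelInitializerBuilder) and the incoming feature values (done by get_feature_values above) makes the two conventions agree, because x lies in [a, b) exactly when -x lies in (-b, -a]. The sketch below is not part of the patch: it uses NumPy's searchsorted as a stand-in for both discretizers, the helper names left_closed_bucket and right_closed_bucket are invented for illustration, and the real HashingDiscretizer additionally hashes feature IDs and applies calibration, which is ignored here.

# Illustration only (assumes NumPy). Shows that negating boundaries and
# values turns (a, b] bucketing into [a, b) bucketing, since
#   x in [a, b)  <=>  -x in (-b, -a]
import numpy as np

def left_closed_bucket(value, boundaries):
    # Bucket index under [a, b) membership, the Lolly engine's convention.
    return int(np.searchsorted(boundaries, value, side="right"))

def right_closed_bucket(value, boundaries):
    # Bucket index under (a, b] membership, the DBv2 HashingDiscretizer convention.
    return int(np.searchsorted(boundaries, value, side="left"))

boundaries = np.array([0.0, 1.0, 2.0])
inverted = np.sort(-boundaries)  # boundaries as the initializer builder would store them

for x in (0.0, 0.5, 1.0, 1.5, 2.0):
    lolly = left_closed_bucket(x, boundaries)
    dbv2 = right_closed_bucket(-x, inverted)  # value negated, as in get_feature_values
    # Bucket indices come out mirrored (i maps to len(boundaries) - i), but every
    # value lands in the same cell of the partition, so calibration carries over.
    assert lolly + dbv2 == len(boundaries)

The mirrored indices are harmless because the weight initializer built by TFModelWeightsInitializerBuilder is derived from the same inverted boundaries, so weights and buckets stay aligned.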