From ee2b6376732343cb7817cf99b629608b44e25d5c Mon Sep 17 00:00:00 2001 From: Sahil B Date: Fri, 31 Mar 2023 19:26:59 -0700 Subject: [PATCH 1/4] Made train.py more readable --- .../scripts/models/earlybird/train.py | 176 +++++++++--------- 1 file changed, 89 insertions(+), 87 deletions(-) diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py index 6ef181f5f..b10a7e240 100644 --- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py +++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py @@ -23,98 +23,97 @@ from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuild import twml def get_feature_values(features_values, params): - if params.lolly_model_tsv: - # The default DBv2 HashingDiscretizer bin membership interval is (a, b] - # - # The Earlybird Lolly prediction engine discretizer bin membership interval is [a, b) - # - # TFModelInitializerBuilder converts (a, b] to [a, b) by inverting the bin boundaries. - # - # Thus, invert the feature values, so that HashingDiscretizer can to find the correct bucket. - return tf.multiply(features_values, -1.0) - else: - return features_values + if params.lolly_model_tsv: + # The default DBv2 HashingDiscretizer bin membership interval is (a, b] + # + # The Earlybird Lolly prediction engine discretizer bin membership interval is [a, b) + # + # TFModelInitializerBuilder converts (a, b] to [a, b) by inverting the bin boundaries. + # + # Thus, invert the feature values, so that HashingDiscretizer can to find the correct bucket. + return tf.multiply(features_values, -1.0) + else: + return features_values def build_graph(features, label, mode, params, config=None): - weights = None - if "weights" in features: - weights = make_weights_tensor(features["weights"], label, params) - - num_bits = params.input_size_bits - - if mode == "infer": - indices = twml.limit_bits(features["input_sparse_tensor_indices"], num_bits) - dense_shape = tf.stack([features["input_sparse_tensor_shape"][0], 1 << num_bits]) - sparse_tf = tf.SparseTensor( - indices=indices, - values=get_feature_values(features["input_sparse_tensor_values"], params), - dense_shape=dense_shape - ) - else: - features["values"] = get_feature_values(features["values"], params) - sparse_tf = twml.util.convert_to_sparse(features, num_bits) - - if params.lolly_model_tsv: - tf_model_initializer = TFModelInitializerBuilder().build(LollyModelReader(params.lolly_model_tsv)) - bias_initializer, weight_initializer = TFModelWeightsInitializerBuilder(num_bits).build(tf_model_initializer) - discretizer = TFModelDiscretizerBuilder(num_bits).build(tf_model_initializer) - else: - discretizer = hub.Module(params.discretizer_save_dir) - bias_initializer, weight_initializer = None, None - - input_sparse = discretizer(sparse_tf, signature="hashing_discretizer_calibrator") - - logits = twml.layers.full_sparse( - inputs=input_sparse, - output_size=1, - bias_initializer=bias_initializer, - weight_initializer=weight_initializer, - use_sparse_grads=(mode == "train"), - use_binary_values=True, - name="full_sparse_1" - ) - - loss = None - - if mode != "infer": - lolly_activations = get_lolly_logits(label) - - if opt.print_data_examples: - logits = print_data_example(logits, lolly_activations, features) - - if params.replicate_lolly: - loss = tf.reduce_mean(tf.math.squared_difference(logits, lolly_activations)) - else: - batch_size = 
tf.shape(label)[0] - target_label = tf.reshape(tensor=label[:, TARGET_LABEL_IDX], shape=(batch_size, 1)) - loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=target_label, logits=logits) - loss = twml.util.weighted_average(loss, weights) - - num_labels = tf.shape(label)[1] - eb_scores = tf.tile(lolly_activations, [1, num_labels]) - logits = tf.tile(logits, [1, num_labels]) - logits = tf.concat([logits, eb_scores], axis=1) - - output = tf.nn.sigmoid(logits) - - return {"output": output, "loss": loss, "weights": weights} + ... + if mode != "infer": + ... + if opt.print_data_examples: + logits = print_data_example(logits, lolly_activations, features) + ... +# Added line breaks and indentation to improve readability def print_data_example(logits, lolly_activations, features): - return tf.Print( - logits, - [logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))], - message="DATA EXAMPLE = ", - summarize=10000 - ) + return tf.Print( + logits, + [ + logits, + lolly_activations, + tf.reshape(features['keys'], (1, -1)), + tf.reshape(tf.multiply(features['values'], -1.0), (1, -1)) + ], + message="DATA EXAMPLE = ", + summarize=10000 + ) -def earlybird_output_fn(graph_output): - export_outputs = { - tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: - tf.estimator.export.PredictOutput( - {"prediction": tf.identity(graph_output["output"], name="output_scores")} - ) - } - return export_outputs + +# Import statements reformatted for better readability +import tensorflow.compat.v1 as tf +from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +import tensorflow_hub as hub + +from datetime import datetime +from tensorflow.compat.v1 import logging +from twitter.deepbird.projects.timelines.configs import all_configs +from twml.trainers import DataRecordTrainer +from twml.contrib.calibrators.common_calibrators import build_percentile_discretizer_graph +from twml.contrib.calibrators.common_calibrators import calibrate_discretizer_and_export +from .metrics import get_multi_binary_class_metric_fn +from .constants import TARGET_LABEL_IDX, PREDICTED_CLASSES +from .example_weights import add_weight_arguments, make_weights_tensor +from .lolly.data_helpers import get_lolly_logits +from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder +from .lolly.reader import LollyModelReader +from .tf_model.discretizer_builder import TFModelDiscretizerBuilder +from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuilder + +import twml + +# Added line breaks and indentation to improve readability +def get_feature_values(features_values, params): + if params.lolly_model_tsv: + return tf.multiply(features_values, -1.0) + else: + return features_values + +# Added line breaks and indentation to improve readability +def build_graph(features, label, mode, params, config=None): + ... + + if mode != "infer": + ... + + if opt.print_data_examples: + logits = print_data_example(logits, lolly_activations, features) + + ... 
+ +# Added line breaks and indentation to improve readability +def print_data_example(logits, lolly_activations, features): + return tf.Print( + logits, + [ + logits, + lolly_activations, + tf.reshape(features['keys'], (1, -1)), + tf.reshape(tf.multiply(features['values'], -1.0), (1, -1)) + ], + message="DATA EXAMPLE = ", + summarize=10000 + ) if __name__ == "__main__": parser = DataRecordTrainer.add_parser_arguments() @@ -210,3 +209,6 @@ if __name__ == "__main__": feature_spec=feature_config.get_feature_spec() ) logging.info("The export model path is: " + opt.export_dir) + + + From fefd9c2404c7e6052e44f072adb77e3511962895 Mon Sep 17 00:00:00 2001 From: Sahil B Date: Fri, 31 Mar 2023 19:40:23 -0700 Subject: [PATCH 2/4] Made train.py more readable. Main function is now easier to understand. --- .../scripts/models/earlybird/train.py | 166 +++++++++--------- 1 file changed, 87 insertions(+), 79 deletions(-) diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py index b10a7e240..db6744d8a 100644 --- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py +++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py @@ -23,99 +23,104 @@ from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuild import twml def get_feature_values(features_values, params): - if params.lolly_model_tsv: - # The default DBv2 HashingDiscretizer bin membership interval is (a, b] - # - # The Earlybird Lolly prediction engine discretizer bin membership interval is [a, b) - # - # TFModelInitializerBuilder converts (a, b] to [a, b) by inverting the bin boundaries. - # - # Thus, invert the feature values, so that HashingDiscretizer can to find the correct bucket. - return tf.multiply(features_values, -1.0) - else: - return features_values + if params.lolly_model_tsv: + # The default DBv2 HashingDiscretizer bin membership interval is (a, b] + # + # The Earlybird Lolly prediction engine discretizer bin membership interval is [a, b) + # + # TFModelInitializerBuilder converts (a, b] to [a, b) by inverting the bin boundaries. + # + # Thus, invert the feature values, so that HashingDiscretizer can to find the correct bucket. + return tf.multiply(features_values, -1.0) + else: + return features_values def build_graph(features, label, mode, params, config=None): - ... - if mode != "infer": - ... - if opt.print_data_examples: - logits = print_data_example(logits, lolly_activations, features) - ... 
+ # Function to build the Earlybird model graph + weights = None + if "weights" in features: + weights = make_weights_tensor(features["weights"], label, params) -# Added line breaks and indentation to improve readability -def print_data_example(logits, lolly_activations, features): - return tf.Print( - logits, - [ - logits, - lolly_activations, - tf.reshape(features['keys'], (1, -1)), - tf.reshape(tf.multiply(features['values'], -1.0), (1, -1)) - ], - message="DATA EXAMPLE = ", - summarize=10000 + num_bits = params.input_size_bits + + if mode == "infer": + indices = twml.limit_bits(features["input_sparse_tensor_indices"], num_bits) + dense_shape = tf.stack([features["input_sparse_tensor_shape"][0], 1 << num_bits]) + sparse_tf = tf.SparseTensor( + indices=indices, + values=get_feature_values(features["input_sparse_tensor_values"], params), + dense_shape=dense_shape ) + else: + features["values"] = get_feature_values(features["values"], params) + sparse_tf = twml.util.convert_to_sparse(features, num_bits) + if params.lolly_model_tsv: + tf_model_initializer = TFModelInitializerBuilder().build(LollyModelReader(params.lolly_model_tsv)) + bias_initializer, weight_initializer = TFModelWeightsInitializerBuilder(num_bits).build(tf_model_initializer) + discretizer = TFModelDiscretizerBuilder(num_bits).build(tf_model_initializer) + else: + discretizer = hub.Module(params.discretizer_save_dir) + bias_initializer, weight_initializer = None, None -# Import statements reformatted for better readability -import tensorflow.compat.v1 as tf -from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -import tensorflow_hub as hub + input_sparse = discretizer(sparse_tf, signature="hashing_discretizer_calibrator") -from datetime import datetime -from tensorflow.compat.v1 import logging -from twitter.deepbird.projects.timelines.configs import all_configs -from twml.trainers import DataRecordTrainer -from twml.contrib.calibrators.common_calibrators import build_percentile_discretizer_graph -from twml.contrib.calibrators.common_calibrators import calibrate_discretizer_and_export -from .metrics import get_multi_binary_class_metric_fn -from .constants import TARGET_LABEL_IDX, PREDICTED_CLASSES -from .example_weights import add_weight_arguments, make_weights_tensor -from .lolly.data_helpers import get_lolly_logits -from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder -from .lolly.reader import LollyModelReader -from .tf_model.discretizer_builder import TFModelDiscretizerBuilder -from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuilder + logits = twml.layers.full_sparse( + inputs=input_sparse, + output_size=1, + bias_initializer=bias_initializer, + weight_initializer=weight_initializer, + use_sparse_grads=(mode == "train"), + use_binary_values=True, + name="full_sparse_1" + ) -import twml + loss = None -# Added line breaks and indentation to improve readability -def get_feature_values(features_values, params): - if params.lolly_model_tsv: - return tf.multiply(features_values, -1.0) + if mode != "infer": + lolly_activations = get_lolly_logits(label) + + if opt.print_data_examples: + logits = print_data_example(logits, lolly_activations, features) + + if params.replicate_lolly: + loss = tf.reduce_mean(tf.math.squared_difference(logits, lolly_activations)) else: - return features_values + batch_size = tf.shape(label)[0] + target_label = 
tf.reshape(tensor=label[:, TARGET_LABEL_IDX], shape=(batch_size, 1)) + loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=target_label, logits=logits) + loss = twml.util.weighted_average(loss, weights) -# Added line breaks and indentation to improve readability -def build_graph(features, label, mode, params, config=None): - ... + num_labels = tf.shape(label)[1] + eb_scores = tf.tile(lolly_activations, [1, num_labels]) + logits = tf.tile(logits, [1, num_labels]) + logits = tf.concat([logits, eb_scores], axis=1) - if mode != "infer": - ... + output = tf.nn.sigmoid(logits) - if opt.print_data_examples: - logits = print_data_example(logits, lolly_activations, features) + return {"output": output, "loss": loss, "weights": weights} - ... - -# Added line breaks and indentation to improve readability def print_data_example(logits, lolly_activations, features): - return tf.Print( - logits, - [ - logits, - lolly_activations, - tf.reshape(features['keys'], (1, -1)), - tf.reshape(tf.multiply(features['values'], -1.0), (1, -1)) - ], - message="DATA EXAMPLE = ", - summarize=10000 - ) + # Function to print data example + return tf.Print( + logits, + [logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))], + message="DATA EXAMPLE = ", + summarize=10000 + ) + +def earlybird_output_fn(graph_output): + # Function to process the Earlybird model output + export_outputs = { + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + tf.estimator.export.PredictOutput( + {"prediction": tf.identity(graph_output["output"], name="output_scores")} + ) + } + return export_outputs if __name__ == "__main__": + # Set up argument parser parser = DataRecordTrainer.add_parser_arguments() parser = twml.contrib.calibrators.add_discretizer_arguments(parser) @@ -136,8 +141,10 @@ if __name__ == "__main__": help="Prints 'DATA EXAMPLE = [[tf logit]][[logged lolly logit]][[feature ids][feature values]]'") add_weight_arguments(parser) + # Parse arguments opt = parser.parse_args() + # Set up feature configuration feature_config_module = all_configs.select_feature_config(opt.feature_config) feature_config = feature_config_module.get_feature_config(data_spec_path=opt.data_spec, label=opt.label) @@ -146,6 +153,7 @@ if __name__ == "__main__": feature_config, keep_fields=("ids", "keys", "values", "batch_size", "total_size", "codes")) + # Discretizer calibration (if necessary) if not opt.lolly_model_tsv: if opt.model_use_existing_discretizer: logging.info("Skipping discretizer calibration [model.use_existing_discretizer=True]") @@ -162,6 +170,7 @@ if __name__ == "__main__": build_graph_fn=build_percentile_discretizer_graph, feature_config=feature_config) + # Initialize trainer trainer = DataRecordTrainer( name="earlybird", params=opt, @@ -175,6 +184,7 @@ if __name__ == "__main__": warm_start_from=None ) + # Train and evaluate model train_input_fn = trainer.get_train_input_fn(parse_fn=parse_fn) eval_input_fn = trainer.get_eval_input_fn(parse_fn=parse_fn) @@ -184,6 +194,7 @@ if __name__ == "__main__": trainingEndTime = datetime.now() logging.info("Training and Evaluation time: " + str(trainingEndTime - trainingStartTime)) + # Export model (if current node is chief) if trainer._estimator.config.is_chief: serving_input_in_earlybird = { "input_sparse_tensor_indices": array_ops.placeholder( @@ -209,6 +220,3 @@ if __name__ == "__main__": feature_spec=feature_config.get_feature_spec() ) logging.info("The export model path is: " + opt.export_dir) - - - From 
7b41414753406e36b58d0fe06e989ba508739d7e Mon Sep 17 00:00:00 2001
From: Sahil B
Date: Fri, 31 Mar 2023 21:24:28 -0700
Subject: [PATCH 3/4] Fixed Issues in README#605

---
 .../scripts/models/earlybird/README.md        | 58 ++++++------------
 1 file changed, 18 insertions(+), 40 deletions(-)

diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/README.md b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/README.md
index 3eb9e6c74..9afa7f438 100644
--- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/README.md
+++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/README.md
@@ -1,60 +1,38 @@
 # Earlybird Light Ranker
-*Note: the light ranker is an old part of the stack which we are currently in the process of replacing.
-The current model was last trained several years ago, and uses some very strange features.
-We are working on training a new model, and eventually rebuilding this part of the stack entirely.*
+*Note: The light ranker is an older part of the stack being replaced. The current model was trained years ago and uses odd features. A new model is being developed, and eventually, the entire stack will be rebuilt.*
 
-The Earlybird light ranker is a logistic regression model which predicts the likelihood that the user will engage with a
-tweet.
-It is intended to be a simplified version of the heavy ranker which can run on a greater amount of tweets.
+The Earlybird light ranker is a logistic regression model predicting user engagement likelihood with tweets. It's a simplified version of the heavy ranker, capable of handling more tweets.
 
-There are currently 2 main light ranker models in use: one for ranking in network tweets (`recap_earlybird`), and
-another for
-out of network (UTEG) tweets (`rectweet_earlybird`). Both models are trained using the `train.py` script which is
-included in this directory. They differ mainly in the set of features
-used by the model.
-The in network model uses
-the `src/python/twitter/deepbird/projects/timelines/configs/recap/feature_config.py` file to define the
-feature configuration, while the
-out of network model uses `src/python/twitter/deepbird/projects/timelines/configs/rectweet_earlybird/feature_config.py`.
-The `train.py` script is essentially a series of hooks provided to for Twitter's `twml` framework to execute,
-which is included under `twml/`.
+
+There are two main light ranker models: one for in-network tweets (`recap_earlybird`) and another for out-of-network (UTEG) tweets (`rectweet_earlybird`). Both models are trained using the `train.py` script, and they mainly differ in the features used. The in-network model uses `src/python/twitter/deepbird/projects/timelines/configs/recap/feature_config.py`, while the out-of-network model uses `src/python/twitter/deepbird/projects/timelines/configs/rectweet_earlybird/feature_config.py`.
+
+The `train.py` script serves as a series of hooks for Twitter's `twml` framework, included under `twml/`.
 
 ### Features
 The light ranker features pipeline is as follows:
 
 ![earlybird_features.png](earlybird_features.png)
 
-Some of these components are explained below:
-
-- Index Ingester: an indexing pipeline that handles the tweets as they are generated. 
This is the main input of - Earlybird, it produces Tweet Data (the basic information about the tweet, the text, the urls, media entities, facets, - etc) and Static Features (the features you can compute directly from a tweet right now, like whether it has URL, has - Cards, has quotes, etc); All information computed here are stored in index and flushed as each realtime index segments - become full. They are loaded back later from disk when Earlybird restarts. Note that the features may be computed in a - non-trivial way (like deciding the value of hasUrl), they could be computed and combined from some more "raw" - information in the tweet and from other services. - Signal Ingester: the ingester for Realtime Features, per-tweet features that can change after the tweet has been - indexed, mostly social engagements like retweetCount, favCount, replyCount, etc, along with some (future) spam signals - that's computed with later activities. These were collected and computed in a Heron topology by processing multiple - event streams and can be extended to support more features. -- User Table Features is another set of features per user. They are from User Table Updater, a different input that - processes a stream written by our user service. It's used to store sparse realtime user - information. These per-user features are propagated to the tweet being scored by - looking up the author of the tweet. -- Search Context Features are basically the information of current searcher, like their UI language, their own - produced/consumed language, and the current time (implied). They are combined with Tweet Data to compute some of the - features used in scoring. +Components explained below: +- Index Ingester: An indexing pipeline responsible for handling tweets as they are generated. This component serves as the primary input for Earlybird. It creates Tweet Data, which includes basic information about the tweet (text, URLs, media entities, facets, and more) and Static Features, which are features that can be computed directly from a tweet (such as whether it has a URL, cards, or quotes). All information computed by the Index Ingester is stored in an index and flushed as each realtime index segment becomes full. When Earlybird restarts, this information is loaded back from the disk. It's important to note that some features might be computed in non-trivial ways, such as determining the value of "hasUrl". These features could be computed and combined from raw information within the tweet and data from other services. +- Signal Ingester: Responsible for Realtime Features—per-tweet features that can change after indexing. These include + social engagements like retweet count, favorite count, and reply count, as well as future spam signals. They are + collected and computed in a Heron topology by processing multiple event streams and can be expanded to support more features. +- User Table Features: A separate set of per-user features sourced from the User Table Updater, which processes a stream + written by the user service. It stores sparse realtime user information, and these per-user features are propagated to + the tweet being scored by looking up the tweet's author. +- Search Context Features: Information about the current searcher, such as their UI language, their own produced/consumed language, + and the current time. These features are combined with Tweet Data to compute some of the features used in scoring. The scoring function in Earlybird uses both static and realtime features. 
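
To make the scoring step concrete, here is a minimal, illustrative sketch of what a logistic regression light ranker computes. This is not code from this patch series; the feature names, weights, and bias below are hypothetical, and the production model operates on discretized sparse features rather than a plain dictionary.

```python
import math

def light_ranker_score(features, weights, bias):
    # Logistic regression: sigmoid of the bias plus a weighted sum of features.
    z = bias + sum(weights.get(name, 0.0) * value for name, value in features.items())
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical inputs for one tweet. Static features are fixed at ingestion
# time; realtime features keep changing after the tweet is indexed.
features = {
    "is_retweet": 0.0,          # static
    "has_url": 1.0,             # static
    "text_quality_score": 0.7,  # static
    "fav_count": 12.0,          # realtime
    "reply_count": 3.0,         # realtime
}
weights = {
    "is_retweet": -0.30,
    "has_url": 0.10,
    "text_quality_score": 0.80,
    "fav_count": 0.05,
    "reply_count": 0.02,
}

print(light_ranker_score(features, weights, bias=-1.0))  # ~0.58
```

In `train.py` the same shape of computation appears in sparse form: the discretized features pass through `twml.layers.full_sparse` to produce a logit, and `tf.nn.sigmoid` turns it into the final score.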
Examples of static features used are: - Whether the tweet is a retweet - Whether the tweet contains a link -- Whether this tweet has any trend words at ingestion time +- Whether the tweet has trend words at ingestion time - Whether the tweet is a reply -- A score for the static quality of the text, computed in TweetTextScorer.java in the Ingester. Based on the factors - such as offensiveness, content entropy, "shout" score, length, and readability. +- A score for the static text quality, computed based on factors like offensiveness, content entropy, "shout" score, length, and readability. - tweepcred, see top-level README.md Examples of realtime features used are: From f94b75133c66de2c34f632c144824a263f733a51 Mon Sep 17 00:00:00 2001 From: Sahil B Date: Fri, 31 Mar 2023 21:30:26 -0700 Subject: [PATCH 4/4] Fixed Issues in README#605 --- .../timelines/scripts/models/earlybird/train.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py index db6744d8a..6ef181f5f 100644 --- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py +++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py @@ -36,7 +36,6 @@ def get_feature_values(features_values, params): return features_values def build_graph(features, label, mode, params, config=None): - # Function to build the Earlybird model graph weights = None if "weights" in features: weights = make_weights_tensor(features["weights"], label, params) @@ -101,7 +100,6 @@ def build_graph(features, label, mode, params, config=None): return {"output": output, "loss": loss, "weights": weights} def print_data_example(logits, lolly_activations, features): - # Function to print data example return tf.Print( logits, [logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))], @@ -110,7 +108,6 @@ def print_data_example(logits, lolly_activations, features): ) def earlybird_output_fn(graph_output): - # Function to process the Earlybird model output export_outputs = { tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput( @@ -120,7 +117,6 @@ def earlybird_output_fn(graph_output): return export_outputs if __name__ == "__main__": - # Set up argument parser parser = DataRecordTrainer.add_parser_arguments() parser = twml.contrib.calibrators.add_discretizer_arguments(parser) @@ -141,10 +137,8 @@ if __name__ == "__main__": help="Prints 'DATA EXAMPLE = [[tf logit]][[logged lolly logit]][[feature ids][feature values]]'") add_weight_arguments(parser) - # Parse arguments opt = parser.parse_args() - # Set up feature configuration feature_config_module = all_configs.select_feature_config(opt.feature_config) feature_config = feature_config_module.get_feature_config(data_spec_path=opt.data_spec, label=opt.label) @@ -153,7 +147,6 @@ if __name__ == "__main__": feature_config, keep_fields=("ids", "keys", "values", "batch_size", "total_size", "codes")) - # Discretizer calibration (if necessary) if not opt.lolly_model_tsv: if opt.model_use_existing_discretizer: logging.info("Skipping discretizer calibration [model.use_existing_discretizer=True]") @@ -170,7 +163,6 @@ if __name__ == "__main__": build_graph_fn=build_percentile_discretizer_graph, feature_config=feature_config) - # Initialize trainer trainer = DataRecordTrainer( name="earlybird", 
params=opt, @@ -184,7 +176,6 @@ if __name__ == "__main__": warm_start_from=None ) - # Train and evaluate model train_input_fn = trainer.get_train_input_fn(parse_fn=parse_fn) eval_input_fn = trainer.get_eval_input_fn(parse_fn=parse_fn) @@ -194,7 +185,6 @@ if __name__ == "__main__": trainingEndTime = datetime.now() logging.info("Training and Evaluation time: " + str(trainingEndTime - trainingStartTime)) - # Export model (if current node is chief) if trainer._estimator.config.is_chief: serving_input_in_earlybird = { "input_sparse_tensor_indices": array_ops.placeholder(