mirror of https://github.com/twitter/the-algorithm.git
synced 2024-12-22 18:21:51 +01:00

Made train.py more readable. Main function is now easier to understand.

This commit is contained in:
parent ee2b637673
commit fefd9c2404
train.py

@@ -23,99 +23,104 @@ from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuilder
 import twml
 
 def get_feature_values(features_values, params):
   if params.lolly_model_tsv:
     # The default DBv2 HashingDiscretizer bin membership interval is (a, b]
     #
     # The Earlybird Lolly prediction engine discretizer bin membership interval is [a, b)
     #
     # TFModelInitializerBuilder converts (a, b] to [a, b) by inverting the bin boundaries.
     #
     # Thus, invert the feature values, so that HashingDiscretizer can find the correct bucket.
     return tf.multiply(features_values, -1.0)
   else:
     return features_values
 
 def build_graph(features, label, mode, params, config=None):
-  ...
-  if mode != "infer":
-    ...
-    if opt.print_data_examples:
-      logits = print_data_example(logits, lolly_activations, features)
-    ...
-
-# Added line breaks and indentation to improve readability
-def print_data_example(logits, lolly_activations, features):
-  return tf.Print(
-    logits,
-    [
-      logits,
-      lolly_activations,
-      tf.reshape(features['keys'], (1, -1)),
-      tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))
-    ],
-    message="DATA EXAMPLE = ",
-    summarize=10000
-  )
-
-# Import statements reformatted for better readability
-import tensorflow.compat.v1 as tf
-from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-import tensorflow_hub as hub
-
-from datetime import datetime
-from tensorflow.compat.v1 import logging
-from twitter.deepbird.projects.timelines.configs import all_configs
-from twml.trainers import DataRecordTrainer
-from twml.contrib.calibrators.common_calibrators import build_percentile_discretizer_graph
-from twml.contrib.calibrators.common_calibrators import calibrate_discretizer_and_export
-from .metrics import get_multi_binary_class_metric_fn
-from .constants import TARGET_LABEL_IDX, PREDICTED_CLASSES
-from .example_weights import add_weight_arguments, make_weights_tensor
-from .lolly.data_helpers import get_lolly_logits
-from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder
-from .lolly.reader import LollyModelReader
-from .tf_model.discretizer_builder import TFModelDiscretizerBuilder
-from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuilder
-
-import twml
-
-# Added line breaks and indentation to improve readability
-def get_feature_values(features_values, params):
-  if params.lolly_model_tsv:
-    return tf.multiply(features_values, -1.0)
-  else:
-    return features_values
-
-# Added line breaks and indentation to improve readability
-def build_graph(features, label, mode, params, config=None):
-  ...
-  if mode != "infer":
-    ...
-    if opt.print_data_examples:
-      logits = print_data_example(logits, lolly_activations, features)
-    ...
-
-# Added line breaks and indentation to improve readability
-def print_data_example(logits, lolly_activations, features):
-  return tf.Print(
-    logits,
-    [
-      logits,
-      lolly_activations,
-      tf.reshape(features['keys'], (1, -1)),
-      tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))
-    ],
-    message="DATA EXAMPLE = ",
-    summarize=10000
-  )
+  # Function to build the Earlybird model graph
+  weights = None
+  if "weights" in features:
+    weights = make_weights_tensor(features["weights"], label, params)
+
+  num_bits = params.input_size_bits
+
+  if mode == "infer":
+    indices = twml.limit_bits(features["input_sparse_tensor_indices"], num_bits)
+    dense_shape = tf.stack([features["input_sparse_tensor_shape"][0], 1 << num_bits])
+    sparse_tf = tf.SparseTensor(
+      indices=indices,
+      values=get_feature_values(features["input_sparse_tensor_values"], params),
+      dense_shape=dense_shape
+    )
+  else:
+    features["values"] = get_feature_values(features["values"], params)
+    sparse_tf = twml.util.convert_to_sparse(features, num_bits)
+
+  if params.lolly_model_tsv:
+    tf_model_initializer = TFModelInitializerBuilder().build(LollyModelReader(params.lolly_model_tsv))
+    bias_initializer, weight_initializer = TFModelWeightsInitializerBuilder(num_bits).build(tf_model_initializer)
+    discretizer = TFModelDiscretizerBuilder(num_bits).build(tf_model_initializer)
+  else:
+    discretizer = hub.Module(params.discretizer_save_dir)
+    bias_initializer, weight_initializer = None, None
+
+  input_sparse = discretizer(sparse_tf, signature="hashing_discretizer_calibrator")
+
+  logits = twml.layers.full_sparse(
+    inputs=input_sparse,
+    output_size=1,
+    bias_initializer=bias_initializer,
+    weight_initializer=weight_initializer,
+    use_sparse_grads=(mode == "train"),
+    use_binary_values=True,
+    name="full_sparse_1"
+  )
+
+  loss = None
+
+  if mode != "infer":
+    lolly_activations = get_lolly_logits(label)
+
+    if opt.print_data_examples:
+      logits = print_data_example(logits, lolly_activations, features)
+
+    if params.replicate_lolly:
+      loss = tf.reduce_mean(tf.math.squared_difference(logits, lolly_activations))
+    else:
+      batch_size = tf.shape(label)[0]
+      target_label = tf.reshape(tensor=label[:, TARGET_LABEL_IDX], shape=(batch_size, 1))
+      loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=target_label, logits=logits)
+      loss = twml.util.weighted_average(loss, weights)
+
+    num_labels = tf.shape(label)[1]
+    eb_scores = tf.tile(lolly_activations, [1, num_labels])
+    logits = tf.tile(logits, [1, num_labels])
+    logits = tf.concat([logits, eb_scores], axis=1)
+
+  output = tf.nn.sigmoid(logits)
+
+  return {"output": output, "loss": loss, "weights": weights}
+
+def print_data_example(logits, lolly_activations, features):
+  # Function to print data example
+  return tf.Print(
+    logits,
+    [logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))],
+    message="DATA EXAMPLE = ",
+    summarize=10000
+  )
+
+def earlybird_output_fn(graph_output):
+  # Function to process the Earlybird model output
+  export_outputs = {
+    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+      tf.estimator.export.PredictOutput(
+        {"prediction": tf.identity(graph_output["output"], name="output_scores")}
+      )
+  }
+  return export_outputs
 
 if __name__ == "__main__":
+  # Set up argument parser
   parser = DataRecordTrainer.add_parser_arguments()
 
   parser = twml.contrib.calibrators.add_discretizer_arguments(parser)
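
To make the interval trick documented in get_feature_values concrete: DBv2's HashingDiscretizer tests membership in intervals open on the left, x in (a, b], while the Earlybird Lolly engine uses intervals open on the right, x in [a, b). Negating the value and negating-and-swapping the boundaries maps one convention onto the other, since a <= x < b is equivalent to -b < -x <= -a. A minimal, self-contained sketch with made-up numbers (the helper names are ours, not the repo's):

  def in_left_open(x, a, b):
    # (a, b] membership: the DBv2 HashingDiscretizer convention.
    return a < x <= b

  def in_right_open(x, a, b):
    # [a, b) membership: the Earlybird Lolly convention.
    return a <= x < b

  x, a, b = 2.0, 2.0, 5.0                # hypothetical feature value and bin boundaries
  assert in_right_open(x, a, b)          # x falls in [2.0, 5.0) under Lolly rules
  assert in_left_open(-x, -b, -a)        # -x falls in (-5.0, -2.0] under DBv2 rules
  assert in_right_open(x, a, b) == in_left_open(-x, -b, -a)

TFModelInitializerBuilder inverts the bin boundaries once when the Lolly model is loaded, and get_feature_values negates every incoming value, so boundary cases like x == a land in the same bucket under both engines.
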
@@ -136,8 +141,10 @@ if __name__ == "__main__":
     help="Prints 'DATA EXAMPLE = [[tf logit]][[logged lolly logit]][[feature ids][feature values]]'")
   add_weight_arguments(parser)
 
+  # Parse arguments
   opt = parser.parse_args()
 
+  # Set up feature configuration
   feature_config_module = all_configs.select_feature_config(opt.feature_config)
 
   feature_config = feature_config_module.get_feature_config(data_spec_path=opt.data_spec, label=opt.label)
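
The registration of the flag whose help text appears in this hunk is elided by the diff; only the help string and the later use of opt.print_data_examples are visible. A standalone approximation using stdlib argparse — the real parser comes from DataRecordTrainer.add_parser_arguments(), and the boolean action here is an assumption:

  import argparse

  parser = argparse.ArgumentParser()
  # Hypothetical registration; only the help string is taken verbatim from the diff.
  parser.add_argument(
    "--print_data_examples",
    action="store_true",
    help="Prints 'DATA EXAMPLE = [[tf logit]][[logged lolly logit]][[feature ids][feature values]]'")

  opt = parser.parse_args(["--print_data_examples"])
  assert opt.print_data_examples
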
@@ -146,6 +153,7 @@ if __name__ == "__main__":
     feature_config,
     keep_fields=("ids", "keys", "values", "batch_size", "total_size", "codes"))
 
+  # Discretizer calibration (if necessary)
   if not opt.lolly_model_tsv:
     if opt.model_use_existing_discretizer:
       logging.info("Skipping discretizer calibration [model.use_existing_discretizer=True]")
@@ -162,6 +170,7 @@ if __name__ == "__main__":
       build_graph_fn=build_percentile_discretizer_graph,
       feature_config=feature_config)
 
+  # Initialize trainer
   trainer = DataRecordTrainer(
     name="earlybird",
     params=opt,
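
The previous hunk and this one are two halves of the same call site: when no Lolly model TSV is given and no existing discretizer is reused, the script calibrates a percentile discretizer before training. A control-flow sketch assembled from the visible lines — the leading keyword arguments of calibrate_discretizer_and_export fall in the elided region, so the ones marked hypothetical below are illustrative guesses, not the repo's actual values:

  if not opt.lolly_model_tsv:
    if opt.model_use_existing_discretizer:
      logging.info("Skipping discretizer calibration [model.use_existing_discretizer=True]")
    else:
      calibrate_discretizer_and_export(
        name="percentile_discretizer",  # hypothetical
        params=opt,                     # hypothetical
        build_graph_fn=build_percentile_discretizer_graph,
        feature_config=feature_config)
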
@@ -175,6 +184,7 @@ if __name__ == "__main__":
     warm_start_from=None
   )
 
+  # Train and evaluate model
   train_input_fn = trainer.get_train_input_fn(parse_fn=parse_fn)
   eval_input_fn = trainer.get_eval_input_fn(parse_fn=parse_fn)
 
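
These input functions feed the graph defined in the first hunk, where build_graph chooses between two objectives: with params.replicate_lolly set, it distills by regressing the TF logit onto the logged Lolly logit; otherwise it trains with sigmoid cross-entropy against the target label (then weighted via twml.util.weighted_average). A runnable toy comparison of the two losses using only public TF1 ops and synthetic tensors — the twml weighting step is omitted here because twml is not generally installable:

  import tensorflow.compat.v1 as tf

  tf.disable_eager_execution()

  logits = tf.constant([[0.3], [-1.2]])             # synthetic TF model logits
  lolly_activations = tf.constant([[0.5], [-1.0]])  # synthetic logged Lolly logits
  target_label = tf.constant([[1.0], [0.0]])        # synthetic binary labels

  # replicate_lolly=True: mean squared difference between the two logit streams.
  distill_loss = tf.reduce_mean(tf.math.squared_difference(logits, lolly_activations))

  # replicate_lolly=False: per-example sigmoid cross-entropy (before weighting).
  ce_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=target_label, logits=logits)

  with tf.Session() as sess:
    print(sess.run([distill_loss, tf.reduce_mean(ce_loss)]))
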
@@ -184,6 +194,7 @@ if __name__ == "__main__":
   trainingEndTime = datetime.now()
   logging.info("Training and Evaluation time: " + str(trainingEndTime - trainingStartTime))
 
+  # Export model (if current node is chief)
   if trainer._estimator.config.is_chief:
     serving_input_in_earlybird = {
       "input_sparse_tensor_indices": array_ops.placeholder(
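
The export block gated on is_chief builds raw placeholders for the serving signature. The placeholder dtypes and shapes are truncated by the diff, so this sketch fills them with plausible values (marked as such) and uses the public tf.compat.v1 aliases rather than the internal array_ops/dtypes modules the file imports:

  import tensorflow.compat.v1 as tf
  from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn

  # Dtypes and shapes below are illustrative, not taken from the repo.
  serving_input_in_earlybird = {
    "input_sparse_tensor_indices": tf.placeholder(tf.int64, shape=[None, 2]),
    "input_sparse_tensor_values": tf.placeholder(tf.float32, shape=[None]),
    "input_sparse_tensor_shape": tf.placeholder(tf.int64, shape=[2]),
  }

  # Yields the serving_input_receiver_fn handed to the estimator's export path.
  serving_input_receiver_fn = build_raw_serving_input_receiver_fn(serving_input_in_earlybird)

These three keys match the features build_graph reads in "infer" mode (first hunk), so the exported signature lines up with what the graph expects.
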
@@ -209,6 +220,3 @@ if __name__ == "__main__":
       feature_spec=feature_config.get_feature_spec()
     )
     logging.info("The export model path is: " + opt.export_dir)
-
-
-