Made train.py more readable

Sahil B 2023-03-31 19:26:59 -07:00
parent ec83d01dca
commit ee2b637673


@@ -36,85 +36,84 @@ def get_feature_values(features_values, params):

Before (old side of the diff):

    return features_values


def build_graph(features, label, mode, params, config=None):
  weights = None
  if "weights" in features:
    weights = make_weights_tensor(features["weights"], label, params)

  num_bits = params.input_size_bits

  if mode == "infer":
    indices = twml.limit_bits(features["input_sparse_tensor_indices"], num_bits)
    dense_shape = tf.stack([features["input_sparse_tensor_shape"][0], 1 << num_bits])
    sparse_tf = tf.SparseTensor(
      indices=indices,
      values=get_feature_values(features["input_sparse_tensor_values"], params),
      dense_shape=dense_shape
    )
  else:
    features["values"] = get_feature_values(features["values"], params)
    sparse_tf = twml.util.convert_to_sparse(features, num_bits)

  if params.lolly_model_tsv:
    tf_model_initializer = TFModelInitializerBuilder().build(LollyModelReader(params.lolly_model_tsv))
    bias_initializer, weight_initializer = TFModelWeightsInitializerBuilder(num_bits).build(tf_model_initializer)
    discretizer = TFModelDiscretizerBuilder(num_bits).build(tf_model_initializer)
  else:
    discretizer = hub.Module(params.discretizer_save_dir)
    bias_initializer, weight_initializer = None, None

  input_sparse = discretizer(sparse_tf, signature="hashing_discretizer_calibrator")

  logits = twml.layers.full_sparse(
    inputs=input_sparse,
    output_size=1,
    bias_initializer=bias_initializer,
    weight_initializer=weight_initializer,
    use_sparse_grads=(mode == "train"),
    use_binary_values=True,
    name="full_sparse_1"
  )

  loss = None
  if mode != "infer":
    lolly_activations = get_lolly_logits(label)

    if opt.print_data_examples:
      logits = print_data_example(logits, lolly_activations, features)

    if params.replicate_lolly:
      loss = tf.reduce_mean(tf.math.squared_difference(logits, lolly_activations))
    else:
      batch_size = tf.shape(label)[0]
      target_label = tf.reshape(tensor=label[:, TARGET_LABEL_IDX], shape=(batch_size, 1))
      loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=target_label, logits=logits)
      loss = twml.util.weighted_average(loss, weights)

    num_labels = tf.shape(label)[1]
    eb_scores = tf.tile(lolly_activations, [1, num_labels])
    logits = tf.tile(logits, [1, num_labels])
    logits = tf.concat([logits, eb_scores], axis=1)

  output = tf.nn.sigmoid(logits)

  return {"output": output, "loss": loss, "weights": weights}
def print_data_example(logits, lolly_activations, features):
  return tf.Print(
    logits,
    [logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))],
    message="DATA EXAMPLE = ",
    summarize=10000
  )
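
One more aside: tf.Print survives only as a deprecated compat alias in TF2. A rough equivalent using tf.print behind a control dependency, sketched here as an assumption rather than as part of this commit:

import tensorflow as tf

def print_data_example_v2(logits, lolly_activations):
  # tf.print returns an op; the control dependency forces it to run
  # whenever logits is evaluated, mirroring tf.Print's pass-through behavior.
  print_op = tf.print("DATA EXAMPLE = ", logits, lolly_activations, summarize=10000)
  with tf.control_dependencies([print_op]):
    return tf.identity(logits)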
def earlybird_output_fn(graph_output):
  export_outputs = {
    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
      tf.estimator.export.PredictOutput(
        {"prediction": tf.identity(graph_output["output"], name="output_scores")}
      )
  }
  return export_outputs


After (new side of the diff; the viewer collapses unchanged stretches to "..."):

# Import statements reformatted for better readability
import tensorflow.compat.v1 as tf
from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
import tensorflow_hub as hub
from datetime import datetime
from tensorflow.compat.v1 import logging
from twitter.deepbird.projects.timelines.configs import all_configs
from twml.trainers import DataRecordTrainer
from twml.contrib.calibrators.common_calibrators import build_percentile_discretizer_graph
from twml.contrib.calibrators.common_calibrators import calibrate_discretizer_and_export
from .metrics import get_multi_binary_class_metric_fn
from .constants import TARGET_LABEL_IDX, PREDICTED_CLASSES
from .example_weights import add_weight_arguments, make_weights_tensor
from .lolly.data_helpers import get_lolly_logits
from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder
from .lolly.reader import LollyModelReader
from .tf_model.discretizer_builder import TFModelDiscretizerBuilder
from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuilder
import twml

# Added line breaks and indentation to improve readability
def get_feature_values(features_values, params):
  if params.lolly_model_tsv:
    return tf.multiply(features_values, -1.0)
  else:
    return features_values

# Added line breaks and indentation to improve readability
def build_graph(features, label, mode, params, config=None):
  ...
  if mode != "infer":
    ...
    if opt.print_data_examples:
      logits = print_data_example(logits, lolly_activations, features)
    ...

# Added line breaks and indentation to improve readability
def print_data_example(logits, lolly_activations, features):
  return tf.Print(
    logits,
    [
      logits,
      lolly_activations,
      tf.reshape(features['keys'], (1, -1)),
      tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))
    ],
    message="DATA EXAMPLE = ",
    summarize=10000
  )
if __name__ == "__main__":
  parser = DataRecordTrainer.add_parser_arguments()
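
For context between the two hunks: export_outputs dicts like the one earlybird_output_fn builds are consumed by tf.estimator at export time. A hedged sketch of that standard wiring (this model_fn is an assumption for illustration; in this repo the equivalent glue presumably lives inside DataRecordTrainer):

import tensorflow.compat.v1 as tf

def model_fn(features, labels, mode, params):
  # build_graph and earlybird_output_fn are the functions shown in this diff.
  graph_output = build_graph(features, labels, mode, params)
  if mode == tf.estimator.ModeKeys.PREDICT:  # ModeKeys.PREDICT is the string "infer"
    return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions={"prediction": graph_output["output"]},
      export_outputs=earlybird_output_fn(graph_output))
  # Train path (the optimizer choice here is an arbitrary placeholder):
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
    graph_output["loss"], global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=graph_output["loss"], train_op=train_op)

Conveniently, tf.estimator.ModeKeys.PREDICT is literally "infer", which matches the string comparisons in build_graph.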
@@ -210,3 +209,6 @@ if __name__ == "__main__":
    feature_spec=feature_config.get_feature_spec()
  )
  logging.info("The export model path is: " + opt.export_dir)