Fixed Issues in README#605

This commit is contained in:
Sahil B 2023-03-31 21:30:26 -07:00
parent 7b41414753
commit f94b75133c

View File

@ -36,7 +36,6 @@ def get_feature_values(features_values, params):
return features_values return features_values
def build_graph(features, label, mode, params, config=None): def build_graph(features, label, mode, params, config=None):
# Function to build the Earlybird model graph
weights = None weights = None
if "weights" in features: if "weights" in features:
weights = make_weights_tensor(features["weights"], label, params) weights = make_weights_tensor(features["weights"], label, params)
@ -101,7 +100,6 @@ def build_graph(features, label, mode, params, config=None):
return {"output": output, "loss": loss, "weights": weights} return {"output": output, "loss": loss, "weights": weights}
def print_data_example(logits, lolly_activations, features): def print_data_example(logits, lolly_activations, features):
# Function to print data example
return tf.Print( return tf.Print(
logits, logits,
[logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))], [logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))],
@ -110,7 +108,6 @@ def print_data_example(logits, lolly_activations, features):
) )
def earlybird_output_fn(graph_output): def earlybird_output_fn(graph_output):
# Function to process the Earlybird model output
export_outputs = { export_outputs = {
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
tf.estimator.export.PredictOutput( tf.estimator.export.PredictOutput(
@ -120,7 +117,6 @@ def earlybird_output_fn(graph_output):
return export_outputs return export_outputs
if __name__ == "__main__": if __name__ == "__main__":
# Set up argument parser
parser = DataRecordTrainer.add_parser_arguments() parser = DataRecordTrainer.add_parser_arguments()
parser = twml.contrib.calibrators.add_discretizer_arguments(parser) parser = twml.contrib.calibrators.add_discretizer_arguments(parser)
@ -141,10 +137,8 @@ if __name__ == "__main__":
help="Prints 'DATA EXAMPLE = [[tf logit]][[logged lolly logit]][[feature ids][feature values]]'") help="Prints 'DATA EXAMPLE = [[tf logit]][[logged lolly logit]][[feature ids][feature values]]'")
add_weight_arguments(parser) add_weight_arguments(parser)
# Parse arguments
opt = parser.parse_args() opt = parser.parse_args()
# Set up feature configuration
feature_config_module = all_configs.select_feature_config(opt.feature_config) feature_config_module = all_configs.select_feature_config(opt.feature_config)
feature_config = feature_config_module.get_feature_config(data_spec_path=opt.data_spec, label=opt.label) feature_config = feature_config_module.get_feature_config(data_spec_path=opt.data_spec, label=opt.label)
@ -153,7 +147,6 @@ if __name__ == "__main__":
feature_config, feature_config,
keep_fields=("ids", "keys", "values", "batch_size", "total_size", "codes")) keep_fields=("ids", "keys", "values", "batch_size", "total_size", "codes"))
# Discretizer calibration (if necessary)
if not opt.lolly_model_tsv: if not opt.lolly_model_tsv:
if opt.model_use_existing_discretizer: if opt.model_use_existing_discretizer:
logging.info("Skipping discretizer calibration [model.use_existing_discretizer=True]") logging.info("Skipping discretizer calibration [model.use_existing_discretizer=True]")
@ -170,7 +163,6 @@ if __name__ == "__main__":
build_graph_fn=build_percentile_discretizer_graph, build_graph_fn=build_percentile_discretizer_graph,
feature_config=feature_config) feature_config=feature_config)
# Initialize trainer
trainer = DataRecordTrainer( trainer = DataRecordTrainer(
name="earlybird", name="earlybird",
params=opt, params=opt,
@ -184,7 +176,6 @@ if __name__ == "__main__":
warm_start_from=None warm_start_from=None
) )
# Train and evaluate model
train_input_fn = trainer.get_train_input_fn(parse_fn=parse_fn) train_input_fn = trainer.get_train_input_fn(parse_fn=parse_fn)
eval_input_fn = trainer.get_eval_input_fn(parse_fn=parse_fn) eval_input_fn = trainer.get_eval_input_fn(parse_fn=parse_fn)
@ -194,7 +185,6 @@ if __name__ == "__main__":
trainingEndTime = datetime.now() trainingEndTime = datetime.now()
logging.info("Training and Evaluation time: " + str(trainingEndTime - trainingStartTime)) logging.info("Training and Evaluation time: " + str(trainingEndTime - trainingStartTime))
# Export model (if current node is chief)
if trainer._estimator.config.is_chief: if trainer._estimator.config.is_chief:
serving_input_in_earlybird = { serving_input_in_earlybird = {
"input_sparse_tensor_indices": array_ops.placeholder( "input_sparse_tensor_indices": array_ops.placeholder(