mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-12-22 10:11:52 +01:00
Fixed Issues in README#605
This commit is contained in:
parent
7b41414753
commit
f94b75133c
@ -36,7 +36,6 @@ def get_feature_values(features_values, params):
|
|||||||
return features_values
|
return features_values
|
||||||
|
|
||||||
def build_graph(features, label, mode, params, config=None):
|
def build_graph(features, label, mode, params, config=None):
|
||||||
# Function to build the Earlybird model graph
|
|
||||||
weights = None
|
weights = None
|
||||||
if "weights" in features:
|
if "weights" in features:
|
||||||
weights = make_weights_tensor(features["weights"], label, params)
|
weights = make_weights_tensor(features["weights"], label, params)
|
||||||
@ -101,7 +100,6 @@ def build_graph(features, label, mode, params, config=None):
|
|||||||
return {"output": output, "loss": loss, "weights": weights}
|
return {"output": output, "loss": loss, "weights": weights}
|
||||||
|
|
||||||
def print_data_example(logits, lolly_activations, features):
|
def print_data_example(logits, lolly_activations, features):
|
||||||
# Function to print data example
|
|
||||||
return tf.Print(
|
return tf.Print(
|
||||||
logits,
|
logits,
|
||||||
[logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))],
|
[logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))],
|
||||||
@ -110,7 +108,6 @@ def print_data_example(logits, lolly_activations, features):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def earlybird_output_fn(graph_output):
|
def earlybird_output_fn(graph_output):
|
||||||
# Function to process the Earlybird model output
|
|
||||||
export_outputs = {
|
export_outputs = {
|
||||||
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
||||||
tf.estimator.export.PredictOutput(
|
tf.estimator.export.PredictOutput(
|
||||||
@ -120,7 +117,6 @@ def earlybird_output_fn(graph_output):
|
|||||||
return export_outputs
|
return export_outputs
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Set up argument parser
|
|
||||||
parser = DataRecordTrainer.add_parser_arguments()
|
parser = DataRecordTrainer.add_parser_arguments()
|
||||||
|
|
||||||
parser = twml.contrib.calibrators.add_discretizer_arguments(parser)
|
parser = twml.contrib.calibrators.add_discretizer_arguments(parser)
|
||||||
@ -141,10 +137,8 @@ if __name__ == "__main__":
|
|||||||
help="Prints 'DATA EXAMPLE = [[tf logit]][[logged lolly logit]][[feature ids][feature values]]'")
|
help="Prints 'DATA EXAMPLE = [[tf logit]][[logged lolly logit]][[feature ids][feature values]]'")
|
||||||
add_weight_arguments(parser)
|
add_weight_arguments(parser)
|
||||||
|
|
||||||
# Parse arguments
|
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
# Set up feature configuration
|
|
||||||
feature_config_module = all_configs.select_feature_config(opt.feature_config)
|
feature_config_module = all_configs.select_feature_config(opt.feature_config)
|
||||||
|
|
||||||
feature_config = feature_config_module.get_feature_config(data_spec_path=opt.data_spec, label=opt.label)
|
feature_config = feature_config_module.get_feature_config(data_spec_path=opt.data_spec, label=opt.label)
|
||||||
@ -153,7 +147,6 @@ if __name__ == "__main__":
|
|||||||
feature_config,
|
feature_config,
|
||||||
keep_fields=("ids", "keys", "values", "batch_size", "total_size", "codes"))
|
keep_fields=("ids", "keys", "values", "batch_size", "total_size", "codes"))
|
||||||
|
|
||||||
# Discretizer calibration (if necessary)
|
|
||||||
if not opt.lolly_model_tsv:
|
if not opt.lolly_model_tsv:
|
||||||
if opt.model_use_existing_discretizer:
|
if opt.model_use_existing_discretizer:
|
||||||
logging.info("Skipping discretizer calibration [model.use_existing_discretizer=True]")
|
logging.info("Skipping discretizer calibration [model.use_existing_discretizer=True]")
|
||||||
@ -170,7 +163,6 @@ if __name__ == "__main__":
|
|||||||
build_graph_fn=build_percentile_discretizer_graph,
|
build_graph_fn=build_percentile_discretizer_graph,
|
||||||
feature_config=feature_config)
|
feature_config=feature_config)
|
||||||
|
|
||||||
# Initialize trainer
|
|
||||||
trainer = DataRecordTrainer(
|
trainer = DataRecordTrainer(
|
||||||
name="earlybird",
|
name="earlybird",
|
||||||
params=opt,
|
params=opt,
|
||||||
@ -184,7 +176,6 @@ if __name__ == "__main__":
|
|||||||
warm_start_from=None
|
warm_start_from=None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Train and evaluate model
|
|
||||||
train_input_fn = trainer.get_train_input_fn(parse_fn=parse_fn)
|
train_input_fn = trainer.get_train_input_fn(parse_fn=parse_fn)
|
||||||
eval_input_fn = trainer.get_eval_input_fn(parse_fn=parse_fn)
|
eval_input_fn = trainer.get_eval_input_fn(parse_fn=parse_fn)
|
||||||
|
|
||||||
@ -194,7 +185,6 @@ if __name__ == "__main__":
|
|||||||
trainingEndTime = datetime.now()
|
trainingEndTime = datetime.now()
|
||||||
logging.info("Training and Evaluation time: " + str(trainingEndTime - trainingStartTime))
|
logging.info("Training and Evaluation time: " + str(trainingEndTime - trainingStartTime))
|
||||||
|
|
||||||
# Export model (if current node is chief)
|
|
||||||
if trainer._estimator.config.is_chief:
|
if trainer._estimator.config.is_chief:
|
||||||
serving_input_in_earlybird = {
|
serving_input_in_earlybird = {
|
||||||
"input_sparse_tensor_indices": array_ops.placeholder(
|
"input_sparse_tensor_indices": array_ops.placeholder(
|
||||||
|
Loading…
Reference in New Issue
Block a user