# Source: the-algorithm/pushservice/src/main/python/models/heavy_ranking/deep_norm.py
# (original listing: 137 lines, 4.6 KiB, Python)
"""
Training job for the heavy ranker of the push notification service.
"""
from datetime import datetime
import json
import os
import twml
from ..libs.metric_fn_utils import flip_disliked_labels, get_metric_fn
from ..libs.model_utils import read_config
from ..libs.warm_start_utils import get_feature_list_for_heavy_ranking, warm_start_checkpoint
from .features import get_feature_config
from .model_pools import ALL_MODELS
from .params import load_graph_params
from .run_args import get_training_arg_parser
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import logging
def _graph_params_payload(params):
    """Return a JSON-serializable representation of ``params`` for ``json.dump``.

    ``params.json()`` may return an already-encoded JSON document (``str``/``bytes``;
    pydantic's ``BaseModel.json()`` behaves this way). Passing that straight to
    ``json.dump`` would write a double-encoded string literal, and ``indent`` would
    have no effect. The document is therefore decoded first. Any other return value
    is passed through unchanged.
    """
    payload = params.json()
    if isinstance(payload, (str, bytes, bytearray)):
        try:
            payload = json.loads(payload)
        except ValueError:
            # Not a JSON document: keep the previous behaviour of saving it verbatim.
            logging.warning("params.json() did not return valid JSON; saving it as a string.")
    return payload


def _save_graph_params(params, save_dir):
    """Write ``params`` as indented JSON to ``<save_dir>/params.json``."""
    param_file = os.path.join(save_dir, "params.json")
    logging.info(f"Saving graph params to: {param_file}")
    with tf.io.gfile.GFile(param_file, mode="w") as file:
        json.dump(_graph_params_payload(params), file, ensure_ascii=False, indent=4)


def _prepare_warm_start(args, feature_list_path):
    """Return the value to pass as the trainer's ``warm_start_from``.

    Without ``args.warm_start_base_dir`` this is simply ``args.warm_start_from``.
    Otherwise the function does the following:

    * (Re)creates an empty ``warm_start_for_<job_name>`` folder under the base dir.
      ``job_name`` is the last component of ``args.save_dir``.
    * Calls ``warm_start_checkpoint`` with the base dir's ``best_checkpoint`` folder
      and its ``continuous_binary_feat_list.json``.
    * Returns whatever that call returns.
    """
    if not args.warm_start_base_dir:
        return args.warm_start_from

    base_dir = args.warm_start_base_dir
    logging.info(f"Get warm started model from: {base_dir}.")

    continuous_binary_feat_list_save_path = os.path.join(base_dir, "continuous_binary_feat_list.json")
    warm_start_folder = os.path.join(base_dir, "best_checkpoint")
    # Strip trailing slashes: basename("a/job/") is "", which would make every such job
    # share (and rmtree) the same "warm_start_for_" folder.
    job_name = os.path.basename(args.save_dir.rstrip("/"))
    ws_output_ckpt_folder = os.path.join(base_dir, f"warm_start_for_{job_name}")

    # Always start from an empty output folder.
    if tf.io.gfile.exists(ws_output_ckpt_folder):
        tf.io.gfile.rmtree(ws_output_ckpt_folder)
    tf.io.gfile.mkdir(ws_output_ckpt_folder)

    warm_start_from = warm_start_checkpoint(
        warm_start_folder,
        continuous_binary_feat_list_save_path,
        feature_list_path,
        args.data_spec,
        ws_output_ckpt_folder,
    )
    logging.info(f"Created warm_start_from_ckpt {warm_start_from}.")
    return warm_start_from


def _train(learn, args, train_input_fn, eval_input_fn):
    """Run ``learn`` with early stopping on ``pr_auc_unweighted_OONC`` and log the elapsed time."""
    logging.info("Starting training")
    start = datetime.now()
    learn(
        early_stop_minimize=False,
        early_stop_metric="pr_auc_unweighted_OONC",
        early_stop_patience=args.early_stop_patience,
        early_stop_tolerance=args.early_stop_tolerance,
        eval_input_fn=eval_input_fn,
        train_input_fn=train_input_fn,
    )
    logging.info(f"Total training time: {datetime.now() - start}")


def _export(trainer, args, feature_config):
    """Export the trainer's models to ``args.export_dir`` and log the elapsed time.

    When ``args.export_dir`` is unset, it is set (on ``args`` itself) to
    ``<save_dir>/exported_models``.
    """
    if not args.export_dir:
        args.export_dir = os.path.join(args.save_dir, "exported_models")

    logging.info(f"Exporting the model to {args.export_dir}.")
    start = datetime.now()
    twml.contrib.export.export_fn.export_all_models(
        trainer=trainer,
        export_dir=args.export_dir,
        parse_fn=feature_config.get_parse_fn(),
        serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
        export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
    )
    logging.info(f"Total model export time: {datetime.now() - start}")


def _save_continuous_binary_feature_list(feature_list_path, args):
    """Write the continuous/binary feature list to ``<save_dir>/continuous_binary_feat_list.json``.

    This is the same file name that ``_prepare_warm_start`` reads from a warm-start base dir.
    """
    save_path = os.path.join(args.save_dir, "continuous_binary_feat_list.json")
    logging.info(f"Saving the list of continuous and binary features to {save_path}.")
    continuous_binary_feat_list = get_feature_list_for_heavy_ranking(feature_list_path, args.data_spec)
    twml.util.write_file(save_path, continuous_binary_feat_list, encode="json")


def main() -> None:
    """Train (unless ``args.directly_export_best`` is set) and then export the heavy-ranking model.

    The steps are:

    1. Parse the CLI args and load the graph params, saving the params to
       ``<save_dir>/params.json``.
    2. Build the feature config and optionally prepare a warm-start checkpoint.
    3. Build a ``DataRecordTrainer`` and its input functions.
    4. Train unless ``args.directly_export_best`` is set.
    5. Export the models in both cases.
    6. Save the continuous/binary feature list next to the model.
    """
    args, _ = get_training_arg_parser().parse_known_args()
    logging.info(f"Parsed args: {args}")

    params = load_graph_params(args)
    logging.info(f"Loaded graph params: {params}")
    _save_graph_params(params, args.save_dir)

    logging.info(f"Get Feature Config: {args.feature_list}")
    feature_list_path = args.feature_list
    feature_list = read_config(feature_list_path).items()
    feature_config = get_feature_config(
        data_spec_path=args.data_spec,
        params=params,
        feature_list_provided=feature_list,
    )

    warm_start_from = _prepare_warm_start(args, feature_list_path)

    logging.info("Build Trainer.")
    # Two tasks select the combined OONC + engagement metrics; anything else uses plain OONC.
    metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)
    trainer = twml.trainers.DataRecordTrainer(
        name="magic_recs",
        params=args,
        # Parameter named graph_args so it does not shadow the parsed CLI ``args``.
        build_graph_fn=lambda *graph_args: ALL_MODELS[params.model.name](params=params)(*graph_args),
        save_dir=args.save_dir,
        run_config=None,
        feature_config=feature_config,
        metric_fn=flip_disliked_labels(metric_fn),
        warm_start_from=warm_start_from,
    )

    logging.info("Build train and eval input functions.")
    train_input_fn = trainer.get_train_input_fn(shuffle=True)
    eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)

    learn = trainer.learn
    if args.distributed or args.num_workers is not None:
        learn = trainer.train_and_evaluate

    if not args.directly_export_best:
        _train(learn, args, train_input_fn, eval_input_fn)
    else:
        logging.info("Directly exporting the model")

    _export(trainer, args, feature_config)
    logging.info(f"The MLP directory is: {args.save_dir}")
    _save_continuous_binary_feature_list(feature_list_path, args)
if __name__ == "__main__":
    # Script entry point: run the whole job, then log that it completed.
    main()
    logging.info("Done.")