137 lines
4.6 KiB
Python
137 lines
4.6 KiB
Python
"""
|
|
Training job for the heavy ranker of the push notification service.
|
|
"""
|
|
from datetime import datetime
|
|
import json
|
|
import os
|
|
|
|
import twml
|
|
|
|
from ..libs.metric_fn_utils import flip_disliked_labels, get_metric_fn
|
|
from ..libs.model_utils import read_config
|
|
from ..libs.warm_start_utils import get_feature_list_for_heavy_ranking, warm_start_checkpoint
|
|
from .features import get_feature_config
|
|
from .model_pools import ALL_MODELS
|
|
from .params import load_graph_params
|
|
from .run_args import get_training_arg_parser
|
|
|
|
import tensorflow.compat.v1 as tf
|
|
from tensorflow.compat.v1 import logging
|
|
|
|
|
|
def main() -> None:
    """Run the heavy-ranker training job from argument parsing to export.

    The steps run in this order:

    1. Parse the command-line arguments (unknown arguments are ignored) and
       load the graph params, then write the params to ``params.json`` under
       ``args.save_dir``.
    2. Build the feature config from ``args.feature_list`` and
       ``args.data_spec``.
    3. If ``args.warm_start_base_dir`` is set, (re)create the folder
       ``warm_start_for_<basename of save_dir>`` inside it and use the
       checkpoint returned by ``warm_start_checkpoint`` instead of
       ``args.warm_start_from``.
    4. Build the ``DataRecordTrainer`` and, unless
       ``args.directly_export_best`` is set, train with early stopping on
       ``pr_auc_unweighted_OONC``.
    5. Export all models to ``args.export_dir`` (defaulting to
       ``<save_dir>/exported_models``).
    6. Write the continuous/binary feature list to
       ``<save_dir>/continuous_binary_feat_list.json``.
    """
    args, _ = get_training_arg_parser().parse_known_args()
    logging.info(f"Parsed args: {args}")

    params = load_graph_params(args)
    logging.info(f"Loaded graph params: {params}")

    # Persist the graph params next to the model artifacts.
    param_file = os.path.join(args.save_dir, "params.json")
    logging.info(f"Saving graph params to: {param_file}")
    with tf.io.gfile.GFile(param_file, mode="w") as fh:
        json.dump(params.json(), fh, ensure_ascii=False, indent=4)

    # Feature configuration.
    logging.info(f"Get Feature Config: {args.feature_list}")
    feature_list_path = args.feature_list
    provided_features = read_config(feature_list_path).items()
    feature_config = get_feature_config(
        data_spec_path=args.data_spec,
        params=params,
        feature_list_provided=provided_features,
    )

    # Optional warm start: overrides args.warm_start_from when a base dir is given.
    warm_start_from = args.warm_start_from
    if args.warm_start_base_dir:
        logging.info(f"Get warm started model from: {args.warm_start_base_dir}.")

        base_dir = args.warm_start_base_dir
        continuous_binary_feat_list_save_path = os.path.join(
            base_dir, "continuous_binary_feat_list.json"
        )
        base_ckpt_dir = os.path.join(base_dir, "best_checkpoint")
        job_name = os.path.basename(args.save_dir)
        ws_ckpt_dir = os.path.join(base_dir, f"warm_start_for_{job_name}")

        # Always start from an empty output folder for this job.
        if tf.io.gfile.exists(ws_ckpt_dir):
            tf.io.gfile.rmtree(ws_ckpt_dir)
        tf.io.gfile.mkdir(ws_ckpt_dir)

        warm_start_from = warm_start_checkpoint(
            base_ckpt_dir,
            continuous_binary_feat_list_save_path,
            feature_list_path,
            args.data_spec,
            ws_ckpt_dir,
        )
        logging.info(f"Created warm_start_from_ckpt {warm_start_from}.")

    logging.info("Build Trainer.")
    # Exactly two tasks selects the engagement variant of the metric set.
    if len(params.tasks) == 2:
        metric_name = "OONC_Engagement"
    else:
        metric_name = "OONC"
    base_metric_fn = get_metric_fn(metric_name, False)

    def build_graph(*graph_args):
        # Look the model up at call time, exactly as the trainer invokes it.
        return ALL_MODELS[params.model.name](params=params)(*graph_args)

    trainer = twml.trainers.DataRecordTrainer(
        name="magic_recs",
        params=args,
        build_graph_fn=build_graph,
        save_dir=args.save_dir,
        run_config=None,
        feature_config=feature_config,
        metric_fn=flip_disliked_labels(base_metric_fn),
        warm_start_from=warm_start_from,
    )

    logging.info("Build train and eval input functions.")
    train_input_fn = trainer.get_train_input_fn(shuffle=True)
    eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)

    # Distributed runs (or an explicit worker count) use train_and_evaluate.
    run_training = trainer.learn
    if args.distributed or args.num_workers is not None:
        run_training = trainer.train_and_evaluate

    if args.directly_export_best:
        logging.info("Directly exporting the model")
    else:
        logging.info("Starting training")
        start = datetime.now()
        run_training(
            early_stop_minimize=False,
            early_stop_metric="pr_auc_unweighted_OONC",
            early_stop_patience=args.early_stop_patience,
            early_stop_tolerance=args.early_stop_tolerance,
            eval_input_fn=eval_input_fn,
            train_input_fn=train_input_fn,
        )
        logging.info(f"Total training time: {datetime.now() - start}")

    # The default is written back onto args, which the trainer also holds as params.
    if not args.export_dir:
        args.export_dir = os.path.join(args.save_dir, "exported_models")

    logging.info(f"Exporting the model to {args.export_dir}.")
    start = datetime.now()
    twml.contrib.export.export_fn.export_all_models(
        trainer=trainer,
        export_dir=args.export_dir,
        parse_fn=feature_config.get_parse_fn(),
        serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
        export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
    )

    logging.info(f"Total model export time: {datetime.now() - start}")
    logging.info(f"The MLP directory is: {args.save_dir}")

    # Save the feature list alongside the model; the warm-start branch above
    # reads a file of the same name from its base directory.
    continuous_binary_feat_list_save_path = os.path.join(
        args.save_dir, "continuous_binary_feat_list.json"
    )
    logging.info(
        f"Saving the list of continuous and binary features to {continuous_binary_feat_list_save_path}."
    )
    feat_list = get_feature_list_for_heavy_ranking(feature_list_path, args.data_spec)
    twml.util.write_file(continuous_binary_feat_list_save_path, feat_list, encode="json")
|
|
|
|
|
|
# Script entry point: run the training job, then log completion.
if __name__ == "__main__":
    main()
    logging.info("Done.")
|