# the-algorithm/pushservice/src/main/python/models/heavy_ranking/update_warm_start_checkpoin...
#
# 147 lines
# 4.2 KiB
# Python

"""
Module for modifying the checkpoints of the magic recs CNN model to account for the addition,
deletion, and reordering of continuous and binary features.
"""
import os
from twitter.deepbird.projects.magic_recs.libs.get_feat_config import FEATURE_LIST_DEFAULT_PATH
from twitter.deepbird.projects.magic_recs.libs.warm_start_utils_v11 import (
get_feature_list_for_heavy_ranking,
mkdirp,
rename_dir,
rmdir,
warm_start_checkpoint,
)
import twml
from twml.trainers import DataRecordTrainer
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import logging
def get_arg_parser():
    """Build the command-line parser for this script.

    Starts from ``DataRecordTrainer.add_parser_arguments()`` and adds the warm-start specific
    flags below. ``_main`` also reads ``opt.data_spec``, which is not added here, so it is
    presumably provided by the trainer parser -- verify.

    Except for ``--model_type`` and the deprecated ``--model_trainer_name``, the added flags
    default to the string ``"none"``, which ``_main`` treats as "unset".

    Returns:
      The parser from ``DataRecordTrainer.add_parser_arguments()`` with the extra arguments added.
    """
    parser = DataRecordTrainer.add_parser_arguments()
    parser.add_argument(
        "--model_type",
        default="deepnorm_gbdt_inputdrop2_rescale",
        type=str,
        help="specify the model type to use.",
    )
    parser.add_argument(
        "--model_trainer_name",
        default="None",
        type=str,
        help="deprecated, added here just for api compatibility.",
    )
    parser.add_argument(
        "--warm_start_base_dir",
        default="none",
        type=str,
        help="latest ckpt in this folder will be used.",
    )
    parser.add_argument(
        "--output_checkpoint_dir",
        default="none",
        type=str,
        help="Output folder for warm started ckpt. If none, it will move warm_start_base_dir to backup, and overwrite it",
    )
    parser.add_argument(
        "--feature_list",
        default="none",
        type=str,
        help="Which features to use for training",
    )
    parser.add_argument(
        "--old_feature_list",
        default="none",
        type=str,
        # Previously a copy of --feature_list's help; this flag describes the *existing* checkpoint.
        help=(
            "Feature list the existing checkpoint was trained with. If set, it is used to rewrite "
            "continuous_binary_feat_list.json in the warm-start source folder; if none, the file "
            "already there is used."
        ),
    )
    return parser
def get_params(args=None):
    """Parse the script's command-line options.

    Args:
      args: Optional sequence of argument strings. When omitted (None), ``parse_args()`` is
        called with no arguments, so the parser uses its own default source (normally
        ``sys.argv[1:]``).

    Returns:
      The namespace produced by the parser from ``get_arg_parser()``.
    """
    arg_parser = get_arg_parser()
    if args is not None:
        return arg_parser.parse_args(args)
    return arg_parser.parse_args()
def _write_continuous_binary_feat_list(feature_list_path, data_spec, save_path):
    """Compute the continuous/binary feature list and (re)write it to ``save_path`` as JSON.

    Args:
      feature_list_path: Feature list passed to ``get_feature_list_for_heavy_ranking``.
      data_spec: Data spec passed to ``get_feature_list_for_heavy_ranking``.
      save_path: Destination file; whatever is at this path is removed first via ``rmdir``.
    """
    continuous_binary_feat_list = get_feature_list_for_heavy_ranking(feature_list_path, data_spec)
    # Clear any stale copy first. rmdir is a project helper; presumably it also removes
    # plain files -- TODO confirm.
    rmdir(save_path)
    twml.util.write_file(save_path, continuous_binary_feat_list, encode="json")
    logging.info(f"Finish writting files to {save_path}")


def _main():
    """Rebuild a warm-start checkpoint so that it matches the current feature list.

    Driven entirely by command-line flags (see ``get_arg_parser``):
      1. Resolve the source folder (``--warm_start_base_dir``) and the output folder. When no
         distinct ``--output_checkpoint_dir`` is given, the source folder is first renamed to
         ``<warm_start_base_dir>_backup_warm_start`` and the output is written to the original
         location.
      2. If ``--old_feature_list`` is set, (re)write ``continuous_binary_feat_list.json`` in the
         source folder from it.
      3. Call ``warm_start_checkpoint`` on ``<source>/best_checkpoint`` (or on the source folder
         itself when that subfolder does not exist), writing into a freshly recreated output
         folder.
      4. Write ``continuous_binary_feat_list.json`` for the new feature list into the output folder.

    If ``--warm_start_base_dir`` is unset or does not exist, a warning is logged and nothing is
    changed.
    """
    opt = get_params()
    logging.info("parse is: ")
    logging.info(opt)

    if opt.feature_list == "none":
        feature_list_path = FEATURE_LIST_DEFAULT_PATH
    else:
        feature_list_path = opt.feature_list

    if opt.warm_start_base_dir == "none" or not tf.io.gfile.exists(opt.warm_start_base_dir):
        # This used to exit silently; make the no-op visible in the job logs.
        logging.warning(
            f"warm_start_base_dir {opt.warm_start_base_dir!r} is unset or does not exist; "
            "no checkpoint was updated."
        )
        return

    # Compare normalized paths: "dir" and "dir/" name the same folder. Treating them as
    # different would take the else-branch below, and rmdir(output_dir) would then delete the
    # source checkpoint before it is read.
    in_place = opt.output_checkpoint_dir == "none" or (
        os.path.normpath(opt.output_checkpoint_dir) == os.path.normpath(opt.warm_start_base_dir)
    )
    if in_place:
        # Move the existing folder aside so the updated checkpoint can take its place.
        # normpath strips a trailing "/" so the suffix attaches to the folder name.
        # NOTE(review): normpath also collapses "//", e.g. "gs://b/d" -> "gs:/b/d"; confirm
        # rename_dir copes with that if remote URIs are passed here.
        source_dir = os.path.normpath(opt.warm_start_base_dir) + "_backup_warm_start"
        output_dir = opt.warm_start_base_dir
        rename_dir(opt.warm_start_base_dir, source_dir)
        logging.info(f"moved {opt.warm_start_base_dir} to {source_dir}")
    else:
        source_dir = opt.warm_start_base_dir
        output_dir = opt.output_checkpoint_dir

    continuous_binary_feat_list_save_path = os.path.join(
        source_dir, "continuous_binary_feat_list.json"
    )
    if opt.old_feature_list != "none":
        logging.info("getting old continuous_binary_feat_list")
        _write_continuous_binary_feat_list(
            opt.old_feature_list, opt.data_spec, continuous_binary_feat_list_save_path
        )

    # Prefer the best_checkpoint subfolder when it exists, otherwise use the folder itself.
    warm_start_folder = os.path.join(source_dir, "best_checkpoint")
    if not tf.io.gfile.exists(warm_start_folder):
        warm_start_folder = source_dir

    # Start from an empty output folder.
    rmdir(output_dir)
    mkdirp(output_dir)

    new_ckpt = warm_start_checkpoint(
        warm_start_folder,
        continuous_binary_feat_list_save_path,
        feature_list_path,
        opt.data_spec,
        output_dir,
        opt.model_type,
    )
    logging.info(f"Created new ckpt {new_ckpt} from {warm_start_folder}")

    logging.info("getting new continuous_binary_feat_list")
    _write_continuous_binary_feat_list(
        feature_list_path,
        opt.data_spec,
        os.path.join(output_dir, "continuous_binary_feat_list.json"),
    )
if __name__ == "__main__":
    # Run the checkpoint update only when executed as a script, not when imported.
    _main()