the-algorithm/pushservice/src/main/python/models/heavy_ranking/run_args.py

60 lines
1.3 KiB
Python

from twml.trainers import DataRecordTrainer
from .features import FEATURE_LIST_DEFAULT_PATH
def get_training_arg_parser():
parser = DataRecordTrainer.add_parser_arguments()
parser.add_argument(
"--feature_list",
default=FEATURE_LIST_DEFAULT_PATH,
type=str,
help="Which features to use for training",
)
parser.add_argument(
"--param_file",
default=None,
type=str,
help="Path to JSON file containing the graph parameters. If None, model will load default parameters.",
)
parser.add_argument(
"--directly_export_best",
default=False,
action="store_true",
help="whether to directly_export best_checkpoint",
)
parser.add_argument(
"--warm_start_from", default=None, type=str, help="model dir to warm start from"
)
parser.add_argument(
"--warm_start_base_dir",
default=None,
type=str,
help="latest ckpt in this folder will be used to ",
)
parser.add_argument(
"--model_type",
default=None,
type=str,
help="Which type of model to train.",
)
return parser
def get_eval_arg_parser():
parser = get_training_arg_parser()
parser.add_argument(
"--eval_checkpoint",
default=None,
type=str,
help="Which checkpoint to use for evaluation",
)
return parser