60 lines
1.3 KiB
Python
60 lines
1.3 KiB
Python
from twml.trainers import DataRecordTrainer
|
|
|
|
from .features import FEATURE_LIST_DEFAULT_PATH
|
|
|
|
|
|
def get_training_arg_parser():
    """Build the command-line argument parser used for training.

    Starts from the parser returned by
    ``DataRecordTrainer.add_parser_arguments()`` and registers the
    model-specific options below on it. The base parser is presumably an
    ``argparse.ArgumentParser`` already holding the generic trainer flags
    (TODO confirm against twml).

    Options added here:
        --feature_list: path of the feature list to use; defaults to
            ``FEATURE_LIST_DEFAULT_PATH``.
        --param_file: optional path to a JSON file with graph parameters;
            ``None`` when not given.
        --directly_export_best: boolean flag, ``False`` unless passed.
        --warm_start_from: optional model dir to warm start from.
        --warm_start_base_dir: optional folder whose latest checkpoint is
            used for warm starting.
        --model_type: optional name of the model type to train.

    Returns:
        The base trainer parser with the options above added.
    """
    parser = DataRecordTrainer.add_parser_arguments()

    parser.add_argument(
        "--feature_list",
        default=FEATURE_LIST_DEFAULT_PATH,
        type=str,
        help="Which features to use for training",
    )

    parser.add_argument(
        "--param_file",
        default=None,
        type=str,
        help="Path to JSON file containing the graph parameters. If None, model will load default parameters.",
    )

    # store_true already implies default=False; it is kept explicit for readers.
    parser.add_argument(
        "--directly_export_best",
        default=False,
        action="store_true",
        help="whether to directly_export best_checkpoint",
    )

    parser.add_argument(
        "--warm_start_from",
        default=None,
        type=str,
        help="model dir to warm start from",
    )

    # The original help text stopped mid-sentence ("... will be used to ").
    # NOTE(review): it has been completed from the flag name; confirm wording.
    parser.add_argument(
        "--warm_start_base_dir",
        default=None,
        type=str,
        help="latest ckpt in this folder will be used to warm start",
    )

    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        help="Which type of model to train.",
    )

    return parser
|
|
|
|
|
|
def get_eval_arg_parser():
    """Build the argument parser used for evaluation runs.

    Reuses the full training parser, so every training flag is still
    accepted. It then registers one extra option, ``--eval_checkpoint``, a
    string that defaults to ``None``.

    Returns:
        The parser from ``get_training_arg_parser`` with ``--eval_checkpoint``
        added.
    """
    eval_parser = get_training_arg_parser()

    checkpoint_opts = dict(
        default=None,
        type=str,
        help="Which checkpoint to use for evaluation",
    )
    eval_parser.add_argument("--eval_checkpoint", **checkpoint_opts)

    return eval_parser
|