Pushservice is the main recommendation service we use to surface recommendations to our users via notifications. It fetches candidates from various sources, ranks them in order of relevance, and applies filters to determine the best one to send.
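As a rough sketch of that fetch-rank-filter flow (every name below is invented for illustration; this is not Pushservice's actual API):

# Hypothetical sketch of the fetch -> rank -> filter flow described above;
# all names here are illustrative stand-ins, not Pushservice's real API.
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Candidate:
    id: str
    score: float = 0.0


def pick_notification(
    sources: List[Callable[[], List[Candidate]]],
    rank: Callable[[List[Candidate]], List[Candidate]],
    filters: List[Callable[[Candidate], bool]],
) -> Optional[Candidate]:
    # Fetch candidates from every source.
    candidates = [c for source in sources for c in source()]
    # Rank candidates by predicted relevance, best first.
    ranked = rank(candidates)
    # Apply filters; send the highest-ranked candidate that passes them all.
    for candidate in ranked:
        if all(f(candidate) for f in filters):
            return candidate
    return None


# Example usage with trivial stand-ins:
best = pick_notification(
    sources=[lambda: [Candidate("tweet_1", 0.9), Candidate("tweet_2", 0.4)]],
    rank=lambda cs: sorted(cs, key=lambda c: c.score, reverse=True),
    filters=[lambda c: c.score > 0.5],
)

The script below is the evaluation entry point for the Light Ranking model used in this pipeline: it can compute basic binary metrics via the estimator, and group metrics either locally or in BigQuery.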
from datetime import datetime
from functools import partial
import os

from ..libs.group_metrics import (
    run_group_metrics_light_ranking,
    run_group_metrics_light_ranking_in_bq,
)
from ..libs.metric_fn_utils import get_metric_fn
from ..libs.model_args import get_arg_parser_light_ranking
from ..libs.model_utils import read_config
from .deep_norm import build_graph, DataRecordTrainer, get_config_func, logging

# checkstyle: noqa

if __name__ == "__main__":
    parser = get_arg_parser_light_ranking()
    parser.add_argument(
        "--eval_checkpoint",
        default=None,
        type=str,
        help="Which checkpoint to use for evaluation",
    )
    parser.add_argument(
        "--saved_model_path",
        default=None,
        type=str,
        help="Path to saved model for evaluation",
    )
    parser.add_argument(
        "--run_binary_metrics",
        default=False,
        action="store_true",
        help="Whether to compute the basic binary metrics for Light Ranking.",
    )

    opt = parser.parse_args()
    logging.info("Parsed args:")
    logging.info(opt)

    # Build the feature config from the provided feature list and data spec.
    feature_list = read_config(opt.feature_list).items()
    feature_config = get_config_func(opt.feat_config_type)(
        data_spec_path=opt.data_spec,
        feature_list_provided=feature_list,
        opt=opt,
        add_gbdt=opt.use_gbdt_features,
        run_light_ranking_group_metrics_in_bq=opt.run_light_ranking_group_metrics_in_bq,
    )

    # -----------------------------------------------
    # Create Trainer
    # -----------------------------------------------
    trainer = DataRecordTrainer(
        name=opt.model_trainer_name,
        params=opt,
        # Note: BQ group-metrics support is hard-wired on in the graph here,
        # independent of the --run_light_ranking_group_metrics_in_bq flag.
        build_graph_fn=partial(build_graph, run_light_ranking_group_metrics_in_bq=True),
        save_dir=opt.save_dir,
        run_config=None,
        feature_config=feature_config,
        metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
    )

    # -----------------------------------------------
    # Model Evaluation
    # -----------------------------------------------
    logging.info("Evaluating...")
    start = datetime.now()

    if opt.run_binary_metrics:
        eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
        # A negative --eval_steps means "evaluate the full dataset" (steps=None).
        eval_steps = None if (opt.eval_steps is not None and opt.eval_steps < 0) else opt.eval_steps
        trainer.estimator.evaluate(eval_input_fn, steps=eval_steps, checkpoint_path=opt.eval_checkpoint)

    if opt.run_light_ranking_group_metrics_in_bq:
        run_group_metrics_light_ranking_in_bq(
            trainer=trainer, params=opt, checkpoint_path=opt.eval_checkpoint
        )

    if opt.run_light_ranking_group_metrics:
        run_group_metrics_light_ranking(
            trainer=trainer,
            data_dir=os.path.join(opt.eval_data_dir, opt.eval_start_datetime),
            model_path=opt.saved_model_path,
            parse_fn=feature_config.get_parse_fn(),
        )

    end = datetime.now()
    logging.info("Evaluating time: " + str(end - start))
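
A hypothetical invocation for reference. The script name and every value below are placeholders, and all flags other than the three defined above (--eval_checkpoint, --saved_model_path, --run_binary_metrics) are assumed to be supplied by get_arg_parser_light_ranking, whose definition is not shown here:

python eval_model.py \
    --feature_list=<feature list config> \
    --data_spec=<data spec path> \
    --save_dir=<output directory> \
    --eval_checkpoint=<checkpoint to evaluate> \
    --run_binary_metrics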