mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-01-05 09:01:54 +01:00
Compare commits
2 Commits
90d7ea370e
...
b389c3d302
Author | SHA1 | Date | |
---|---|---|---|
|
b389c3d302 | ||
|
01dbfee4c0 |
11
README.md
11
README.md
@ -24,6 +24,7 @@ Product surfaces at Twitter are built on a shared set of data, models, and softw
|
||||
| | [timelines-aggregation-framework](timelines/data_processing/ml_util/aggregation_framework/README.md) | Framework for generating aggregate features in batch or real time. |
|
||||
| | [representation-manager](representation-manager/README.md) | Service to retrieve embeddings (i.e. SimClusers and TwHIN). |
|
||||
| | [twml](twml/README.md) | Legacy machine learning framework built on TensorFlow v1. |
|
||||
| | [Tweetypie](tweetypie/server/README.md) | Core Tweet service that handles the reading and writing of Tweet data. |
|
||||
|
||||
The product surface currently included in this repository is the For You Timeline.
|
||||
|
||||
@ -47,6 +48,16 @@ The core components of the For You Timeline included in this repository are list
|
||||
| | [visibility-filters](visibilitylib/README.md) | Responsible for filtering Twitter content to support legal compliance, improve product quality, increase user trust, protect revenue through the use of hard-filtering, visible product treatments, and coarse-grained downranking. |
|
||||
| | [timelineranker](timelineranker/README.md) | Legacy service which provides relevance-scored tweets from the Earlybird Search Index and UTEG service. |
|
||||
|
||||
### Recommended Notifications
|
||||
|
||||
The core components that power Recommended Notifications included in this repository are listed below:
|
||||
|
||||
| Type | Component | Description |
|
||||
|------------|------------|------------|
|
||||
| Service | [pushservice](pushservice/README.md) | Main recommendation service at Twitter used to surface recommendations to our users via notifications.
|
||||
| Ranking | [pushservice-light-ranker](pushservice/src/main/python/models/light_ranking/README.md) | Light Ranker model used by pushservice to rank Tweets. Bridges candidate generation and heavy ranking by pre-selecting highly-relevant candidates from the initial huge candidate pool. |
|
||||
| | [pushservice-heavy-ranker](pushservice/src/main/python/models/heavy_ranking/README.md) | Multi-task learning model to predict the probabilities that the target users will open and engage with the sent notifications. |
|
||||
|
||||
## Build and test code
|
||||
|
||||
We include Bazel BUILD files for most components, but not a top-level BUILD or WORKSPACE file. We plan to add a more complete build and test system in the future.
|
||||
|
48
pushservice/BUILD.bazel
Normal file
48
pushservice/BUILD.bazel
Normal file
@ -0,0 +1,48 @@
|
||||
alias(
|
||||
name = "frigate-pushservice",
|
||||
target = ":frigate-pushservice_lib",
|
||||
)
|
||||
|
||||
target(
|
||||
name = "frigate-pushservice_lib",
|
||||
dependencies = [
|
||||
"frigate/frigate-pushservice-opensource/src/main/scala/com/twitter/frigate/pushservice",
|
||||
],
|
||||
)
|
||||
|
||||
jvm_binary(
|
||||
name = "bin",
|
||||
basename = "frigate-pushservice",
|
||||
main = "com.twitter.frigate.pushservice.PushServiceMain",
|
||||
runtime_platform = "java11",
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"3rdparty/jvm/ch/qos/logback:logback-classic",
|
||||
"finatra/inject/inject-logback/src/main/scala",
|
||||
"frigate/frigate-pushservice-opensource/src/main/scala/com/twitter/frigate/pushservice",
|
||||
"loglens/loglens-logback/src/main/scala/com/twitter/loglens/logback",
|
||||
"twitter-server/logback-classic/src/main/scala",
|
||||
],
|
||||
excludes = [
|
||||
exclude("com.twitter.translations", "translations-twitter"),
|
||||
exclude("org.apache.hadoop", "hadoop-aws"),
|
||||
exclude("org.tensorflow"),
|
||||
scala_exclude("com.twitter", "ckoia-scala"),
|
||||
],
|
||||
)
|
||||
|
||||
jvm_app(
|
||||
name = "bundle",
|
||||
basename = "frigate-pushservice-package-dist",
|
||||
archive = "zip",
|
||||
binary = ":bin",
|
||||
tags = ["bazel-compatible"],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "mr_model_constants",
|
||||
sources = [
|
||||
"config/deepbird/constants.py",
|
||||
],
|
||||
tags = ["bazel-compatible"],
|
||||
)
|
45
pushservice/readme.md
Normal file
45
pushservice/readme.md
Normal file
@ -0,0 +1,45 @@
|
||||
# Pushservice
|
||||
|
||||
Pushservice is the main push recommendation service at Twitter used to generate recommendation-based notifications for users. It currently powers two functionalities:
|
||||
|
||||
- RefreshForPushHandler: This handler determines whether to send a recommendation push to a user based on their ID. It generates the best push recommendation item and coordinates with downstream services to deliver it
|
||||
- SendHandler: This handler determines and manage whether send the push to users based on the given target user details and the provided push recommendation item
|
||||
|
||||
## Overview
|
||||
|
||||
### RefreshForPushHandler
|
||||
|
||||
RefreshForPushHandler follows these steps:
|
||||
|
||||
- Building Target and checking eligibility
|
||||
- Builds a target user object based on the given user ID
|
||||
- Performs target-level filterings to determine if the target is eligible for a recommendation push
|
||||
- Fetch Candidates
|
||||
- Retrieves a list of potential candidates for the push by querying various candidate sources using the target
|
||||
- Candidate Hydration
|
||||
- Hydrates the candidate details with batch calls to different downstream services.
|
||||
- Pre-rank Filtering, also called Light Filtering
|
||||
- Filters the hydrated candidates with lightweight RPC calls.
|
||||
- Rank
|
||||
- Perform feature hydration for candidates and target user
|
||||
- Performs light ranking on candidates
|
||||
- Performs heavy ranking on candidates
|
||||
- Take Step, also called Heavy Filtering
|
||||
- Takes the top-ranked candidates one by one and applies heavy filtering until one candidate passes all filter steps
|
||||
- Send
|
||||
- Calls the appropriate downstream service to deliver the eligible candidate as a push and in-app notification to the target user
|
||||
|
||||
### SendHandler
|
||||
|
||||
SendHandler follows these steps:
|
||||
|
||||
- Building Target
|
||||
- Builds a target user object based on the given user ID
|
||||
- Candidate Hydration
|
||||
- Hydrates the candidate details with batch calls to different downstream services.
|
||||
- Feature Hydration
|
||||
- Perform feature hydration for candidates and target user
|
||||
- Take Step, also called Heavy Filtering
|
||||
- Perform filterings and validation checking for the given candidate
|
||||
- Send
|
||||
- Calls the appropriate downstream service to deliver the given candidate as a push and/or in-app notification to the target user
|
169
pushservice/src/main/python/models/heavy_ranking/BUILD
Normal file
169
pushservice/src/main/python/models/heavy_ranking/BUILD
Normal file
@ -0,0 +1,169 @@
|
||||
python37_binary(
|
||||
name = "update_warm_start_checkpoint",
|
||||
source = "update_warm_start_checkpoint.py",
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":deep_norm_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:update_warm_start_checkpoint",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "params_lib",
|
||||
sources = ["params.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
"3rdparty/python/pydantic:default",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/v11/lib:params_lib",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "features_lib",
|
||||
sources = ["features.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":params_lib",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "model_pools_lib",
|
||||
sources = ["model_pools.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":features_lib",
|
||||
":params_lib",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/v11/lib:model_lib",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "graph_lib",
|
||||
sources = ["graph.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":params_lib",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "run_args_lib",
|
||||
sources = ["run_args.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":features_lib",
|
||||
":params_lib",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "deep_norm_lib",
|
||||
sources = ["deep_norm.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":features_lib",
|
||||
":graph_lib",
|
||||
":model_pools_lib",
|
||||
":params_lib",
|
||||
":run_args_lib",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
"src/python/twitter/deepbird/util/data",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "eval_lib",
|
||||
sources = ["eval.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":features_lib",
|
||||
":graph_lib",
|
||||
":model_pools_lib",
|
||||
":params_lib",
|
||||
":run_args_lib",
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "deep_norm",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":deep_norm_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:deep_norm",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "eval",
|
||||
source = "eval.py",
|
||||
dependencies = [
|
||||
":eval_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "mlwf_libs",
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
":deep_norm_lib",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "train_model",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":deep_norm_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:train_model",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "train_model_local",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":deep_norm_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:train_model_local",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "eval_model_local",
|
||||
source = "eval.py",
|
||||
dependencies = [
|
||||
":eval_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval_model_local",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "eval_model",
|
||||
source = "eval.py",
|
||||
dependencies = [
|
||||
":eval_lib",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval_model",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "mlwf_model",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":mlwf_libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:mlwf_model",
|
||||
],
|
||||
)
|
20
pushservice/src/main/python/models/heavy_ranking/README.md
Normal file
20
pushservice/src/main/python/models/heavy_ranking/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# Notification Heavy Ranker Model
|
||||
|
||||
## Model Context
|
||||
There are 4 major components of Twitter notifications recommendation system: 1) candidate generation 2) light ranking 3) heavy ranking & 4) quality control. This notification heavy ranker model is the core ranking model for the personalised notifications recommendation. It's a multi-task learning model to predict the probabilities that the target users will open and engage with the sent notifications.
|
||||
|
||||
|
||||
## Directory Structure
|
||||
- BUILD: this file defines python library dependencies
|
||||
- deep_norm.py: this file contains how to set up continuous training, model evaluation and model exporting for the notification heavy ranker model
|
||||
- eval.py: the main python entry file to set up the overall model evaluation pipeline
|
||||
- features.py: this file contains importing feature list and support functions for feature engineering
|
||||
- graph.py: this file defines how to build the tensorflow graph with specified model architecture, loss function and training configuration
|
||||
- model_pools.py: this file defines the available model types for the heavy ranker
|
||||
- params.py: this file defines hyper-parameters used in the notification heavy ranker
|
||||
- run_args.py: this file defines command line parameters to run model training & evaluation
|
||||
- update_warm_start_checkpoint.py: this file contains the support to modify checkpoints of the given saved heavy ranker model
|
||||
- lib/BUILD: this file defines python library dependencies for tensorflow model architecture
|
||||
- lib/layers.py: this file defines different type of convolution layers to be used in the heavy ranker model
|
||||
- lib/model.py: this file defines the module containing ClemNet, the heavy ranker model type
|
||||
- lib/params.py: this file defines parameters used in the heavy ranker model
|
136
pushservice/src/main/python/models/heavy_ranking/deep_norm.py
Normal file
136
pushservice/src/main/python/models/heavy_ranking/deep_norm.py
Normal file
@ -0,0 +1,136 @@
|
||||
"""
|
||||
Training job for the heavy ranker of the push notification service.
|
||||
"""
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
import twml
|
||||
|
||||
from ..libs.metric_fn_utils import flip_disliked_labels, get_metric_fn
|
||||
from ..libs.model_utils import read_config
|
||||
from ..libs.warm_start_utils import get_feature_list_for_heavy_ranking, warm_start_checkpoint
|
||||
from .features import get_feature_config
|
||||
from .model_pools import ALL_MODELS
|
||||
from .params import load_graph_params
|
||||
from .run_args import get_training_arg_parser
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compat.v1 import logging
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args, _ = get_training_arg_parser().parse_known_args()
|
||||
logging.info(f"Parsed args: {args}")
|
||||
|
||||
params = load_graph_params(args)
|
||||
logging.info(f"Loaded graph params: {params}")
|
||||
|
||||
param_file = os.path.join(args.save_dir, "params.json")
|
||||
logging.info(f"Saving graph params to: {param_file}")
|
||||
with tf.io.gfile.GFile(param_file, mode="w") as file:
|
||||
json.dump(params.json(), file, ensure_ascii=False, indent=4)
|
||||
|
||||
logging.info(f"Get Feature Config: {args.feature_list}")
|
||||
feature_list = read_config(args.feature_list).items()
|
||||
feature_config = get_feature_config(
|
||||
data_spec_path=args.data_spec,
|
||||
params=params,
|
||||
feature_list_provided=feature_list,
|
||||
)
|
||||
feature_list_path = args.feature_list
|
||||
|
||||
warm_start_from = args.warm_start_from
|
||||
if args.warm_start_base_dir:
|
||||
logging.info(f"Get warm started model from: {args.warm_start_base_dir}.")
|
||||
|
||||
continuous_binary_feat_list_save_path = os.path.join(
|
||||
args.warm_start_base_dir, "continuous_binary_feat_list.json"
|
||||
)
|
||||
warm_start_folder = os.path.join(args.warm_start_base_dir, "best_checkpoint")
|
||||
job_name = os.path.basename(args.save_dir)
|
||||
ws_output_ckpt_folder = os.path.join(args.warm_start_base_dir, f"warm_start_for_{job_name}")
|
||||
if tf.io.gfile.exists(ws_output_ckpt_folder):
|
||||
tf.io.gfile.rmtree(ws_output_ckpt_folder)
|
||||
|
||||
tf.io.gfile.mkdir(ws_output_ckpt_folder)
|
||||
|
||||
warm_start_from = warm_start_checkpoint(
|
||||
warm_start_folder,
|
||||
continuous_binary_feat_list_save_path,
|
||||
feature_list_path,
|
||||
args.data_spec,
|
||||
ws_output_ckpt_folder,
|
||||
)
|
||||
logging.info(f"Created warm_start_from_ckpt {warm_start_from}.")
|
||||
|
||||
logging.info("Build Trainer.")
|
||||
metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)
|
||||
|
||||
trainer = twml.trainers.DataRecordTrainer(
|
||||
name="magic_recs",
|
||||
params=args,
|
||||
build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(*args),
|
||||
save_dir=args.save_dir,
|
||||
run_config=None,
|
||||
feature_config=feature_config,
|
||||
metric_fn=flip_disliked_labels(metric_fn),
|
||||
warm_start_from=warm_start_from,
|
||||
)
|
||||
|
||||
logging.info("Build train and eval input functions.")
|
||||
train_input_fn = trainer.get_train_input_fn(shuffle=True)
|
||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
||||
|
||||
learn = trainer.learn
|
||||
if args.distributed or args.num_workers is not None:
|
||||
learn = trainer.train_and_evaluate
|
||||
|
||||
if not args.directly_export_best:
|
||||
logging.info("Starting training")
|
||||
start = datetime.now()
|
||||
learn(
|
||||
early_stop_minimize=False,
|
||||
early_stop_metric="pr_auc_unweighted_OONC",
|
||||
early_stop_patience=args.early_stop_patience,
|
||||
early_stop_tolerance=args.early_stop_tolerance,
|
||||
eval_input_fn=eval_input_fn,
|
||||
train_input_fn=train_input_fn,
|
||||
)
|
||||
logging.info(f"Total training time: {datetime.now() - start}")
|
||||
else:
|
||||
logging.info("Directly exporting the model")
|
||||
|
||||
if not args.export_dir:
|
||||
args.export_dir = os.path.join(args.save_dir, "exported_models")
|
||||
|
||||
logging.info(f"Exporting the model to {args.export_dir}.")
|
||||
start = datetime.now()
|
||||
twml.contrib.export.export_fn.export_all_models(
|
||||
trainer=trainer,
|
||||
export_dir=args.export_dir,
|
||||
parse_fn=feature_config.get_parse_fn(),
|
||||
serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
|
||||
export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
|
||||
)
|
||||
|
||||
logging.info(f"Total model export time: {datetime.now() - start}")
|
||||
logging.info(f"The MLP directory is: {args.save_dir}")
|
||||
|
||||
continuous_binary_feat_list_save_path = os.path.join(
|
||||
args.save_dir, "continuous_binary_feat_list.json"
|
||||
)
|
||||
logging.info(
|
||||
f"Saving the list of continuous and binary features to {continuous_binary_feat_list_save_path}."
|
||||
)
|
||||
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
|
||||
feature_list_path, args.data_spec
|
||||
)
|
||||
twml.util.write_file(
|
||||
continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
logging.info("Done.")
|
59
pushservice/src/main/python/models/heavy_ranking/eval.py
Normal file
59
pushservice/src/main/python/models/heavy_ranking/eval.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""
|
||||
Evaluation job for the heavy ranker of the push notification service.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import twml
|
||||
|
||||
from ..libs.metric_fn_utils import get_metric_fn
|
||||
from ..libs.model_utils import read_config
|
||||
from .features import get_feature_config
|
||||
from .model_pools import ALL_MODELS
|
||||
from .params import load_graph_params
|
||||
from .run_args import get_eval_arg_parser
|
||||
|
||||
from tensorflow.compat.v1 import logging
|
||||
|
||||
|
||||
def main():
|
||||
args, _ = get_eval_arg_parser().parse_known_args()
|
||||
logging.info(f"Parsed args: {args}")
|
||||
|
||||
params = load_graph_params(args)
|
||||
logging.info(f"Loaded graph params: {params}")
|
||||
|
||||
logging.info(f"Get Feature Config: {args.feature_list}")
|
||||
feature_list = read_config(args.feature_list).items()
|
||||
feature_config = get_feature_config(
|
||||
data_spec_path=args.data_spec,
|
||||
params=params,
|
||||
feature_list_provided=feature_list,
|
||||
)
|
||||
|
||||
logging.info("Build DataRecordTrainer.")
|
||||
metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)
|
||||
|
||||
trainer = twml.trainers.DataRecordTrainer(
|
||||
name="magic_recs",
|
||||
params=args,
|
||||
build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(*args),
|
||||
save_dir=args.save_dir,
|
||||
run_config=None,
|
||||
feature_config=feature_config,
|
||||
metric_fn=metric_fn,
|
||||
)
|
||||
|
||||
logging.info("Run the evaluation.")
|
||||
start = datetime.now()
|
||||
trainer._estimator.evaluate(
|
||||
input_fn=trainer.get_eval_input_fn(repeat=False, shuffle=False),
|
||||
steps=None if (args.eval_steps is not None and args.eval_steps < 0) else args.eval_steps,
|
||||
checkpoint_path=args.eval_checkpoint,
|
||||
)
|
||||
logging.info(f"Evaluating time: {datetime.now() - start}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
logging.info("Job done.")
|
138
pushservice/src/main/python/models/heavy_ranking/features.py
Normal file
138
pushservice/src/main/python/models/heavy_ranking/features.py
Normal file
@ -0,0 +1,138 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from twitter.deepbird.projects.magic_recs.libs.model_utils import filter_nans_and_infs
|
||||
import twml
|
||||
from twml.layers import full_sparse, sparse_max_norm
|
||||
|
||||
from .params import FeaturesParams, GraphParams, SparseFeaturesParams
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import Tensor
|
||||
import tensorflow.compat.v1 as tf1
|
||||
|
||||
|
||||
FEAT_CONFIG_DEFAULT_VAL = 0
|
||||
DEFAULT_FEATURE_LIST_PATH = "./feature_list_default.yaml"
|
||||
FEATURE_LIST_DEFAULT_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_PATH
|
||||
)
|
||||
|
||||
|
||||
def get_feature_config(data_spec_path=None, feature_list_provided=[], params: GraphParams = None):
|
||||
|
||||
a_string_feat_list = [feat for feat, feat_type in feature_list_provided if feat_type != "S"]
|
||||
|
||||
builder = twml.contrib.feature_config.FeatureConfigBuilder(
|
||||
data_spec_path=data_spec_path, debug=False
|
||||
)
|
||||
|
||||
builder = builder.extract_feature_group(
|
||||
feature_regexes=a_string_feat_list,
|
||||
group_name="continuous_features",
|
||||
default_value=FEAT_CONFIG_DEFAULT_VAL,
|
||||
type_filter=["CONTINUOUS"],
|
||||
)
|
||||
|
||||
builder = builder.extract_feature_group(
|
||||
feature_regexes=a_string_feat_list,
|
||||
group_name="binary_features",
|
||||
type_filter=["BINARY"],
|
||||
)
|
||||
|
||||
if params.model.features.sparse_features:
|
||||
builder = builder.extract_features_as_hashed_sparse(
|
||||
feature_regexes=a_string_feat_list,
|
||||
hash_space_size_bits=params.model.features.sparse_features.bits,
|
||||
type_filter=["DISCRETE", "STRING", "SPARSE_BINARY"],
|
||||
output_tensor_name="sparse_not_continuous",
|
||||
)
|
||||
|
||||
builder = builder.extract_features_as_hashed_sparse(
|
||||
feature_regexes=[feat for feat, feat_type in feature_list_provided if feat_type == "S"],
|
||||
hash_space_size_bits=params.model.features.sparse_features.bits,
|
||||
type_filter=["SPARSE_CONTINUOUS"],
|
||||
output_tensor_name="sparse_continuous",
|
||||
)
|
||||
|
||||
builder = builder.add_labels([task.label for task in params.tasks] + ["label.ntabDislike"])
|
||||
|
||||
if params.weight:
|
||||
builder = builder.define_weight(params.weight)
|
||||
|
||||
return builder.build()
|
||||
|
||||
|
||||
def dense_features(features: Dict[str, Tensor], training: bool) -> Tensor:
|
||||
"""
|
||||
Performs feature transformations on the raw dense features (continuous and binary).
|
||||
"""
|
||||
with tf.name_scope("dense_features"):
|
||||
x = filter_nans_and_infs(features["continuous_features"])
|
||||
|
||||
x = tf.sign(x) * tf.math.log(tf.abs(x) + 1)
|
||||
x = tf1.layers.batch_normalization(
|
||||
x, momentum=0.9999, training=training, renorm=training, axis=1
|
||||
)
|
||||
x = tf.clip_by_value(x, -5, 5)
|
||||
|
||||
transformed_continous_features = tf.where(tf.math.is_nan(x), tf.zeros_like(x), x)
|
||||
|
||||
binary_features = filter_nans_and_infs(features["binary_features"])
|
||||
binary_features = tf.dtypes.cast(binary_features, tf.float32)
|
||||
|
||||
output = tf.concat([transformed_continous_features, binary_features], axis=1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def sparse_features(
|
||||
features: Dict[str, Tensor], training: bool, params: SparseFeaturesParams
|
||||
) -> Tensor:
|
||||
"""
|
||||
Performs feature transformations on the raw sparse features.
|
||||
"""
|
||||
|
||||
with tf.name_scope("sparse_features"):
|
||||
with tf.name_scope("sparse_not_continuous"):
|
||||
sparse_not_continuous = full_sparse(
|
||||
inputs=features["sparse_not_continuous"],
|
||||
output_size=params.embedding_size,
|
||||
use_sparse_grads=training,
|
||||
use_binary_values=False,
|
||||
)
|
||||
|
||||
with tf.name_scope("sparse_continuous"):
|
||||
shape_enforced_input = twml.util.limit_sparse_tensor_size(
|
||||
sparse_tf=features["sparse_continuous"], input_size_bits=params.bits, mask_indices=False
|
||||
)
|
||||
|
||||
normalized_continuous_sparse = sparse_max_norm(
|
||||
inputs=shape_enforced_input, is_training=training
|
||||
)
|
||||
|
||||
sparse_continuous = full_sparse(
|
||||
inputs=normalized_continuous_sparse,
|
||||
output_size=params.embedding_size,
|
||||
use_sparse_grads=training,
|
||||
use_binary_values=False,
|
||||
)
|
||||
|
||||
output = tf.concat([sparse_not_continuous, sparse_continuous], axis=1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def get_features(features: Dict[str, Tensor], training: bool, params: FeaturesParams) -> Tensor:
|
||||
"""
|
||||
Performs feature transformations on the dense and sparse features and combine the resulting
|
||||
tensors into a single one.
|
||||
"""
|
||||
with tf.name_scope("features"):
|
||||
x = dense_features(features, training)
|
||||
tf1.logging.info(f"Dense features: {x.shape}")
|
||||
|
||||
if params.sparse_features:
|
||||
x = tf.concat([x, sparse_features(features, training, params.sparse_features)], axis=1)
|
||||
|
||||
return x
|
129
pushservice/src/main/python/models/heavy_ranking/graph.py
Normal file
129
pushservice/src/main/python/models/heavy_ranking/graph.py
Normal file
@ -0,0 +1,129 @@
|
||||
"""
|
||||
Graph class defining methods to obtain key quantities such as:
|
||||
* the logits
|
||||
* the probabilities
|
||||
* the final score
|
||||
* the loss function
|
||||
* the training operator
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
|
||||
from twitter.deepbird.hparam import HParams
|
||||
import twml
|
||||
|
||||
from ..libs.model_utils import generate_disliked_mask
|
||||
from .params import GraphParams
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tf1
|
||||
|
||||
|
||||
class Graph(ABC):
|
||||
def __init__(self, params: GraphParams):
|
||||
self.params = params
|
||||
|
||||
@abstractmethod
|
||||
def get_logits(self, features: Dict[str, tf.Tensor], mode: tf.estimator.ModeKeys) -> tf.Tensor:
|
||||
pass
|
||||
|
||||
def get_probabilities(self, logits: tf.Tensor) -> tf.Tensor:
|
||||
return tf.math.cumprod(tf.nn.sigmoid(logits), axis=1, name="probabilities")
|
||||
|
||||
def get_task_weights(self, labels: tf.Tensor) -> tf.Tensor:
|
||||
oonc_label = tf.reshape(labels[:, 0], shape=(-1, 1))
|
||||
task_weights = tf.concat([tf.ones_like(oonc_label), oonc_label], axis=1)
|
||||
|
||||
n_labels = len(self.params.tasks)
|
||||
task_weights = tf.reshape(task_weights[:, 0:n_labels], shape=(-1, n_labels))
|
||||
|
||||
return task_weights
|
||||
|
||||
def get_loss(self, labels: tf.Tensor, logits: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
||||
with tf.name_scope("weights"):
|
||||
disliked_mask = generate_disliked_mask(labels)
|
||||
|
||||
labels = tf.reshape(labels[:, 0:2], shape=[-1, 2])
|
||||
|
||||
labels = labels * tf.cast(tf.logical_not(disliked_mask), dtype=labels.dtype)
|
||||
|
||||
with tf.name_scope("task_weight"):
|
||||
task_weights = self.get_task_weights(labels)
|
||||
|
||||
with tf.name_scope("batch_size"):
|
||||
batch_size = tf.cast(tf.shape(labels)[0], dtype=tf.float32, name="batch_size")
|
||||
|
||||
weights = task_weights / batch_size
|
||||
|
||||
with tf.name_scope("loss"):
|
||||
loss = tf.reduce_sum(
|
||||
tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) * weights,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def get_score(self, probabilities: tf.Tensor) -> tf.Tensor:
|
||||
with tf.name_scope("score_weight"):
|
||||
score_weights = tf.constant([task.score_weight for task in self.params.tasks])
|
||||
score_weights = score_weights / tf.reduce_sum(score_weights, axis=0)
|
||||
|
||||
with tf.name_scope("score"):
|
||||
score = tf.reshape(tf.reduce_sum(probabilities * score_weights, axis=1), shape=[-1, 1])
|
||||
|
||||
return score
|
||||
|
||||
def get_train_op(self, loss: tf.Tensor, twml_params) -> Any:
|
||||
with tf.name_scope("optimizer"):
|
||||
learning_rate = twml_params.learning_rate
|
||||
optimizer = tf1.train.GradientDescentOptimizer(learning_rate=learning_rate)
|
||||
|
||||
update_ops = set(tf1.get_collection(tf1.GraphKeys.UPDATE_OPS))
|
||||
with tf.control_dependencies(update_ops):
|
||||
train_op = twml.optimizers.optimize_loss(
|
||||
loss=loss,
|
||||
variables=tf1.trainable_variables(),
|
||||
global_step=tf1.train.get_global_step(),
|
||||
optimizer=optimizer,
|
||||
learning_rate=None,
|
||||
)
|
||||
|
||||
return train_op
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
features: Dict[str, tf.Tensor],
|
||||
labels: tf.Tensor,
|
||||
mode: tf.estimator.ModeKeys,
|
||||
params: HParams,
|
||||
config=None,
|
||||
) -> Dict[str, tf.Tensor]:
|
||||
training = mode == tf.estimator.ModeKeys.TRAIN
|
||||
logits = self.get_logits(features=features, training=training)
|
||||
probabilities = self.get_probabilities(logits=logits)
|
||||
score = None
|
||||
loss = None
|
||||
train_op = None
|
||||
|
||||
if mode == tf.estimator.ModeKeys.PREDICT:
|
||||
score = self.get_score(probabilities=probabilities)
|
||||
output = {"loss": loss, "train_op": train_op, "prediction": score}
|
||||
|
||||
elif mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
|
||||
loss = self.get_loss(labels=labels, logits=logits)
|
||||
|
||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
train_op = self.get_train_op(loss=loss, twml_params=params)
|
||||
|
||||
output = {"loss": loss, "train_op": train_op, "output": probabilities}
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"""
|
||||
Invalid mode. Possible values are: {tf.estimator.ModeKeys.PREDICT}, {tf.estimator.ModeKeys.TRAIN}, and {tf.estimator.ModeKeys.EVAL}
|
||||
. Passed: {mode}
|
||||
"""
|
||||
)
|
||||
|
||||
return output
|
42
pushservice/src/main/python/models/heavy_ranking/lib/BUILD
Normal file
42
pushservice/src/main/python/models/heavy_ranking/lib/BUILD
Normal file
@ -0,0 +1,42 @@
|
||||
python3_library(
|
||||
name = "params_lib",
|
||||
sources = [
|
||||
"params.py",
|
||||
],
|
||||
tags = [
|
||||
"bazel-compatible",
|
||||
"no-mypy",
|
||||
],
|
||||
dependencies = [
|
||||
"3rdparty/python/pydantic:default",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "layers_lib",
|
||||
sources = [
|
||||
"layers.py",
|
||||
],
|
||||
tags = [
|
||||
"bazel-compatible",
|
||||
"no-mypy",
|
||||
],
|
||||
dependencies = [
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "model_lib",
|
||||
sources = [
|
||||
"model.py",
|
||||
],
|
||||
tags = [
|
||||
"bazel-compatible",
|
||||
"no-mypy",
|
||||
],
|
||||
dependencies = [
|
||||
":layers_lib",
|
||||
":params_lib",
|
||||
"3rdparty/python/absl-py:default",
|
||||
],
|
||||
)
|
128
pushservice/src/main/python/models/heavy_ranking/lib/layers.py
Normal file
128
pushservice/src/main/python/models/heavy_ranking/lib/layers.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""
|
||||
Different type of convolution layers to be used in the ClemNet.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class KerasConv1D(tf.keras.layers.Layer):
|
||||
"""
|
||||
Basic Conv1D layer in a wrapper to be compatible with ClemNet.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: int,
|
||||
filters: int,
|
||||
strides: int,
|
||||
padding: str,
|
||||
use_bias: bool = True,
|
||||
kernel_initializer: str = "glorot_uniform",
|
||||
bias_initializer: str = "zeros",
|
||||
**kwargs: Any,
|
||||
):
|
||||
super(KerasConv1D, self).__init__(**kwargs)
|
||||
self.kernel_size = kernel_size
|
||||
self.filters = filters
|
||||
self.use_bias = use_bias
|
||||
self.kernel_initializer = kernel_initializer
|
||||
self.bias_initializer = bias_initializer
|
||||
self.strides = strides
|
||||
self.padding = padding
|
||||
|
||||
def build(self, input_shape: tf.TensorShape) -> None:
|
||||
assert (
|
||||
len(input_shape) == 3
|
||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
||||
|
||||
self.features = input_shape[1]
|
||||
|
||||
self.w = tf.keras.layers.Conv1D(
|
||||
kernel_size=self.kernel_size,
|
||||
filters=self.filters,
|
||||
strides=self.strides,
|
||||
padding=self.padding,
|
||||
use_bias=self.use_bias,
|
||||
kernel_initializer=self.kernel_initializer,
|
||||
bias_initializer=self.bias_initializer,
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
||||
return self.w(inputs)
|
||||
|
||||
|
||||
class ChannelWiseDense(tf.keras.layers.Layer):
|
||||
"""
|
||||
Dense layer is applied to each channel separately. This is more memory and computationally
|
||||
efficient than flattening the channels and performing single dense layers over it which is the
|
||||
default behavior in tf1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_size: int,
|
||||
use_bias: bool,
|
||||
kernel_initializer: str = "uniform_glorot",
|
||||
bias_initializer: str = "zeros",
|
||||
**kwargs: Any,
|
||||
):
|
||||
super(ChannelWiseDense, self).__init__(**kwargs)
|
||||
self.output_size = output_size
|
||||
self.use_bias = use_bias
|
||||
self.kernel_initializer = kernel_initializer
|
||||
self.bias_initializer = bias_initializer
|
||||
|
||||
def build(self, input_shape: tf.TensorShape) -> None:
|
||||
assert (
|
||||
len(input_shape) == 3
|
||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
||||
|
||||
input_size = input_shape[1]
|
||||
channels = input_shape[2]
|
||||
|
||||
self.kernel = self.add_weight(
|
||||
name="kernel",
|
||||
shape=(channels, input_size, self.output_size),
|
||||
initializer=self.kernel_initializer,
|
||||
trainable=True,
|
||||
)
|
||||
|
||||
self.bias = self.add_weight(
|
||||
name="bias",
|
||||
shape=(channels, self.output_size),
|
||||
initializer=self.bias_initializer,
|
||||
trainable=self.use_bias,
|
||||
)
|
||||
|
||||
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
||||
x = inputs
|
||||
|
||||
transposed_x = tf.transpose(x, perm=[2, 0, 1])
|
||||
transposed_residual = (
|
||||
tf.transpose(tf.matmul(transposed_x, self.kernel), perm=[1, 0, 2]) + self.bias
|
||||
)
|
||||
output = tf.transpose(transposed_residual, perm=[0, 2, 1])
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResidualLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
Layer implementing a 3D-residual connection.
|
||||
"""
|
||||
|
||||
def build(self, input_shape: tf.TensorShape) -> None:
|
||||
assert (
|
||||
len(input_shape) == 3
|
||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
||||
|
||||
def call(self, inputs: tf.Tensor, residual: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
||||
shortcut = tf.keras.layers.Conv1D(
|
||||
filters=int(residual.shape[2]), strides=1, kernel_size=1, padding="SAME", use_bias=False
|
||||
)(inputs)
|
||||
|
||||
output = tf.add(shortcut, residual)
|
||||
|
||||
return output
|
@ -0,0 +1,76 @@
|
||||
"""
|
||||
Module containing ClemNet.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from .layers import ChannelWiseDense, KerasConv1D, ResidualLayer
|
||||
from .params import BlockParams, ClemNetParams
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tf1
|
||||
|
||||
|
||||
class Block2(tf.keras.layers.Layer):
|
||||
"""
|
||||
Possible ClemNet block. Architecture is as follow:
|
||||
Optional(DenseLayer + BN + Act)
|
||||
Optional(ConvLayer + BN + Act)
|
||||
Optional(Residual Layer)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, params: BlockParams, **kwargs: Any):
|
||||
super(Block2, self).__init__(**kwargs)
|
||||
self.params = params
|
||||
|
||||
def build(self, input_shape: tf.TensorShape) -> None:
|
||||
assert (
|
||||
len(input_shape) == 3
|
||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
||||
|
||||
def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
|
||||
x = inputs
|
||||
if self.params.dense:
|
||||
x = ChannelWiseDense(**self.params.dense.dict())(inputs=x, training=training)
|
||||
x = tf1.layers.batch_normalization(x, momentum=0.9999, training=training, axis=1)
|
||||
x = tf.keras.layers.Activation(self.params.activation)(x)
|
||||
|
||||
if self.params.conv:
|
||||
x = KerasConv1D(**self.params.conv.dict())(inputs=x, training=training)
|
||||
x = tf1.layers.batch_normalization(x, momentum=0.9999, training=training, axis=1)
|
||||
x = tf.keras.layers.Activation(self.params.activation)(x)
|
||||
|
||||
if self.params.residual:
|
||||
x = ResidualLayer()(inputs=inputs, residual=x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ClemNet(tf.keras.layers.Layer):
|
||||
"""
|
||||
A residual network stacking residual blocks composed of dense layers and convolutions.
|
||||
"""
|
||||
|
||||
def __init__(self, params: ClemNetParams, **kwargs: Any):
|
||||
super(ClemNet, self).__init__(**kwargs)
|
||||
self.params = params
|
||||
|
||||
def build(self, input_shape: tf.TensorShape) -> None:
|
||||
assert len(input_shape) in (
|
||||
2,
|
||||
3,
|
||||
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
||||
|
||||
def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
|
||||
if len(inputs.shape) < 3:
|
||||
inputs = tf.expand_dims(inputs, axis=-1)
|
||||
|
||||
x = inputs
|
||||
for block_params in self.params.blocks:
|
||||
x = Block2(block_params)(inputs=x, training=training)
|
||||
|
||||
x = tf.keras.layers.Flatten(name="flattened")(x)
|
||||
if self.params.top:
|
||||
x = tf.keras.layers.Dense(units=self.params.top.n_labels, name="logits")(x)
|
||||
|
||||
return x
|
@ -0,0 +1,49 @@
|
||||
"""
|
||||
Parameters used in ClemNet.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, PositiveInt
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
class ExtendedBaseModel(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
class DenseParams(ExtendedBaseModel):
|
||||
name: Optional[str]
|
||||
bias_initializer: str = "zeros"
|
||||
kernel_initializer: str = "glorot_uniform"
|
||||
output_size: PositiveInt
|
||||
use_bias: bool = Field(True)
|
||||
|
||||
|
||||
class ConvParams(ExtendedBaseModel):
|
||||
name: Optional[str]
|
||||
bias_initializer: str = "zeros"
|
||||
filters: PositiveInt
|
||||
kernel_initializer: str = "glorot_uniform"
|
||||
kernel_size: PositiveInt
|
||||
padding: str = "SAME"
|
||||
strides: PositiveInt = 1
|
||||
use_bias: bool = Field(True)
|
||||
|
||||
|
||||
class BlockParams(ExtendedBaseModel):
|
||||
activation: Optional[str]
|
||||
conv: Optional[ConvParams]
|
||||
dense: Optional[DenseParams]
|
||||
residual: Optional[bool]
|
||||
|
||||
|
||||
class TopLayerParams(ExtendedBaseModel):
|
||||
n_labels: PositiveInt
|
||||
|
||||
|
||||
class ClemNetParams(ExtendedBaseModel):
|
||||
blocks: List[BlockParams] = []
|
||||
top: Optional[TopLayerParams]
|
@ -0,0 +1,34 @@
|
||||
"""
|
||||
Candidate architectures for each task's.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .features import get_features
|
||||
from .graph import Graph
|
||||
from .lib.model import ClemNet
|
||||
from .params import ModelTypeEnum
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class MagicRecsClemNet(Graph):
|
||||
def get_logits(self, features: Dict[str, tf.Tensor], training: bool) -> tf.Tensor:
|
||||
|
||||
with tf.name_scope("logits"):
|
||||
inputs = get_features(features=features, training=training, params=self.params.model.features)
|
||||
|
||||
with tf.name_scope("OONC_logits"):
|
||||
model = ClemNet(params=self.params.model.architecture)
|
||||
oonc_logit = model(inputs=inputs, training=training)
|
||||
|
||||
with tf.name_scope("EngagementGivenOONC_logits"):
|
||||
model = ClemNet(params=self.params.model.architecture)
|
||||
eng_logits = model(inputs=inputs, training=training)
|
||||
|
||||
return tf.concat([oonc_logit, eng_logits], axis=1)
|
||||
|
||||
|
||||
ALL_MODELS = {ModelTypeEnum.clemnet: MagicRecsClemNet}
|
89
pushservice/src/main/python/models/heavy_ranking/params.py
Normal file
89
pushservice/src/main/python/models/heavy_ranking/params.py
Normal file
@ -0,0 +1,89 @@
|
||||
import enum
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from .lib.params import BlockParams, ClemNetParams, ConvParams, DenseParams, TopLayerParams
|
||||
|
||||
from pydantic import BaseModel, Extra, NonNegativeFloat
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
class ExtendedBaseModel(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
class SparseFeaturesParams(ExtendedBaseModel):
|
||||
bits: int
|
||||
embedding_size: int
|
||||
|
||||
|
||||
class FeaturesParams(ExtendedBaseModel):
|
||||
sparse_features: Optional[SparseFeaturesParams]
|
||||
|
||||
|
||||
class ModelTypeEnum(str, enum.Enum):
|
||||
clemnet: str = "clemnet"
|
||||
|
||||
|
||||
class ModelParams(ExtendedBaseModel):
|
||||
name: ModelTypeEnum
|
||||
features: FeaturesParams
|
||||
architecture: ClemNetParams
|
||||
|
||||
|
||||
class TaskNameEnum(str, enum.Enum):
|
||||
oonc: str = "OONC"
|
||||
engagement: str = "Engagement"
|
||||
|
||||
|
||||
class Task(ExtendedBaseModel):
|
||||
name: TaskNameEnum
|
||||
label: str
|
||||
score_weight: NonNegativeFloat
|
||||
|
||||
|
||||
DEFAULT_TASKS = [
|
||||
Task(name=TaskNameEnum.oonc, label="label", score_weight=0.9),
|
||||
Task(name=TaskNameEnum.engagement, label="label.engagement", score_weight=0.1),
|
||||
]
|
||||
|
||||
|
||||
class GraphParams(ExtendedBaseModel):
|
||||
tasks: List[Task] = DEFAULT_TASKS
|
||||
model: ModelParams
|
||||
weight: Optional[str]
|
||||
|
||||
|
||||
DEFAULT_ARCHITECTURE_PARAMS = ClemNetParams(
|
||||
blocks=[
|
||||
BlockParams(
|
||||
activation="relu",
|
||||
conv=ConvParams(kernel_size=3, filters=5),
|
||||
dense=DenseParams(output_size=output_size),
|
||||
residual=False,
|
||||
)
|
||||
for output_size in [1024, 512, 256, 128]
|
||||
],
|
||||
top=TopLayerParams(n_labels=1),
|
||||
)
|
||||
|
||||
DEFAULT_GRAPH_PARAMS = GraphParams(
|
||||
model=ModelParams(
|
||||
name=ModelTypeEnum.clemnet,
|
||||
architecture=DEFAULT_ARCHITECTURE_PARAMS,
|
||||
features=FeaturesParams(sparse_features=SparseFeaturesParams(bits=18, embedding_size=50)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def load_graph_params(args) -> GraphParams:
|
||||
params = DEFAULT_GRAPH_PARAMS
|
||||
if args.param_file:
|
||||
with tf.io.gfile.GFile(args.param_file, mode="r+") as file:
|
||||
params = GraphParams.parse_obj(json.load(file))
|
||||
|
||||
return params
|
59
pushservice/src/main/python/models/heavy_ranking/run_args.py
Normal file
59
pushservice/src/main/python/models/heavy_ranking/run_args.py
Normal file
@ -0,0 +1,59 @@
|
||||
from twml.trainers import DataRecordTrainer
|
||||
|
||||
from .features import FEATURE_LIST_DEFAULT_PATH
|
||||
|
||||
|
||||
def get_training_arg_parser():
|
||||
parser = DataRecordTrainer.add_parser_arguments()
|
||||
|
||||
parser.add_argument(
|
||||
"--feature_list",
|
||||
default=FEATURE_LIST_DEFAULT_PATH,
|
||||
type=str,
|
||||
help="Which features to use for training",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--param_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to JSON file containing the graph parameters. If None, model will load default parameters.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--directly_export_best",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to directly_export best_checkpoint",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warm_start_from", default=None, type=str, help="model dir to warm start from"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warm_start_base_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="latest ckpt in this folder will be used to ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Which type of model to train.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_eval_arg_parser():
|
||||
parser = get_training_arg_parser()
|
||||
parser.add_argument(
|
||||
"--eval_checkpoint",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Which checkpoint to use for evaluation",
|
||||
)
|
||||
|
||||
return parser
|
@ -0,0 +1,146 @@
|
||||
"""
|
||||
Model for modifying the checkpoints of the magic recs cnn Model with addition, deletion, and reordering
|
||||
of continuous and binary features.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from twitter.deepbird.projects.magic_recs.libs.get_feat_config import FEATURE_LIST_DEFAULT_PATH
|
||||
from twitter.deepbird.projects.magic_recs.libs.warm_start_utils_v11 import (
|
||||
get_feature_list_for_heavy_ranking,
|
||||
mkdirp,
|
||||
rename_dir,
|
||||
rmdir,
|
||||
warm_start_checkpoint,
|
||||
)
|
||||
import twml
|
||||
from twml.trainers import DataRecordTrainer
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compat.v1 import logging
|
||||
|
||||
|
||||
def get_arg_parser():
|
||||
parser = DataRecordTrainer.add_parser_arguments()
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default="deepnorm_gbdt_inputdrop2_rescale",
|
||||
type=str,
|
||||
help="specify the model type to use.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_trainer_name",
|
||||
default="None",
|
||||
type=str,
|
||||
help="deprecated, added here just for api compatibility.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warm_start_base_dir",
|
||||
default="none",
|
||||
type=str,
|
||||
help="latest ckpt in this folder will be used.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_checkpoint_dir",
|
||||
default="none",
|
||||
type=str,
|
||||
help="Output folder for warm started ckpt. If none, it will move warm_start_base_dir to backup, and overwrite it",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feature_list",
|
||||
default="none",
|
||||
type=str,
|
||||
help="Which features to use for training",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--old_feature_list",
|
||||
default="none",
|
||||
type=str,
|
||||
help="Which features to use for training",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params(args=None):
|
||||
parser = get_arg_parser()
|
||||
if args is None:
|
||||
return parser.parse_args()
|
||||
else:
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def _main():
|
||||
opt = get_params()
|
||||
logging.info("parse is: ")
|
||||
logging.info(opt)
|
||||
|
||||
if opt.feature_list == "none":
|
||||
feature_list_path = FEATURE_LIST_DEFAULT_PATH
|
||||
else:
|
||||
feature_list_path = opt.feature_list
|
||||
|
||||
if opt.warm_start_base_dir != "none" and tf.io.gfile.exists(opt.warm_start_base_dir):
|
||||
if opt.output_checkpoint_dir == "none" or opt.output_checkpoint_dir == opt.warm_start_base_dir:
|
||||
_warm_start_base_dir = os.path.normpath(opt.warm_start_base_dir) + "_backup_warm_start"
|
||||
_output_folder_dir = opt.warm_start_base_dir
|
||||
|
||||
rename_dir(opt.warm_start_base_dir, _warm_start_base_dir)
|
||||
tf.logging.info(f"moved {opt.warm_start_base_dir} to {_warm_start_base_dir}")
|
||||
else:
|
||||
_warm_start_base_dir = opt.warm_start_base_dir
|
||||
_output_folder_dir = opt.output_checkpoint_dir
|
||||
|
||||
continuous_binary_feat_list_save_path = os.path.join(
|
||||
_warm_start_base_dir, "continuous_binary_feat_list.json"
|
||||
)
|
||||
|
||||
if opt.old_feature_list != "none":
|
||||
tf.logging.info("getting old continuous_binary_feat_list")
|
||||
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
|
||||
opt.old_feature_list, opt.data_spec
|
||||
)
|
||||
rmdir(continuous_binary_feat_list_save_path)
|
||||
twml.util.write_file(
|
||||
continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
|
||||
)
|
||||
tf.logging.info(f"Finish writting files to {continuous_binary_feat_list_save_path}")
|
||||
|
||||
warm_start_folder = os.path.join(_warm_start_base_dir, "best_checkpoint")
|
||||
if not tf.io.gfile.exists(warm_start_folder):
|
||||
warm_start_folder = _warm_start_base_dir
|
||||
|
||||
rmdir(_output_folder_dir)
|
||||
mkdirp(_output_folder_dir)
|
||||
|
||||
new_ckpt = warm_start_checkpoint(
|
||||
warm_start_folder,
|
||||
continuous_binary_feat_list_save_path,
|
||||
feature_list_path,
|
||||
opt.data_spec,
|
||||
_output_folder_dir,
|
||||
opt.model_type,
|
||||
)
|
||||
logging.info(f"Created new ckpt {new_ckpt} from {warm_start_folder}")
|
||||
|
||||
tf.logging.info("getting new continuous_binary_feat_list")
|
||||
new_continuous_binary_feat_list_save_path = os.path.join(
|
||||
_output_folder_dir, "continuous_binary_feat_list.json"
|
||||
)
|
||||
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
|
||||
feature_list_path, opt.data_spec
|
||||
)
|
||||
rmdir(new_continuous_binary_feat_list_save_path)
|
||||
twml.util.write_file(
|
||||
new_continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
|
||||
)
|
||||
tf.logging.info(f"Finish writting files to {new_continuous_binary_feat_list_save_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_main()
|
16
pushservice/src/main/python/models/libs/BUILD
Normal file
16
pushservice/src/main/python/models/libs/BUILD
Normal file
@ -0,0 +1,16 @@
|
||||
python3_library(
|
||||
name = "libs",
|
||||
sources = ["*.py"],
|
||||
tags = [
|
||||
"bazel-compatible",
|
||||
"no-mypy",
|
||||
],
|
||||
dependencies = [
|
||||
"cortex/recsys/src/python/twitter/cortex/recsys/utils",
|
||||
"magicpony/common/file_access/src/python/twitter/magicpony/common/file_access",
|
||||
"src/python/twitter/cortex/ml/embeddings/deepbird",
|
||||
"src/python/twitter/cortex/ml/embeddings/deepbird/grouped_metrics",
|
||||
"src/python/twitter/deepbird/util/data",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
0
pushservice/src/main/python/models/libs/__init__.py
Normal file
0
pushservice/src/main/python/models/libs/__init__.py
Normal file
@ -0,0 +1,56 @@
|
||||
# pylint: disable=no-member, arguments-differ, attribute-defined-outside-init, unused-argument
|
||||
"""
|
||||
Implementing Full Sparse Layer, allow specify use_binary_value in call() to
|
||||
overide default action.
|
||||
"""
|
||||
|
||||
from twml.layers import FullSparse as defaultFullSparse
|
||||
from twml.layers.full_sparse import sparse_dense_matmul
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class FullSparse(defaultFullSparse):
|
||||
def call(self, inputs, use_binary_values=None, **kwargs): # pylint: disable=unused-argument
|
||||
"""The logic of the layer lives here.
|
||||
|
||||
Arguments:
|
||||
inputs:
|
||||
A SparseTensor or a list of SparseTensors.
|
||||
If `inputs` is a list, all tensors must have same `dense_shape`.
|
||||
|
||||
Returns:
|
||||
- If `inputs` is `SparseTensor`, then returns `bias + inputs * dense_b`.
|
||||
- If `inputs` is a `list[SparseTensor`, then returns
|
||||
`bias + add_n([sp_a * dense_b for sp_a in inputs])`.
|
||||
"""
|
||||
|
||||
if use_binary_values is not None:
|
||||
default_use_binary_values = use_binary_values
|
||||
else:
|
||||
default_use_binary_values = self.use_binary_values
|
||||
|
||||
if isinstance(default_use_binary_values, (list, tuple)):
|
||||
raise ValueError(
|
||||
"use_binary_values can not be %s when inputs is %s"
|
||||
% (type(default_use_binary_values), type(inputs))
|
||||
)
|
||||
|
||||
outputs = sparse_dense_matmul(
|
||||
inputs,
|
||||
self.weight,
|
||||
self.use_sparse_grads,
|
||||
default_use_binary_values,
|
||||
name="sparse_mm",
|
||||
partition_axis=self.partition_axis,
|
||||
num_partitions=self.num_partitions,
|
||||
compress_ids=self._use_compression,
|
||||
cast_indices_dtype=self._cast_indices_dtype,
|
||||
)
|
||||
|
||||
if self.bias is not None:
|
||||
outputs = tf.nn.bias_add(outputs, self.bias)
|
||||
|
||||
if self.activation is not None:
|
||||
return self.activation(outputs) # pylint: disable=not-callable
|
||||
return outputs
|
176
pushservice/src/main/python/models/libs/get_feat_config.py
Normal file
176
pushservice/src/main/python/models/libs/get_feat_config.py
Normal file
@ -0,0 +1,176 @@
|
||||
import os
|
||||
|
||||
from twitter.deepbird.projects.magic_recs.libs.metric_fn_utils import USER_AGE_FEATURE_NAME
|
||||
from twitter.deepbird.projects.magic_recs.libs.model_utils import read_config
|
||||
from twml.contrib import feature_config as contrib_feature_config
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
FEAT_CONFIG_DEFAULT_VAL = -1.23456789
|
||||
|
||||
DEFAULT_INPUT_SIZE_BITS = 18
|
||||
|
||||
DEFAULT_FEATURE_LIST_PATH = "./feature_list_default.yaml"
|
||||
FEATURE_LIST_DEFAULT_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_PATH
|
||||
)
|
||||
|
||||
DEFAULT_FEATURE_LIST_LIGHT_RANKING_PATH = "./feature_list_light_ranking.yaml"
|
||||
FEATURE_LIST_DEFAULT_LIGHT_RANKING_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_LIGHT_RANKING_PATH
|
||||
)
|
||||
|
||||
FEATURE_LIST_DEFAULT = read_config(FEATURE_LIST_DEFAULT_PATH).items()
|
||||
FEATURE_LIST_LIGHT_RANKING_DEFAULT = read_config(FEATURE_LIST_DEFAULT_LIGHT_RANKING_PATH).items()
|
||||
|
||||
|
||||
LABELS = ["label"]
|
||||
LABELS_MTL = {"OONC": ["label"], "OONC_Engagement": ["label", "label.engagement"]}
|
||||
LABELS_LR = {
|
||||
"Sent": ["label.sent"],
|
||||
"HeavyRankPosition": ["meta.ranking.is_top3"],
|
||||
"HeavyRankProbability": ["meta.ranking.weighted_oonc_model_score"],
|
||||
}
|
||||
|
||||
|
||||
def _get_new_feature_config_base(
|
||||
data_spec_path,
|
||||
labels,
|
||||
add_sparse_continous=True,
|
||||
add_gbdt=True,
|
||||
add_user_id=False,
|
||||
add_timestamp=False,
|
||||
add_user_age=False,
|
||||
feature_list_provided=[],
|
||||
opt=None,
|
||||
run_light_ranking_group_metrics_in_bq=False,
|
||||
):
|
||||
"""
|
||||
Getter of the feature config based on specification.
|
||||
|
||||
Args:
|
||||
data_spec_path: A string indicating the path of the data_spec.json file, which could be
|
||||
either a local path or a hdfs path.
|
||||
labels: A list of strings indicating the name of the label in the data spec.
|
||||
add_sparse_continous: A bool indicating if sparse_continuous feature needs to be included.
|
||||
add_gbdt: A bool indicating if gbdt feature needs to be included.
|
||||
add_user_id: A bool indicating if user_id feature needs to be included.
|
||||
add_timestamp: A bool indicating if timestamp feature needs to be included. This will be useful
|
||||
for sequential models and meta learning models.
|
||||
add_user_age: A bool indicating if the user age feature needs to be included.
|
||||
feature_list_provided: A list of features thats need to be included. If not specified, will use
|
||||
FEATURE_LIST_DEFAULT by default.
|
||||
opt: A namespace of arguments indicating the hyparameters.
|
||||
run_light_ranking_group_metrics_in_bq: A bool indicating if heavy ranker score info needs to be included to compute group metrics in BigQuery.
|
||||
|
||||
Returns:
|
||||
A twml feature config object.
|
||||
"""
|
||||
|
||||
input_size_bits = DEFAULT_INPUT_SIZE_BITS if opt is None else opt.input_size_bits
|
||||
|
||||
feature_list = feature_list_provided if feature_list_provided != [] else FEATURE_LIST_DEFAULT
|
||||
a_string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
|
||||
|
||||
builder = contrib_feature_config.FeatureConfigBuilder(data_spec_path=data_spec_path)
|
||||
|
||||
builder = builder.extract_feature_group(
|
||||
feature_regexes=a_string_feat_list,
|
||||
group_name="continuous",
|
||||
default_value=FEAT_CONFIG_DEFAULT_VAL,
|
||||
type_filter=["CONTINUOUS"],
|
||||
)
|
||||
|
||||
builder = builder.extract_features_as_hashed_sparse(
|
||||
feature_regexes=a_string_feat_list,
|
||||
output_tensor_name="sparse_no_continuous",
|
||||
hash_space_size_bits=input_size_bits,
|
||||
type_filter=["BINARY", "DISCRETE", "STRING", "SPARSE_BINARY"],
|
||||
)
|
||||
|
||||
if add_gbdt:
|
||||
builder = builder.extract_features_as_hashed_sparse(
|
||||
feature_regexes=["ads\..*"],
|
||||
output_tensor_name="gbdt_sparse",
|
||||
hash_space_size_bits=input_size_bits,
|
||||
)
|
||||
|
||||
if add_sparse_continous:
|
||||
s_string_feat_list = [f[0] for f in feature_list if f[1] == "S"]
|
||||
|
||||
builder = builder.extract_features_as_hashed_sparse(
|
||||
feature_regexes=s_string_feat_list,
|
||||
output_tensor_name="sparse_continuous",
|
||||
hash_space_size_bits=input_size_bits,
|
||||
type_filter=["SPARSE_CONTINUOUS"],
|
||||
)
|
||||
|
||||
if add_user_id:
|
||||
builder = builder.extract_feature("meta.user_id")
|
||||
if add_timestamp:
|
||||
builder = builder.extract_feature("meta.timestamp")
|
||||
if add_user_age:
|
||||
builder = builder.extract_feature(USER_AGE_FEATURE_NAME)
|
||||
|
||||
if run_light_ranking_group_metrics_in_bq:
|
||||
builder = builder.extract_feature("meta.trace_id")
|
||||
builder = builder.extract_feature("meta.ranking.weighted_oonc_model_score")
|
||||
|
||||
builder = builder.add_labels(labels).define_weight("meta.weight")
|
||||
|
||||
return builder.build()
|
||||
|
||||
|
||||
def get_feature_config_with_sparse_continuous(
|
||||
data_spec_path,
|
||||
feature_list_provided=[],
|
||||
opt=None,
|
||||
add_user_id=False,
|
||||
add_timestamp=False,
|
||||
add_user_age=False,
|
||||
):
|
||||
task_name = opt.task_name if getattr(opt, "task_name", None) is not None else "OONC"
|
||||
if task_name not in LABELS_MTL:
|
||||
raise ValueError("Invalid Task Name !")
|
||||
|
||||
return _get_new_feature_config_base(
|
||||
data_spec_path=data_spec_path,
|
||||
labels=LABELS_MTL[task_name],
|
||||
add_sparse_continous=True,
|
||||
add_user_id=add_user_id,
|
||||
add_timestamp=add_timestamp,
|
||||
add_user_age=add_user_age,
|
||||
feature_list_provided=feature_list_provided,
|
||||
opt=opt,
|
||||
)
|
||||
|
||||
|
||||
def get_feature_config_light_ranking(
|
||||
data_spec_path,
|
||||
feature_list_provided=[],
|
||||
opt=None,
|
||||
add_user_id=True,
|
||||
add_timestamp=False,
|
||||
add_user_age=False,
|
||||
add_gbdt=False,
|
||||
run_light_ranking_group_metrics_in_bq=False,
|
||||
):
|
||||
task_name = opt.task_name if getattr(opt, "task_name", None) is not None else "HeavyRankPosition"
|
||||
if task_name not in LABELS_LR:
|
||||
raise ValueError("Invalid Task Name !")
|
||||
if not feature_list_provided:
|
||||
feature_list_provided = FEATURE_LIST_LIGHT_RANKING_DEFAULT
|
||||
|
||||
return _get_new_feature_config_base(
|
||||
data_spec_path=data_spec_path,
|
||||
labels=LABELS_LR[task_name],
|
||||
add_sparse_continous=False,
|
||||
add_gbdt=add_gbdt,
|
||||
add_user_id=add_user_id,
|
||||
add_timestamp=add_timestamp,
|
||||
add_user_age=add_user_age,
|
||||
feature_list_provided=feature_list_provided,
|
||||
opt=opt,
|
||||
run_light_ranking_group_metrics_in_bq=run_light_ranking_group_metrics_in_bq,
|
||||
)
|
42
pushservice/src/main/python/models/libs/graph_utils.py
Normal file
42
pushservice/src/main/python/models/libs/graph_utils.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""
|
||||
Utilties that aid in building the magic recs graph.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
def get_trainable_variables(all_trainable_variables, trainable_regexes):
|
||||
"""Returns a subset of trainable variables for training.
|
||||
|
||||
Given a collection of trainable variables, this will return all those that match the given regexes.
|
||||
Will also log those variables.
|
||||
|
||||
Args:
|
||||
all_trainable_variables (a collection of trainable tf.Variable): The variables to search through.
|
||||
trainable_regexes (a collection of regexes): Variables that match any regex will be included.
|
||||
|
||||
Returns a list of tf.Variable
|
||||
"""
|
||||
if trainable_regexes is None or len(trainable_regexes) == 0:
|
||||
tf.logging.info("No trainable regexes found. Not using get_trainable_variables behavior.")
|
||||
return None
|
||||
|
||||
assert any(
|
||||
tf.is_tensor(var) for var in all_trainable_variables
|
||||
), f"Non TF variable found: {all_trainable_variables}"
|
||||
trainable_variables = list(
|
||||
filter(
|
||||
lambda var: any(re.match(regex, var.name, re.IGNORECASE) for regex in trainable_regexes),
|
||||
all_trainable_variables,
|
||||
)
|
||||
)
|
||||
tf.logging.info(f"Using filtered trainable variables: {trainable_variables}")
|
||||
|
||||
assert (
|
||||
trainable_variables
|
||||
), "Did not find trainable variables after filtering after filtering from {} number of vars originaly. All vars: {} and train regexes: {}".format(
|
||||
len(all_trainable_variables), all_trainable_variables, trainable_regexes
|
||||
)
|
||||
return trainable_variables
|
114
pushservice/src/main/python/models/libs/group_metrics.py
Normal file
114
pushservice/src/main/python/models/libs/group_metrics.py
Normal file
@ -0,0 +1,114 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.computation import (
|
||||
write_grouped_metrics_to_mldash,
|
||||
)
|
||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.configuration import (
|
||||
ClassificationGroupedMetricsConfiguration,
|
||||
NDCGGroupedMetricsConfiguration,
|
||||
)
|
||||
import twml
|
||||
|
||||
from .light_ranking_metrics import (
|
||||
CGRGroupedMetricsConfiguration,
|
||||
ExpectedLossGroupedMetricsConfiguration,
|
||||
RecallGroupedMetricsConfiguration,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compat.v1 import logging
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def run_group_metrics(trainer, data_dir, model_path, parse_fn, group_feature_name="meta.user_id"):
|
||||
|
||||
start_time = time.time()
|
||||
logging.info("Evaluating with group metrics.")
|
||||
|
||||
metrics = write_grouped_metrics_to_mldash(
|
||||
trainer=trainer,
|
||||
data_dir=data_dir,
|
||||
model_path=model_path,
|
||||
group_fn=lambda datarecord: str(
|
||||
datarecord.discreteFeatures[twml.feature_id(group_feature_name)[0]]
|
||||
),
|
||||
parse_fn=parse_fn,
|
||||
metric_configurations=[
|
||||
ClassificationGroupedMetricsConfiguration(),
|
||||
NDCGGroupedMetricsConfiguration(k=[5, 10, 20]),
|
||||
],
|
||||
total_records_to_read=1000000000,
|
||||
shuffle=False,
|
||||
mldash_metrics_name="grouped_metrics",
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
logging.info(f"Evaluated Group Metics: {metrics}.")
|
||||
logging.info(f"Group metrics evaluation time {end_time - start_time}.")
|
||||
|
||||
|
||||
def run_group_metrics_light_ranking(
|
||||
trainer, data_dir, model_path, parse_fn, group_feature_name="meta.trace_id"
|
||||
):
|
||||
|
||||
start_time = time.time()
|
||||
logging.info("Evaluating with group metrics.")
|
||||
|
||||
metrics = write_grouped_metrics_to_mldash(
|
||||
trainer=trainer,
|
||||
data_dir=data_dir,
|
||||
model_path=model_path,
|
||||
group_fn=lambda datarecord: str(
|
||||
datarecord.discreteFeatures[twml.feature_id(group_feature_name)[0]]
|
||||
),
|
||||
parse_fn=parse_fn,
|
||||
metric_configurations=[
|
||||
CGRGroupedMetricsConfiguration(lightNs=[50, 100, 200], heavyKs=[1, 3, 10, 20, 50]),
|
||||
RecallGroupedMetricsConfiguration(n=[50, 100, 200], k=[1, 3, 10, 20, 50]),
|
||||
ExpectedLossGroupedMetricsConfiguration(lightNs=[50, 100, 200]),
|
||||
],
|
||||
total_records_to_read=10000000,
|
||||
num_batches_to_load=50,
|
||||
batch_size=1024,
|
||||
shuffle=False,
|
||||
mldash_metrics_name="grouped_metrics_for_light_ranking",
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
logging.info(f"Evaluated Group Metics for Light Ranking: {metrics}.")
|
||||
logging.info(f"Group metrics evaluation time {end_time - start_time}.")
|
||||
|
||||
|
||||
def run_group_metrics_light_ranking_in_bq(trainer, params, checkpoint_path):
|
||||
logging.info("getting Test Predictions for Light Ranking Group Metrics in BigQuery !!!")
|
||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
||||
info_pool = []
|
||||
|
||||
for result in trainer.estimator.predict(
|
||||
eval_input_fn, checkpoint_path=checkpoint_path, yield_single_examples=False
|
||||
):
|
||||
traceID = result["trace_id"]
|
||||
pred = result["prediction"]
|
||||
label = result["target"]
|
||||
info = np.concatenate([traceID, pred, label], axis=1)
|
||||
info_pool.append(info)
|
||||
|
||||
info_pool = np.concatenate(info_pool)
|
||||
|
||||
locname = "/tmp/000/"
|
||||
if not os.path.exists(locname):
|
||||
os.makedirs(locname)
|
||||
|
||||
locfile = locname + params.pred_file_name
|
||||
columns = ["trace_id", "model_prediction", "meta__ranking__weighted_oonc_model_score"]
|
||||
np.savetxt(locfile, info_pool, delimiter=",", header=",".join(columns))
|
||||
tf.io.gfile.copy(locfile, params.pred_file_path + params.pred_file_name, overwrite=True)
|
||||
|
||||
if os.path.isfile(locfile):
|
||||
os.remove(locfile)
|
||||
|
||||
logging.info("Done Prediction for Light Ranking Group Metrics in BigQuery.")
|
118
pushservice/src/main/python/models/libs/initializer.py
Normal file
118
pushservice/src/main/python/models/libs/initializer.py
Normal file
@ -0,0 +1,118 @@
|
||||
import numpy as np
|
||||
from tensorflow.keras import backend as K
|
||||
|
||||
|
||||
class VarianceScaling(object):
|
||||
"""Initializer capable of adapting its scale to the shape of weights.
|
||||
With `distribution="normal"`, samples are drawn from a truncated normal
|
||||
distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
|
||||
- number of input units in the weight tensor, if mode = "fan_in"
|
||||
- number of output units, if mode = "fan_out"
|
||||
- average of the numbers of input and output units, if mode = "fan_avg"
|
||||
With `distribution="uniform"`,
|
||||
samples are drawn from a uniform distribution
|
||||
within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
|
||||
# Arguments
|
||||
scale: Scaling factor (positive float).
|
||||
mode: One of "fan_in", "fan_out", "fan_avg".
|
||||
distribution: Random distribution to use. One of "normal", "uniform".
|
||||
seed: A Python integer. Used to seed the random generator.
|
||||
# Raises
|
||||
ValueError: In case of an invalid value for the "scale", mode" or
|
||||
"distribution" arguments."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale=1.0,
|
||||
mode="fan_in",
|
||||
distribution="normal",
|
||||
seed=None,
|
||||
fan_in=None,
|
||||
fan_out=None,
|
||||
):
|
||||
self.fan_in = fan_in
|
||||
self.fan_out = fan_out
|
||||
if scale <= 0.0:
|
||||
raise ValueError("`scale` must be a positive float. Got:", scale)
|
||||
mode = mode.lower()
|
||||
if mode not in {"fan_in", "fan_out", "fan_avg"}:
|
||||
raise ValueError(
|
||||
"Invalid `mode` argument: " 'expected on of {"fan_in", "fan_out", "fan_avg"} ' "but got",
|
||||
mode,
|
||||
)
|
||||
distribution = distribution.lower()
|
||||
if distribution not in {"normal", "uniform"}:
|
||||
raise ValueError(
|
||||
"Invalid `distribution` argument: " 'expected one of {"normal", "uniform"} ' "but got",
|
||||
distribution,
|
||||
)
|
||||
self.scale = scale
|
||||
self.mode = mode
|
||||
self.distribution = distribution
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, shape, dtype=None, partition_info=None):
|
||||
fan_in = shape[-2] if self.fan_in is None else self.fan_in
|
||||
fan_out = shape[-1] if self.fan_out is None else self.fan_out
|
||||
|
||||
scale = self.scale
|
||||
if self.mode == "fan_in":
|
||||
scale /= max(1.0, fan_in)
|
||||
elif self.mode == "fan_out":
|
||||
scale /= max(1.0, fan_out)
|
||||
else:
|
||||
scale /= max(1.0, float(fan_in + fan_out) / 2)
|
||||
if self.distribution == "normal":
|
||||
stddev = np.sqrt(scale) / 0.87962566103423978
|
||||
return K.truncated_normal(shape, 0.0, stddev, dtype=dtype, seed=self.seed)
|
||||
else:
|
||||
limit = np.sqrt(3.0 * scale)
|
||||
return K.random_uniform(shape, -limit, limit, dtype=dtype, seed=self.seed)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"scale": self.scale,
|
||||
"mode": self.mode,
|
||||
"distribution": self.distribution,
|
||||
"seed": self.seed,
|
||||
}
|
||||
|
||||
|
||||
def customized_glorot_uniform(seed=None, fan_in=None, fan_out=None):
|
||||
"""Glorot uniform initializer, also called Xavier uniform initializer.
|
||||
It draws samples from a uniform distribution within [-limit, limit]
|
||||
where `limit` is `sqrt(6 / (fan_in + fan_out))`
|
||||
where `fan_in` is the number of input units in the weight tensor
|
||||
and `fan_out` is the number of output units in the weight tensor.
|
||||
# Arguments
|
||||
seed: A Python integer. Used to seed the random generator.
|
||||
# Returns
|
||||
An initializer."""
|
||||
return VarianceScaling(
|
||||
scale=1.0,
|
||||
mode="fan_avg",
|
||||
distribution="uniform",
|
||||
seed=seed,
|
||||
fan_in=fan_in,
|
||||
fan_out=fan_out,
|
||||
)
|
||||
|
||||
|
||||
def customized_glorot_norm(seed=None, fan_in=None, fan_out=None):
|
||||
"""Glorot norm initializer, also called Xavier uniform initializer.
|
||||
It draws samples from a uniform distribution within [-limit, limit]
|
||||
where `limit` is `sqrt(6 / (fan_in + fan_out))`
|
||||
where `fan_in` is the number of input units in the weight tensor
|
||||
and `fan_out` is the number of output units in the weight tensor.
|
||||
# Arguments
|
||||
seed: A Python integer. Used to seed the random generator.
|
||||
# Returns
|
||||
An initializer."""
|
||||
return VarianceScaling(
|
||||
scale=1.0,
|
||||
mode="fan_avg",
|
||||
distribution="normal",
|
||||
seed=seed,
|
||||
fan_in=fan_in,
|
||||
fan_out=fan_out,
|
||||
)
|
255
pushservice/src/main/python/models/libs/light_ranking_metrics.py
Normal file
255
pushservice/src/main/python/models/libs/light_ranking_metrics.py
Normal file
@ -0,0 +1,255 @@
|
||||
from functools import partial
|
||||
|
||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.configuration import (
|
||||
GroupedMetricsConfiguration,
|
||||
)
|
||||
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.helpers import (
|
||||
extract_prediction_from_prediction_record,
|
||||
)
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def score_loss_at_n(labels, predictions, lightN):
|
||||
"""
|
||||
Compute the absolute ScoreLoss ranking metric
|
||||
Args:
|
||||
labels (list) : A list of label values (HeavyRanking Reference)
|
||||
predictions (list): A list of prediction values (LightRanking Predictions)
|
||||
lightN (int): size of the list at which of Initial candidates to compute ScoreLoss. (LightRanking)
|
||||
"""
|
||||
assert len(labels) == len(predictions)
|
||||
|
||||
if lightN <= 0:
|
||||
return None
|
||||
|
||||
labels_with_predictions = zip(labels, predictions)
|
||||
labels_with_sorted_predictions = sorted(
|
||||
labels_with_predictions, key=lambda x: x[1], reverse=True
|
||||
)[:lightN]
|
||||
labels_top1_light = max([label for label, _ in labels_with_sorted_predictions])
|
||||
labels_top1_heavy = max(labels)
|
||||
|
||||
return labels_top1_heavy - labels_top1_light
|
||||
|
||||
|
||||
def cgr_at_nk(labels, predictions, lightN, heavyK):
|
||||
"""
|
||||
Compute Cumulative Gain Ratio (CGR) ranking metric
|
||||
Args:
|
||||
labels (list) : A list of label values (HeavyRanking Reference)
|
||||
predictions (list): A list of prediction values (LightRanking Predictions)
|
||||
lightN (int): size of the list at which of Initial candidates to compute CGR. (LightRanking)
|
||||
heavyK (int): size of the list at which of Refined candidates to compute CGR. (HeavyRanking)
|
||||
"""
|
||||
assert len(labels) == len(predictions)
|
||||
|
||||
if (not lightN) or (not heavyK):
|
||||
out = None
|
||||
elif lightN <= 0 or heavyK <= 0:
|
||||
out = None
|
||||
else:
|
||||
|
||||
labels_with_predictions = zip(labels, predictions)
|
||||
labels_with_sorted_predictions = sorted(
|
||||
labels_with_predictions, key=lambda x: x[1], reverse=True
|
||||
)[:lightN]
|
||||
labels_topN_light = [label for label, _ in labels_with_sorted_predictions]
|
||||
|
||||
if lightN <= heavyK:
|
||||
cg_light = sum(labels_topN_light)
|
||||
else:
|
||||
labels_topK_heavy_from_light = sorted(labels_topN_light, reverse=True)[:heavyK]
|
||||
cg_light = sum(labels_topK_heavy_from_light)
|
||||
|
||||
ideal_ordering = sorted(labels, reverse=True)
|
||||
cg_heavy = sum(ideal_ordering[: min(lightN, heavyK)])
|
||||
|
||||
out = 0.0
|
||||
if cg_heavy != 0:
|
||||
out = max(cg_light / cg_heavy, 0)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _get_weight(w, atK):
|
||||
if not w:
|
||||
return 1.0
|
||||
elif len(w) <= atK:
|
||||
return 0.0
|
||||
else:
|
||||
return w[atK]
|
||||
|
||||
|
||||
def recall_at_nk(labels, predictions, n=None, k=None, w=None):
|
||||
"""
|
||||
Recall at N-K ranking metric
|
||||
Args:
|
||||
labels (list): A list of label values
|
||||
predictions (list): A list of prediction values
|
||||
n (int): size of the list at which of predictions to compute recall. (Light Ranking Predictions)
|
||||
The default is None in which case the length of the provided predictions is used as L
|
||||
k (int): size of the list at which of labels to compute recall. (Heavy Ranking Predictions)
|
||||
The default is None in which case the length of the provided labels is used as L
|
||||
w (list): weight vector sorted by labels
|
||||
"""
|
||||
assert len(labels) == len(predictions)
|
||||
|
||||
if not any(labels):
|
||||
out = None
|
||||
else:
|
||||
|
||||
safe_n = len(predictions) if not n else min(len(predictions), n)
|
||||
safe_k = len(labels) if not k else min(len(labels), k)
|
||||
|
||||
labels_with_predictions = zip(labels, predictions)
|
||||
sorted_labels_with_predictions = sorted(
|
||||
labels_with_predictions, key=lambda x: x[0], reverse=True
|
||||
)
|
||||
|
||||
order_sorted_labels_predictions = zip(range(len(labels)), *zip(*sorted_labels_with_predictions))
|
||||
|
||||
order_with_predictions = [
|
||||
(order, pred) for order, label, pred in order_sorted_labels_predictions
|
||||
]
|
||||
order_with_sorted_predictions = sorted(order_with_predictions, key=lambda x: x[1], reverse=True)
|
||||
|
||||
pred_sorted_order_at_n = [order for order, _ in order_with_sorted_predictions][:safe_n]
|
||||
|
||||
intersection_weight = [
|
||||
_get_weight(w, order) if order < safe_k else 0 for order in pred_sorted_order_at_n
|
||||
]
|
||||
|
||||
intersection_score = sum(intersection_weight)
|
||||
full_score = sum(w) if w else float(safe_k)
|
||||
|
||||
out = 0.0
|
||||
if full_score != 0:
|
||||
out = intersection_score / full_score
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ExpectedLossGroupedMetricsConfiguration(GroupedMetricsConfiguration):
|
||||
"""
|
||||
This is the Expected Loss Grouped metric computation configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, lightNs=[]):
|
||||
"""
|
||||
Args:
|
||||
lightNs (list): size of the list at which of Initial candidates to compute Expected Loss. (LightRanking)
|
||||
"""
|
||||
self.lightNs = lightNs
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "ExpectedLoss"
|
||||
|
||||
@property
|
||||
def metrics_dict(self):
|
||||
metrics_to_compute = {}
|
||||
for lightN in self.lightNs:
|
||||
metric_name = "ExpectedLoss_atLight_" + str(lightN)
|
||||
metrics_to_compute[metric_name] = partial(score_loss_at_n, lightN=lightN)
|
||||
return metrics_to_compute
|
||||
|
||||
def extract_label(self, prec, drec, drec_label):
|
||||
return drec_label
|
||||
|
||||
def extract_prediction(self, prec, drec, drec_label):
|
||||
return extract_prediction_from_prediction_record(prec)
|
||||
|
||||
|
||||
class CGRGroupedMetricsConfiguration(GroupedMetricsConfiguration):
|
||||
"""
|
||||
This is the Cumulative Gain Ratio (CGR) Grouped metric computation configuration.
|
||||
CGR at the max length of each session is the default.
|
||||
CGR at additional positions can be computed by specifying a list of 'n's and 'k's
|
||||
"""
|
||||
|
||||
def __init__(self, lightNs=[], heavyKs=[]):
|
||||
"""
|
||||
Args:
|
||||
lightNs (list): size of the list at which of Initial candidates to compute CGR. (LightRanking)
|
||||
heavyK (int): size of the list at which of Refined candidates to compute CGR. (HeavyRanking)
|
||||
"""
|
||||
self.lightNs = lightNs
|
||||
self.heavyKs = heavyKs
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "cgr"
|
||||
|
||||
@property
|
||||
def metrics_dict(self):
|
||||
metrics_to_compute = {}
|
||||
for lightN in self.lightNs:
|
||||
for heavyK in self.heavyKs:
|
||||
metric_name = "cgr_atLight_" + str(lightN) + "_atHeavy_" + str(heavyK)
|
||||
metrics_to_compute[metric_name] = partial(cgr_at_nk, lightN=lightN, heavyK=heavyK)
|
||||
return metrics_to_compute
|
||||
|
||||
def extract_label(self, prec, drec, drec_label):
|
||||
return drec_label
|
||||
|
||||
def extract_prediction(self, prec, drec, drec_label):
|
||||
return extract_prediction_from_prediction_record(prec)
|
||||
|
||||
|
||||
class RecallGroupedMetricsConfiguration(GroupedMetricsConfiguration):
|
||||
"""
|
||||
This is the Recall Grouped metric computation configuration.
|
||||
Recall at the max length of each session is the default.
|
||||
Recall at additional positions can be computed by specifying a list of 'n's and 'k's
|
||||
"""
|
||||
|
||||
def __init__(self, n=[], k=[], w=[]):
|
||||
"""
|
||||
Args:
|
||||
n (list): A list of ints. List of prediction rank thresholds (for light)
|
||||
k (list): A list of ints. List of label rank thresholds (for heavy)
|
||||
"""
|
||||
self.predN = n
|
||||
self.labelK = k
|
||||
self.weight = w
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "group_recall"
|
||||
|
||||
@property
|
||||
def metrics_dict(self):
|
||||
metrics_to_compute = {"group_recall_unweighted": recall_at_nk}
|
||||
if not self.weight:
|
||||
metrics_to_compute["group_recall_weighted"] = partial(recall_at_nk, w=self.weight)
|
||||
|
||||
if self.predN and self.labelK:
|
||||
for n in self.predN:
|
||||
for k in self.labelK:
|
||||
if n >= k:
|
||||
metrics_to_compute[
|
||||
"group_recall_unweighted_at_L" + str(n) + "_at_H" + str(k)
|
||||
] = partial(recall_at_nk, n=n, k=k)
|
||||
if self.weight:
|
||||
metrics_to_compute[
|
||||
"group_recall_weighted_at_L" + str(n) + "_at_H" + str(k)
|
||||
] = partial(recall_at_nk, n=n, k=k, w=self.weight)
|
||||
|
||||
if self.labelK and not self.predN:
|
||||
for k in self.labelK:
|
||||
metrics_to_compute["group_recall_unweighted_at_full_at_H" + str(k)] = partial(
|
||||
recall_at_nk, k=k
|
||||
)
|
||||
if self.weight:
|
||||
metrics_to_compute["group_recall_weighted_at_full_at_H" + str(k)] = partial(
|
||||
recall_at_nk, k=k, w=self.weight
|
||||
)
|
||||
return metrics_to_compute
|
||||
|
||||
def extract_label(self, prec, drec, drec_label):
|
||||
return drec_label
|
||||
|
||||
def extract_prediction(self, prec, drec, drec_label):
|
||||
return extract_prediction_from_prediction_record(prec)
|
294
pushservice/src/main/python/models/libs/metric_fn_utils.py
Normal file
294
pushservice/src/main/python/models/libs/metric_fn_utils.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""
|
||||
Utilties for constructing a metric_fn for magic recs.
|
||||
"""
|
||||
|
||||
from twml.contrib.metrics.metrics import (
|
||||
get_dual_binary_tasks_metric_fn,
|
||||
get_numeric_metric_fn,
|
||||
get_partial_multi_binary_class_metric_fn,
|
||||
get_single_binary_task_metric_fn,
|
||||
)
|
||||
|
||||
from .model_utils import generate_disliked_mask
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
METRIC_BOOK = {
|
||||
"OONC": ["OONC"],
|
||||
"OONC_Engagement": ["OONC", "Engagement"],
|
||||
"Sent": ["Sent"],
|
||||
"HeavyRankPosition": ["HeavyRankPosition"],
|
||||
"HeavyRankProbability": ["HeavyRankProbability"],
|
||||
}
|
||||
|
||||
USER_AGE_FEATURE_NAME = "accountAge"
|
||||
NEW_USER_AGE_CUTOFF = 0
|
||||
|
||||
|
||||
def remove_padding_and_flatten(tensor, valid_batch_size):
|
||||
"""Remove the padding of the input padded tensor given the valid batch size tensor,
|
||||
then flatten the output with respect to the first dimension.
|
||||
Args:
|
||||
tensor: A tensor of size [META_BATCH_SIZE, BATCH_SIZE, FEATURE_DIM].
|
||||
valid_batch_size: A tensor of size [META_BATCH_SIZE], with each element indicating
|
||||
the effective batch size of the BATCH_SIZE dimension.
|
||||
|
||||
Returns:
|
||||
A tesnor of size [tf.reduce_sum(valid_batch_size), FEATURE_DIM].
|
||||
"""
|
||||
unpadded_ragged_tensor = tf.RaggedTensor.from_tensor(tensor=tensor, lengths=valid_batch_size)
|
||||
|
||||
return unpadded_ragged_tensor.flat_values
|
||||
|
||||
|
||||
def safe_mask(values, mask):
|
||||
"""Mask values if possible.
|
||||
|
||||
Boolean mask inputed values if and only if values is a tensor of the same dimension as mask (or can be broadcasted to that dimension).
|
||||
|
||||
Args:
|
||||
values (Any or Tensor): Input tensor to mask. Dim 0 should be size N.
|
||||
mask (boolean tensor): A boolean tensor of size N.
|
||||
|
||||
Returns Values or Values masked.
|
||||
"""
|
||||
if values is None:
|
||||
return values
|
||||
if not tf.is_tensor(values):
|
||||
return values
|
||||
values_shape = values.get_shape()
|
||||
if not values_shape or len(values_shape) == 0:
|
||||
return values
|
||||
if not mask.get_shape().is_compatible_with(values_shape[0]):
|
||||
return values
|
||||
return tf.boolean_mask(values, mask)
|
||||
|
||||
|
||||
def add_new_user_metrics(metric_fn):
|
||||
"""Will stratify the metric_fn by adding new user metrics.
|
||||
|
||||
Given an input metric_fn, double every metric: One will be the orignal and the other will only include those for new users.
|
||||
|
||||
Args:
|
||||
metric_fn (python function): Base twml metric_fn.
|
||||
|
||||
Returns a metric_fn with new user metrics included.
|
||||
"""
|
||||
|
||||
def metric_fn_with_new_users(graph_output, labels, weights):
|
||||
if USER_AGE_FEATURE_NAME not in graph_output:
|
||||
raise ValueError(
|
||||
"In order to get metrics stratified by user age, {name} feature should be added to model graph output. However, only the following output keys were found: {keys}.".format(
|
||||
name=USER_AGE_FEATURE_NAME, keys=graph_output.keys()
|
||||
)
|
||||
)
|
||||
|
||||
metric_ops = metric_fn(graph_output, labels, weights)
|
||||
|
||||
is_new = tf.reshape(
|
||||
tf.math.less_equal(
|
||||
tf.cast(graph_output[USER_AGE_FEATURE_NAME], tf.int64),
|
||||
tf.cast(NEW_USER_AGE_CUTOFF, tf.int64),
|
||||
),
|
||||
[-1],
|
||||
)
|
||||
|
||||
labels = safe_mask(labels, is_new)
|
||||
weights = safe_mask(weights, is_new)
|
||||
graph_output = {key: safe_mask(values, is_new) for key, values in graph_output.items()}
|
||||
|
||||
new_user_metric_ops = metric_fn(graph_output, labels, weights)
|
||||
new_user_metric_ops = {name + "_new_users": ops for name, ops in new_user_metric_ops.items()}
|
||||
metric_ops.update(new_user_metric_ops)
|
||||
return metric_ops
|
||||
|
||||
return metric_fn_with_new_users
|
||||
|
||||
|
||||
def get_meta_learn_single_binary_task_metric_fn(
|
||||
metrics, classnames, top_k=(5, 5, 5), use_top_k=False
|
||||
):
|
||||
"""Wrapper function to use the metric_fn with meta learning evaluation scheme.
|
||||
|
||||
Args:
|
||||
metrics: A list of string representing metric names.
|
||||
classnames: A list of string repsenting class names, In case of multiple binary class models,
|
||||
the names for each class or label.
|
||||
top_k: A tuple of int to specify top K metrics.
|
||||
use_top_k: A boolean value indicating of top K of metrics is used.
|
||||
|
||||
Returns:
|
||||
A customized metric_fn function.
|
||||
"""
|
||||
|
||||
def get_eval_metric_ops(graph_output, labels, weights):
|
||||
"""The op func of the eval_metrics. Comparing with normal version,
|
||||
the difference is we flatten the output, label, and weights.
|
||||
|
||||
Args:
|
||||
graph_output: A dict of tensors.
|
||||
labels: A tensor of int32 be the value of either 0 or 1.
|
||||
weights: A tensor of float32 to indicate the per record weight.
|
||||
|
||||
Returns:
|
||||
A dict of metric names and values.
|
||||
"""
|
||||
metric_op_weighted = get_partial_multi_binary_class_metric_fn(
|
||||
metrics, predcols=0, classes=classnames
|
||||
)
|
||||
classnames_unweighted = ["unweighted_" + classname for classname in classnames]
|
||||
metric_op_unweighted = get_partial_multi_binary_class_metric_fn(
|
||||
metrics, predcols=0, classes=classnames_unweighted
|
||||
)
|
||||
|
||||
valid_batch_size = graph_output["valid_batch_size"]
|
||||
graph_output["output"] = remove_padding_and_flatten(graph_output["output"], valid_batch_size)
|
||||
labels = remove_padding_and_flatten(labels, valid_batch_size)
|
||||
weights = remove_padding_and_flatten(weights, valid_batch_size)
|
||||
|
||||
tf.ensure_shape(graph_output["output"], [None, 1])
|
||||
tf.ensure_shape(labels, [None, 1])
|
||||
tf.ensure_shape(weights, [None, 1])
|
||||
|
||||
metrics_weighted = metric_op_weighted(graph_output, labels, weights)
|
||||
metrics_unweighted = metric_op_unweighted(graph_output, labels, None)
|
||||
metrics_weighted.update(metrics_unweighted)
|
||||
|
||||
if use_top_k:
|
||||
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=top_k, predcol=0, labelcol=1)
|
||||
metrics_numeric = metric_op_numeric(graph_output, labels, weights)
|
||||
metrics_weighted.update(metrics_numeric)
|
||||
return metrics_weighted
|
||||
|
||||
return get_eval_metric_ops
|
||||
|
||||
|
||||
def get_meta_learn_dual_binary_tasks_metric_fn(
|
||||
metrics, classnames, top_k=(5, 5, 5), use_top_k=False
|
||||
):
|
||||
"""Wrapper function to use the metric_fn with meta learning evaluation scheme.
|
||||
|
||||
Args:
|
||||
metrics: A list of string representing metric names.
|
||||
classnames: A list of string repsenting class names, In case of multiple binary class models,
|
||||
the names for each class or label.
|
||||
top_k: A tuple of int to specify top K metrics.
|
||||
use_top_k: A boolean value indicating of top K of metrics is used.
|
||||
|
||||
Returns:
|
||||
A customized metric_fn function.
|
||||
"""
|
||||
|
||||
def get_eval_metric_ops(graph_output, labels, weights):
|
||||
"""The op func of the eval_metrics. Comparing with normal version,
|
||||
the difference is we flatten the output, label, and weights.
|
||||
|
||||
Args:
|
||||
graph_output: A dict of tensors.
|
||||
labels: A tensor of int32 be the value of either 0 or 1.
|
||||
weights: A tensor of float32 to indicate the per record weight.
|
||||
|
||||
Returns:
|
||||
A dict of metric names and values.
|
||||
"""
|
||||
metric_op_weighted = get_partial_multi_binary_class_metric_fn(
|
||||
metrics, predcols=[0, 1], classes=classnames
|
||||
)
|
||||
classnames_unweighted = ["unweighted_" + classname for classname in classnames]
|
||||
metric_op_unweighted = get_partial_multi_binary_class_metric_fn(
|
||||
metrics, predcols=[0, 1], classes=classnames_unweighted
|
||||
)
|
||||
|
||||
valid_batch_size = graph_output["valid_batch_size"]
|
||||
graph_output["output"] = remove_padding_and_flatten(graph_output["output"], valid_batch_size)
|
||||
labels = remove_padding_and_flatten(labels, valid_batch_size)
|
||||
weights = remove_padding_and_flatten(weights, valid_batch_size)
|
||||
|
||||
tf.ensure_shape(graph_output["output"], [None, 2])
|
||||
tf.ensure_shape(labels, [None, 2])
|
||||
tf.ensure_shape(weights, [None, 1])
|
||||
|
||||
metrics_weighted = metric_op_weighted(graph_output, labels, weights)
|
||||
metrics_unweighted = metric_op_unweighted(graph_output, labels, None)
|
||||
metrics_weighted.update(metrics_unweighted)
|
||||
|
||||
if use_top_k:
|
||||
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=top_k, predcol=2, labelcol=2)
|
||||
metrics_numeric = metric_op_numeric(graph_output, labels, weights)
|
||||
metrics_weighted.update(metrics_numeric)
|
||||
return metrics_weighted
|
||||
|
||||
return get_eval_metric_ops
|
||||
|
||||
|
||||
def get_metric_fn(task_name, use_stratify_metrics, use_meta_batch=False):
|
||||
"""Will retrieve the metric_fn for magic recs.
|
||||
|
||||
Args:
|
||||
task_name (string): Which task is being used for this model.
|
||||
use_stratify_metrics (boolean): Should we add stratified metrics (new user metrics).
|
||||
use_meta_batch (boolean): If the output/label/weights are passed in 3D shape instead of
|
||||
2D shape.
|
||||
|
||||
Returns:
|
||||
A metric_fn function to pass in twml Trainer.
|
||||
"""
|
||||
if task_name not in METRIC_BOOK:
|
||||
raise ValueError(
|
||||
"Task name of {task_name} not recognized. Unable to retrieve metrics.".format(
|
||||
task_name=task_name
|
||||
)
|
||||
)
|
||||
class_names = METRIC_BOOK[task_name]
|
||||
if use_meta_batch:
|
||||
get_n_binary_task_metric_fn = (
|
||||
get_meta_learn_single_binary_task_metric_fn
|
||||
if len(class_names) == 1
|
||||
else get_meta_learn_dual_binary_tasks_metric_fn
|
||||
)
|
||||
else:
|
||||
get_n_binary_task_metric_fn = (
|
||||
get_single_binary_task_metric_fn if len(class_names) == 1 else get_dual_binary_tasks_metric_fn
|
||||
)
|
||||
|
||||
metric_fn = get_n_binary_task_metric_fn(metrics=None, classnames=METRIC_BOOK[task_name])
|
||||
|
||||
if use_stratify_metrics:
|
||||
metric_fn = add_new_user_metrics(metric_fn)
|
||||
|
||||
return metric_fn
|
||||
|
||||
|
||||
def flip_disliked_labels(metric_fn):
|
||||
"""This function returns an adapted metric_fn which flips the labels of the OONCed evaluation data to 0 if it is disliked.
|
||||
Args:
|
||||
metric_fn: A metric_fn function to pass in twml Trainer.
|
||||
|
||||
Returns:
|
||||
_adapted_metric_fn: A customized metric_fn function with disliked OONC labels flipped.
|
||||
"""
|
||||
|
||||
def _adapted_metric_fn(graph_output, labels, weights):
|
||||
"""A customized metric_fn function with disliked OONC labels flipped.
|
||||
|
||||
Args:
|
||||
graph_output: A dict of tensors.
|
||||
labels: labels of training samples, which is a 2D tensor of shape batch_size x 3: [OONCs, engagements, dislikes]
|
||||
weights: A tensor of float32 to indicate the per record weight.
|
||||
|
||||
Returns:
|
||||
A dict of metric names and values.
|
||||
"""
|
||||
# We want to multiply the label of the observation by 0 only when it is disliked
|
||||
disliked_mask = generate_disliked_mask(labels)
|
||||
|
||||
# Extract OONC and engagement labels only.
|
||||
labels = tf.reshape(labels[:, 0:2], shape=[-1, 2])
|
||||
|
||||
# Labels will be set to 0 if it is disliked.
|
||||
adapted_labels = labels * tf.cast(tf.logical_not(disliked_mask), dtype=labels.dtype)
|
||||
|
||||
return metric_fn(graph_output, adapted_labels, weights)
|
||||
|
||||
return _adapted_metric_fn
|
231
pushservice/src/main/python/models/libs/model_args.py
Normal file
231
pushservice/src/main/python/models/libs/model_args.py
Normal file
@ -0,0 +1,231 @@
|
||||
from twml.trainers import DataRecordTrainer
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def get_arg_parser():
|
||||
parser = DataRecordTrainer.add_parser_arguments()
|
||||
|
||||
parser.add_argument(
|
||||
"--input_size_bits",
|
||||
type=int,
|
||||
default=18,
|
||||
help="number of bits allocated to the input size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_trainer_name",
|
||||
default="magic_recs_mlp_calibration_MTL_OONC_Engagement",
|
||||
type=str,
|
||||
help="specify the model trainer name.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default="deepnorm_gbdt_inputdrop2_rescale",
|
||||
type=str,
|
||||
help="specify the model type to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--feat_config_type",
|
||||
default="get_feature_config_with_sparse_continuous",
|
||||
type=str,
|
||||
help="specify the feature configure function to use.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--directly_export_best",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to directly_export best_checkpoint",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warm_start_base_dir",
|
||||
default="none",
|
||||
type=str,
|
||||
help="latest ckpt in this folder will be used to ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feature_list",
|
||||
default="none",
|
||||
type=str,
|
||||
help="Which features to use for training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warm_start_from", default=None, type=str, help="model dir to warm start from"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--momentum", default=0.99999, type=float, help="Momentum term for batch normalization"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dropout",
|
||||
default=0.2,
|
||||
type=float,
|
||||
help="input_dropout_rate to rescale output by (1 - input_dropout_rate)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out_layer_1_size", default=256, type=int, help="Size of MLP_branch layer 1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out_layer_2_size", default=128, type=int, help="Size of MLP_branch layer 2"
|
||||
)
|
||||
parser.add_argument("--out_layer_3_size", default=64, type=int, help="Size of MLP_branch layer 3")
|
||||
parser.add_argument(
|
||||
"--sparse_embedding_size", default=50, type=int, help="Dimensionality of sparse embedding layer"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dense_embedding_size", default=128, type=int, help="Dimensionality of dense embedding layer"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_uam_label",
|
||||
default=False,
|
||||
type=str,
|
||||
help="Whether to use uam_label or not",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default="OONC_Engagement",
|
||||
type=str,
|
||||
help="specify the task name to use: OONC or OONC_Engagement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_weight",
|
||||
default=0.9,
|
||||
type=float,
|
||||
help="Initial OONC Task Weight MTL: OONC+Engagement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_engagement_weight",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to use engagement weight for base model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mtl_num_extra_layers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of Hidden Layers for each TaskBranch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mtl_neuron_scale", type=int, default=4, help="Scaling Factor of Neurons in MTL Extra Layers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_oonc_score",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to use oonc score only or combined score.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_stratified_metrics",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use stratified metrics: Break out new-user metrics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_group_metrics",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Will run evaluation metrics grouped by user.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_full_scope",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Will add extra scope and naming to graph.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trainable_regexes",
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="The union of variables specified by the list of regexes will be considered trainable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fine_tuning.ckpt_to_initialize_from",
|
||||
dest="fine_tuning_ckpt_to_initialize_from",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Checkpoint path from which to warm start. Indicates the pre-trained model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fine_tuning.warm_start_scope_regex",
|
||||
dest="fine_tuning_warm_start_scope_regex",
|
||||
type=str,
|
||||
default=None,
|
||||
help="All variables matching this will be restored.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params(args=None):
|
||||
parser = get_arg_parser()
|
||||
if args is None:
|
||||
return parser.parse_args()
|
||||
else:
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def get_arg_parser_light_ranking():
|
||||
parser = get_arg_parser()
|
||||
|
||||
parser.add_argument(
|
||||
"--use_record_weight",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to use record weight for base model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_record_weight", default=0.0, type=float, help="Minimum record weight to use."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--smooth_weight", default=0.0, type=float, help="Factor to smooth Rank Position Weight."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_mlp_layers", type=int, default=3, help="Number of Hidden Layers for MLP model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlp_neuron_scale", type=int, default=4, help="Scaling Factor of Neurons in MLP Layers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_light_ranking_group_metrics",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Will run evaluation metrics grouped by user for Light Ranking.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_missing_sub_branch",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use missing value sub-branch for Light Ranking.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gbdt_features",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use GBDT features for Light Ranking.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_light_ranking_group_metrics_in_bq",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to get_predictions for Light Ranking to compute group metrics in BigQuery.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pred_file_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pred_file_name",
|
||||
default=None,
|
||||
type=str,
|
||||
help="path",
|
||||
)
|
||||
return parser
|
339
pushservice/src/main/python/models/libs/model_utils.py
Normal file
339
pushservice/src/main/python/models/libs/model_utils.py
Normal file
@ -0,0 +1,339 @@
|
||||
import sys
|
||||
|
||||
import twml
|
||||
|
||||
from .initializer import customized_glorot_uniform
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
import yaml
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def read_config(whitelist_yaml_file):
|
||||
with tf.gfile.FastGFile(whitelist_yaml_file) as f:
|
||||
try:
|
||||
return yaml.safe_load(f)
|
||||
except yaml.YAMLError as exc:
|
||||
print(exc)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _sparse_feature_fixup(features, input_size_bits):
|
||||
"""Rebuild a sparse tensor feature so that its dense shape attribute is present.
|
||||
|
||||
Arguments:
|
||||
features (SparseTensor): Sparse feature tensor of shape ``(B, sparse_feature_dim)``.
|
||||
input_size_bits (int): Number of columns in ``log2`` scale. Must be positive.
|
||||
|
||||
Returns:
|
||||
SparseTensor: Rebuilt and non-faulty version of `features`."""
|
||||
sparse_feature_dim = tf.constant(2**input_size_bits, dtype=tf.int64)
|
||||
sparse_shape = tf.stack([features.dense_shape[0], sparse_feature_dim])
|
||||
sparse_tf = tf.SparseTensor(features.indices, features.values, sparse_shape)
|
||||
return sparse_tf
|
||||
|
||||
|
||||
def self_atten_dense(input, out_dim, activation=None, use_bias=True, name=None):
|
||||
def safe_concat(base, suffix):
|
||||
"""Concats variables name components if base is given."""
|
||||
if not base:
|
||||
return base
|
||||
return f"{base}:{suffix}"
|
||||
|
||||
input_dim = input.shape.as_list()[1]
|
||||
|
||||
sigmoid_out = twml.layers.FullDense(
|
||||
input_dim, dtype=tf.float32, activation=tf.nn.sigmoid, name=safe_concat(name, "sigmoid_out")
|
||||
)(input)
|
||||
atten_input = sigmoid_out * input
|
||||
mlp_out = twml.layers.FullDense(
|
||||
out_dim,
|
||||
dtype=tf.float32,
|
||||
activation=activation,
|
||||
use_bias=use_bias,
|
||||
name=safe_concat(name, "mlp_out"),
|
||||
)(atten_input)
|
||||
return mlp_out
|
||||
|
||||
|
||||
def get_dense_out(input, out_dim, activation, dense_type):
|
||||
if dense_type == "full_dense":
|
||||
out = twml.layers.FullDense(out_dim, dtype=tf.float32, activation=activation)(input)
|
||||
elif dense_type == "self_atten_dense":
|
||||
out = self_atten_dense(input, out_dim, activation=activation)
|
||||
return out
|
||||
|
||||
|
||||
def get_input_trans_func(bn_normalized_dense, is_training):
|
||||
gw_normalized_dense = tf.expand_dims(bn_normalized_dense, -1)
|
||||
group_num = bn_normalized_dense.shape.as_list()[1]
|
||||
|
||||
gw_normalized_dense = GroupWiseTrans(group_num, 1, 8, name="groupwise_1", activation=tf.tanh)(
|
||||
gw_normalized_dense
|
||||
)
|
||||
gw_normalized_dense = GroupWiseTrans(group_num, 8, 4, name="groupwise_2", activation=tf.tanh)(
|
||||
gw_normalized_dense
|
||||
)
|
||||
gw_normalized_dense = GroupWiseTrans(group_num, 4, 1, name="groupwise_3", activation=tf.tanh)(
|
||||
gw_normalized_dense
|
||||
)
|
||||
|
||||
gw_normalized_dense = tf.squeeze(gw_normalized_dense, [-1])
|
||||
|
||||
bn_gw_normalized_dense = tf.layers.batch_normalization(
|
||||
gw_normalized_dense,
|
||||
training=is_training,
|
||||
renorm_momentum=0.9999,
|
||||
momentum=0.9999,
|
||||
renorm=is_training,
|
||||
trainable=True,
|
||||
)
|
||||
|
||||
return bn_gw_normalized_dense
|
||||
|
||||
|
||||
def tensor_dropout(
|
||||
input_tensor,
|
||||
rate,
|
||||
is_training,
|
||||
sparse_tensor=None,
|
||||
):
|
||||
"""
|
||||
Implements dropout layer for both dense and sparse input_tensor
|
||||
|
||||
Arguments:
|
||||
input_tensor:
|
||||
B x D dense tensor, or a sparse tensor
|
||||
rate (float32):
|
||||
dropout rate
|
||||
is_training (bool):
|
||||
training stage or not.
|
||||
sparse_tensor (bool):
|
||||
whether the input_tensor is sparse tensor or not. Default to be None, this value has to be passed explicitly.
|
||||
rescale_sparse_dropout (bool):
|
||||
Do we need to do rescaling or not.
|
||||
Returns:
|
||||
tensor dropped out"""
|
||||
if sparse_tensor == True:
|
||||
if is_training:
|
||||
with tf.variable_scope("sparse_dropout"):
|
||||
values = input_tensor.values
|
||||
keep_mask = tf.keras.backend.random_binomial(
|
||||
tf.shape(values), p=1 - rate, dtype=tf.float32, seed=None
|
||||
)
|
||||
keep_mask.set_shape([None])
|
||||
keep_mask = tf.cast(keep_mask, tf.bool)
|
||||
|
||||
keep_indices = tf.boolean_mask(input_tensor.indices, keep_mask, axis=0)
|
||||
keep_values = tf.boolean_mask(values, keep_mask, axis=0)
|
||||
|
||||
dropped_tensor = tf.SparseTensor(keep_indices, keep_values, input_tensor.dense_shape)
|
||||
return dropped_tensor
|
||||
else:
|
||||
return input_tensor
|
||||
elif sparse_tensor == False:
|
||||
return tf.layers.dropout(input_tensor, rate=rate, training=is_training)
|
||||
|
||||
|
||||
def adaptive_transformation(bn_normalized_dense, is_training, func_type="default"):
|
||||
assert func_type in [
|
||||
"default",
|
||||
"tiny",
|
||||
], f"fun_type can only be one of default and tiny, but get {func_type}"
|
||||
|
||||
gw_normalized_dense = tf.expand_dims(bn_normalized_dense, -1)
|
||||
group_num = bn_normalized_dense.shape.as_list()[1]
|
||||
|
||||
if func_type == "default":
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 1, 8, name="groupwise_1", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 8, 4, name="groupwise_2", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 4, 1, name="groupwise_3", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
elif func_type == "tiny":
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 1, 2, name="groupwise_1", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 2, 1, name="groupwise_2", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
|
||||
gw_normalized_dense = FastGroupWiseTrans(
|
||||
group_num, 1, 1, name="groupwise_3", activation=tf.tanh, init_multiplier=8
|
||||
)(gw_normalized_dense)
|
||||
|
||||
gw_normalized_dense = tf.squeeze(gw_normalized_dense, [-1])
|
||||
bn_gw_normalized_dense = tf.layers.batch_normalization(
|
||||
gw_normalized_dense,
|
||||
training=is_training,
|
||||
renorm_momentum=0.9999,
|
||||
momentum=0.9999,
|
||||
renorm=is_training,
|
||||
trainable=True,
|
||||
)
|
||||
|
||||
return bn_gw_normalized_dense
|
||||
|
||||
|
||||
class FastGroupWiseTrans(object):
|
||||
"""
|
||||
used to apply group-wise fully connected layers to the input.
|
||||
it applies a tiny, unique MLP to each individual feature."""
|
||||
|
||||
def __init__(self, group_num, input_dim, out_dim, name, activation=None, init_multiplier=1):
|
||||
self.group_num = group_num
|
||||
self.input_dim = input_dim
|
||||
self.out_dim = out_dim
|
||||
self.activation = activation
|
||||
self.init_multiplier = init_multiplier
|
||||
|
||||
self.w = tf.get_variable(
|
||||
name + "_group_weight",
|
||||
[1, group_num, input_dim, out_dim],
|
||||
initializer=customized_glorot_uniform(
|
||||
fan_in=input_dim * init_multiplier, fan_out=out_dim * init_multiplier
|
||||
),
|
||||
trainable=True,
|
||||
)
|
||||
self.b = tf.get_variable(
|
||||
name + "_group_bias",
|
||||
[1, group_num, out_dim],
|
||||
initializer=tf.constant_initializer(0.0),
|
||||
trainable=True,
|
||||
)
|
||||
|
||||
def __call__(self, input_tensor):
|
||||
"""
|
||||
input_tensor: batch_size x group_num x input_dim
|
||||
output_tensor: batch_size x group_num x out_dim"""
|
||||
input_tensor_expand = tf.expand_dims(input_tensor, axis=-1)
|
||||
|
||||
output_tensor = tf.add(
|
||||
tf.reduce_sum(tf.multiply(input_tensor_expand, self.w), axis=-2, keepdims=False),
|
||||
self.b,
|
||||
)
|
||||
|
||||
if self.activation is not None:
|
||||
output_tensor = self.activation(output_tensor)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class GroupWiseTrans(object):
|
||||
"""
|
||||
Used to apply group fully connected layers to the input.
|
||||
"""
|
||||
|
||||
def __init__(self, group_num, input_dim, out_dim, name, activation=None):
|
||||
self.group_num = group_num
|
||||
self.input_dim = input_dim
|
||||
self.out_dim = out_dim
|
||||
self.activation = activation
|
||||
|
||||
w_list, b_list = [], []
|
||||
for idx in range(out_dim):
|
||||
this_w = tf.get_variable(
|
||||
name + f"_group_weight_{idx}",
|
||||
[1, group_num, input_dim],
|
||||
initializer=tf.keras.initializers.glorot_uniform(),
|
||||
trainable=True,
|
||||
)
|
||||
this_b = tf.get_variable(
|
||||
name + f"_group_bias_{idx}",
|
||||
[1, group_num, 1],
|
||||
initializer=tf.constant_initializer(0.0),
|
||||
trainable=True,
|
||||
)
|
||||
w_list.append(this_w)
|
||||
b_list.append(this_b)
|
||||
self.w_list = w_list
|
||||
self.b_list = b_list
|
||||
|
||||
def __call__(self, input_tensor):
|
||||
"""
|
||||
input_tensor: batch_size x group_num x input_dim
|
||||
output_tensor: batch_size x group_num x out_dim
|
||||
"""
|
||||
out_tensor_list = []
|
||||
for idx in range(self.out_dim):
|
||||
this_res = (
|
||||
tf.reduce_sum(input_tensor * self.w_list[idx], axis=-1, keepdims=True) + self.b_list[idx]
|
||||
)
|
||||
out_tensor_list.append(this_res)
|
||||
output_tensor = tf.concat(out_tensor_list, axis=-1)
|
||||
|
||||
if self.activation is not None:
|
||||
output_tensor = self.activation(output_tensor)
|
||||
return output_tensor
|
||||
|
||||
|
||||
def add_scalar_summary(var, name, name_scope="hist_dense_feature/"):
|
||||
with tf.name_scope("summaries/"):
|
||||
with tf.name_scope(name_scope):
|
||||
tf.summary.scalar(name, var)
|
||||
|
||||
|
||||
def add_histogram_summary(var, name, name_scope="hist_dense_feature/"):
|
||||
with tf.name_scope("summaries/"):
|
||||
with tf.name_scope(name_scope):
|
||||
tf.summary.histogram(name, tf.reshape(var, [-1]))
|
||||
|
||||
|
||||
def sparse_clip_by_value(sparse_tf, min_val, max_val):
|
||||
new_vals = tf.clip_by_value(sparse_tf.values, min_val, max_val)
|
||||
return tf.SparseTensor(sparse_tf.indices, new_vals, sparse_tf.dense_shape)
|
||||
|
||||
|
||||
def check_numerics_with_msg(tensor, message="", sparse_tensor=False):
|
||||
if sparse_tensor:
|
||||
values = tf.debugging.check_numerics(tensor.values, message=message)
|
||||
return tf.SparseTensor(tensor.indices, values, tensor.dense_shape)
|
||||
else:
|
||||
return tf.debugging.check_numerics(tensor, message=message)
|
||||
|
||||
|
||||
def pad_empty_sparse_tensor(tensor):
|
||||
dummy_tensor = tf.SparseTensor(
|
||||
indices=[[0, 0]],
|
||||
values=[0.00001],
|
||||
dense_shape=tensor.dense_shape,
|
||||
)
|
||||
result = tf.cond(
|
||||
tf.equal(tf.size(tensor.values), 0),
|
||||
lambda: dummy_tensor,
|
||||
lambda: tensor,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def filter_nans_and_infs(tensor, sparse_tensor=False):
|
||||
if sparse_tensor:
|
||||
sparse_values = tensor.values
|
||||
filtered_val = tf.where(
|
||||
tf.logical_or(tf.is_nan(sparse_values), tf.is_inf(sparse_values)),
|
||||
tf.zeros_like(sparse_values),
|
||||
sparse_values,
|
||||
)
|
||||
return tf.SparseTensor(tensor.indices, filtered_val, tensor.dense_shape)
|
||||
else:
|
||||
return tf.where(
|
||||
tf.logical_or(tf.is_nan(tensor), tf.is_inf(tensor)), tf.zeros_like(tensor), tensor
|
||||
)
|
||||
|
||||
|
||||
def generate_disliked_mask(labels):
|
||||
"""Generate a disliked mask where only samples with dislike labels are set to 1 otherwise set to 0.
|
||||
Args:
|
||||
labels: labels of training samples, which is a 2D tensor of shape batch_size x 3: [OONCs, engagements, dislikes]
|
||||
Returns:
|
||||
1D tensor of shape batch_size x 1: [dislikes (booleans)]
|
||||
"""
|
||||
return tf.equal(tf.reshape(labels[:, 2], shape=[-1, 1]), 1)
|
309
pushservice/src/main/python/models/libs/warm_start_utils.py
Normal file
309
pushservice/src/main/python/models/libs/warm_start_utils.py
Normal file
@ -0,0 +1,309 @@
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
import os
|
||||
from os.path import join
|
||||
|
||||
from twitter.magicpony.common import file_access
|
||||
import twml
|
||||
|
||||
from .model_utils import read_config
|
||||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def get_model_type_to_tensors_to_change_axis():
|
||||
model_type_to_tensors_to_change_axis = {
|
||||
"magic_recs/model/batch_normalization/beta": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/gamma": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/moving_mean": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/moving_stddev": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/moving_variance": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/renorm_mean": ([0], "continuous"),
|
||||
"magic_recs/model/batch_normalization/renorm_stddev": ([0], "continuous"),
|
||||
"magic_recs/model/logits/EngagementGivenOONC_logits/clem_net_1/block2_4/channel_wise_dense_4/kernel": (
|
||||
[1],
|
||||
"all",
|
||||
),
|
||||
"magic_recs/model/logits/OONC_logits/clem_net/block2/channel_wise_dense/kernel": ([1], "all"),
|
||||
}
|
||||
|
||||
return model_type_to_tensors_to_change_axis
|
||||
|
||||
|
||||
def mkdirp(dirname):
|
||||
if not tf.io.gfile.exists(dirname):
|
||||
tf.io.gfile.makedirs(dirname)
|
||||
|
||||
|
||||
def rename_dir(dirname, dst):
|
||||
file_access.hdfs.mv(dirname, dst)
|
||||
|
||||
|
||||
def rmdir(dirname):
|
||||
if tf.io.gfile.exists(dirname):
|
||||
if tf.io.gfile.isdir(dirname):
|
||||
tf.io.gfile.rmtree(dirname)
|
||||
else:
|
||||
tf.io.gfile.remove(dirname)
|
||||
|
||||
|
||||
def get_var_dict(checkpoint_path):
|
||||
checkpoint = tf.train.get_checkpoint_state(checkpoint_path)
|
||||
var_dict = OrderedDict()
|
||||
with tf.Session() as sess:
|
||||
all_var_list = tf.train.list_variables(checkpoint_path)
|
||||
for var_name, _ in all_var_list:
|
||||
# Load the variable
|
||||
var = tf.train.load_variable(checkpoint_path, var_name)
|
||||
var_dict[var_name] = var
|
||||
return var_dict
|
||||
|
||||
|
||||
def get_continunous_mapping_from_feat_list(old_feature_list, new_feature_list):
|
||||
"""
|
||||
get var_ind for old_feature and corresponding var_ind for new_feature
|
||||
"""
|
||||
new_var_ind, old_var_ind = [], []
|
||||
for this_new_id, this_new_name in enumerate(new_feature_list):
|
||||
if this_new_name in old_feature_list:
|
||||
this_old_id = old_feature_list.index(this_new_name)
|
||||
new_var_ind.append(this_new_id)
|
||||
old_var_ind.append(this_old_id)
|
||||
return np.asarray(old_var_ind), np.asarray(new_var_ind)
|
||||
|
||||
|
||||
def get_continuous_mapping_from_feat_dict(old_feature_dict, new_feature_dict):
|
||||
"""
|
||||
get var_ind for old_feature and corresponding var_ind for new_feature
|
||||
"""
|
||||
old_cont = old_feature_dict["continuous"]
|
||||
old_bin = old_feature_dict["binary"]
|
||||
|
||||
new_cont = new_feature_dict["continuous"]
|
||||
new_bin = new_feature_dict["binary"]
|
||||
|
||||
_dummy_sparse_feat = [f"sparse_feature_{_idx}" for _idx in range(100)]
|
||||
|
||||
cont_old_var_ind, cont_new_var_ind = get_continunous_mapping_from_feat_list(old_cont, new_cont)
|
||||
|
||||
all_old_var_ind, all_new_var_ind = get_continunous_mapping_from_feat_list(
|
||||
old_cont + old_bin + _dummy_sparse_feat, new_cont + new_bin + _dummy_sparse_feat
|
||||
)
|
||||
|
||||
_res = {
|
||||
"continuous": (cont_old_var_ind, cont_new_var_ind),
|
||||
"all": (all_old_var_ind, all_new_var_ind),
|
||||
}
|
||||
|
||||
return _res
|
||||
|
||||
|
||||
def warm_start_from_var_dict(
|
||||
old_ckpt_path,
|
||||
var_ind_dict,
|
||||
output_dir,
|
||||
new_len_var,
|
||||
var_to_change_dict_fn=get_model_type_to_tensors_to_change_axis,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
old_ckpt_path (str): path to the old checkpoint path
|
||||
new_var_ind (array of int): index to overlapping features in new var between old and new feature list.
|
||||
old_var_ind (array of int): index to overlapping features in old var between old and new feature list.
|
||||
|
||||
output_dir (str): dir that used to write modified checkpoint
|
||||
new_len_var ({str:int}): number of feature in the new feature list.
|
||||
var_to_change_dict_fn (dict): A function to get the dictionary of format {var_name: dim_to_change}
|
||||
"""
|
||||
old_var_dict = get_var_dict(old_ckpt_path)
|
||||
|
||||
ckpt_file_name = os.path.basename(old_ckpt_path)
|
||||
mkdirp(output_dir)
|
||||
output_path = join(output_dir, ckpt_file_name)
|
||||
|
||||
tensors_to_change = var_to_change_dict_fn()
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
with tf.Session() as sess:
|
||||
var_name_shape_list = tf.train.list_variables(old_ckpt_path)
|
||||
count = 0
|
||||
|
||||
for var_name, var_shape in var_name_shape_list:
|
||||
old_var = old_var_dict[var_name]
|
||||
if var_name in tensors_to_change.keys():
|
||||
_info_tuple = tensors_to_change[var_name]
|
||||
dims_to_remove_from, var_type = _info_tuple
|
||||
|
||||
new_var_ind, old_var_ind = var_ind_dict[var_type]
|
||||
|
||||
this_shape = list(old_var.shape)
|
||||
for this_dim in dims_to_remove_from:
|
||||
this_shape[this_dim] = new_len_var[var_type]
|
||||
|
||||
stddev = np.std(old_var)
|
||||
truncated_norm_generator = stats.truncnorm(-0.5, 0.5, loc=0, scale=stddev)
|
||||
size = np.prod(this_shape)
|
||||
new_var = truncated_norm_generator.rvs(size).reshape(this_shape)
|
||||
new_var = new_var.astype(old_var.dtype)
|
||||
|
||||
new_var = copy_feat_based_on_mapping(
|
||||
new_var, old_var, dims_to_remove_from, new_var_ind, old_var_ind
|
||||
)
|
||||
count = count + 1
|
||||
else:
|
||||
new_var = old_var
|
||||
var = tf.Variable(new_var, name=var_name)
|
||||
assert count == len(tensors_to_change.keys()), "not all variables are exchanged.\n"
|
||||
saver = tf.train.Saver()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
saver.save(sess, output_path)
|
||||
return output_path
|
||||
|
||||
|
||||
def copy_feat_based_on_mapping(new_array, old_array, dims_to_remove_from, new_var_ind, old_var_ind):
|
||||
if dims_to_remove_from == [0, 1]:
|
||||
for this_new_ind, this_old_ind in zip(new_var_ind, old_var_ind):
|
||||
new_array[this_new_ind, new_var_ind] = old_array[this_old_ind, old_var_ind]
|
||||
elif dims_to_remove_from == [0]:
|
||||
new_array[new_var_ind] = old_array[old_var_ind]
|
||||
elif dims_to_remove_from == [1]:
|
||||
new_array[:, new_var_ind] = old_array[:, old_var_ind]
|
||||
else:
|
||||
raise RuntimeError(f"undefined dims_to_remove_from pattern: ({dims_to_remove_from})")
|
||||
return new_array
|
||||
|
||||
|
||||
def read_file(filename, decode=False):
|
||||
"""
|
||||
Reads contents from a file and optionally decodes it.
|
||||
|
||||
Arguments:
|
||||
filename:
|
||||
path to file where the contents will be loaded from.
|
||||
Accepts HDFS and local paths.
|
||||
decode:
|
||||
False or 'json'. When decode='json', contents is decoded
|
||||
with json.loads. When False, contents is returned as is.
|
||||
"""
|
||||
graph = tf.Graph()
|
||||
with graph.as_default():
|
||||
read = tf.read_file(filename)
|
||||
|
||||
with tf.Session(graph=graph) as sess:
|
||||
contents = sess.run(read)
|
||||
if not isinstance(contents, str):
|
||||
contents = contents.decode()
|
||||
|
||||
if decode == "json":
|
||||
contents = json.loads(contents)
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
def read_feat_list_from_disk(file_path):
|
||||
return read_file(file_path, decode="json")
|
||||
|
||||
|
||||
def get_feature_list_for_light_ranking(feature_list_path, data_spec_path):
|
||||
feature_list = read_config(feature_list_path).items()
|
||||
string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
|
||||
|
||||
feature_config_builder = twml.contrib.feature_config.FeatureConfigBuilder(
|
||||
data_spec_path=data_spec_path
|
||||
)
|
||||
feature_config_builder = feature_config_builder.extract_feature_group(
|
||||
feature_regexes=string_feat_list,
|
||||
group_name="continuous",
|
||||
default_value=-1,
|
||||
type_filter=["CONTINUOUS"],
|
||||
)
|
||||
feature_config = feature_config_builder.build()
|
||||
feature_list = feature_config_builder._feature_group_extraction_configs[0].feature_map[
|
||||
"CONTINUOUS"
|
||||
]
|
||||
return feature_list
|
||||
|
||||
|
||||
def get_feature_list_for_heavy_ranking(feature_list_path, data_spec_path):
|
||||
feature_list = read_config(feature_list_path).items()
|
||||
string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
|
||||
|
||||
feature_config_builder = twml.contrib.feature_config.FeatureConfigBuilder(
|
||||
data_spec_path=data_spec_path
|
||||
)
|
||||
feature_config_builder = feature_config_builder.extract_feature_group(
|
||||
feature_regexes=string_feat_list,
|
||||
group_name="continuous",
|
||||
default_value=-1,
|
||||
type_filter=["CONTINUOUS"],
|
||||
)
|
||||
|
||||
feature_config_builder = feature_config_builder.extract_feature_group(
|
||||
feature_regexes=string_feat_list,
|
||||
group_name="binary",
|
||||
default_value=False,
|
||||
type_filter=["BINARY"],
|
||||
)
|
||||
|
||||
feature_config_builder = feature_config_builder.build()
|
||||
|
||||
continuous_feature_list = feature_config_builder._feature_group_extraction_configs[0].feature_map[
|
||||
"CONTINUOUS"
|
||||
]
|
||||
|
||||
binary_feature_list = feature_config_builder._feature_group_extraction_configs[1].feature_map[
|
||||
"BINARY"
|
||||
]
|
||||
return {"continuous": continuous_feature_list, "binary": binary_feature_list}
|
||||
|
||||
|
||||
def warm_start_checkpoint(
|
||||
old_best_ckpt_folder,
|
||||
old_feature_list_path,
|
||||
feature_allow_list_path,
|
||||
data_spec_path,
|
||||
output_ckpt_folder,
|
||||
*args,
|
||||
):
|
||||
"""
|
||||
Reads old checkpoint and the old feature list, and create a new ckpt warm started from old ckpt using new features .
|
||||
|
||||
Arguments:
|
||||
old_best_ckpt_folder:
|
||||
path to the best_checkpoint_folder for old model
|
||||
old_feature_list_path:
|
||||
path to the json file that stores the list of continuous features used in old models.
|
||||
feature_allow_list_path:
|
||||
yaml file that contain the feature allow list.
|
||||
data_spec_path:
|
||||
path to the data_spec file
|
||||
output_ckpt_folder:
|
||||
folder that contains the modified ckpt.
|
||||
|
||||
Returns:
|
||||
path to the modified ckpt."""
|
||||
old_ckpt_path = tf.train.latest_checkpoint(old_best_ckpt_folder, latest_filename=None)
|
||||
|
||||
new_feature_dict = get_feature_list(feature_allow_list_path, data_spec_path)
|
||||
old_feature_dict = read_feat_list_from_disk(old_feature_list_path)
|
||||
|
||||
var_ind_dict = get_continuous_mapping_from_feat_dict(new_feature_dict, old_feature_dict)
|
||||
|
||||
new_len_var = {
|
||||
"continuous": len(new_feature_dict["continuous"]),
|
||||
"all": len(new_feature_dict["continuous"] + new_feature_dict["binary"]) + 100,
|
||||
}
|
||||
|
||||
warm_started_ckpt_path = warm_start_from_var_dict(
|
||||
old_ckpt_path,
|
||||
var_ind_dict,
|
||||
output_dir=output_ckpt_folder,
|
||||
new_len_var=new_len_var,
|
||||
)
|
||||
|
||||
return warm_started_ckpt_path
|
69
pushservice/src/main/python/models/light_ranking/BUILD
Normal file
69
pushservice/src/main/python/models/light_ranking/BUILD
Normal file
@ -0,0 +1,69 @@
|
||||
#":mlwf_libs",
|
||||
|
||||
python37_binary(
|
||||
name = "eval_model",
|
||||
source = "eval_model.py",
|
||||
dependencies = [
|
||||
":libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:eval_model",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "train_model",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:train_model",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "train_model_local",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:train_model_local",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "eval_model_local",
|
||||
source = "eval_model.py",
|
||||
dependencies = [
|
||||
":libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:eval_model_local",
|
||||
"twml",
|
||||
],
|
||||
)
|
||||
|
||||
python37_binary(
|
||||
name = "mlwf_model",
|
||||
source = "deep_norm.py",
|
||||
dependencies = [
|
||||
":mlwf_libs",
|
||||
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:mlwf_model",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "libs",
|
||||
sources = ["**/*.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
"src/python/twitter/deepbird/util/data",
|
||||
"twml:twml-nodeps",
|
||||
],
|
||||
)
|
||||
|
||||
python3_library(
|
||||
name = "mlwf_libs",
|
||||
sources = ["**/*.py"],
|
||||
tags = ["no-mypy"],
|
||||
dependencies = [
|
||||
"src/python/twitter/deepbird/projects/magic_recs/libs",
|
||||
"twml",
|
||||
],
|
||||
)
|
14
pushservice/src/main/python/models/light_ranking/README.md
Normal file
14
pushservice/src/main/python/models/light_ranking/README.md
Normal file
@ -0,0 +1,14 @@
|
||||
# Notification Light Ranker Model
|
||||
|
||||
## Model Context
|
||||
There are 4 major components of Twitter notifications recommendation system: 1) candidate generation 2) light ranking 3) heavy ranking & 4) quality control. This notification light ranker model bridges candidate generation and heavy ranking by pre-selecting highly-relevant candidates from the initial huge candidate pool. It’s a light-weight model to reduce system cost during heavy ranking without hurting user experience.
|
||||
|
||||
## Directory Structure
|
||||
- BUILD: this file defines python library dependencies
|
||||
- model_pools_mlp.py: this file defines tensorflow model architecture for the notification light ranker model
|
||||
- deep_norm.py: this file contains 1) how to build the tensorflow graph with specified model architecture, loss function and training configuration. 2) how to set up the overall model training & evaluation pipeline
|
||||
- eval_model.py: the main python entry file to set up the overall model evaluation pipeline
|
||||
|
||||
|
||||
|
||||
|
226
pushservice/src/main/python/models/light_ranking/deep_norm.py
Normal file
226
pushservice/src/main/python/models/light_ranking/deep_norm.py
Normal file
@ -0,0 +1,226 @@
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
import os
|
||||
|
||||
from twitter.cortex.ml.embeddings.common.helpers import decode_str_or_unicode
|
||||
import twml
|
||||
from twml.trainers import DataRecordTrainer
|
||||
|
||||
from ..libs.get_feat_config import get_feature_config_light_ranking, LABELS_LR
|
||||
from ..libs.graph_utils import get_trainable_variables
|
||||
from ..libs.group_metrics import (
|
||||
run_group_metrics_light_ranking,
|
||||
run_group_metrics_light_ranking_in_bq,
|
||||
)
|
||||
from ..libs.metric_fn_utils import get_metric_fn
|
||||
from ..libs.model_args import get_arg_parser_light_ranking
|
||||
from ..libs.model_utils import read_config
|
||||
from ..libs.warm_start_utils import get_feature_list_for_light_ranking
|
||||
from .model_pools_mlp import light_ranking_mlp_ngbdt
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compat.v1 import logging
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
|
||||
def build_graph(
|
||||
features, label, mode, params, config=None, run_light_ranking_group_metrics_in_bq=False
|
||||
):
|
||||
is_training = mode == tf.estimator.ModeKeys.TRAIN
|
||||
this_model_func = light_ranking_mlp_ngbdt
|
||||
model_output = this_model_func(features, is_training, params, label)
|
||||
|
||||
logits = model_output["output"]
|
||||
graph_output = {}
|
||||
# --------------------------------------------------------
|
||||
# define graph output dict
|
||||
# --------------------------------------------------------
|
||||
if mode == tf.estimator.ModeKeys.PREDICT:
|
||||
loss = None
|
||||
output_label = "prediction"
|
||||
if params.task_name in LABELS_LR:
|
||||
output = tf.nn.sigmoid(logits)
|
||||
output = tf.clip_by_value(output, 0, 1)
|
||||
|
||||
if run_light_ranking_group_metrics_in_bq:
|
||||
graph_output["trace_id"] = features["meta.trace_id"]
|
||||
graph_output["target"] = features["meta.ranking.weighted_oonc_model_score"]
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid Task Name !")
|
||||
|
||||
else:
|
||||
output_label = "output"
|
||||
weights = tf.cast(features["weights"], dtype=tf.float32, name="RecordWeights")
|
||||
|
||||
if params.task_name in LABELS_LR:
|
||||
if params.use_record_weight:
|
||||
weights = tf.clip_by_value(
|
||||
1.0 / (1.0 + weights + params.smooth_weight), params.min_record_weight, 1.0
|
||||
)
|
||||
|
||||
loss = tf.reduce_sum(
|
||||
tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits) * weights
|
||||
) / (tf.reduce_sum(weights))
|
||||
else:
|
||||
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits))
|
||||
output = tf.nn.sigmoid(logits)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid Task Name !")
|
||||
|
||||
train_op = None
|
||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
# --------------------------------------------------------
|
||||
# get train_op
|
||||
# --------------------------------------------------------
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=params.learning_rate)
|
||||
update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
variables = get_trainable_variables(
|
||||
all_trainable_variables=tf.trainable_variables(), trainable_regexes=params.trainable_regexes
|
||||
)
|
||||
with tf.control_dependencies(update_ops):
|
||||
train_op = twml.optimizers.optimize_loss(
|
||||
loss=loss,
|
||||
variables=variables,
|
||||
global_step=tf.train.get_global_step(),
|
||||
optimizer=optimizer,
|
||||
learning_rate=params.learning_rate,
|
||||
learning_rate_decay_fn=twml.learning_rate_decay.get_learning_rate_decay_fn(params),
|
||||
)
|
||||
|
||||
graph_output[output_label] = output
|
||||
graph_output["loss"] = loss
|
||||
graph_output["train_op"] = train_op
|
||||
return graph_output
|
||||
|
||||
|
||||
def get_params(args=None):
|
||||
parser = get_arg_parser_light_ranking()
|
||||
if args is None:
|
||||
return parser.parse_args()
|
||||
else:
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def _main():
|
||||
opt = get_params()
|
||||
logging.info("parse is: ")
|
||||
logging.info(opt)
|
||||
|
||||
feature_list = read_config(opt.feature_list).items()
|
||||
feature_config = get_feature_config_light_ranking(
|
||||
data_spec_path=opt.data_spec,
|
||||
feature_list_provided=feature_list,
|
||||
opt=opt,
|
||||
add_gbdt=opt.use_gbdt_features,
|
||||
run_light_ranking_group_metrics_in_bq=opt.run_light_ranking_group_metrics_in_bq,
|
||||
)
|
||||
feature_list_path = opt.feature_list
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Create Trainer
|
||||
# --------------------------------------------------------
|
||||
trainer = DataRecordTrainer(
|
||||
name=opt.model_trainer_name,
|
||||
params=opt,
|
||||
build_graph_fn=build_graph,
|
||||
save_dir=opt.save_dir,
|
||||
run_config=None,
|
||||
feature_config=feature_config,
|
||||
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
|
||||
)
|
||||
if opt.directly_export_best:
|
||||
logging.info("Directly exporting the model without training")
|
||||
else:
|
||||
# ----------------------------------------------------
|
||||
# Model Training & Evaluation
|
||||
# ----------------------------------------------------
|
||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
||||
train_input_fn = trainer.get_train_input_fn(shuffle=True)
|
||||
|
||||
if opt.distributed or opt.num_workers is not None:
|
||||
learn = trainer.train_and_evaluate
|
||||
else:
|
||||
learn = trainer.learn
|
||||
logging.info("Training...")
|
||||
start = datetime.now()
|
||||
|
||||
early_stop_metric = "rce_unweighted_" + opt.task_name
|
||||
learn(
|
||||
early_stop_minimize=False,
|
||||
early_stop_metric=early_stop_metric,
|
||||
early_stop_patience=opt.early_stop_patience,
|
||||
early_stop_tolerance=opt.early_stop_tolerance,
|
||||
eval_input_fn=eval_input_fn,
|
||||
train_input_fn=train_input_fn,
|
||||
)
|
||||
|
||||
end = datetime.now()
|
||||
logging.info("Training time: " + str(end - start))
|
||||
|
||||
logging.info("Exporting the models...")
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Do the model exporting
|
||||
# --------------------------------------------------------
|
||||
start = datetime.now()
|
||||
if not opt.export_dir:
|
||||
opt.export_dir = os.path.join(opt.save_dir, "exported_models")
|
||||
|
||||
raw_model_path = twml.contrib.export.export_fn.export_all_models(
|
||||
trainer=trainer,
|
||||
export_dir=opt.export_dir,
|
||||
parse_fn=feature_config.get_parse_fn(),
|
||||
serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
|
||||
export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
|
||||
)
|
||||
export_model_dir = decode_str_or_unicode(raw_model_path)
|
||||
|
||||
logging.info("Model export time: " + str(datetime.now() - start))
|
||||
logging.info("The saved model directory is: " + opt.save_dir)
|
||||
|
||||
tf.logging.info("getting default continuous_feature_list")
|
||||
continuous_feature_list = get_feature_list_for_light_ranking(feature_list_path, opt.data_spec)
|
||||
continous_feature_list_save_path = os.path.join(opt.save_dir, "continuous_feature_list.json")
|
||||
twml.util.write_file(continous_feature_list_save_path, continuous_feature_list, encode="json")
|
||||
tf.logging.info(f"Finish writting files to {continous_feature_list_save_path}")
|
||||
|
||||
if opt.run_light_ranking_group_metrics:
|
||||
# --------------------------------------------
|
||||
# Run Light Ranking Group Metrics
|
||||
# --------------------------------------------
|
||||
run_group_metrics_light_ranking(
|
||||
trainer=trainer,
|
||||
data_dir=os.path.join(opt.eval_data_dir, opt.eval_start_datetime),
|
||||
model_path=export_model_dir,
|
||||
parse_fn=feature_config.get_parse_fn(),
|
||||
)
|
||||
|
||||
if opt.run_light_ranking_group_metrics_in_bq:
|
||||
# ----------------------------------------------------------------------------------------
|
||||
# Get Light/Heavy Ranker Predictions for Light Ranking Group Metrics in BigQuery
|
||||
# ----------------------------------------------------------------------------------------
|
||||
trainer_pred = DataRecordTrainer(
|
||||
name=opt.model_trainer_name,
|
||||
params=opt,
|
||||
build_graph_fn=partial(build_graph, run_light_ranking_group_metrics_in_bq=True),
|
||||
save_dir=opt.save_dir + "/tmp/",
|
||||
run_config=None,
|
||||
feature_config=feature_config,
|
||||
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
|
||||
)
|
||||
checkpoint_folder = os.path.join(opt.save_dir, "best_checkpoint")
|
||||
checkpoint = tf.train.latest_checkpoint(checkpoint_folder, latest_filename=None)
|
||||
tf.logging.info("\n\nPrediction from Checkpoint: {:}.\n\n".format(checkpoint))
|
||||
run_group_metrics_light_ranking_in_bq(
|
||||
trainer=trainer_pred, params=opt, checkpoint_path=checkpoint
|
||||
)
|
||||
|
||||
tf.logging.info("Done Training & Prediction.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_main()
|
@ -0,0 +1,89 @@
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
import os
|
||||
|
||||
from ..libs.group_metrics import (
|
||||
run_group_metrics_light_ranking,
|
||||
run_group_metrics_light_ranking_in_bq,
|
||||
)
|
||||
from ..libs.metric_fn_utils import get_metric_fn
|
||||
from ..libs.model_args import get_arg_parser_light_ranking
|
||||
from ..libs.model_utils import read_config
|
||||
from .deep_norm import build_graph, DataRecordTrainer, get_config_func, logging
|
||||
|
||||
|
||||
# checkstyle: noqa
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_arg_parser_light_ranking()
|
||||
parser.add_argument(
|
||||
"--eval_checkpoint",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Which checkpoint to use for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--saved_model_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to saved model for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_binary_metrics",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to compute the basic binary metrics for Light Ranking.",
|
||||
)
|
||||
|
||||
opt = parser.parse_args()
|
||||
logging.info("parse is: ")
|
||||
logging.info(opt)
|
||||
|
||||
feature_list = read_config(opt.feature_list).items()
|
||||
feature_config = get_config_func(opt.feat_config_type)(
|
||||
data_spec_path=opt.data_spec,
|
||||
feature_list_provided=feature_list,
|
||||
opt=opt,
|
||||
add_gbdt=opt.use_gbdt_features,
|
||||
run_light_ranking_group_metrics_in_bq=opt.run_light_ranking_group_metrics_in_bq,
|
||||
)
|
||||
|
||||
# -----------------------------------------------
|
||||
# Create Trainer
|
||||
# -----------------------------------------------
|
||||
trainer = DataRecordTrainer(
|
||||
name=opt.model_trainer_name,
|
||||
params=opt,
|
||||
build_graph_fn=partial(build_graph, run_light_ranking_group_metrics_in_bq=True),
|
||||
save_dir=opt.save_dir,
|
||||
run_config=None,
|
||||
feature_config=feature_config,
|
||||
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
|
||||
)
|
||||
|
||||
# -----------------------------------------------
|
||||
# Model Evaluation
|
||||
# -----------------------------------------------
|
||||
logging.info("Evaluating...")
|
||||
start = datetime.now()
|
||||
|
||||
if opt.run_binary_metrics:
|
||||
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
|
||||
eval_steps = None if (opt.eval_steps is not None and opt.eval_steps < 0) else opt.eval_steps
|
||||
trainer.estimator.evaluate(eval_input_fn, steps=eval_steps, checkpoint_path=opt.eval_checkpoint)
|
||||
|
||||
if opt.run_light_ranking_group_metrics_in_bq:
|
||||
run_group_metrics_light_ranking_in_bq(
|
||||
trainer=trainer, params=opt, checkpoint_path=opt.eval_checkpoint
|
||||
)
|
||||
|
||||
if opt.run_light_ranking_group_metrics:
|
||||
run_group_metrics_light_ranking(
|
||||
trainer=trainer,
|
||||
data_dir=os.path.join(opt.eval_data_dir, opt.eval_start_datetime),
|
||||
model_path=opt.saved_model_path,
|
||||
parse_fn=feature_config.get_parse_fn(),
|
||||
)
|
||||
|
||||
end = datetime.now()
|
||||
logging.info("Evaluating time: " + str(end - start))
|
@ -0,0 +1,187 @@
|
||||
import warnings
|
||||
|
||||
from twml.contrib.layers import ZscoreNormalization
|
||||
|
||||
from ...libs.customized_full_sparse import FullSparse
|
||||
from ...libs.get_feat_config import FEAT_CONFIG_DEFAULT_VAL as MISSING_VALUE_MARKER
|
||||
from ...libs.model_utils import (
|
||||
_sparse_feature_fixup,
|
||||
adaptive_transformation,
|
||||
filter_nans_and_infs,
|
||||
get_dense_out,
|
||||
tensor_dropout,
|
||||
)
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
# checkstyle: noqa
|
||||
|
||||
def light_ranking_mlp_ngbdt(features, is_training, params, label=None):
|
||||
return deepnorm_light_ranking(
|
||||
features,
|
||||
is_training,
|
||||
params,
|
||||
label=label,
|
||||
decay=params.momentum,
|
||||
dense_emb_size=params.dense_embedding_size,
|
||||
base_activation=tf.keras.layers.LeakyReLU(),
|
||||
input_dropout_rate=params.dropout,
|
||||
use_gbdt=False,
|
||||
)
|
||||
|
||||
|
||||
def deepnorm_light_ranking(
|
||||
features,
|
||||
is_training,
|
||||
params,
|
||||
label=None,
|
||||
decay=0.99999,
|
||||
dense_emb_size=128,
|
||||
base_activation=None,
|
||||
input_dropout_rate=None,
|
||||
input_dense_type="self_atten_dense",
|
||||
emb_dense_type="self_atten_dense",
|
||||
mlp_dense_type="self_atten_dense",
|
||||
use_gbdt=False,
|
||||
):
|
||||
# --------------------------------------------------------
|
||||
# Initial Parameter Checking
|
||||
# --------------------------------------------------------
|
||||
if base_activation is None:
|
||||
base_activation = tf.keras.layers.LeakyReLU()
|
||||
|
||||
if label is not None:
|
||||
warnings.warn(
|
||||
"Label is unused in deepnorm_gbdt. Stop using this argument.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
with tf.variable_scope("helper_layers"):
|
||||
full_sparse_layer = FullSparse(
|
||||
output_size=params.sparse_embedding_size,
|
||||
activation=base_activation,
|
||||
use_sparse_grads=is_training,
|
||||
use_binary_values=False,
|
||||
dtype=tf.float32,
|
||||
)
|
||||
input_normalizing_layer = ZscoreNormalization(decay=decay, name="input_normalizing_layer")
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Feature Selection & Embedding
|
||||
# --------------------------------------------------------
|
||||
if use_gbdt:
|
||||
sparse_gbdt_features = _sparse_feature_fixup(features["gbdt_sparse"], params.input_size_bits)
|
||||
if input_dropout_rate is not None:
|
||||
sparse_gbdt_features = tensor_dropout(
|
||||
sparse_gbdt_features, input_dropout_rate, is_training, sparse_tensor=True
|
||||
)
|
||||
|
||||
total_embed = full_sparse_layer(sparse_gbdt_features, use_binary_values=True)
|
||||
|
||||
if (input_dropout_rate is not None) and is_training:
|
||||
total_embed = total_embed / (1 - input_dropout_rate)
|
||||
|
||||
else:
|
||||
with tf.variable_scope("dense_branch"):
|
||||
dense_continuous_features = filter_nans_and_infs(features["continuous"])
|
||||
|
||||
if params.use_missing_sub_branch:
|
||||
is_missing = tf.equal(dense_continuous_features, MISSING_VALUE_MARKER)
|
||||
continuous_features_filled = tf.where(
|
||||
is_missing,
|
||||
tf.zeros_like(dense_continuous_features),
|
||||
dense_continuous_features,
|
||||
)
|
||||
normalized_features = input_normalizing_layer(
|
||||
continuous_features_filled, is_training, tf.math.logical_not(is_missing)
|
||||
)
|
||||
|
||||
with tf.variable_scope("missing_sub_branch"):
|
||||
missing_feature_embed = get_dense_out(
|
||||
tf.cast(is_missing, tf.float32),
|
||||
dense_emb_size,
|
||||
activation=base_activation,
|
||||
dense_type=input_dense_type,
|
||||
)
|
||||
|
||||
else:
|
||||
continuous_features_filled = dense_continuous_features
|
||||
normalized_features = input_normalizing_layer(continuous_features_filled, is_training)
|
||||
|
||||
with tf.variable_scope("continuous_sub_branch"):
|
||||
normalized_features = adaptive_transformation(
|
||||
normalized_features, is_training, func_type="tiny"
|
||||
)
|
||||
|
||||
if input_dropout_rate is not None:
|
||||
normalized_features = tensor_dropout(
|
||||
normalized_features,
|
||||
input_dropout_rate,
|
||||
is_training,
|
||||
sparse_tensor=False,
|
||||
)
|
||||
filled_feature_embed = get_dense_out(
|
||||
normalized_features,
|
||||
dense_emb_size,
|
||||
activation=base_activation,
|
||||
dense_type=input_dense_type,
|
||||
)
|
||||
|
||||
if params.use_missing_sub_branch:
|
||||
dense_embed = tf.concat(
|
||||
[filled_feature_embed, missing_feature_embed], axis=1, name="merge_dense_emb"
|
||||
)
|
||||
else:
|
||||
dense_embed = filled_feature_embed
|
||||
|
||||
with tf.variable_scope("sparse_branch"):
|
||||
sparse_discrete_features = _sparse_feature_fixup(
|
||||
features["sparse_no_continuous"], params.input_size_bits
|
||||
)
|
||||
if input_dropout_rate is not None:
|
||||
sparse_discrete_features = tensor_dropout(
|
||||
sparse_discrete_features, input_dropout_rate, is_training, sparse_tensor=True
|
||||
)
|
||||
|
||||
discrete_features_embed = full_sparse_layer(sparse_discrete_features, use_binary_values=True)
|
||||
|
||||
if (input_dropout_rate is not None) and is_training:
|
||||
discrete_features_embed = discrete_features_embed / (1 - input_dropout_rate)
|
||||
|
||||
total_embed = tf.concat(
|
||||
[dense_embed, discrete_features_embed],
|
||||
axis=1,
|
||||
name="total_embed",
|
||||
)
|
||||
|
||||
total_embed = tf.layers.batch_normalization(
|
||||
total_embed,
|
||||
training=is_training,
|
||||
renorm_momentum=decay,
|
||||
momentum=decay,
|
||||
renorm=is_training,
|
||||
trainable=True,
|
||||
)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# MLP Layers
|
||||
# --------------------------------------------------------
|
||||
with tf.variable_scope("MLP_branch"):
|
||||
|
||||
assert params.num_mlp_layers >= 0
|
||||
embed_list = [total_embed] + [None for _ in range(params.num_mlp_layers)]
|
||||
dense_types = [emb_dense_type] + [mlp_dense_type for _ in range(params.num_mlp_layers - 1)]
|
||||
|
||||
for xl in range(1, params.num_mlp_layers + 1):
|
||||
neurons = params.mlp_neuron_scale ** (params.num_mlp_layers + 1 - xl)
|
||||
embed_list[xl] = get_dense_out(
|
||||
embed_list[xl - 1], neurons, activation=base_activation, dense_type=dense_types[xl - 1]
|
||||
)
|
||||
|
||||
if params.task_name in ["Sent", "HeavyRankPosition", "HeavyRankProbability"]:
|
||||
logits = get_dense_out(embed_list[-1], 1, activation=None, dense_type=mlp_dense_type)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid Task Name !")
|
||||
|
||||
output_dict = {"output": logits}
|
||||
return output_dict
|
@ -0,0 +1,337 @@
|
||||
scala_library(
|
||||
sources = ["**/*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
strict_deps = True,
|
||||
tags = [
|
||||
"bazel-compatible",
|
||||
],
|
||||
dependencies = [
|
||||
"3rdparty/jvm/com/twitter/bijection:scrooge",
|
||||
"3rdparty/jvm/com/twitter/storehaus:core",
|
||||
"abdecider",
|
||||
"abuse/detection/src/main/thrift/com/twitter/abuse/detection/scoring:thrift-scala",
|
||||
"ann/src/main/scala/com/twitter/ann/common",
|
||||
"ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
|
||||
"audience-rewards/thrift/src/main/thrift:thrift-scala",
|
||||
"communities/thrift/src/main/thrift/com/twitter/communities:thrift-scala",
|
||||
"configapi/configapi-core",
|
||||
"configapi/configapi-decider",
|
||||
"content-mixer/thrift/src/main/thrift:thrift-scala",
|
||||
"content-recommender/thrift/src/main/thrift:thrift-scala",
|
||||
"copyselectionservice/server/src/main/scala/com/twitter/copyselectionservice/algorithms",
|
||||
"copyselectionservice/thrift/src/main/thrift:copyselectionservice-scala",
|
||||
"cortex-deepbird/thrift/src/main/thrift:thrift-java",
|
||||
"cr-mixer/thrift/src/main/thrift:thrift-scala",
|
||||
"cuad/projects/hashspace/thrift:thrift-scala",
|
||||
"cuad/projects/tagspace/thrift/src/main/thrift:thrift-scala",
|
||||
"detopic/thrift/src/main/thrift:thrift-scala",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/configapi",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/ddg",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/environment",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/fatigue",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/nackwarmupfilter",
|
||||
"discovery-common/src/main/scala/com/twitter/discovery/common/server",
|
||||
"discovery-ds/src/main/thrift/com/twitter/dds/scio/searcher_aggregate_history_srp:searcher_aggregate_history_srp-scala",
|
||||
"escherbird/src/scala/com/twitter/escherbird/util/metadatastitch",
|
||||
"escherbird/src/scala/com/twitter/escherbird/util/uttclient",
|
||||
"escherbird/src/thrift/com/twitter/escherbird/utt:strato-columns-scala",
|
||||
"eventbus/client",
|
||||
"eventdetection/event_context/src/main/scala/com/twitter/eventdetection/event_context/util",
|
||||
"events-recos/events-recos-service/src/main/thrift:events-recos-thrift-scala",
|
||||
"explore/explore-ranker/thrift/src/main/thrift:thrift-scala",
|
||||
"featureswitches/featureswitches-core/src/main/scala",
|
||||
"featureswitches/featureswitches-core/src/main/scala:dynmap",
|
||||
"featureswitches/featureswitches-core/src/main/scala:recipient",
|
||||
"featureswitches/featureswitches-core/src/main/scala:useragent",
|
||||
"featureswitches/featureswitches-core/src/main/scala/com/twitter/featureswitches/v2/builder",
|
||||
"finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/authentication",
|
||||
"finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/server",
|
||||
"finagle-internal/ostrich-stats",
|
||||
"finagle/finagle-core/src/main",
|
||||
"finagle/finagle-http/src/main/scala",
|
||||
"finagle/finagle-memcached/src/main/scala",
|
||||
"finagle/finagle-stats",
|
||||
"finagle/finagle-thriftmux",
|
||||
"finagle/finagle-tunable/src/main/scala",
|
||||
"finagle/finagle-zipkin-scribe",
|
||||
"finatra-internal/abdecider",
|
||||
"finatra-internal/decider",
|
||||
"finatra-internal/mtls-http/src/main/scala",
|
||||
"finatra-internal/mtls-thriftmux/src/main/scala",
|
||||
"finatra/http-client/src/main/scala",
|
||||
"finatra/http-core/src/main/java/com/twitter/finatra/http",
|
||||
"finatra/http-core/src/main/scala/com/twitter/finatra/http/response",
|
||||
"finatra/http-server/src/main/scala/com/twitter/finatra/http",
|
||||
"finatra/http-server/src/main/scala/com/twitter/finatra/http/filters",
|
||||
"finatra/inject/inject-app/src/main/java/com/twitter/inject/annotations",
|
||||
"finatra/inject/inject-app/src/main/scala",
|
||||
"finatra/inject/inject-core/src/main/scala",
|
||||
"finatra/inject/inject-server/src/main/scala",
|
||||
"finatra/inject/inject-slf4j/src/main/scala/com/twitter/inject",
|
||||
"finatra/inject/inject-thrift-client/src/main/scala",
|
||||
"finatra/inject/inject-utils/src/main/scala",
|
||||
"finatra/utils/src/main/java/com/twitter/finatra/annotations",
|
||||
"fleets/fleets-proxy/thrift/src/main/thrift:fleet-scala",
|
||||
"fleets/fleets-proxy/thrift/src/main/thrift/service:baseservice-scala",
|
||||
"flock-client/src/main/scala",
|
||||
"flock-client/src/main/thrift:thrift-scala",
|
||||
"follow-recommendations-service/thrift/src/main/thrift:thrift-scala",
|
||||
"frigate/frigate-common:base",
|
||||
"frigate/frigate-common:config",
|
||||
"frigate/frigate-common:debug",
|
||||
"frigate/frigate-common:entity_graph_client",
|
||||
"frigate/frigate-common:history",
|
||||
"frigate/frigate-common:logger",
|
||||
"frigate/frigate-common:ml-base",
|
||||
"frigate/frigate-common:ml-feature",
|
||||
"frigate/frigate-common:ml-prediction",
|
||||
"frigate/frigate-common:ntab",
|
||||
"frigate/frigate-common:predicate",
|
||||
"frigate/frigate-common:rec_types",
|
||||
"frigate/frigate-common:score_summary",
|
||||
"frigate/frigate-common:util",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/candidate",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/experiments",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/filter",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/modules/store:semantic_core_stores",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/deviceinfo",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/interests",
|
||||
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/strato",
|
||||
"frigate/push-mixer/thrift/src/main/thrift:thrift-scala",
|
||||
"geo/geo-prediction/src/main/thrift:local-viral-tweets-thrift-scala",
|
||||
"geoduck/service/src/main/scala/com/twitter/geoduck/service/common/clientmodules",
|
||||
"geoduck/util/country",
|
||||
"gizmoduck/client/src/main/scala/com/twitter/gizmoduck/testusers/client",
|
||||
"hermit/hermit-core:model-user_state",
|
||||
"hermit/hermit-core:predicate",
|
||||
"hermit/hermit-core:predicate-gizmoduck",
|
||||
"hermit/hermit-core:predicate-scarecrow",
|
||||
"hermit/hermit-core:predicate-socialgraph",
|
||||
"hermit/hermit-core:predicate-tweetypie",
|
||||
"hermit/hermit-core:store-labeled_push_recs",
|
||||
"hermit/hermit-core:store-metastore",
|
||||
"hermit/hermit-core:store-timezone",
|
||||
"hermit/hermit-core:store-tweetypie",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/constants",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/model",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/common",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/gizmoduck",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/scarecrow",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/semantic_core",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/user_htl_session_store",
|
||||
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/user_interest",
|
||||
"hmli/hss/src/main/thrift/com/twitter/hss:thrift-scala",
|
||||
"ibis2/service/src/main/scala/com/twitter/ibis2/lib",
|
||||
"ibis2/service/src/main/thrift/com/twitter/ibis2/service:ibis2-service-scala",
|
||||
"interests-service/thrift/src/main/thrift:thrift-scala",
|
||||
"interests_discovery/thrift/src/main/thrift:batch-thrift-scala",
|
||||
"interests_discovery/thrift/src/main/thrift:service-thrift-scala",
|
||||
"kujaku/thrift/src/main/thrift:domain-scala",
|
||||
"live-video-timeline/client/src/main/scala/com/twitter/livevideo/timeline/client/v2",
|
||||
"live-video-timeline/domain/src/main/scala/com/twitter/livevideo/timeline/domain",
|
||||
"live-video-timeline/domain/src/main/scala/com/twitter/livevideo/timeline/domain/v2",
|
||||
"live-video-timeline/thrift/src/main/thrift/com/twitter/livevideo/timeline:thrift-scala",
|
||||
"live-video/common/src/main/scala/com/twitter/livevideo/common/domain/v2",
|
||||
"live-video/common/src/main/scala/com/twitter/livevideo/common/ids",
|
||||
"notifications-platform/inbound-notifications/src/main/thrift/com/twitter/inbound_notifications:exception-scala",
|
||||
"notifications-platform/inbound-notifications/src/main/thrift/com/twitter/inbound_notifications:thrift-scala",
|
||||
"notifications-platform/platform-lib/src/main/thrift/com/twitter/notifications/platform:custom-notification-actions-scala",
|
||||
"notifications-platform/platform-lib/src/main/thrift/com/twitter/notifications/platform:thrift-scala",
|
||||
"notifications-relevance/src/scala/com/twitter/nrel/heavyranker",
|
||||
"notifications-relevance/src/scala/com/twitter/nrel/hydration/base",
|
||||
"notifications-relevance/src/scala/com/twitter/nrel/hydration/frigate",
|
||||
"notifications-relevance/src/scala/com/twitter/nrel/hydration/push",
|
||||
"notifications-relevance/src/scala/com/twitter/nrel/lightranker",
|
||||
"notificationservice/common/src/main/scala/com/twitter/notificationservice/genericfeedbackstore",
|
||||
"notificationservice/common/src/main/scala/com/twitter/notificationservice/model:alias",
|
||||
"notificationservice/common/src/main/scala/com/twitter/notificationservice/model/service",
|
||||
"notificationservice/common/src/test/scala/com/twitter/notificationservice/mocks",
|
||||
"notificationservice/scribe/src/main/scala/com/twitter/notificationservice/scribe/manhattan:mh_wrapper",
|
||||
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/api:thrift-scala",
|
||||
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/badgecount-api:thrift-scala",
|
||||
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/generic_notifications:thrift-scala",
|
||||
"notifinfra/ni-lib/src/main/scala/com/twitter/ni/lib/logged_out_transform",
|
||||
"observability/observability-manhattan-client/src/main/scala",
|
||||
"onboarding/service/src/main/scala/com/twitter/onboarding/task/service/models/external",
|
||||
"onboarding/service/thrift/src/main/thrift:thrift-scala",
|
||||
"people-discovery/api/thrift/src/main/thrift:thrift-scala",
|
||||
"periscope/api-proxy-thrift/thrift/src/main/thrift:thrift-scala",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/module",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/module/stringcenter",
|
||||
"product-mixer/core/src/main/thrift/com/twitter/product_mixer/core:thrift-scala",
|
||||
"qig-ranker/thrift/src/main/thrift:thrift-scala",
|
||||
"rux-ds/src/main/thrift/com/twitter/ruxds/jobs/user_past_aggregate:user_past_aggregate-scala",
|
||||
"rux/common/src/main/scala/com/twitter/rux/common/encode",
|
||||
"rux/common/thrift/src/main/thrift/rux-context:rux-context-scala",
|
||||
"rux/common/thrift/src/main/thrift/strato:strato-scala",
|
||||
"scribelib/marshallers/src/main/scala/com/twitter/scribelib/marshallers",
|
||||
"scrooge/scrooge-core",
|
||||
"scrooge/scrooge-serializer/src/main/scala",
|
||||
"sensitive-ds/src/main/thrift/com/twitter/scio/nsfw_user_segmentation:nsfw_user_segmentation-scala",
|
||||
"servo/decider/src/main/scala",
|
||||
"servo/request/src/main/scala",
|
||||
"servo/util/src/main/scala",
|
||||
"src/java/com/twitter/ml/api:api-base",
|
||||
"src/java/com/twitter/ml/prediction/core",
|
||||
"src/scala/com/twitter/frigate/data_pipeline/common",
|
||||
"src/scala/com/twitter/frigate/data_pipeline/embedding_cg:embedding_cg-test-user-ids",
|
||||
"src/scala/com/twitter/frigate/data_pipeline/features_common",
|
||||
"src/scala/com/twitter/frigate/news_article_recs/news_articles_metadata:thrift-scala",
|
||||
"src/scala/com/twitter/frontpage/stream/util",
|
||||
"src/scala/com/twitter/language/normalization",
|
||||
"src/scala/com/twitter/ml/api/embedding",
|
||||
"src/scala/com/twitter/ml/api/util:datarecord",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/entities/core",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/entities/magicrecs",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/features/core:aggregate",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/features/cuad:aggregate",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/features/embeddings",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/features/magicrecs:aggregate",
|
||||
"src/scala/com/twitter/ml/featurestore/catalog/features/topic_signals:aggregate",
|
||||
"src/scala/com/twitter/ml/featurestore/lib",
|
||||
"src/scala/com/twitter/ml/featurestore/lib/data",
|
||||
"src/scala/com/twitter/ml/featurestore/lib/dynamic",
|
||||
"src/scala/com/twitter/ml/featurestore/lib/entity",
|
||||
"src/scala/com/twitter/ml/featurestore/lib/online",
|
||||
"src/scala/com/twitter/recommendation/interests/discovery/core/config",
|
||||
"src/scala/com/twitter/recommendation/interests/discovery/core/deploy",
|
||||
"src/scala/com/twitter/recommendation/interests/discovery/core/model",
|
||||
"src/scala/com/twitter/recommendation/interests/discovery/popgeo/deploy",
|
||||
"src/scala/com/twitter/simclusters_v2/common",
|
||||
"src/scala/com/twitter/storehaus_internal/manhattan",
|
||||
"src/scala/com/twitter/storehaus_internal/manhattan/config",
|
||||
"src/scala/com/twitter/storehaus_internal/memcache",
|
||||
"src/scala/com/twitter/storehaus_internal/memcache/config",
|
||||
"src/scala/com/twitter/storehaus_internal/util",
|
||||
"src/scala/com/twitter/taxi/common",
|
||||
"src/scala/com/twitter/taxi/config",
|
||||
"src/scala/com/twitter/taxi/deploy",
|
||||
"src/scala/com/twitter/taxi/trending/common",
|
||||
"src/thrift/com/twitter/ads/adserver:adserver_rpc-scala",
|
||||
"src/thrift/com/twitter/clientapp/gen:clientapp-scala",
|
||||
"src/thrift/com/twitter/core_workflows/user_model:user_model-scala",
|
||||
"src/thrift/com/twitter/escherbird/common:constants-scala",
|
||||
"src/thrift/com/twitter/escherbird/metadata:megadata-scala",
|
||||
"src/thrift/com/twitter/escherbird/metadata:metadata-service-scala",
|
||||
"src/thrift/com/twitter/escherbird/search:search-service-scala",
|
||||
"src/thrift/com/twitter/expandodo:only-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-common-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-ml-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-notification-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-secondary-accounts-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate:frigate-user-media-representation-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/data_pipeline:frigate-user-history-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/dau_model:frigate-dau-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/magic_events:frigate-magic-events-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/magic_events/scribe:thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/pushcap:frigate-pushcap-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/pushservice:frigate-pushservice-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/scribe:frigate-scribe-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/subscribed_search:frigate-subscribed-search-thrift-scala",
|
||||
"src/thrift/com/twitter/frigate/user_states:frigate-userstates-thrift-scala",
|
||||
"src/thrift/com/twitter/geoduck:geoduck-scala",
|
||||
"src/thrift/com/twitter/gizmoduck:thrift-scala",
|
||||
"src/thrift/com/twitter/gizmoduck:user-thrift-scala",
|
||||
"src/thrift/com/twitter/hermit:hermit-scala",
|
||||
"src/thrift/com/twitter/hermit/pop_geo:hermit-pop-geo-scala",
|
||||
"src/thrift/com/twitter/hermit/stp:hermit-stp-scala",
|
||||
"src/thrift/com/twitter/ibis:service-scala",
|
||||
"src/thrift/com/twitter/manhattan:v1-scala",
|
||||
"src/thrift/com/twitter/manhattan:v2-scala",
|
||||
"src/thrift/com/twitter/ml/api:data-java",
|
||||
"src/thrift/com/twitter/ml/api:data-scala",
|
||||
"src/thrift/com/twitter/ml/featurestore/timelines:ml-features-timelines-scala",
|
||||
"src/thrift/com/twitter/ml/featurestore/timelines:ml-features-timelines-strato",
|
||||
"src/thrift/com/twitter/ml/prediction_service:prediction_service-java",
|
||||
"src/thrift/com/twitter/permissions_storage:thrift-scala",
|
||||
"src/thrift/com/twitter/pink-floyd/thrift:thrift-scala",
|
||||
"src/thrift/com/twitter/recos:recos-common-scala",
|
||||
"src/thrift/com/twitter/recos/user_tweet_entity_graph:user_tweet_entity_graph-scala",
|
||||
"src/thrift/com/twitter/recos/user_user_graph:user_user_graph-scala",
|
||||
"src/thrift/com/twitter/relevance/feature_store:feature_store-scala",
|
||||
"src/thrift/com/twitter/search:earlybird-scala",
|
||||
"src/thrift/com/twitter/search/common:features-scala",
|
||||
"src/thrift/com/twitter/search/query_interaction_graph:query_interaction_graph-scala",
|
||||
"src/thrift/com/twitter/search/query_interaction_graph/service:qig-service-scala",
|
||||
"src/thrift/com/twitter/service/metastore/gen:thrift-scala",
|
||||
"src/thrift/com/twitter/service/scarecrow/gen:scarecrow-scala",
|
||||
"src/thrift/com/twitter/service/scarecrow/gen:tiered-actions-scala",
|
||||
"src/thrift/com/twitter/simclusters_v2:simclusters_v2-thrift-scala",
|
||||
"src/thrift/com/twitter/socialgraph:thrift-scala",
|
||||
"src/thrift/com/twitter/spam/rtf:safety-level-scala",
|
||||
"src/thrift/com/twitter/timelinemixer:thrift-scala",
|
||||
"src/thrift/com/twitter/timelinemixer/server/internal:thrift-scala",
|
||||
"src/thrift/com/twitter/timelines/author_features/user_health:thrift-scala",
|
||||
"src/thrift/com/twitter/timelines/real_graph:real_graph-scala",
|
||||
"src/thrift/com/twitter/timelinescorer:thrift-scala",
|
||||
"src/thrift/com/twitter/timelinescorer/server/internal:thrift-scala",
|
||||
"src/thrift/com/twitter/timelineservice/server/internal:thrift-scala",
|
||||
"src/thrift/com/twitter/timelineservice/server/suggests/logging:thrift-scala",
|
||||
"src/thrift/com/twitter/trends/common:common-scala",
|
||||
"src/thrift/com/twitter/trends/trip_v1:trip-tweets-thrift-scala",
|
||||
"src/thrift/com/twitter/tweetypie:service-scala",
|
||||
"src/thrift/com/twitter/tweetypie:tweet-scala",
|
||||
"src/thrift/com/twitter/user_session_store:thrift-scala",
|
||||
"src/thrift/com/twitter/wtf/candidate:wtf-candidate-scala",
|
||||
"src/thrift/com/twitter/wtf/interest:interest-thrift-scala",
|
||||
"src/thrift/com/twitter/wtf/scalding/common:thrift-scala",
|
||||
"stitch/stitch-core",
|
||||
"stitch/stitch-gizmoduck",
|
||||
"stitch/stitch-socialgraph/src/main/scala",
|
||||
"stitch/stitch-storehaus/src/main/scala",
|
||||
"stitch/stitch-tweetypie/src/main/scala",
|
||||
"storage/clients/manhattan/client/src/main/scala",
|
||||
"strato/config/columns/clients:clients-strato-client",
|
||||
"strato/config/columns/geo/user:user-strato-client",
|
||||
"strato/config/columns/globe/curation:curation-strato-client",
|
||||
"strato/config/columns/interests:interests-strato-client",
|
||||
"strato/config/columns/ml/featureStore:featureStore-strato-client",
|
||||
"strato/config/columns/notifications:notifications-strato-client",
|
||||
"strato/config/columns/notifinfra:notifinfra-strato-client",
|
||||
"strato/config/columns/periscope:periscope-strato-client",
|
||||
"strato/config/columns/rux",
|
||||
"strato/config/columns/rux:rux-strato-client",
|
||||
"strato/config/columns/rux/open-app:open-app-strato-client",
|
||||
"strato/config/columns/socialgraph/graphs:graphs-strato-client",
|
||||
"strato/config/columns/socialgraph/service/soft_users:soft_users-strato-client",
|
||||
"strato/config/columns/translation/service:service-strato-client",
|
||||
"strato/config/columns/translation/service/platform:platform-strato-client",
|
||||
"strato/config/columns/trends/trip:trip-strato-client",
|
||||
"strato/config/src/thrift/com/twitter/strato/columns/frigate:logged-out-web-notifications-scala",
|
||||
"strato/config/src/thrift/com/twitter/strato/columns/notifications:thrift-scala",
|
||||
"strato/src/main/scala/com/twitter/strato/config",
|
||||
"strato/src/main/scala/com/twitter/strato/response",
|
||||
"thrift-web-forms",
|
||||
"timeline-training-service/service/thrift/src/main/thrift:thrift-scala",
|
||||
"timelines/src/main/scala/com/twitter/timelines/features/app",
|
||||
"topic-social-proof/server/src/main/thrift:thrift-scala",
|
||||
"topiclisting/topiclisting-core/src/main/scala/com/twitter/topiclisting",
|
||||
"topiclisting/topiclisting-utt/src/main/scala/com/twitter/topiclisting/utt",
|
||||
"trends/common/src/main/thrift/com/twitter/trends/common:thrift-scala",
|
||||
"tweetypie/src/scala/com/twitter/tweetypie/tweettext",
|
||||
"twitter-context/src/main/scala",
|
||||
"twitter-server-internal",
|
||||
"twitter-server/server/src/main/scala",
|
||||
"twitter-text/lib/java/src/main/java/com/twitter/twittertext",
|
||||
"twml/runtime/src/main/scala/com/twitter/deepbird/runtime/prediction_engine:prediction_engine_mkl",
|
||||
"ubs/common/src/main/thrift/com/twitter/ubs:broadcast-thrift-scala",
|
||||
"ubs/common/src/main/thrift/com/twitter/ubs:seller_application-thrift-scala",
|
||||
"user_session_store/src/main/scala/com/twitter/user_session_store/impl/manhattan/readwrite",
|
||||
"util-internal/scribe",
|
||||
"util-internal/tunable/src/main/scala/com/twitter/util/tunable",
|
||||
"util/util-app",
|
||||
"util/util-hashing/src/main/scala",
|
||||
"util/util-slf4j-api/src/main/scala",
|
||||
"util/util-stats/src/main/scala",
|
||||
"visibility/lib/src/main/scala/com/twitter/visibility/builder",
|
||||
"visibility/lib/src/main/scala/com/twitter/visibility/interfaces/push_service",
|
||||
"visibility/lib/src/main/scala/com/twitter/visibility/interfaces/spaces",
|
||||
"visibility/lib/src/main/scala/com/twitter/visibility/util",
|
||||
],
|
||||
exports = [
|
||||
"strato/config/src/thrift/com/twitter/strato/columns/frigate:logged-out-web-notifications-scala",
|
||||
],
|
||||
)
|
@ -0,0 +1,93 @@
|
||||
package com.twitter.frigate.pushservice
|
||||
|
||||
import com.google.inject.Inject
|
||||
import com.google.inject.Singleton
|
||||
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
|
||||
import com.twitter.finagle.thrift.ClientId
|
||||
import com.twitter.finatra.thrift.routing.ThriftWarmup
|
||||
import com.twitter.util.logging.Logging
|
||||
import com.twitter.inject.utils.Handler
|
||||
import com.twitter.frigate.pushservice.{thriftscala => t}
|
||||
import com.twitter.frigate.thriftscala.NotificationDisplayLocation
|
||||
import com.twitter.util.Stopwatch
|
||||
import com.twitter.scrooge.Request
|
||||
import com.twitter.scrooge.Response
|
||||
import com.twitter.util.Return
|
||||
import com.twitter.util.Throw
|
||||
import com.twitter.util.Try
|
||||
|
||||
/**
|
||||
* Warms up the refresh request path.
|
||||
* If service is running as pushservice-send then the warmup does nothing.
|
||||
*
|
||||
* When making the warmup refresh requests we
|
||||
* - Set skipFilters to true to execute as much of the request path as possible
|
||||
* - Set darkWrite to true to prevent sending a push
|
||||
*/
|
||||
@Singleton
|
||||
class PushMixerThriftServerWarmupHandler @Inject() (
|
||||
warmup: ThriftWarmup,
|
||||
serviceIdentifier: ServiceIdentifier)
|
||||
extends Handler
|
||||
with Logging {
|
||||
|
||||
private val clientId = ClientId("thrift-warmup-client")
|
||||
|
||||
def handle(): Unit = {
|
||||
val refreshServices = Set(
|
||||
"frigate-pushservice",
|
||||
"frigate-pushservice-canary",
|
||||
"frigate-pushservice-canary-control",
|
||||
"frigate-pushservice-canary-treatment"
|
||||
)
|
||||
val isRefresh = refreshServices.contains(serviceIdentifier.service)
|
||||
if (isRefresh && !serviceIdentifier.isLocal) refreshWarmup()
|
||||
}
|
||||
|
||||
def refreshWarmup(): Unit = {
|
||||
val elapsed = Stopwatch.start()
|
||||
val testIds = Seq(
|
||||
1,
|
||||
2,
|
||||
3
|
||||
)
|
||||
try {
|
||||
clientId.asCurrent {
|
||||
testIds.foreach { id =>
|
||||
val warmupReq = warmupQuery(id)
|
||||
info(s"Sending warm-up request to service with query: $warmupReq")
|
||||
warmup.sendRequest(
|
||||
method = t.PushService.Refresh,
|
||||
req = Request(t.PushService.Refresh.Args(warmupReq)))(assertWarmupResponse)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: Throwable =>
|
||||
error(e.getMessage, e)
|
||||
}
|
||||
info(s"Warm up complete. Time taken: ${elapsed().toString}")
|
||||
}
|
||||
|
||||
private def warmupQuery(userId: Long): t.RefreshRequest = {
|
||||
t.RefreshRequest(
|
||||
userId = userId,
|
||||
notificationDisplayLocation = NotificationDisplayLocation.PushToMobileDevice,
|
||||
context = Some(
|
||||
t.PushContext(
|
||||
skipFilters = Some(true),
|
||||
darkWrite = Some(true)
|
||||
))
|
||||
)
|
||||
}
|
||||
|
||||
private def assertWarmupResponse(
|
||||
result: Try[Response[t.PushService.Refresh.SuccessType]]
|
||||
): Unit = {
|
||||
result match {
|
||||
case Return(_) => // ok
|
||||
case Throw(exception) =>
|
||||
warn("Error performing warm-up request.")
|
||||
error(exception.getMessage, exception)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,193 @@
|
||||
package com.twitter.frigate.pushservice
|
||||
|
||||
import com.twitter.discovery.common.environment.modules.EnvironmentModule
|
||||
import com.twitter.finagle.Filter
|
||||
import com.twitter.finatra.annotations.DarkTrafficFilterType
|
||||
import com.twitter.finatra.decider.modules.DeciderModule
|
||||
import com.twitter.finatra.http.HttpServer
|
||||
import com.twitter.finatra.http.filters.CommonFilters
|
||||
import com.twitter.finatra.http.routing.HttpRouter
|
||||
import com.twitter.finatra.mtls.http.{Mtls => HttpMtls}
|
||||
import com.twitter.finatra.mtls.thriftmux.{Mtls => ThriftMtls}
|
||||
import com.twitter.finatra.mtls.thriftmux.filters.MtlsServerSessionTrackerFilter
|
||||
import com.twitter.finatra.thrift.ThriftServer
|
||||
import com.twitter.finatra.thrift.filters.ExceptionMappingFilter
|
||||
import com.twitter.finatra.thrift.filters.LoggingMDCFilter
|
||||
import com.twitter.finatra.thrift.filters.StatsFilter
|
||||
import com.twitter.finatra.thrift.filters.ThriftMDCFilter
|
||||
import com.twitter.finatra.thrift.filters.TraceIdMDCFilter
|
||||
import com.twitter.finatra.thrift.routing.ThriftRouter
|
||||
import com.twitter.frigate.common.logger.MRLoggerGlobalVariables
|
||||
import com.twitter.frigate.pushservice.controller.PushServiceController
|
||||
import com.twitter.frigate.pushservice.module._
|
||||
import com.twitter.inject.TwitterModule
|
||||
import com.twitter.inject.annotations.Flags
|
||||
import com.twitter.inject.thrift.modules.ThriftClientIdModule
|
||||
import com.twitter.logging.BareFormatter
|
||||
import com.twitter.logging.Level
|
||||
import com.twitter.logging.LoggerFactory
|
||||
import com.twitter.logging.{Logging => JLogging}
|
||||
import com.twitter.logging.QueueingHandler
|
||||
import com.twitter.logging.ScribeHandler
|
||||
import com.twitter.product_mixer.core.module.product_mixer_flags.ProductMixerFlagModule
|
||||
import com.twitter.product_mixer.core.module.ABDeciderModule
|
||||
import com.twitter.product_mixer.core.module.FeatureSwitchesModule
|
||||
import com.twitter.product_mixer.core.module.StratoClientModule
|
||||
|
||||
object PushServiceMain extends PushServiceFinatraServer
|
||||
|
||||
class PushServiceFinatraServer
|
||||
extends ThriftServer
|
||||
with ThriftMtls
|
||||
with HttpServer
|
||||
with HttpMtls
|
||||
with JLogging {
|
||||
|
||||
override val name = "PushService"
|
||||
|
||||
override val modules: Seq[TwitterModule] = {
|
||||
Seq(
|
||||
ABDeciderModule,
|
||||
DeciderModule,
|
||||
FeatureSwitchesModule,
|
||||
FilterModule,
|
||||
FlagModule,
|
||||
EnvironmentModule,
|
||||
ThriftClientIdModule,
|
||||
DeployConfigModule,
|
||||
ProductMixerFlagModule,
|
||||
StratoClientModule,
|
||||
PushHandlerModule,
|
||||
PushTargetUserBuilderModule,
|
||||
PushServiceDarkTrafficModule,
|
||||
LoggedOutPushTargetUserBuilderModule,
|
||||
new ThriftWebFormsModule(this),
|
||||
)
|
||||
}
|
||||
|
||||
override def configureThrift(router: ThriftRouter): Unit = {
|
||||
router
|
||||
.filter[ExceptionMappingFilter]
|
||||
.filter[LoggingMDCFilter]
|
||||
.filter[TraceIdMDCFilter]
|
||||
.filter[ThriftMDCFilter]
|
||||
.filter[MtlsServerSessionTrackerFilter]
|
||||
.filter[StatsFilter]
|
||||
.filter[Filter.TypeAgnostic, DarkTrafficFilterType]
|
||||
.add[PushServiceController]
|
||||
}
|
||||
|
||||
override def configureHttp(router: HttpRouter): Unit =
|
||||
router
|
||||
.filter[CommonFilters]
|
||||
|
||||
override protected def start(): Unit = {
|
||||
MRLoggerGlobalVariables.setRequiredFlags(
|
||||
traceLogFlag = injector.instance[Boolean](Flags.named(FlagModule.mrLoggerIsTraceAll.name)),
|
||||
nthLogFlag = injector.instance[Boolean](Flags.named(FlagModule.mrLoggerNthLog.name)),
|
||||
nthLogValFlag = injector.instance[Long](Flags.named(FlagModule.mrLoggerNthVal.name))
|
||||
)
|
||||
}
|
||||
|
||||
override protected def warmup(): Unit = {
|
||||
handle[PushMixerThriftServerWarmupHandler]()
|
||||
}
|
||||
|
||||
override protected def configureLoggerFactories(): Unit = {
|
||||
loggerFactories.foreach { _() }
|
||||
}
|
||||
|
||||
override def loggerFactories: List[LoggerFactory] = {
|
||||
val scribeScope = statsReceiver.scope("scribe")
|
||||
List(
|
||||
LoggerFactory(
|
||||
level = Some(levelFlag()),
|
||||
handlers = handlers
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "request_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 10000,
|
||||
handler = ScribeHandler(
|
||||
category = "frigate_pushservice_log",
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("frigate_pushservice_log")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "notification_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 10000,
|
||||
handler = ScribeHandler(
|
||||
category = "frigate_notifier",
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("frigate_notifier")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "push_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 10000,
|
||||
handler = ScribeHandler(
|
||||
category = "test_frigate_push",
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("test_frigate_push")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "push_subsample_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 2500,
|
||||
handler = ScribeHandler(
|
||||
category = "magicrecs_candidates_subsample_scribe",
|
||||
maxMessagesPerTransaction = 250,
|
||||
maxMessagesToBuffer = 2500,
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("magicrecs_candidates_subsample_scribe")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "mr_request_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 2500,
|
||||
handler = ScribeHandler(
|
||||
category = "mr_request_scribe",
|
||||
maxMessagesPerTransaction = 250,
|
||||
maxMessagesToBuffer = 2500,
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("mr_request_scribe")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
LoggerFactory(
|
||||
node = "high_quality_candidates_scribe",
|
||||
level = Some(Level.INFO),
|
||||
useParents = false,
|
||||
handlers = QueueingHandler(
|
||||
maxQueueSize = 2500,
|
||||
handler = ScribeHandler(
|
||||
category = "frigate_high_quality_candidates_log",
|
||||
maxMessagesPerTransaction = 250,
|
||||
maxMessagesToBuffer = 2500,
|
||||
formatter = BareFormatter,
|
||||
statsReceiver = scribeScope.scope("high_quality_candidates_scribe")
|
||||
)
|
||||
) :: Nil
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
@ -0,0 +1,323 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.contentrecommender.thriftscala.MetricTag
|
||||
import com.twitter.cr_mixer.thriftscala.CrMixerTweetRequest
|
||||
import com.twitter.cr_mixer.thriftscala.NotificationsContext
|
||||
import com.twitter.cr_mixer.thriftscala.Product
|
||||
import com.twitter.cr_mixer.thriftscala.ProductContext
|
||||
import com.twitter.cr_mixer.thriftscala.{MetricTag => CrMixerMetricTag}
|
||||
import com.twitter.finagle.stats.Stat
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.AlgorithmScore
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base.CrMixerCandidate
|
||||
import com.twitter.frigate.common.base.TopicCandidate
|
||||
import com.twitter.frigate.common.base.TopicProofTweetCandidate
|
||||
import com.twitter.frigate.common.base.TweetCandidate
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutInNetworkTweets
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.frigate.pushservice.store.CrMixerTweetStore
|
||||
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
|
||||
import com.twitter.frigate.pushservice.util.AdaptorUtils
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.pushservice.util.TopicsUtil
|
||||
import com.twitter.frigate.pushservice.util.TweetWithTopicProof
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.hermit.predicate.socialgraph.RelationEdge
|
||||
import com.twitter.product_mixer.core.thriftscala.ClientContext
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.topiclisting.utt.LocalizedEntity
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
|
||||
import com.twitter.util.Future
|
||||
import scala.collection.Map
|
||||
|
||||
case class ContentRecommenderMixerAdaptor(
|
||||
crMixerTweetStore: CrMixerTweetStore,
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
edgeStore: ReadableStore[RelationEdge, Boolean],
|
||||
topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse],
|
||||
uttEntityHydrationStore: UttEntityHydrationStore,
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override val name: String = this.getClass.getSimpleName
|
||||
|
||||
private[this] val stats = globalStats.scope("ContentRecommenderMixerAdaptor")
|
||||
private[this] val numOfValidAuthors = stats.stat("num_of_valid_authors")
|
||||
private[this] val numOutOfMaximumDropped = stats.stat("dropped_due_out_of_maximum")
|
||||
private[this] val totalInputRecs = stats.counter("input_recs")
|
||||
private[this] val totalOutputRecs = stats.stat("output_recs")
|
||||
private[this] val totalRequests = stats.counter("total_requests")
|
||||
private[this] val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
|
||||
private[this] val totalOutNetworkRecs = stats.counter("out_network_tweets")
|
||||
private[this] val totalInNetworkRecs = stats.counter("in_network_tweets")
|
||||
|
||||
/**
|
||||
* Builds OON raw candidates based on input OON Tweets
|
||||
*/
|
||||
def buildOONRawCandidates(
|
||||
inputTarget: Target,
|
||||
oonTweets: Seq[TweetyPieResult],
|
||||
tweetScoreMap: Map[Long, Double],
|
||||
tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]],
|
||||
maxNumOfCandidates: Int
|
||||
): Option[Seq[RawCandidate]] = {
|
||||
val cands = oonTweets.flatMap { tweetResult =>
|
||||
val tweetId = tweetResult.tweet.id
|
||||
generateOONRawCandidate(
|
||||
inputTarget,
|
||||
tweetId,
|
||||
Some(tweetResult),
|
||||
tweetScoreMap,
|
||||
tweetIdToTagsMap
|
||||
)
|
||||
}
|
||||
|
||||
val candidates = restrict(
|
||||
maxNumOfCandidates,
|
||||
cands,
|
||||
numOutOfMaximumDropped,
|
||||
totalOutputRecs
|
||||
)
|
||||
|
||||
Some(candidates)
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a single RawCandidate With TopicProofTweetCandidate
|
||||
*/
|
||||
def buildTopicTweetRawCandidate(
|
||||
inputTarget: Target,
|
||||
tweetWithTopicProof: TweetWithTopicProof,
|
||||
localizedEntity: LocalizedEntity,
|
||||
tags: Option[Seq[MetricTag]],
|
||||
): RawCandidate with TopicProofTweetCandidate = {
|
||||
new RawCandidate with TopicProofTweetCandidate {
|
||||
override def target: Target = inputTarget
|
||||
override def topicListingSetting: Option[String] = Some(
|
||||
tweetWithTopicProof.topicListingSetting)
|
||||
override def tweetId: Long = tweetWithTopicProof.tweetId
|
||||
override def tweetyPieResult: Option[TweetyPieResult] = Some(
|
||||
tweetWithTopicProof.tweetyPieResult)
|
||||
override def semanticCoreEntityId: Option[Long] = Some(tweetWithTopicProof.topicId)
|
||||
override def localizedUttEntity: Option[LocalizedEntity] = Some(localizedEntity)
|
||||
override def algorithmCR: Option[String] = tweetWithTopicProof.algorithmCR
|
||||
override def tagsCR: Option[Seq[MetricTag]] = tags
|
||||
override def isOutOfNetwork: Boolean = tweetWithTopicProof.isOON
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes a group of TopicTweets and transforms them into RawCandidates
|
||||
*/
|
||||
def buildTopicTweetRawCandidates(
|
||||
inputTarget: Target,
|
||||
topicProofCandidates: Seq[TweetWithTopicProof],
|
||||
tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]],
|
||||
maxNumberOfCands: Int
|
||||
): Future[Option[Seq[RawCandidate]]] = {
|
||||
val semanticCoreEntityIds = topicProofCandidates
|
||||
.map(_.topicId)
|
||||
.toSet
|
||||
|
||||
TopicsUtil
|
||||
.getLocalizedEntityMap(inputTarget, semanticCoreEntityIds, uttEntityHydrationStore)
|
||||
.map { localizedEntityMap =>
|
||||
val rawCandidates = topicProofCandidates.collect {
|
||||
case topicSocialProof: TweetWithTopicProof
|
||||
if localizedEntityMap.contains(topicSocialProof.topicId) =>
|
||||
// Once we deprecate CR calls, we should replace this code to use the CrMixerMetricTag
|
||||
val tags = tweetIdToTagsMap.get(topicSocialProof.tweetId).map {
|
||||
_.flatMap { tag => MetricTag.get(tag.value) }
|
||||
}
|
||||
buildTopicTweetRawCandidate(
|
||||
inputTarget,
|
||||
topicSocialProof,
|
||||
localizedEntityMap(topicSocialProof.topicId),
|
||||
tags
|
||||
)
|
||||
}
|
||||
|
||||
val candResult = restrict(
|
||||
maxNumberOfCands,
|
||||
rawCandidates,
|
||||
numOutOfMaximumDropped,
|
||||
totalOutputRecs
|
||||
)
|
||||
|
||||
Some(candResult)
|
||||
}
|
||||
}
|
||||
|
||||
private def generateOONRawCandidate(
|
||||
inputTarget: Target,
|
||||
id: Long,
|
||||
result: Option[TweetyPieResult],
|
||||
tweetScoreMap: Map[Long, Double],
|
||||
tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]]
|
||||
): Option[RawCandidate with TweetCandidate] = {
|
||||
val tagsFromCR = tweetIdToTagsMap.get(id).map { _.flatMap { tag => MetricTag.get(tag.value) } }
|
||||
val candidate = new RawCandidate with CrMixerCandidate with TopicCandidate with AlgorithmScore {
|
||||
override val tweetId = id
|
||||
override val target = inputTarget
|
||||
override val tweetyPieResult = result
|
||||
override val localizedUttEntity = None
|
||||
override val semanticCoreEntityId = None
|
||||
override def commonRecType =
|
||||
getMediaBasedCRT(
|
||||
CommonRecommendationType.TwistlyTweet,
|
||||
CommonRecommendationType.TwistlyPhoto,
|
||||
CommonRecommendationType.TwistlyVideo)
|
||||
override def tagsCR = tagsFromCR
|
||||
override def algorithmScore = tweetScoreMap.get(id)
|
||||
override def algorithmCR = None
|
||||
}
|
||||
Some(candidate)
|
||||
}
|
||||
|
||||
private def restrict(
|
||||
maxNumToReturn: Int,
|
||||
candidates: Seq[RawCandidate],
|
||||
numOutOfMaximumDropped: Stat,
|
||||
totalOutputRecs: Stat
|
||||
): Seq[RawCandidate] = {
|
||||
val newCandidates = candidates.take(maxNumToReturn)
|
||||
val numDropped = candidates.length - newCandidates.length
|
||||
numOutOfMaximumDropped.add(numDropped)
|
||||
totalOutputRecs.add(newCandidates.size)
|
||||
newCandidates
|
||||
}
|
||||
|
||||
private def buildCrMixerRequest(
|
||||
target: Target,
|
||||
countryCode: Option[String],
|
||||
language: Option[String],
|
||||
seenTweets: Seq[Long]
|
||||
): CrMixerTweetRequest = {
|
||||
CrMixerTweetRequest(
|
||||
clientContext = ClientContext(
|
||||
userId = Some(target.targetId),
|
||||
countryCode = countryCode,
|
||||
languageCode = language
|
||||
),
|
||||
product = Product.Notifications,
|
||||
productContext = Some(ProductContext.NotificationsContext(NotificationsContext())),
|
||||
excludedTweetIds = Some(seenTweets)
|
||||
)
|
||||
}
|
||||
|
||||
private def selectCandidatesToSendBasedOnSettings(
|
||||
isRecommendationsEligible: Boolean,
|
||||
isTopicsEligible: Boolean,
|
||||
oonRawCandidates: Option[Seq[RawCandidate]],
|
||||
topicTweetCandidates: Option[Seq[RawCandidate]]
|
||||
): Option[Seq[RawCandidate]] = {
|
||||
if (isRecommendationsEligible && isTopicsEligible) {
|
||||
Some(topicTweetCandidates.getOrElse(Seq.empty) ++ oonRawCandidates.getOrElse(Seq.empty))
|
||||
} else if (isRecommendationsEligible) {
|
||||
oonRawCandidates
|
||||
} else if (isTopicsEligible) {
|
||||
topicTweetCandidates
|
||||
} else None
|
||||
}
|
||||
|
||||
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
Future
|
||||
.join(
|
||||
target.seenTweetIds,
|
||||
target.countryCode,
|
||||
target.inferredUserDeviceLanguage,
|
||||
PushDeviceUtil.isTopicsEligible(target),
|
||||
PushDeviceUtil.isRecommendationsEligible(target)
|
||||
).flatMap {
|
||||
case (seenTweets, countryCode, language, isTopicsEligible, isRecommendationsEligible) =>
|
||||
val request = buildCrMixerRequest(target, countryCode, language, seenTweets)
|
||||
crMixerTweetStore.getTweetRecommendations(request).flatMap {
|
||||
case Some(response) =>
|
||||
totalInputRecs.incr(response.tweets.size)
|
||||
totalRequests.incr()
|
||||
AdaptorUtils
|
||||
.getTweetyPieResults(
|
||||
response.tweets.map(_.tweetId).toSet,
|
||||
tweetyPieStore).flatMap { tweetyPieResultMap =>
|
||||
filterOutInNetworkTweets(
|
||||
target,
|
||||
filterOutReplyTweet(tweetyPieResultMap.toMap, nonReplyTweetsCounter),
|
||||
edgeStore,
|
||||
numOfValidAuthors).flatMap {
|
||||
outNetworkTweetsWithId: Seq[(Long, TweetyPieResult)] =>
|
||||
totalOutNetworkRecs.incr(outNetworkTweetsWithId.size)
|
||||
totalInNetworkRecs.incr(response.tweets.size - outNetworkTweetsWithId.size)
|
||||
val outNetworkTweets: Seq[TweetyPieResult] = outNetworkTweetsWithId.map {
|
||||
case (_, tweetyPieResult) => tweetyPieResult
|
||||
}
|
||||
|
||||
val tweetIdToTagsMap = response.tweets.map { tweet =>
|
||||
tweet.tweetId -> tweet.metricTags.getOrElse(Seq.empty)
|
||||
}.toMap
|
||||
|
||||
val tweetScoreMap = response.tweets.map { tweet =>
|
||||
tweet.tweetId -> tweet.score
|
||||
}.toMap
|
||||
|
||||
val maxNumOfCandidates =
|
||||
target.params(PushFeatureSwitchParams.NumberOfMaxCrMixerCandidatesParam)
|
||||
|
||||
val oonRawCandidates =
|
||||
buildOONRawCandidates(
|
||||
target,
|
||||
outNetworkTweets,
|
||||
tweetScoreMap,
|
||||
tweetIdToTagsMap,
|
||||
maxNumOfCandidates)
|
||||
|
||||
TopicsUtil
|
||||
.getTopicSocialProofs(
|
||||
target,
|
||||
outNetworkTweets,
|
||||
topicSocialProofServiceStore,
|
||||
edgeStore,
|
||||
PushFeatureSwitchParams.TopicProofTweetCandidatesTopicScoreThreshold).flatMap {
|
||||
tweetsWithTopicProof =>
|
||||
buildTopicTweetRawCandidates(
|
||||
target,
|
||||
tweetsWithTopicProof,
|
||||
tweetIdToTagsMap,
|
||||
maxNumOfCandidates)
|
||||
}.map { topicTweetCandidates =>
|
||||
selectCandidatesToSendBasedOnSettings(
|
||||
isRecommendationsEligible,
|
||||
isTopicsEligible,
|
||||
oonRawCandidates,
|
||||
topicTweetCandidates)
|
||||
}
|
||||
}
|
||||
}
|
||||
case _ => Future.None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* For a user to be available the following news to happen
|
||||
*/
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
Future
|
||||
.join(
|
||||
PushDeviceUtil.isRecommendationsEligible(target),
|
||||
PushDeviceUtil.isTopicsEligible(target)
|
||||
).map {
|
||||
case (isRecommendationsEligible, isTopicsEligible) =>
|
||||
(isRecommendationsEligible || isTopicsEligible) &&
|
||||
target.params(PushParams.ContentRecommenderMixerAdaptorDecider)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,293 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.Stat
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.candidate._
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.hermit.store.tweetypie.UserTweet
|
||||
import com.twitter.recos.recos_common.thriftscala.SocialProofType
|
||||
import com.twitter.search.common.features.thriftscala.ThriftSearchResultFeatures
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.timelines.configapi.Param
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.Time
|
||||
import scala.collection.Map
|
||||
|
||||
case class EarlyBirdFirstDegreeCandidateAdaptor(
|
||||
earlyBirdFirstDegreeCandidates: CandidateSource[
|
||||
EarlybirdCandidateSource.Query,
|
||||
EarlybirdCandidate
|
||||
],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
userTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
|
||||
maxResultsParam: Param[Int],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
type EBCandidate = EarlybirdCandidate with TweetDetails
|
||||
private val stats = globalStats.scope("EarlyBirdFirstDegreeAdaptor")
|
||||
private val earlyBirdCandsStat: Stat = stats.stat("early_bird_cands_dist")
|
||||
private val emptyEarlyBirdCands = stats.counter("empty_early_bird_candidates")
|
||||
private val seedSetEmpty = stats.counter("empty_seedset")
|
||||
private val seenTweetsStat = stats.stat("filtered_by_seen_tweets")
|
||||
private val emptyTweetyPieResult = stats.stat("empty_tweetypie_result")
|
||||
private val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
|
||||
private val enableRetweets = stats.counter("enable_retweets")
|
||||
private val f1withoutSocialContexts = stats.counter("f1_without_social_context")
|
||||
private val userTweetTweetyPieStoreCounter = stats.counter("user_tweet_tweetypie_store")
|
||||
|
||||
override val name: String = earlyBirdFirstDegreeCandidates.name
|
||||
|
||||
private def getAllSocialContextActions(
|
||||
socialProofTypes: Seq[(SocialProofType, Seq[Long])]
|
||||
): Seq[SocialContextAction] = {
|
||||
socialProofTypes.flatMap {
|
||||
case (SocialProofType.Favorite, scIds) =>
|
||||
scIds.map { scId =>
|
||||
SocialContextAction(
|
||||
scId,
|
||||
Time.now.inMilliseconds,
|
||||
socialContextActionType = Some(SocialContextActionType.Favorite)
|
||||
)
|
||||
}
|
||||
case (SocialProofType.Retweet, scIds) =>
|
||||
scIds.map { scId =>
|
||||
SocialContextAction(
|
||||
scId,
|
||||
Time.now.inMilliseconds,
|
||||
socialContextActionType = Some(SocialContextActionType.Retweet)
|
||||
)
|
||||
}
|
||||
case (SocialProofType.Reply, scIds) =>
|
||||
scIds.map { scId =>
|
||||
SocialContextAction(
|
||||
scId,
|
||||
Time.now.inMilliseconds,
|
||||
socialContextActionType = Some(SocialContextActionType.Reply)
|
||||
)
|
||||
}
|
||||
case (SocialProofType.Tweet, scIds) =>
|
||||
scIds.map { scId =>
|
||||
SocialContextAction(
|
||||
scId,
|
||||
Time.now.inMilliseconds,
|
||||
socialContextActionType = Some(SocialContextActionType.Tweet)
|
||||
)
|
||||
}
|
||||
case _ => Nil
|
||||
}
|
||||
}
|
||||
|
||||
private def generateRetweetCandidate(
|
||||
inputTarget: Target,
|
||||
candidate: EBCandidate,
|
||||
scIds: Seq[Long],
|
||||
socialProofTypes: Seq[(SocialProofType, Seq[Long])]
|
||||
): RawCandidate = {
|
||||
val scActions = scIds.map { scId => SocialContextAction(scId, Time.now.inMilliseconds) }
|
||||
new RawCandidate with TweetRetweetCandidate with EarlybirdTweetFeatures {
|
||||
override val socialContextActions = scActions
|
||||
override val socialContextAllTypeActions = getAllSocialContextActions(socialProofTypes)
|
||||
override val tweetId = candidate.tweetId
|
||||
override val target = inputTarget
|
||||
override val tweetyPieResult = candidate.tweetyPieResult
|
||||
override val features = candidate.features
|
||||
}
|
||||
}
|
||||
|
||||
private def generateF1CandidateWithoutSocialContext(
|
||||
inputTarget: Target,
|
||||
candidate: EBCandidate
|
||||
): RawCandidate = {
|
||||
f1withoutSocialContexts.incr()
|
||||
new RawCandidate with F1FirstDegree with EarlybirdTweetFeatures {
|
||||
override val tweetId = candidate.tweetId
|
||||
override val target = inputTarget
|
||||
override val tweetyPieResult = candidate.tweetyPieResult
|
||||
override val features = candidate.features
|
||||
}
|
||||
}
|
||||
|
||||
private def generateEarlyBirdCandidate(
|
||||
id: Long,
|
||||
result: Option[TweetyPieResult],
|
||||
ebFeatures: Option[ThriftSearchResultFeatures]
|
||||
): EBCandidate = {
|
||||
new EarlybirdCandidate with TweetDetails {
|
||||
override val tweetyPieResult: Option[TweetyPieResult] = result
|
||||
override val tweetId: Long = id
|
||||
override val features: Option[ThriftSearchResultFeatures] = ebFeatures
|
||||
}
|
||||
}
|
||||
|
||||
private def filterOutSeenTweets(seenTweetIds: Seq[Long], inputTweetIds: Seq[Long]): Seq[Long] = {
|
||||
inputTweetIds.filterNot(seenTweetIds.contains)
|
||||
}
|
||||
|
||||
private def filterInvalidTweets(
|
||||
tweetIds: Seq[Long],
|
||||
target: Target
|
||||
): Future[Seq[(Long, TweetyPieResult)]] = {
|
||||
|
||||
val resMap = {
|
||||
if (target.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors)) {
|
||||
userTweetTweetyPieStoreCounter.incr()
|
||||
val keys = tweetIds.map { tweetId =>
|
||||
UserTweet(tweetId, Some(target.targetId))
|
||||
}
|
||||
|
||||
userTweetTweetyPieStore
|
||||
.multiGet(keys.toSet).map {
|
||||
case (userTweet, resultFut) =>
|
||||
userTweet.tweetId -> resultFut
|
||||
}.toMap
|
||||
} else {
|
||||
(target.params(PushFeatureSwitchParams.EnableVFInTweetypie) match {
|
||||
case true => tweetyPieStore
|
||||
case false => tweetyPieStoreNoVF
|
||||
}).multiGet(tweetIds.toSet)
|
||||
}
|
||||
}
|
||||
Future.collect(resMap).map { tweetyPieResultMap =>
|
||||
val cands = filterOutReplyTweet(tweetyPieResultMap, nonReplyTweetsCounter).collect {
|
||||
case (id: Long, Some(result)) =>
|
||||
id -> result
|
||||
}
|
||||
|
||||
emptyTweetyPieResult.add(tweetyPieResultMap.size - cands.size)
|
||||
cands.toSeq
|
||||
}
|
||||
}
|
||||
|
||||
private def getEBRetweetCandidates(
|
||||
inputTarget: Target,
|
||||
retweets: Seq[(Long, TweetyPieResult)]
|
||||
): Seq[RawCandidate] = {
|
||||
retweets.flatMap {
|
||||
case (_, tweetypieResult) =>
|
||||
tweetypieResult.tweet.coreData.flatMap { coreData =>
|
||||
tweetypieResult.sourceTweet.map { sourceTweet =>
|
||||
val tweetId = sourceTweet.id
|
||||
val scId = coreData.userId
|
||||
val socialProofTypes = Seq((SocialProofType.Retweet, Seq(scId)))
|
||||
val candidate = generateEarlyBirdCandidate(
|
||||
tweetId,
|
||||
Some(TweetyPieResult(sourceTweet, None, None)),
|
||||
None
|
||||
)
|
||||
generateRetweetCandidate(
|
||||
inputTarget,
|
||||
candidate,
|
||||
Seq(scId),
|
||||
socialProofTypes
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def getEBFirstDegreeCands(
|
||||
tweets: Seq[(Long, TweetyPieResult)],
|
||||
ebTweetIdMap: Map[Long, Option[ThriftSearchResultFeatures]]
|
||||
): Seq[EBCandidate] = {
|
||||
tweets.map {
|
||||
case (id, tweetypieResult) =>
|
||||
val features = ebTweetIdMap.getOrElse(id, None)
|
||||
generateEarlyBirdCandidate(id, Some(tweetypieResult), features)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a combination of raw candidates made of: f1 recs, topic social proof recs, sc recs and retweet candidates
|
||||
*/
|
||||
def buildRawCandidates(
|
||||
inputTarget: Target,
|
||||
firstDegreeCandidates: Seq[EBCandidate],
|
||||
retweetCandidates: Seq[RawCandidate]
|
||||
): Seq[RawCandidate] = {
|
||||
val hydratedF1Recs =
|
||||
firstDegreeCandidates.map(generateF1CandidateWithoutSocialContext(inputTarget, _))
|
||||
hydratedF1Recs ++ retweetCandidates
|
||||
}
|
||||
|
||||
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
inputTarget.seedsWithWeight.flatMap { seedsetOpt =>
|
||||
val seedsetMap = seedsetOpt.getOrElse(Map.empty)
|
||||
|
||||
if (seedsetMap.isEmpty) {
|
||||
seedSetEmpty.incr()
|
||||
Future.None
|
||||
} else {
|
||||
val maxResultsToReturn = inputTarget.params(maxResultsParam)
|
||||
val maxTweetAge = inputTarget.params(PushFeatureSwitchParams.F1CandidateMaxTweetAgeParam)
|
||||
val earlybirdQuery = EarlybirdCandidateSource.Query(
|
||||
maxNumResultsToReturn = maxResultsToReturn,
|
||||
seedset = seedsetMap,
|
||||
maxConsecutiveResultsByTheSameUser = Some(1),
|
||||
maxTweetAge = maxTweetAge,
|
||||
disableTimelinesMLModel = false,
|
||||
searcherId = Some(inputTarget.targetId),
|
||||
isProtectTweetsEnabled =
|
||||
inputTarget.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors),
|
||||
followedUserIds = Some(seedsetMap.keySet.toSeq)
|
||||
)
|
||||
|
||||
Future
|
||||
.join(inputTarget.seenTweetIds, earlyBirdFirstDegreeCandidates.get(earlybirdQuery))
|
||||
.flatMap {
|
||||
case (seenTweetIds, Some(candidates)) =>
|
||||
earlyBirdCandsStat.add(candidates.size)
|
||||
|
||||
val ebTweetIdMap = candidates.map { cand => cand.tweetId -> cand.features }.toMap
|
||||
|
||||
val ebTweetIds = ebTweetIdMap.keys.toSeq
|
||||
|
||||
val tweetIds = filterOutSeenTweets(seenTweetIds, ebTweetIds)
|
||||
seenTweetsStat.add(ebTweetIds.size - tweetIds.size)
|
||||
|
||||
filterInvalidTweets(tweetIds, inputTarget)
|
||||
.map { validTweets =>
|
||||
val (retweets, tweets) = validTweets.partition {
|
||||
case (_, tweetypieResult) =>
|
||||
tweetypieResult.sourceTweet.isDefined
|
||||
}
|
||||
|
||||
val firstDegreeCandidates = getEBFirstDegreeCands(tweets, ebTweetIdMap)
|
||||
|
||||
val retweetCandidates = {
|
||||
if (inputTarget.params(PushParams.EarlyBirdSCBasedCandidatesParam) &&
|
||||
inputTarget.params(PushParams.MRTweetRetweetRecsParam)) {
|
||||
enableRetweets.incr()
|
||||
getEBRetweetCandidates(inputTarget, retweets)
|
||||
} else Nil
|
||||
}
|
||||
|
||||
Some(
|
||||
buildRawCandidates(
|
||||
inputTarget,
|
||||
firstDegreeCandidates,
|
||||
retweetCandidates
|
||||
))
|
||||
}
|
||||
|
||||
case _ =>
|
||||
emptyEarlyBirdCands.incr()
|
||||
Future.None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
PushDeviceUtil.isRecommendationsEligible(target)
|
||||
}
|
||||
}
|
@ -0,0 +1,120 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.explore_ranker.thriftscala.ExploreRankerProductResponse
|
||||
import com.twitter.explore_ranker.thriftscala.ExploreRankerRequest
|
||||
import com.twitter.explore_ranker.thriftscala.ExploreRankerResponse
|
||||
import com.twitter.explore_ranker.thriftscala.ExploreRecommendation
|
||||
import com.twitter.explore_ranker.thriftscala.ImmersiveRecsResponse
|
||||
import com.twitter.explore_ranker.thriftscala.ImmersiveRecsResult
|
||||
import com.twitter.explore_ranker.thriftscala.NotificationsVideoRecs
|
||||
import com.twitter.explore_ranker.thriftscala.Product
|
||||
import com.twitter.explore_ranker.thriftscala.ProductContext
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base.OutOfNetworkTweetCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.util.AdaptorUtils
|
||||
import com.twitter.frigate.pushservice.util.MediaCRT
|
||||
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.product_mixer.core.thriftscala.ClientContext
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
|
||||
case class ExploreVideoTweetCandidateAdaptor(
|
||||
exploreRankerStore: ReadableStore[ExploreRankerRequest, ExploreRankerResponse],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override def name: String = this.getClass.getSimpleName
|
||||
private[this] val stats = globalStats.scope("ExploreVideoTweetCandidateAdaptor")
|
||||
private[this] val totalInputRecs = stats.stat("input_recs")
|
||||
private[this] val totalRequests = stats.counter("total_requests")
|
||||
private[this] val totalEmptyResponse = stats.counter("total_empty_response")
|
||||
|
||||
private def buildExploreRankerRequest(
|
||||
target: Target,
|
||||
countryCode: Option[String],
|
||||
language: Option[String],
|
||||
): ExploreRankerRequest = {
|
||||
ExploreRankerRequest(
|
||||
clientContext = ClientContext(
|
||||
userId = Some(target.targetId),
|
||||
countryCode = countryCode,
|
||||
languageCode = language,
|
||||
),
|
||||
product = Product.NotificationsVideoRecs,
|
||||
productContext = Some(ProductContext.NotificationsVideoRecs(NotificationsVideoRecs())),
|
||||
maxResults = Some(target.params(PushFeatureSwitchParams.MaxExploreVideoTweets))
|
||||
)
|
||||
}
|
||||
|
||||
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
Future
|
||||
.join(
|
||||
target.countryCode,
|
||||
target.inferredUserDeviceLanguage
|
||||
).flatMap {
|
||||
case (countryCode, language) =>
|
||||
val request = buildExploreRankerRequest(target, countryCode, language)
|
||||
exploreRankerStore.get(request).flatMap {
|
||||
case Some(response) =>
|
||||
val exploreResonseTweetIds = response match {
|
||||
case ExploreRankerResponse(ExploreRankerProductResponse
|
||||
.ImmersiveRecsResponse(ImmersiveRecsResponse(immersiveRecsResult))) =>
|
||||
immersiveRecsResult.collect {
|
||||
case ImmersiveRecsResult(ExploreRecommendation
|
||||
.ExploreTweetRecommendation(exploreTweetRecommendation)) =>
|
||||
exploreTweetRecommendation.tweetId
|
||||
}
|
||||
case _ =>
|
||||
Seq.empty
|
||||
}
|
||||
|
||||
totalInputRecs.add(exploreResonseTweetIds.size)
|
||||
totalRequests.incr()
|
||||
AdaptorUtils
|
||||
.getTweetyPieResults(exploreResonseTweetIds.toSet, tweetyPieStore).map {
|
||||
tweetyPieResultMap =>
|
||||
val candidates = tweetyPieResultMap.values.flatten
|
||||
.map(buildVideoRawCandidates(target, _))
|
||||
Some(candidates.toSeq)
|
||||
}
|
||||
case _ =>
|
||||
totalEmptyResponse.incr()
|
||||
Future.None
|
||||
}
|
||||
case _ =>
|
||||
Future.None
|
||||
}
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
PushDeviceUtil.isRecommendationsEligible(target).map { userRecommendationsEligible =>
|
||||
userRecommendationsEligible && target.params(PushFeatureSwitchParams.EnableExploreVideoTweets)
|
||||
}
|
||||
}
|
||||
private def buildVideoRawCandidates(
|
||||
target: Target,
|
||||
tweetyPieResult: TweetyPieResult
|
||||
): RawCandidate with OutOfNetworkTweetCandidate = {
|
||||
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
|
||||
inputTarget = target,
|
||||
id = tweetyPieResult.tweet.id,
|
||||
mediaCRT = MediaCRT(
|
||||
CommonRecommendationType.ExploreVideoTweet,
|
||||
CommonRecommendationType.ExploreVideoTweet,
|
||||
CommonRecommendationType.ExploreVideoTweet
|
||||
),
|
||||
result = Some(tweetyPieResult),
|
||||
localizedEntity = None
|
||||
)
|
||||
}
|
||||
}
|
@ -0,0 +1,272 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.cr_mixer.thriftscala.FrsTweetRequest
|
||||
import com.twitter.cr_mixer.thriftscala.NotificationsContext
|
||||
import com.twitter.cr_mixer.thriftscala.Product
|
||||
import com.twitter.cr_mixer.thriftscala.ProductContext
|
||||
import com.twitter.finagle.stats.Counter
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.store.CrMixerTweetStore
|
||||
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
|
||||
import com.twitter.frigate.pushservice.util.MediaCRT
|
||||
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.pushservice.util.TopicsUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.hermit.constants.AlgorithmFeedbackTokens
|
||||
import com.twitter.hermit.model.Algorithm.Algorithm
|
||||
import com.twitter.hermit.model.Algorithm.CrowdSearchAccounts
|
||||
import com.twitter.hermit.model.Algorithm.ForwardEmailBook
|
||||
import com.twitter.hermit.model.Algorithm.ForwardPhoneBook
|
||||
import com.twitter.hermit.model.Algorithm.ReverseEmailBookIbis
|
||||
import com.twitter.hermit.model.Algorithm.ReversePhoneBook
|
||||
import com.twitter.hermit.store.tweetypie.UserTweet
|
||||
import com.twitter.product_mixer.core.thriftscala.ClientContext
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
|
||||
import com.twitter.util.Future
|
||||
|
||||
object FRSAlgorithmFeedbackTokenUtil {
|
||||
private val crtsByAlgoToken = Map(
|
||||
getAlgorithmToken(ReverseEmailBookIbis) -> CommonRecommendationType.ReverseAddressbookTweet,
|
||||
getAlgorithmToken(ReversePhoneBook) -> CommonRecommendationType.ReverseAddressbookTweet,
|
||||
getAlgorithmToken(ForwardEmailBook) -> CommonRecommendationType.ForwardAddressbookTweet,
|
||||
getAlgorithmToken(ForwardPhoneBook) -> CommonRecommendationType.ForwardAddressbookTweet,
|
||||
getAlgorithmToken(CrowdSearchAccounts) -> CommonRecommendationType.CrowdSearchTweet
|
||||
)
|
||||
|
||||
def getAlgorithmToken(algorithm: Algorithm): Int = {
|
||||
AlgorithmFeedbackTokens.AlgorithmToFeedbackTokenMap(algorithm)
|
||||
}
|
||||
|
||||
def getCRTForAlgoToken(algorithmToken: Int): Option[CommonRecommendationType] = {
|
||||
crtsByAlgoToken.get(algorithmToken)
|
||||
}
|
||||
}
|
||||
|
||||
case class FRSTweetCandidateAdaptor(
|
||||
crMixerTweetStore: CrMixerTweetStore,
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
userTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
|
||||
uttEntityHydrationStore: UttEntityHydrationStore,
|
||||
topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
private val stats = globalStats.scope(this.getClass.getSimpleName)
|
||||
private val crtStats = stats.scope("CandidateDistribution")
|
||||
private val totalRequests = stats.counter("total_requests")
|
||||
|
||||
// Candidate Distribution stats
|
||||
private val reverseAddressbookCounter = crtStats.counter("reverse_addressbook")
|
||||
private val forwardAddressbookCounter = crtStats.counter("forward_addressbook")
|
||||
private val frsTweetCounter = crtStats.counter("frs_tweet")
|
||||
private val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
|
||||
private val crtToCounterMapping: Map[CommonRecommendationType, Counter] = Map(
|
||||
CommonRecommendationType.ReverseAddressbookTweet -> reverseAddressbookCounter,
|
||||
CommonRecommendationType.ForwardAddressbookTweet -> forwardAddressbookCounter,
|
||||
CommonRecommendationType.FrsTweet -> frsTweetCounter
|
||||
)
|
||||
|
||||
private val emptyTweetyPieResult = stats.stat("empty_tweetypie_result")
|
||||
|
||||
private[this] val numberReturnedCandidates = stats.stat("returned_candidates_from_earlybird")
|
||||
private[this] val numberCandidateWithTopic: Counter = stats.counter("num_can_with_topic")
|
||||
private[this] val numberCandidateWithoutTopic: Counter = stats.counter("num_can_without_topic")
|
||||
|
||||
private val userTweetTweetyPieStoreCounter = stats.counter("user_tweet_tweetypie_store")
|
||||
|
||||
override val name: String = this.getClass.getSimpleName
|
||||
|
||||
private def filterInvalidTweets(
|
||||
tweetIds: Seq[Long],
|
||||
target: Target
|
||||
): Future[Map[Long, TweetyPieResult]] = {
|
||||
val resMap = {
|
||||
if (target.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors)) {
|
||||
userTweetTweetyPieStoreCounter.incr()
|
||||
val keys = tweetIds.map { tweetId =>
|
||||
UserTweet(tweetId, Some(target.targetId))
|
||||
}
|
||||
userTweetTweetyPieStore
|
||||
.multiGet(keys.toSet).map {
|
||||
case (userTweet, resultFut) =>
|
||||
userTweet.tweetId -> resultFut
|
||||
}.toMap
|
||||
} else {
|
||||
(if (target.params(PushFeatureSwitchParams.EnableVFInTweetypie)) {
|
||||
tweetyPieStore
|
||||
} else {
|
||||
tweetyPieStoreNoVF
|
||||
}).multiGet(tweetIds.toSet)
|
||||
}
|
||||
}
|
||||
|
||||
Future.collect(resMap).map { tweetyPieResultMap =>
|
||||
// Filter out replies and generate earlybird candidates only for non-empty tweetypie result
|
||||
val cands = filterOutReplyTweet(tweetyPieResultMap, nonReplyTweetsCounter).collect {
|
||||
case (id: Long, Some(result)) =>
|
||||
id -> result
|
||||
}
|
||||
|
||||
emptyTweetyPieResult.add(tweetyPieResultMap.size - cands.size)
|
||||
cands
|
||||
}
|
||||
}
|
||||
|
||||
private def buildRawCandidates(
|
||||
target: Target,
|
||||
ebCandidates: Seq[FRSTweetCandidate]
|
||||
): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
|
||||
|
||||
val enableTopic = target.params(PushFeatureSwitchParams.EnableFrsTweetCandidatesTopicAnnotation)
|
||||
val topicScoreThre =
|
||||
target.params(PushFeatureSwitchParams.FrsTweetCandidatesTopicScoreThreshold)
|
||||
|
||||
val ebTweets = ebCandidates.map { ebCandidate =>
|
||||
ebCandidate.tweetId -> ebCandidate.tweetyPieResult
|
||||
}.toMap
|
||||
|
||||
val tweetIdLocalizedEntityMapFut = TopicsUtil.getTweetIdLocalizedEntityMap(
|
||||
target,
|
||||
ebTweets,
|
||||
uttEntityHydrationStore,
|
||||
topicSocialProofServiceStore,
|
||||
enableTopic,
|
||||
topicScoreThre
|
||||
)
|
||||
|
||||
Future.join(target.deviceInfo, tweetIdLocalizedEntityMapFut).map {
|
||||
case (Some(deviceInfo), tweetIdLocalizedEntityMap) =>
|
||||
val candidates = ebCandidates
|
||||
.map { ebCandidate =>
|
||||
val crt = ebCandidate.commonRecType
|
||||
crtToCounterMapping.get(crt).foreach(_.incr())
|
||||
|
||||
val tweetId = ebCandidate.tweetId
|
||||
val localizedEntityOpt = {
|
||||
if (tweetIdLocalizedEntityMap
|
||||
.contains(tweetId) && tweetIdLocalizedEntityMap.contains(
|
||||
tweetId) && deviceInfo.isTopicsEligible) {
|
||||
tweetIdLocalizedEntityMap(tweetId)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
|
||||
inputTarget = target,
|
||||
id = ebCandidate.tweetId,
|
||||
mediaCRT = MediaCRT(
|
||||
crt,
|
||||
crt,
|
||||
crt
|
||||
),
|
||||
result = ebCandidate.tweetyPieResult,
|
||||
localizedEntity = localizedEntityOpt)
|
||||
}.filter { candidate =>
|
||||
// If user only has the topic setting enabled, filter out all non-topic cands
|
||||
deviceInfo.isRecommendationsEligible || (deviceInfo.isTopicsEligible && candidate.semanticCoreEntityId.nonEmpty)
|
||||
}
|
||||
|
||||
candidates.map { candidate =>
|
||||
if (candidate.semanticCoreEntityId.nonEmpty) {
|
||||
numberCandidateWithTopic.incr()
|
||||
} else {
|
||||
numberCandidateWithoutTopic.incr()
|
||||
}
|
||||
}
|
||||
|
||||
numberReturnedCandidates.add(candidates.length)
|
||||
Some(candidates)
|
||||
case _ => Some(Seq.empty)
|
||||
}
|
||||
}
|
||||
|
||||
def getTweetCandidatesFromCrMixer(
|
||||
inputTarget: Target,
|
||||
showAllResultsFromFrs: Boolean,
|
||||
): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
|
||||
Future
|
||||
.join(
|
||||
inputTarget.seenTweetIds,
|
||||
inputTarget.pushRecItems,
|
||||
inputTarget.countryCode,
|
||||
inputTarget.targetLanguage).flatMap {
|
||||
case (seenTweetIds, pastRecItems, countryCode, language) =>
|
||||
val pastUserRecs = pastRecItems.userIds.toSeq
|
||||
val request = FrsTweetRequest(
|
||||
clientContext = ClientContext(
|
||||
userId = Some(inputTarget.targetId),
|
||||
countryCode = countryCode,
|
||||
languageCode = language
|
||||
),
|
||||
product = Product.Notifications,
|
||||
productContext = Some(ProductContext.NotificationsContext(NotificationsContext())),
|
||||
excludedUserIds = Some(pastUserRecs),
|
||||
excludedTweetIds = Some(seenTweetIds)
|
||||
)
|
||||
crMixerTweetStore.getFRSTweetCandidates(request).flatMap {
|
||||
case Some(response) =>
|
||||
val tweetIds = response.tweets.map(_.tweetId)
|
||||
val validTweets = filterInvalidTweets(tweetIds, inputTarget)
|
||||
validTweets.flatMap { tweetypieMap =>
|
||||
val ebCandidates = response.tweets
|
||||
.map { frsTweet =>
|
||||
val candidateTweetId = frsTweet.tweetId
|
||||
val resultFromTweetyPie = tweetypieMap.get(candidateTweetId)
|
||||
new FRSTweetCandidate {
|
||||
override val tweetId = candidateTweetId
|
||||
override val features = None
|
||||
override val tweetyPieResult = resultFromTweetyPie
|
||||
override val feedbackToken = frsTweet.frsPrimarySource
|
||||
override val commonRecType: CommonRecommendationType = feedbackToken
|
||||
.flatMap(token =>
|
||||
FRSAlgorithmFeedbackTokenUtil.getCRTForAlgoToken(token)).getOrElse(
|
||||
CommonRecommendationType.FrsTweet)
|
||||
}
|
||||
}.filter { ebCandidate =>
|
||||
showAllResultsFromFrs || ebCandidate.commonRecType == CommonRecommendationType.ReverseAddressbookTweet
|
||||
}
|
||||
|
||||
numberReturnedCandidates.add(ebCandidates.length)
|
||||
buildRawCandidates(
|
||||
inputTarget,
|
||||
ebCandidates
|
||||
)
|
||||
}
|
||||
case _ => Future.None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
|
||||
totalRequests.incr()
|
||||
val enableResultsFromFrs =
|
||||
inputTarget.params(PushFeatureSwitchParams.EnableResultFromFrsCandidates)
|
||||
getTweetCandidatesFromCrMixer(inputTarget, enableResultsFromFrs)
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
lazy val enableFrsCandidates = target.params(PushFeatureSwitchParams.EnableFrsCandidates)
|
||||
PushDeviceUtil.isRecommendationsEligible(target).flatMap { isEnabledForRecosSetting =>
|
||||
PushDeviceUtil.isTopicsEligible(target).map { topicSettingEnabled =>
|
||||
val isEnabledForTopics =
|
||||
topicSettingEnabled && target.params(
|
||||
PushFeatureSwitchParams.EnableFrsTweetCandidatesTopicSetting)
|
||||
(isEnabledForRecosSetting || isEnabledForTopics) && enableFrsCandidates
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,107 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.candidate._
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
|
||||
object GenericCandidates {
|
||||
type Target =
|
||||
TargetUser
|
||||
with UserDetails
|
||||
with TargetDecider
|
||||
with TargetABDecider
|
||||
with TweetImpressionHistory
|
||||
with HTLVisitHistory
|
||||
with MaxTweetAge
|
||||
with NewUserDetails
|
||||
with FrigateHistory
|
||||
with TargetWithSeedUsers
|
||||
}
|
||||
|
||||
case class GenericCandidateAdaptor(
|
||||
genericCandidates: CandidateSource[GenericCandidates.Target, Candidate],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
stats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override val name: String = genericCandidates.name
|
||||
|
||||
private def generateTweetFavCandidate(
|
||||
_target: Target,
|
||||
_tweetId: Long,
|
||||
_socialContextActions: Seq[SocialContextAction],
|
||||
socialContextActionsAllTypes: Seq[SocialContextAction],
|
||||
_tweetyPieResult: Option[TweetyPieResult]
|
||||
): RawCandidate = {
|
||||
new RawCandidate with TweetFavoriteCandidate {
|
||||
override val socialContextActions = _socialContextActions
|
||||
override val socialContextAllTypeActions =
|
||||
socialContextActionsAllTypes
|
||||
val tweetId = _tweetId
|
||||
val target = _target
|
||||
val tweetyPieResult = _tweetyPieResult
|
||||
}
|
||||
}
|
||||
|
||||
private def generateTweetRetweetCandidate(
|
||||
_target: Target,
|
||||
_tweetId: Long,
|
||||
_socialContextActions: Seq[SocialContextAction],
|
||||
socialContextActionsAllTypes: Seq[SocialContextAction],
|
||||
_tweetyPieResult: Option[TweetyPieResult]
|
||||
): RawCandidate = {
|
||||
new RawCandidate with TweetRetweetCandidate {
|
||||
override val socialContextActions = _socialContextActions
|
||||
override val socialContextAllTypeActions = socialContextActionsAllTypes
|
||||
val tweetId = _tweetId
|
||||
val target = _target
|
||||
val tweetyPieResult = _tweetyPieResult
|
||||
}
|
||||
}
|
||||
|
||||
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
genericCandidates.get(inputTarget).map { candidatesOpt =>
|
||||
candidatesOpt
|
||||
.map { candidates =>
|
||||
val candidatesSeq =
|
||||
candidates.collect {
|
||||
case tweetRetweet: TweetRetweetCandidate
|
||||
if inputTarget.params(PushParams.MRTweetRetweetRecsParam) =>
|
||||
generateTweetRetweetCandidate(
|
||||
inputTarget,
|
||||
tweetRetweet.tweetId,
|
||||
tweetRetweet.socialContextActions,
|
||||
tweetRetweet.socialContextAllTypeActions,
|
||||
tweetRetweet.tweetyPieResult)
|
||||
case tweetFavorite: TweetFavoriteCandidate
|
||||
if inputTarget.params(PushParams.MRTweetFavRecsParam) =>
|
||||
generateTweetFavCandidate(
|
||||
inputTarget,
|
||||
tweetFavorite.tweetId,
|
||||
tweetFavorite.socialContextActions,
|
||||
tweetFavorite.socialContextAllTypeActions,
|
||||
tweetFavorite.tweetyPieResult)
|
||||
}
|
||||
candidatesSeq.foreach { candidate =>
|
||||
stats.counter(s"${candidate.commonRecType}_count").incr()
|
||||
}
|
||||
candidatesSeq
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
PushDeviceUtil.isRecommendationsEligible(target).map { isAvailable =>
|
||||
isAvailable && target.params(PushParams.GenericCandidateAdaptorDecider)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,280 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.Stat
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.HighQualityCandidateGroupEnum
|
||||
import com.twitter.frigate.pushservice.params.HighQualityCandidateGroupEnum._
|
||||
import com.twitter.frigate.pushservice.params.PushConstants.targetUserAgeFeatureName
|
||||
import com.twitter.frigate.pushservice.params.PushConstants.targetUserPreferredLanguage
|
||||
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
|
||||
import com.twitter.frigate.pushservice.predicate.TargetPredicates
|
||||
import com.twitter.frigate.pushservice.util.MediaCRT
|
||||
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.pushservice.util.TopicsUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.interests.thriftscala.InterestId.SemanticCore
|
||||
import com.twitter.interests.thriftscala.UserInterests
|
||||
import com.twitter.language.normalization.UserDisplayLanguage
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweet
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
|
||||
import com.twitter.util.Future
|
||||
|
||||
object HighQualityTweetsHelper {
|
||||
def getFollowedTopics(
|
||||
target: Target,
|
||||
interestsWithLookupContextStore: ReadableStore[
|
||||
InterestsLookupRequestWithContext,
|
||||
UserInterests
|
||||
],
|
||||
followedTopicsStats: Stat
|
||||
): Future[Seq[Long]] = {
|
||||
TopicsUtil
|
||||
.getTopicsFollowedByUser(target, interestsWithLookupContextStore, followedTopicsStats).map {
|
||||
userInterestsOpt =>
|
||||
val userInterests = userInterestsOpt.getOrElse(Seq.empty)
|
||||
val extractedTopicIds = userInterests.flatMap {
|
||||
_.interestId match {
|
||||
case SemanticCore(semanticCore) => Some(semanticCore.id)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
extractedTopicIds
|
||||
}
|
||||
}
|
||||
|
||||
def getTripQueries(
|
||||
target: Target,
|
||||
enabledGroups: Set[HighQualityCandidateGroupEnum.Value],
|
||||
interestsWithLookupContextStore: ReadableStore[
|
||||
InterestsLookupRequestWithContext,
|
||||
UserInterests
|
||||
],
|
||||
sourceIds: Seq[String],
|
||||
stat: Stat
|
||||
): Future[Set[TripDomain]] = {
|
||||
|
||||
val followedTopicIdsSetFut: Future[Set[Long]] = if (enabledGroups.contains(Topic)) {
|
||||
getFollowedTopics(target, interestsWithLookupContextStore, stat).map(topicIds =>
|
||||
topicIds.toSet)
|
||||
} else {
|
||||
Future.value(Set.empty)
|
||||
}
|
||||
|
||||
Future
|
||||
.join(target.featureMap, target.inferredUserDeviceLanguage, followedTopicIdsSetFut).map {
|
||||
case (
|
||||
featureMap,
|
||||
deviceLanguageOpt,
|
||||
followedTopicIds
|
||||
) =>
|
||||
val ageBucketOpt = if (enabledGroups.contains(AgeBucket)) {
|
||||
featureMap.categoricalFeatures.get(targetUserAgeFeatureName)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
val languageOptions: Set[Option[String]] = if (enabledGroups.contains(Language)) {
|
||||
val userPreferredLanguages = featureMap.sparseBinaryFeatures
|
||||
.getOrElse(targetUserPreferredLanguage, Set.empty[String])
|
||||
if (userPreferredLanguages.nonEmpty) {
|
||||
userPreferredLanguages.map(lang => Some(UserDisplayLanguage.toTweetLanguage(lang)))
|
||||
} else {
|
||||
Set(deviceLanguageOpt.map(UserDisplayLanguage.toTweetLanguage))
|
||||
}
|
||||
} else Set(None)
|
||||
|
||||
val followedTopicOptions: Set[Option[Long]] = if (followedTopicIds.nonEmpty) {
|
||||
followedTopicIds.map(topic => Some(topic))
|
||||
} else Set(None)
|
||||
|
||||
val tripQueries = followedTopicOptions.flatMap { topicOption =>
|
||||
languageOptions.flatMap { languageOption =>
|
||||
sourceIds.map { sourceId =>
|
||||
TripDomain(
|
||||
sourceId = sourceId,
|
||||
language = languageOption,
|
||||
placeId = None,
|
||||
topicId = topicOption,
|
||||
gender = None,
|
||||
ageBucket = ageBucketOpt
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tripQueries
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class HighQualityTweetsAdaptor(
|
||||
tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
|
||||
interestsWithLookupContextStore: ReadableStore[InterestsLookupRequestWithContext, UserInterests],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override def name: String = this.getClass.getSimpleName
|
||||
|
||||
private val stats = globalStats.scope("HighQualityCandidateAdaptor")
|
||||
private val followedTopicsStats = stats.stat("followed_topics")
|
||||
private val missingResponseCounter = stats.counter("missing_respond_counter")
|
||||
private val crtFatigueCounter = stats.counter("fatigue_by_crt")
|
||||
private val fallbackRequestsCounter = stats.counter("fallback_requests")
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
PushDeviceUtil.isRecommendationsEligible(target).map {
|
||||
_ && target.params(FS.HighQualityCandidatesEnableCandidateSource)
|
||||
}
|
||||
}
|
||||
|
||||
private val highQualityCandidateFrequencyPredicate = {
|
||||
TargetPredicates
|
||||
.pushRecTypeFatiguePredicate(
|
||||
CommonRecommendationType.TripHqTweet,
|
||||
FS.HighQualityTweetsPushInterval,
|
||||
FS.MaxHighQualityTweetsPushGivenInterval,
|
||||
stats
|
||||
)
|
||||
}
|
||||
|
||||
private def getTripCandidatesStrato(
|
||||
target: Target
|
||||
): Future[Map[Long, Set[TripDomain]]] = {
|
||||
val tripQueriesF: Future[Set[TripDomain]] = HighQualityTweetsHelper.getTripQueries(
|
||||
target = target,
|
||||
enabledGroups = target.params(FS.HighQualityCandidatesEnableGroups).toSet,
|
||||
interestsWithLookupContextStore = interestsWithLookupContextStore,
|
||||
sourceIds = target.params(FS.TripTweetCandidateSourceIds),
|
||||
stat = followedTopicsStats
|
||||
)
|
||||
|
||||
lazy val fallbackTripQueriesFut: Future[Set[TripDomain]] =
|
||||
if (target.params(FS.HighQualityCandidatesEnableFallback))
|
||||
HighQualityTweetsHelper.getTripQueries(
|
||||
target = target,
|
||||
enabledGroups = target.params(FS.HighQualityCandidatesFallbackEnabledGroups).toSet,
|
||||
interestsWithLookupContextStore = interestsWithLookupContextStore,
|
||||
sourceIds = target.params(FS.HighQualityCandidatesFallbackSourceIds),
|
||||
stat = followedTopicsStats
|
||||
)
|
||||
else Future.value(Set.empty)
|
||||
|
||||
val initialTweetsFut: Future[Map[TripDomain, Seq[TripTweet]]] = tripQueriesF.flatMap {
|
||||
tripQueries => getTripTweetsByDomains(tripQueries)
|
||||
}
|
||||
|
||||
val tweetsByDomainFut: Future[Map[TripDomain, Seq[TripTweet]]] =
|
||||
if (target.params(FS.HighQualityCandidatesEnableFallback)) {
|
||||
initialTweetsFut.flatMap { candidates =>
|
||||
val minCandidatesForFallback: Int =
|
||||
target.params(FS.HighQualityCandidatesMinNumOfCandidatesToFallback)
|
||||
val validCandidates = candidates.filter(_._2.size >= minCandidatesForFallback)
|
||||
|
||||
if (validCandidates.nonEmpty) {
|
||||
Future.value(validCandidates)
|
||||
} else {
|
||||
fallbackTripQueriesFut.flatMap { fallbackTripDomains =>
|
||||
fallbackRequestsCounter.incr(fallbackTripDomains.size)
|
||||
getTripTweetsByDomains(fallbackTripDomains)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
initialTweetsFut
|
||||
}
|
||||
|
||||
val numOfCandidates: Int = target.params(FS.HighQualityCandidatesNumberOfCandidates)
|
||||
tweetsByDomainFut.map(tweetsByDomain => reformatDomainTweetMap(tweetsByDomain, numOfCandidates))
|
||||
}
|
||||
|
||||
private def getTripTweetsByDomains(
|
||||
tripQueries: Set[TripDomain]
|
||||
): Future[Map[TripDomain, Seq[TripTweet]]] = {
|
||||
Future.collect(tripTweetCandidateStore.multiGet(tripQueries)).map { response =>
|
||||
response
|
||||
.filter(p => p._2.exists(_.tweets.nonEmpty))
|
||||
.mapValues(_.map(_.tweets).getOrElse(Seq.empty))
|
||||
}
|
||||
}
|
||||
|
||||
private def reformatDomainTweetMap(
|
||||
tweetsByDomain: Map[TripDomain, Seq[TripTweet]],
|
||||
numOfCandidates: Int
|
||||
): Map[Long, Set[TripDomain]] = tweetsByDomain
|
||||
.flatMap {
|
||||
case (tripDomain, tripTweets) =>
|
||||
tripTweets
|
||||
.sortBy(_.score)(Ordering[Double].reverse)
|
||||
.take(numOfCandidates)
|
||||
.map { tweet => (tweet.tweetId, tripDomain) }
|
||||
}.groupBy(_._1).mapValues(_.map(_._2).toSet)
|
||||
|
||||
private def buildRawCandidate(
|
||||
target: Target,
|
||||
tweetyPieResult: TweetyPieResult,
|
||||
tripDomain: Option[scala.collection.Set[TripDomain]]
|
||||
): RawCandidate = {
|
||||
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
|
||||
inputTarget = target,
|
||||
id = tweetyPieResult.tweet.id,
|
||||
mediaCRT = MediaCRT(
|
||||
CommonRecommendationType.TripHqTweet,
|
||||
CommonRecommendationType.TripHqTweet,
|
||||
CommonRecommendationType.TripHqTweet
|
||||
),
|
||||
result = Some(tweetyPieResult),
|
||||
tripTweetDomain = tripDomain
|
||||
)
|
||||
}
|
||||
|
||||
private def getTweetyPieResults(
|
||||
target: Target,
|
||||
tweetToTripDomain: Map[Long, Set[TripDomain]]
|
||||
): Future[Map[Long, Option[TweetyPieResult]]] = {
|
||||
Future.collect((if (target.params(FS.EnableVFInTweetypie)) {
|
||||
tweetyPieStore
|
||||
} else {
|
||||
tweetyPieStoreNoVF
|
||||
}).multiGet(tweetToTripDomain.keySet))
|
||||
}
|
||||
|
||||
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
for {
|
||||
tweetsToTripDomainMap <- getTripCandidatesStrato(target)
|
||||
tweetyPieResults <- getTweetyPieResults(target, tweetsToTripDomainMap)
|
||||
} yield {
|
||||
val candidates = tweetyPieResults.flatMap {
|
||||
case (tweetId, tweetyPieResultOpt) =>
|
||||
tweetyPieResultOpt.map(buildRawCandidate(target, _, tweetsToTripDomainMap.get(tweetId)))
|
||||
}
|
||||
if (candidates.nonEmpty) {
|
||||
highQualityCandidateFrequencyPredicate(Seq(target))
|
||||
.map(_.head)
|
||||
.map { isTargetFatigueEligible =>
|
||||
if (isTargetFatigueEligible) Some(candidates)
|
||||
else {
|
||||
crtFatigueCounter.incr()
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
Some(candidates.toSeq)
|
||||
} else {
|
||||
missingResponseCounter.incr()
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,152 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base.ListPushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.predicate.TargetPredicates
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.geoduck.service.thriftscala.LocationResponse
|
||||
import com.twitter.interests_discovery.thriftscala.DisplayLocation
|
||||
import com.twitter.interests_discovery.thriftscala.NonPersonalizedRecommendedLists
|
||||
import com.twitter.interests_discovery.thriftscala.RecommendedListsRequest
|
||||
import com.twitter.interests_discovery.thriftscala.RecommendedListsResponse
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
|
||||
case class ListsToRecommendCandidateAdaptor(
|
||||
listRecommendationsStore: ReadableStore[String, NonPersonalizedRecommendedLists],
|
||||
geoDuckV2Store: ReadableStore[Long, LocationResponse],
|
||||
idsStore: ReadableStore[RecommendedListsRequest, RecommendedListsResponse],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override val name: String = this.getClass.getSimpleName
|
||||
|
||||
private[this] val stats = globalStats.scope(name)
|
||||
private[this] val noLocationCodeCounter = stats.counter("no_location_code")
|
||||
private[this] val noCandidatesCounter = stats.counter("no_candidates_for_geo")
|
||||
private[this] val disablePopGeoListsCounter = stats.counter("disable_pop_geo_lists")
|
||||
private[this] val disableIDSListsCounter = stats.counter("disable_ids_lists")
|
||||
|
||||
private def getListCandidate(
|
||||
targetUser: Target,
|
||||
_listId: Long
|
||||
): RawCandidate with ListPushCandidate = {
|
||||
new RawCandidate with ListPushCandidate {
|
||||
override val listId: Long = _listId
|
||||
|
||||
override val commonRecType: CommonRecommendationType = CommonRecommendationType.List
|
||||
|
||||
override val target: Target = targetUser
|
||||
}
|
||||
}
|
||||
|
||||
private def getListsRecommendedFromHistory(
|
||||
target: Target
|
||||
): Future[Seq[Long]] = {
|
||||
target.history.map { history =>
|
||||
history.sortedHistory.flatMap {
|
||||
case (_, notif) if notif.commonRecommendationType == List =>
|
||||
notif.listNotification.map(_.listId)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def getIDSListRecs(
|
||||
target: Target,
|
||||
historicalListIds: Seq[Long]
|
||||
): Future[Seq[Long]] = {
|
||||
val request = RecommendedListsRequest(
|
||||
target.targetId,
|
||||
DisplayLocation.ListDiscoveryPage,
|
||||
Some(historicalListIds)
|
||||
)
|
||||
if (target.params(PushFeatureSwitchParams.EnableIDSListRecommendations)) {
|
||||
idsStore.get(request).map {
|
||||
case Some(response) =>
|
||||
response.channels.map(_.id)
|
||||
case _ => Nil
|
||||
}
|
||||
} else {
|
||||
disableIDSListsCounter.incr()
|
||||
Future.Nil
|
||||
}
|
||||
}
|
||||
|
||||
private def getPopGeoLists(
|
||||
target: Target,
|
||||
historicalListIds: Seq[Long]
|
||||
): Future[Seq[Long]] = {
|
||||
if (target.params(PushFeatureSwitchParams.EnablePopGeoListRecommendations)) {
|
||||
geoDuckV2Store.get(target.targetId).flatMap {
|
||||
case Some(locationResponse) if locationResponse.geohash.isDefined =>
|
||||
val geoHashLength =
|
||||
target.params(PushFeatureSwitchParams.ListRecommendationsGeoHashLength)
|
||||
val geoHash = locationResponse.geohash.get.take(geoHashLength)
|
||||
listRecommendationsStore
|
||||
.get(s"geohash_$geoHash")
|
||||
.map {
|
||||
case Some(recommendedLists) =>
|
||||
recommendedLists.recommendedListsByAlgo.flatMap { topLists =>
|
||||
topLists.lists.collect {
|
||||
case list if !historicalListIds.contains(list.listId) => list.listId
|
||||
}
|
||||
}
|
||||
case _ => Nil
|
||||
}
|
||||
case _ =>
|
||||
noLocationCodeCounter.incr()
|
||||
Future.Nil
|
||||
}
|
||||
} else {
|
||||
disablePopGeoListsCounter.incr()
|
||||
Future.Nil
|
||||
}
|
||||
}
|
||||
|
||||
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
getListsRecommendedFromHistory(target).flatMap { historicalListIds =>
|
||||
Future
|
||||
.join(
|
||||
getPopGeoLists(target, historicalListIds),
|
||||
getIDSListRecs(target, historicalListIds)
|
||||
)
|
||||
.map {
|
||||
case (popGeoListsIds, idsListIds) =>
|
||||
val candidates = (idsListIds ++ popGeoListsIds).map(getListCandidate(target, _))
|
||||
Some(candidates)
|
||||
case _ =>
|
||||
noCandidatesCounter.incr()
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val pushCapFatiguePredicate = TargetPredicates.pushRecTypeFatiguePredicate(
|
||||
CommonRecommendationType.List,
|
||||
PushFeatureSwitchParams.ListRecommendationsPushInterval,
|
||||
PushFeatureSwitchParams.MaxListRecommendationsPushGivenInterval,
|
||||
stats,
|
||||
)
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
|
||||
val isNotFatigued = pushCapFatiguePredicate.apply(Seq(target)).map(_.head)
|
||||
|
||||
Future
|
||||
.join(
|
||||
PushDeviceUtil.isRecommendationsEligible(target),
|
||||
isNotFatigued
|
||||
).map {
|
||||
case (userRecommendationsEligible, isUnderCAP) =>
|
||||
userRecommendationsEligible && isUnderCAP && target.params(
|
||||
PushFeatureSwitchParams.EnableListRecommendations)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.geoduck.service.thriftscala.LocationResponse
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
|
||||
import com.twitter.content_mixer.thriftscala.ContentMixerRequest
|
||||
import com.twitter.content_mixer.thriftscala.ContentMixerResponse
|
||||
import com.twitter.geoduck.common.thriftscala.Location
|
||||
import com.twitter.hermit.pop_geo.thriftscala.PopTweetsInPlace
|
||||
import com.twitter.recommendation.interests.discovery.core.model.InterestDomain
|
||||
|
||||
class LoggedOutPushCandidateSourceGenerator(
|
||||
tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
|
||||
geoDuckV2Store: ReadableStore[Long, LocationResponse],
|
||||
safeCachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
|
||||
cachedTweetyPieStoreV2NoVF: ReadableStore[Long, TweetyPieResult],
|
||||
cachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
|
||||
contentMixerStore: ReadableStore[ContentMixerRequest, ContentMixerResponse],
|
||||
softUserLocationStore: ReadableStore[Long, Location],
|
||||
topTweetsByGeoStore: ReadableStore[InterestDomain[String], Map[String, List[(Long, Double)]]],
|
||||
topTweetsByGeoV2VersionedStore: ReadableStore[String, PopTweetsInPlace],
|
||||
)(
|
||||
implicit val globalStats: StatsReceiver) {
|
||||
val sources: Seq[CandidateSource[Target, RawCandidate] with CandidateSourceEligible[
|
||||
Target,
|
||||
RawCandidate
|
||||
]] = {
|
||||
Seq(
|
||||
TripGeoCandidatesAdaptor(
|
||||
tripTweetCandidateStore,
|
||||
contentMixerStore,
|
||||
safeCachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
globalStats
|
||||
),
|
||||
TopTweetsByGeoAdaptor(
|
||||
geoDuckV2Store,
|
||||
softUserLocationStore,
|
||||
topTweetsByGeoStore,
|
||||
topTweetsByGeoV2VersionedStore,
|
||||
cachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
globalStats
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
@ -0,0 +1,101 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base.DiscoverTwitterCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
|
||||
import com.twitter.frigate.pushservice.predicate.DiscoverTwitterPredicate
|
||||
import com.twitter.frigate.pushservice.predicate.TargetPredicates
|
||||
import com.twitter.frigate.pushservice.util.PushAppPermissionUtil
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.thriftscala.{CommonRecommendationType => CRT}
|
||||
import com.twitter.util.Future
|
||||
|
||||
class OnboardingPushCandidateAdaptor(
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override val name: String = this.getClass.getSimpleName
|
||||
|
||||
private[this] val stats = globalStats.scope(name)
|
||||
private[this] val requestNum = stats.counter("request_num")
|
||||
private[this] val addressBookCandNum = stats.counter("address_book_cand_num")
|
||||
private[this] val completeOnboardingCandNum = stats.counter("complete_onboarding_cand_num")
|
||||
|
||||
private def generateOnboardingPushRawCandidate(
|
||||
_target: Target,
|
||||
_commonRecType: CRT
|
||||
): RawCandidate = {
|
||||
new RawCandidate with DiscoverTwitterCandidate {
|
||||
override val target = _target
|
||||
override val commonRecType = _commonRecType
|
||||
}
|
||||
}
|
||||
|
||||
private def getEligibleCandsForTarget(
|
||||
target: Target
|
||||
): Future[Option[Seq[RawCandidate]]] = {
|
||||
val addressBookFatigue =
|
||||
TargetPredicates
|
||||
.pushRecTypeFatiguePredicate(
|
||||
CRT.AddressBookUploadPush,
|
||||
FS.FatigueForOnboardingPushes,
|
||||
FS.MaxOnboardingPushInInterval,
|
||||
stats)(Seq(target)).map(_.head)
|
||||
val completeOnboardingFatigue =
|
||||
TargetPredicates
|
||||
.pushRecTypeFatiguePredicate(
|
||||
CRT.CompleteOnboardingPush,
|
||||
FS.FatigueForOnboardingPushes,
|
||||
FS.MaxOnboardingPushInInterval,
|
||||
stats)(Seq(target)).map(_.head)
|
||||
|
||||
Future
|
||||
.join(
|
||||
target.appPermissions,
|
||||
addressBookFatigue,
|
||||
completeOnboardingFatigue
|
||||
).map {
|
||||
case (appPermissionOpt, addressBookPredicate, completeOnboardingPredicate) =>
|
||||
val addressBookUploaded =
|
||||
PushAppPermissionUtil.hasTargetUploadedAddressBook(appPermissionOpt)
|
||||
val abUploadCandidate =
|
||||
if (!addressBookUploaded && addressBookPredicate && target.params(
|
||||
FS.EnableAddressBookPush)) {
|
||||
addressBookCandNum.incr()
|
||||
Some(generateOnboardingPushRawCandidate(target, CRT.AddressBookUploadPush))
|
||||
} else if (!addressBookUploaded && (completeOnboardingPredicate ||
|
||||
target.params(FS.DisableOnboardingPushFatigue)) && target.params(
|
||||
FS.EnableCompleteOnboardingPush)) {
|
||||
completeOnboardingCandNum.incr()
|
||||
Some(generateOnboardingPushRawCandidate(target, CRT.CompleteOnboardingPush))
|
||||
} else None
|
||||
|
||||
val allCandidates =
|
||||
Seq(abUploadCandidate).filter(_.isDefined).flatten
|
||||
if (allCandidates.nonEmpty) Some(allCandidates) else None
|
||||
}
|
||||
}
|
||||
|
||||
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
requestNum.incr()
|
||||
val minDurationForMRElapsed =
|
||||
DiscoverTwitterPredicate
|
||||
.minDurationElapsedSinceLastMrPushPredicate(
|
||||
name,
|
||||
FS.MrMinDurationSincePushForOnboardingPushes,
|
||||
stats)(Seq(inputTarget)).map(_.head)
|
||||
minDurationForMRElapsed.flatMap { minDurationElapsed =>
|
||||
if (minDurationElapsed) getEligibleCandsForTarget(inputTarget) else Future.None
|
||||
}
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
PushDeviceUtil
|
||||
.isRecommendationsEligible(target).map(_ && target.params(FS.EnableOnboardingPushes))
|
||||
}
|
||||
}
|
@ -0,0 +1,162 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.content_mixer.thriftscala.ContentMixerRequest
|
||||
import com.twitter.content_mixer.thriftscala.ContentMixerResponse
|
||||
import com.twitter.explore_ranker.thriftscala.ExploreRankerRequest
|
||||
import com.twitter.explore_ranker.thriftscala.ExploreRankerResponse
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.candidate._
|
||||
import com.twitter.frigate.common.store.RecentTweetsQuery
|
||||
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.store._
|
||||
import com.twitter.geoduck.common.thriftscala.Location
|
||||
import com.twitter.geoduck.service.thriftscala.LocationResponse
|
||||
import com.twitter.hermit.pop_geo.thriftscala.PopTweetsInPlace
|
||||
import com.twitter.hermit.predicate.socialgraph.RelationEdge
|
||||
import com.twitter.hermit.store.tweetypie.UserTweet
|
||||
import com.twitter.interests.thriftscala.UserInterests
|
||||
import com.twitter.interests_discovery.thriftscala.NonPersonalizedRecommendedLists
|
||||
import com.twitter.interests_discovery.thriftscala.RecommendedListsRequest
|
||||
import com.twitter.interests_discovery.thriftscala.RecommendedListsResponse
|
||||
import com.twitter.recommendation.interests.discovery.core.model.InterestDomain
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
|
||||
|
||||
/**
|
||||
* PushCandidateSourceGenerator generates candidate source list for a given Target user
|
||||
*/
|
||||
class PushCandidateSourceGenerator(
|
||||
earlybirdCandidates: CandidateSource[EarlybirdCandidateSource.Query, EarlybirdCandidate],
|
||||
userTweetEntityGraphCandidates: CandidateSource[UserTweetEntityGraphCandidates.Target, Candidate],
|
||||
cachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
|
||||
safeCachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
|
||||
userTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
|
||||
safeUserTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
|
||||
cachedTweetyPieStoreV2NoVF: ReadableStore[Long, TweetyPieResult],
|
||||
edgeStore: ReadableStore[RelationEdge, Boolean],
|
||||
interestsLookupStore: ReadableStore[InterestsLookupRequestWithContext, UserInterests],
|
||||
uttEntityHydrationStore: UttEntityHydrationStore,
|
||||
geoDuckV2Store: ReadableStore[Long, LocationResponse],
|
||||
topTweetsByGeoStore: ReadableStore[InterestDomain[String], Map[String, List[(Long, Double)]]],
|
||||
topTweetsByGeoV2VersionedStore: ReadableStore[String, PopTweetsInPlace],
|
||||
tweetImpressionsStore: TweetImpressionsStore,
|
||||
recommendedTrendsCandidateSource: RecommendedTrendsCandidateSource,
|
||||
recentTweetsByAuthorStore: ReadableStore[RecentTweetsQuery, Seq[Seq[Long]]],
|
||||
topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse],
|
||||
crMixerStore: CrMixerTweetStore,
|
||||
contentMixerStore: ReadableStore[ContentMixerRequest, ContentMixerResponse],
|
||||
exploreRankerStore: ReadableStore[ExploreRankerRequest, ExploreRankerResponse],
|
||||
softUserLocationStore: ReadableStore[Long, Location],
|
||||
tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
|
||||
listRecsStore: ReadableStore[String, NonPersonalizedRecommendedLists],
|
||||
idsStore: ReadableStore[RecommendedListsRequest, RecommendedListsResponse]
|
||||
)(
|
||||
implicit val globalStats: StatsReceiver) {
|
||||
|
||||
private val earlyBirdFirstDegreeCandidateAdaptor = EarlyBirdFirstDegreeCandidateAdaptor(
|
||||
earlybirdCandidates,
|
||||
cachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
userTweetTweetyPieStore,
|
||||
PushFeatureSwitchParams.NumberOfMaxEarlybirdInNetworkCandidatesParam,
|
||||
globalStats
|
||||
)
|
||||
|
||||
private val frsTweetCandidateAdaptor = FRSTweetCandidateAdaptor(
|
||||
crMixerStore,
|
||||
cachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
userTweetTweetyPieStore,
|
||||
uttEntityHydrationStore,
|
||||
topicSocialProofServiceStore,
|
||||
globalStats
|
||||
)
|
||||
|
||||
private val contentRecommenderMixerAdaptor = ContentRecommenderMixerAdaptor(
|
||||
crMixerStore,
|
||||
safeCachedTweetyPieStoreV2,
|
||||
edgeStore,
|
||||
topicSocialProofServiceStore,
|
||||
uttEntityHydrationStore,
|
||||
globalStats
|
||||
)
|
||||
|
||||
private val tripGeoCandidatesAdaptor = TripGeoCandidatesAdaptor(
|
||||
tripTweetCandidateStore,
|
||||
contentMixerStore,
|
||||
safeCachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
globalStats
|
||||
)
|
||||
|
||||
val sources: Seq[
|
||||
CandidateSource[Target, RawCandidate] with CandidateSourceEligible[
|
||||
Target,
|
||||
RawCandidate
|
||||
]
|
||||
] = {
|
||||
Seq(
|
||||
earlyBirdFirstDegreeCandidateAdaptor,
|
||||
GenericCandidateAdaptor(
|
||||
userTweetEntityGraphCandidates,
|
||||
cachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
globalStats.scope("UserTweetEntityGraphCandidates")
|
||||
),
|
||||
new OnboardingPushCandidateAdaptor(globalStats),
|
||||
TopTweetsByGeoAdaptor(
|
||||
geoDuckV2Store,
|
||||
softUserLocationStore,
|
||||
topTweetsByGeoStore,
|
||||
topTweetsByGeoV2VersionedStore,
|
||||
cachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
globalStats
|
||||
),
|
||||
frsTweetCandidateAdaptor,
|
||||
TopTweetImpressionsCandidateAdaptor(
|
||||
recentTweetsByAuthorStore,
|
||||
cachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
tweetImpressionsStore,
|
||||
globalStats
|
||||
),
|
||||
TrendsCandidatesAdaptor(
|
||||
softUserLocationStore,
|
||||
recommendedTrendsCandidateSource,
|
||||
safeCachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
safeUserTweetTweetyPieStore,
|
||||
globalStats
|
||||
),
|
||||
contentRecommenderMixerAdaptor,
|
||||
tripGeoCandidatesAdaptor,
|
||||
HighQualityTweetsAdaptor(
|
||||
tripTweetCandidateStore,
|
||||
interestsLookupStore,
|
||||
cachedTweetyPieStoreV2,
|
||||
cachedTweetyPieStoreV2NoVF,
|
||||
globalStats
|
||||
),
|
||||
ExploreVideoTweetCandidateAdaptor(
|
||||
exploreRankerStore,
|
||||
cachedTweetyPieStoreV2,
|
||||
globalStats
|
||||
),
|
||||
ListsToRecommendCandidateAdaptor(
|
||||
listRecsStore,
|
||||
geoDuckV2Store,
|
||||
idsStore,
|
||||
globalStats
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
@ -0,0 +1,326 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.conversions.DurationOps._
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base.TopTweetImpressionsCandidate
|
||||
import com.twitter.frigate.common.store.RecentTweetsQuery
|
||||
import com.twitter.frigate.common.util.SnowflakeUtils
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
|
||||
import com.twitter.frigate.pushservice.store.TweetImpressionsStore
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.FutureOps
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
|
||||
case class TweetImpressionsCandidate(
|
||||
tweetId: Long,
|
||||
tweetyPieResultOpt: Option[TweetyPieResult],
|
||||
impressionsCountOpt: Option[Long])
|
||||
|
||||
case class TopTweetImpressionsCandidateAdaptor(
|
||||
recentTweetsFromTflockStore: ReadableStore[RecentTweetsQuery, Seq[Seq[Long]]],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
tweetImpressionsStore: TweetImpressionsStore,
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
private val stats = globalStats.scope("TopTweetImpressionsAdaptor")
|
||||
private val tweetImpressionsCandsStat = stats.stat("top_tweet_impressions_cands_dist")
|
||||
|
||||
private val eligibleUsersCounter = stats.counter("eligible_users")
|
||||
private val noneligibleUsersCounter = stats.counter("noneligible_users")
|
||||
private val meetsMinTweetsRequiredCounter = stats.counter("meets_min_tweets_required")
|
||||
private val belowMinTweetsRequiredCounter = stats.counter("below_min_tweets_required")
|
||||
private val aboveMaxInboundFavoritesCounter = stats.counter("above_max_inbound_favorites")
|
||||
private val meetsImpressionsRequiredCounter = stats.counter("meets_impressions_required")
|
||||
private val belowImpressionsRequiredCounter = stats.counter("below_impressions_required")
|
||||
private val meetsFavoritesThresholdCounter = stats.counter("meets_favorites_threshold")
|
||||
private val aboveFavoritesThresholdCounter = stats.counter("above_favorites_threshold")
|
||||
private val emptyImpressionsMapCounter = stats.counter("empty_impressions_map")
|
||||
|
||||
private val tflockResultsStat = stats.stat("tflock", "results")
|
||||
private val emptyTflockResult = stats.counter("tflock", "empty_result")
|
||||
private val nonEmptyTflockResult = stats.counter("tflock", "non_empty_result")
|
||||
|
||||
private val originalTweetsStat = stats.stat("tweets", "original_tweets")
|
||||
private val retweetsStat = stats.stat("tweets", "retweets")
|
||||
private val allRetweetsOnlyCounter = stats.counter("tweets", "all_retweets_only")
|
||||
private val allOriginalTweetsOnlyCounter = stats.counter("tweets", "all_original_tweets_only")
|
||||
|
||||
private val emptyTweetypieMap = stats.counter("", "empty_tweetypie_map")
|
||||
private val emptyTweetyPieResult = stats.stat("", "empty_tweetypie_result")
|
||||
private val allEmptyTweetypieResults = stats.counter("", "all_empty_tweetypie_results")
|
||||
|
||||
private val eligibleUsersAfterImpressionsFilter =
|
||||
stats.counter("eligible_users_after_impressions_filter")
|
||||
private val eligibleUsersAfterFavoritesFilter =
|
||||
stats.counter("eligible_users_after_favorites_filter")
|
||||
private val eligibleUsersWithEligibleTweets =
|
||||
stats.counter("eligible_users_with_eligible_tweets")
|
||||
|
||||
private val eligibleTweetCands = stats.stat("eligible_tweet_cands")
|
||||
private val getCandsRequestCounter =
|
||||
stats.counter("top_tweet_impressions_get_request")
|
||||
|
||||
override val name: String = this.getClass.getSimpleName
|
||||
|
||||
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
getCandsRequestCounter.incr()
|
||||
val eligibleCandidatesFut = getTweetImpressionsCandidates(inputTarget)
|
||||
eligibleCandidatesFut.map { eligibleCandidates =>
|
||||
if (eligibleCandidates.nonEmpty) {
|
||||
eligibleUsersWithEligibleTweets.incr()
|
||||
eligibleTweetCands.add(eligibleCandidates.size)
|
||||
val candidate = getMostImpressionsTweet(eligibleCandidates)
|
||||
Some(
|
||||
Seq(
|
||||
generateTopTweetImpressionsCandidate(
|
||||
inputTarget,
|
||||
candidate.tweetId,
|
||||
candidate.tweetyPieResultOpt,
|
||||
candidate.impressionsCountOpt.getOrElse(0L))))
|
||||
} else None
|
||||
}
|
||||
}
|
||||
|
||||
private def getTweetImpressionsCandidates(
|
||||
inputTarget: Target
|
||||
): Future[Seq[TweetImpressionsCandidate]] = {
|
||||
val originalTweets = getRecentOriginalTweetsForUser(inputTarget)
|
||||
originalTweets.flatMap { tweetyPieResultsMap =>
|
||||
val numDaysSearchForOriginalTweets =
|
||||
inputTarget.params(FS.TopTweetImpressionsOriginalTweetsNumDaysSearch)
|
||||
val moreRecentTweetIds =
|
||||
getMoreRecentTweetIds(tweetyPieResultsMap.keySet.toSeq, numDaysSearchForOriginalTweets)
|
||||
val isEligible = isEligibleUser(inputTarget, tweetyPieResultsMap, moreRecentTweetIds)
|
||||
if (isEligible) filterByEligibility(inputTarget, tweetyPieResultsMap, moreRecentTweetIds)
|
||||
else Future.Nil
|
||||
}
|
||||
}
|
||||
|
||||
private def getRecentOriginalTweetsForUser(
|
||||
targetUser: Target
|
||||
): Future[Map[Long, TweetyPieResult]] = {
|
||||
val tweetyPieResultsMapFut = getTflockStoreResults(targetUser).flatMap { recentTweetIds =>
|
||||
FutureOps.mapCollect((targetUser.params(FS.EnableVFInTweetypie) match {
|
||||
case true => tweetyPieStore
|
||||
case false => tweetyPieStoreNoVF
|
||||
}).multiGet(recentTweetIds.toSet))
|
||||
}
|
||||
tweetyPieResultsMapFut.map { tweetyPieResultsMap =>
|
||||
if (tweetyPieResultsMap.isEmpty) {
|
||||
emptyTweetypieMap.incr()
|
||||
Map.empty
|
||||
} else removeRetweets(tweetyPieResultsMap)
|
||||
}
|
||||
}
|
||||
|
||||
private def getTflockStoreResults(targetUser: Target): Future[Seq[Long]] = {
|
||||
val maxResults = targetUser.params(FS.TopTweetImpressionsRecentTweetsByAuthorStoreMaxResults)
|
||||
val maxAge = targetUser.params(FS.TopTweetImpressionsTotalFavoritesLimitNumDaysSearch)
|
||||
val recentTweetsQuery =
|
||||
RecentTweetsQuery(
|
||||
userIds = Seq(targetUser.targetId),
|
||||
maxResults = maxResults,
|
||||
maxAge = maxAge.days
|
||||
)
|
||||
recentTweetsFromTflockStore
|
||||
.get(recentTweetsQuery).map {
|
||||
case Some(tweetIdsAll) =>
|
||||
val tweetIds = tweetIdsAll.headOption.getOrElse(Seq.empty)
|
||||
val numTweets = tweetIds.size
|
||||
if (numTweets > 0) {
|
||||
tflockResultsStat.add(numTweets)
|
||||
nonEmptyTflockResult.incr()
|
||||
} else emptyTflockResult.incr()
|
||||
tweetIds
|
||||
case _ => Nil
|
||||
}
|
||||
}
|
||||
|
||||
private def removeRetweets(
|
||||
tweetyPieResultsMap: Map[Long, Option[TweetyPieResult]]
|
||||
): Map[Long, TweetyPieResult] = {
|
||||
val nonEmptyTweetyPieResults: Map[Long, TweetyPieResult] = tweetyPieResultsMap.collect {
|
||||
case (key, Some(value)) => (key, value)
|
||||
}
|
||||
emptyTweetyPieResult.add(tweetyPieResultsMap.size - nonEmptyTweetyPieResults.size)
|
||||
|
||||
if (nonEmptyTweetyPieResults.nonEmpty) {
|
||||
val originalTweets = nonEmptyTweetyPieResults.filter {
|
||||
case (_, tweetyPieResult) =>
|
||||
tweetyPieResult.sourceTweet.isEmpty
|
||||
}
|
||||
val numOriginalTweets = originalTweets.size
|
||||
val numRetweets = nonEmptyTweetyPieResults.size - originalTweets.size
|
||||
originalTweetsStat.add(numOriginalTweets)
|
||||
retweetsStat.add(numRetweets)
|
||||
if (numRetweets == 0) allOriginalTweetsOnlyCounter.incr()
|
||||
if (numOriginalTweets == 0) allRetweetsOnlyCounter.incr()
|
||||
originalTweets
|
||||
} else {
|
||||
allEmptyTweetypieResults.incr()
|
||||
Map.empty
|
||||
}
|
||||
}
|
||||
|
||||
private def getMoreRecentTweetIds(
|
||||
tweetIds: Seq[Long],
|
||||
numDays: Int
|
||||
): Seq[Long] = {
|
||||
tweetIds.filter { tweetId =>
|
||||
SnowflakeUtils.isRecent(tweetId, numDays.days)
|
||||
}
|
||||
}
|
||||
|
||||
private def isEligibleUser(
|
||||
inputTarget: Target,
|
||||
tweetyPieResults: Map[Long, TweetyPieResult],
|
||||
recentTweetIds: Seq[Long]
|
||||
): Boolean = {
|
||||
val minNumTweets = inputTarget.params(FS.TopTweetImpressionsMinNumOriginalTweets)
|
||||
lazy val totalFavoritesLimit =
|
||||
inputTarget.params(FS.TopTweetImpressionsTotalInboundFavoritesLimit)
|
||||
if (recentTweetIds.size >= minNumTweets) {
|
||||
meetsMinTweetsRequiredCounter.incr()
|
||||
val isUnderLimit = isUnderTotalInboundFavoritesLimit(tweetyPieResults, totalFavoritesLimit)
|
||||
if (isUnderLimit) eligibleUsersCounter.incr()
|
||||
else {
|
||||
aboveMaxInboundFavoritesCounter.incr()
|
||||
noneligibleUsersCounter.incr()
|
||||
}
|
||||
isUnderLimit
|
||||
} else {
|
||||
belowMinTweetsRequiredCounter.incr()
|
||||
noneligibleUsersCounter.incr()
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
private def getFavoriteCounts(
|
||||
tweetyPieResult: TweetyPieResult
|
||||
): Long = tweetyPieResult.tweet.counts.flatMap(_.favoriteCount).getOrElse(0L)
|
||||
|
||||
private def isUnderTotalInboundFavoritesLimit(
|
||||
tweetyPieResults: Map[Long, TweetyPieResult],
|
||||
totalFavoritesLimit: Long
|
||||
): Boolean = {
|
||||
val favoritesIterator = tweetyPieResults.valuesIterator.map(getFavoriteCounts)
|
||||
val totalInboundFavorites = favoritesIterator.sum
|
||||
totalInboundFavorites <= totalFavoritesLimit
|
||||
}
|
||||
|
||||
def filterByEligibility(
|
||||
inputTarget: Target,
|
||||
tweetyPieResults: Map[Long, TweetyPieResult],
|
||||
tweetIds: Seq[Long]
|
||||
): Future[Seq[TweetImpressionsCandidate]] = {
|
||||
lazy val minNumImpressions: Long = inputTarget.params(FS.TopTweetImpressionsMinRequired)
|
||||
lazy val maxNumLikes: Long = inputTarget.params(FS.TopTweetImpressionsMaxFavoritesPerTweet)
|
||||
for {
|
||||
filteredImpressionsMap <- getFilteredImpressionsMap(tweetIds, minNumImpressions)
|
||||
tweetIdsFilteredByFavorites <-
|
||||
getTweetIdsFilteredByFavorites(filteredImpressionsMap.keySet, tweetyPieResults, maxNumLikes)
|
||||
} yield {
|
||||
if (filteredImpressionsMap.nonEmpty) eligibleUsersAfterImpressionsFilter.incr()
|
||||
if (tweetIdsFilteredByFavorites.nonEmpty) eligibleUsersAfterFavoritesFilter.incr()
|
||||
|
||||
val candidates = tweetIdsFilteredByFavorites.map { tweetId =>
|
||||
TweetImpressionsCandidate(
|
||||
tweetId,
|
||||
tweetyPieResults.get(tweetId),
|
||||
filteredImpressionsMap.get(tweetId))
|
||||
}
|
||||
tweetImpressionsCandsStat.add(candidates.length)
|
||||
candidates
|
||||
}
|
||||
}
|
||||
|
||||
private def getFilteredImpressionsMap(
|
||||
tweetIds: Seq[Long],
|
||||
minNumImpressions: Long
|
||||
): Future[Map[Long, Long]] = {
|
||||
getImpressionsCounts(tweetIds).map { impressionsMap =>
|
||||
if (impressionsMap.isEmpty) emptyImpressionsMapCounter.incr()
|
||||
impressionsMap.filter {
|
||||
case (_, numImpressions) =>
|
||||
val isValid = numImpressions >= minNumImpressions
|
||||
if (isValid) {
|
||||
meetsImpressionsRequiredCounter.incr()
|
||||
} else {
|
||||
belowImpressionsRequiredCounter.incr()
|
||||
}
|
||||
isValid
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def getTweetIdsFilteredByFavorites(
|
||||
filteredTweetIds: Set[Long],
|
||||
tweetyPieResults: Map[Long, TweetyPieResult],
|
||||
maxNumLikes: Long
|
||||
): Future[Seq[Long]] = {
|
||||
val filteredByFavoritesTweetIds = filteredTweetIds.filter { tweetId =>
|
||||
val tweetyPieResultOpt = tweetyPieResults.get(tweetId)
|
||||
val isValid = tweetyPieResultOpt.exists { tweetyPieResult =>
|
||||
getFavoriteCounts(tweetyPieResult) <= maxNumLikes
|
||||
}
|
||||
if (isValid) meetsFavoritesThresholdCounter.incr()
|
||||
else aboveFavoritesThresholdCounter.incr()
|
||||
isValid
|
||||
}
|
||||
Future(filteredByFavoritesTweetIds.toSeq)
|
||||
}
|
||||
|
||||
private def getMostImpressionsTweet(
|
||||
filteredResults: Seq[TweetImpressionsCandidate]
|
||||
): TweetImpressionsCandidate = {
|
||||
val maxImpressions: Long = filteredResults.map {
|
||||
_.impressionsCountOpt.getOrElse(0L)
|
||||
}.max
|
||||
|
||||
val mostImpressionsCandidates: Seq[TweetImpressionsCandidate] =
|
||||
filteredResults.filter(_.impressionsCountOpt.getOrElse(0L) == maxImpressions)
|
||||
|
||||
mostImpressionsCandidates.maxBy(_.tweetId)
|
||||
}
|
||||
|
||||
private def getImpressionsCounts(
|
||||
tweetIds: Seq[Long]
|
||||
): Future[Map[Long, Long]] = {
|
||||
val impressionCountMap = tweetIds.map { tweetId =>
|
||||
tweetId -> tweetImpressionsStore
|
||||
.getCounts(tweetId).map(_.getOrElse(0L))
|
||||
}.toMap
|
||||
Future.collect(impressionCountMap)
|
||||
}
|
||||
|
||||
private def generateTopTweetImpressionsCandidate(
|
||||
inputTarget: Target,
|
||||
_tweetId: Long,
|
||||
result: Option[TweetyPieResult],
|
||||
_impressionsCount: Long
|
||||
): RawCandidate = {
|
||||
new RawCandidate with TopTweetImpressionsCandidate {
|
||||
override val target: Target = inputTarget
|
||||
override val tweetId: Long = _tweetId
|
||||
override val tweetyPieResult: Option[TweetyPieResult] = result
|
||||
override val impressionsCount: Long = _impressionsCount
|
||||
}
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
val enabledTopTweetImpressionsNotification =
|
||||
target.params(FS.EnableTopTweetImpressionsNotification)
|
||||
|
||||
PushDeviceUtil
|
||||
.isRecommendationsEligible(target).map(_ && enabledTopTweetImpressionsNotification)
|
||||
}
|
||||
}
|
@ -0,0 +1,413 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.finagle.stats.Counter
|
||||
import com.twitter.finagle.stats.Stat
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base.TweetCandidate
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.model.PushTypes
|
||||
import com.twitter.frigate.pushservice.params.PopGeoTweetVersion
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.frigate.pushservice.params.TopTweetsForGeoCombination
|
||||
import com.twitter.frigate.pushservice.params.TopTweetsForGeoRankingFunction
|
||||
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
|
||||
import com.twitter.frigate.pushservice.predicate.DiscoverTwitterPredicate
|
||||
import com.twitter.frigate.pushservice.predicate.TargetPredicates
|
||||
import com.twitter.frigate.pushservice.util.MediaCRT
|
||||
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.geoduck.common.thriftscala.{Location => GeoLocation}
|
||||
import com.twitter.geoduck.service.thriftscala.LocationResponse
|
||||
import com.twitter.gizmoduck.thriftscala.UserType
|
||||
import com.twitter.hermit.pop_geo.thriftscala.PopTweetsInPlace
|
||||
import com.twitter.recommendation.interests.discovery.core.model.InterestDomain
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.FutureOps
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.Time
|
||||
import scala.collection.Map
|
||||
|
||||
case class PlaceTweetScore(place: String, tweetId: Long, score: Double) {
|
||||
def toTweetScore: (Long, Double) = (tweetId, score)
|
||||
}
|
||||
case class TopTweetsByGeoAdaptor(
|
||||
geoduckStoreV2: ReadableStore[Long, LocationResponse],
|
||||
softUserGeoLocationStore: ReadableStore[Long, GeoLocation],
|
||||
topTweetsByGeoStore: ReadableStore[InterestDomain[String], Map[String, List[(Long, Double)]]],
|
||||
topTweetsByGeoStoreV2: ReadableStore[String, PopTweetsInPlace],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
globalStats: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override def name: String = this.getClass.getSimpleName
|
||||
|
||||
private[this] val stats = globalStats.scope("TopTweetsByGeoAdaptor")
|
||||
private[this] val noGeohashUserCounter: Counter = stats.counter("users_with_no_geohash_counter")
|
||||
private[this] val incomingRequestCounter: Counter = stats.counter("incoming_request_counter")
|
||||
private[this] val incomingLoggedOutRequestCounter: Counter =
|
||||
stats.counter("incoming_logged_out_request_counter")
|
||||
private[this] val loggedOutRawCandidatesCounter =
|
||||
stats.counter("logged_out_raw_candidates_counter")
|
||||
private[this] val emptyLoggedOutRawCandidatesCounter =
|
||||
stats.counter("logged_out_empty_raw_candidates")
|
||||
private[this] val outputTopTweetsByGeoCounter: Stat =
|
||||
stats.stat("output_top_tweets_by_geo_counter")
|
||||
private[this] val loggedOutPopByGeoV2CandidatesCounter: Counter =
|
||||
stats.counter("logged_out_pop_by_geo_candidates")
|
||||
private[this] val dormantUsersSince14DaysCounter: Counter =
|
||||
stats.counter("dormant_user_since_14_days_counter")
|
||||
private[this] val dormantUsersSince30DaysCounter: Counter =
|
||||
stats.counter("dormant_user_since_30_days_counter")
|
||||
private[this] val nonDormantUsersSince14DaysCounter: Counter =
|
||||
stats.counter("non_dormant_user_since_14_days_counter")
|
||||
private[this] val topTweetsByGeoTake100Counter: Counter =
|
||||
stats.counter("top_tweets_by_geo_take_100_counter")
|
||||
private[this] val combinationRequestsCounter =
|
||||
stats.scope("combination_method_request_counter")
|
||||
private[this] val popGeoTweetVersionCounter =
|
||||
stats.scope("popgeo_tweet_version_counter")
|
||||
private[this] val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
|
||||
|
||||
val MaxGeoHashSize = 4
|
||||
|
||||
private def constructKeys(
|
||||
geohash: Option[String],
|
||||
accountCountryCode: Option[String],
|
||||
keyLengths: Seq[Int],
|
||||
version: PopGeoTweetVersion.Value
|
||||
): Set[String] = {
|
||||
val geohashKeys = geohash match {
|
||||
case Some(hash) => keyLengths.map { version + "_geohash_" + hash.take(_) }
|
||||
case _ => Seq.empty
|
||||
}
|
||||
|
||||
val accountCountryCodeKeys =
|
||||
accountCountryCode.toSeq.map(version + "_country_" + _.toUpperCase)
|
||||
(geohashKeys ++ accountCountryCodeKeys).toSet
|
||||
}
|
||||
|
||||
def convertToPlaceTweetScore(
|
||||
popTweetsInPlace: Seq[PopTweetsInPlace]
|
||||
): Seq[PlaceTweetScore] = {
|
||||
popTweetsInPlace.flatMap {
|
||||
case p =>
|
||||
p.popTweets.map {
|
||||
case popTweet => PlaceTweetScore(p.place, popTweet.tweetId, popTweet.score)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def sortGeoHashTweets(
|
||||
placeTweetScores: Seq[PlaceTweetScore],
|
||||
rankingFunction: TopTweetsForGeoRankingFunction.Value
|
||||
): Seq[PlaceTweetScore] = {
|
||||
rankingFunction match {
|
||||
case TopTweetsForGeoRankingFunction.Score =>
|
||||
placeTweetScores.sortBy(_.score)(Ordering[Double].reverse)
|
||||
case TopTweetsForGeoRankingFunction.GeohashLengthAndThenScore =>
|
||||
placeTweetScores
|
||||
.sortBy(row => (row.place.length, row.score))(Ordering[(Int, Double)].reverse)
|
||||
}
|
||||
}
|
||||
|
||||
def getResultsForLambdaStore(
|
||||
inputTarget: Target,
|
||||
geohash: Option[String],
|
||||
store: ReadableStore[String, PopTweetsInPlace],
|
||||
topk: Int,
|
||||
version: PopGeoTweetVersion.Value
|
||||
): Future[Seq[(Long, Double)]] = {
|
||||
inputTarget.accountCountryCode.flatMap { countryCode =>
|
||||
val keys = {
|
||||
if (inputTarget.params(FS.EnableCountryCodeBackoffTopTweetsByGeo))
|
||||
constructKeys(geohash, countryCode, inputTarget.params(FS.GeoHashLengthList), version)
|
||||
else
|
||||
constructKeys(geohash, None, inputTarget.params(FS.GeoHashLengthList), version)
|
||||
}
|
||||
FutureOps
|
||||
.mapCollect(store.multiGet(keys)).map {
|
||||
case geohashTweetMap =>
|
||||
val popTweets =
|
||||
geohashTweetMap.values.flatten.toSeq
|
||||
val results = sortGeoHashTweets(
|
||||
convertToPlaceTweetScore(popTweets),
|
||||
inputTarget.params(FS.RankingFunctionForTopTweetsByGeo))
|
||||
.map(_.toTweetScore).take(topk)
|
||||
results
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def getPopGeoTweetsForLoggedOutUsers(
|
||||
inputTarget: Target,
|
||||
store: ReadableStore[String, PopTweetsInPlace]
|
||||
): Future[Seq[(Long, Double)]] = {
|
||||
inputTarget.countryCode.flatMap { countryCode =>
|
||||
val keys = constructKeys(None, countryCode, Seq(4), PopGeoTweetVersion.Prod)
|
||||
FutureOps.mapCollect(store.multiGet(keys)).map {
|
||||
case tweetMap =>
|
||||
val tweets = tweetMap.values.flatten.toSeq
|
||||
loggedOutPopByGeoV2CandidatesCounter.incr(tweets.size)
|
||||
val popTweets = sortGeoHashTweets(
|
||||
convertToPlaceTweetScore(tweets),
|
||||
TopTweetsForGeoRankingFunction.Score).map(_.toTweetScore)
|
||||
popTweets
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def getRankedTweets(
|
||||
inputTarget: Target,
|
||||
geohash: Option[String]
|
||||
): Future[Seq[(Long, Double)]] = {
|
||||
val MaxTopTweetsByGeoCandidatesToTake =
|
||||
inputTarget.params(FS.MaxTopTweetsByGeoCandidatesToTake)
|
||||
val scoringFn: String = inputTarget.params(FS.ScoringFuncForTopTweetsByGeo)
|
||||
val combinationMethod = inputTarget.params(FS.TopTweetsByGeoCombinationParam)
|
||||
val popGeoTweetVersion = inputTarget.params(FS.PopGeoTweetVersionParam)
|
||||
|
||||
inputTarget.isHeavyUserState.map { isHeavyUser =>
|
||||
stats
|
||||
.scope(combinationMethod.toString).scope(popGeoTweetVersion.toString).scope(
|
||||
"IsHeavyUser_" + isHeavyUser.toString).counter().incr()
|
||||
}
|
||||
combinationRequestsCounter.scope(combinationMethod.toString).counter().incr()
|
||||
popGeoTweetVersionCounter.scope(popGeoTweetVersion.toString).counter().incr()
|
||||
lazy val geoStoreResults = if (geohash.isDefined) {
|
||||
val hash = geohash.get.take(MaxGeoHashSize)
|
||||
topTweetsByGeoStore
|
||||
.get(
|
||||
InterestDomain[String](hash)
|
||||
)
|
||||
.map {
|
||||
case Some(scoringFnToTweetsMapOpt) =>
|
||||
val tweetsWithScore = scoringFnToTweetsMapOpt
|
||||
.getOrElse(scoringFn, List.empty)
|
||||
val sortedResults = sortGeoHashTweets(
|
||||
tweetsWithScore.map {
|
||||
case (tweetId, score) => PlaceTweetScore(hash, tweetId, score)
|
||||
},
|
||||
TopTweetsForGeoRankingFunction.Score
|
||||
).map(_.toTweetScore).take(
|
||||
MaxTopTweetsByGeoCandidatesToTake
|
||||
)
|
||||
sortedResults
|
||||
case _ => Seq.empty
|
||||
}
|
||||
} else Future.value(Seq.empty)
|
||||
lazy val versionPopGeoTweetResults =
|
||||
getResultsForLambdaStore(
|
||||
inputTarget,
|
||||
geohash,
|
||||
topTweetsByGeoStoreV2,
|
||||
MaxTopTweetsByGeoCandidatesToTake,
|
||||
popGeoTweetVersion
|
||||
)
|
||||
combinationMethod match {
|
||||
case TopTweetsForGeoCombination.Default => geoStoreResults
|
||||
case TopTweetsForGeoCombination.AccountsTweetFavAsBackfill =>
|
||||
Future.join(geoStoreResults, versionPopGeoTweetResults).map {
|
||||
case (geoStoreTweets, versionPopGeoTweets) =>
|
||||
(geoStoreTweets ++ versionPopGeoTweets).take(MaxTopTweetsByGeoCandidatesToTake)
|
||||
}
|
||||
case TopTweetsForGeoCombination.AccountsTweetFavIntermixed =>
|
||||
Future.join(geoStoreResults, versionPopGeoTweetResults).map {
|
||||
case (geoStoreTweets, versionPopGeoTweets) =>
|
||||
CandidateSource.interleaveSeqs(Seq(geoStoreTweets, versionPopGeoTweets))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
if (inputTarget.isLoggedOutUser) {
|
||||
incomingLoggedOutRequestCounter.incr()
|
||||
val rankedTweets = getPopGeoTweetsForLoggedOutUsers(inputTarget, topTweetsByGeoStoreV2)
|
||||
val rawCandidates = {
|
||||
rankedTweets.map { rt =>
|
||||
FutureOps
|
||||
.mapCollect(
|
||||
tweetyPieStore
|
||||
.multiGet(rt.map { case (tweetId, _) => tweetId }.toSet))
|
||||
.map { tweetyPieResultMap =>
|
||||
val results = buildTopTweetsByGeoRawCandidates(
|
||||
inputTarget,
|
||||
None,
|
||||
tweetyPieResultMap
|
||||
)
|
||||
if (results.isEmpty) {
|
||||
emptyLoggedOutRawCandidatesCounter.incr()
|
||||
}
|
||||
loggedOutRawCandidatesCounter.incr(results.size)
|
||||
Some(results)
|
||||
}
|
||||
}.flatten
|
||||
}
|
||||
rawCandidates
|
||||
} else {
|
||||
incomingRequestCounter.incr()
|
||||
getGeoHashForUsers(inputTarget).flatMap { geohash =>
|
||||
if (geohash.isEmpty) noGeohashUserCounter.incr()
|
||||
getRankedTweets(inputTarget, geohash).map { rt =>
|
||||
if (rt.size == 100) {
|
||||
topTweetsByGeoTake100Counter.incr(1)
|
||||
}
|
||||
FutureOps
|
||||
.mapCollect((inputTarget.params(FS.EnableVFInTweetypie) match {
|
||||
case true => tweetyPieStore
|
||||
case false => tweetyPieStoreNoVF
|
||||
}).multiGet(rt.map { case (tweetId, _) => tweetId }.toSet))
|
||||
.map { tweetyPieResultMap =>
|
||||
Some(
|
||||
buildTopTweetsByGeoRawCandidates(
|
||||
inputTarget,
|
||||
None,
|
||||
filterOutReplyTweet(
|
||||
tweetyPieResultMap,
|
||||
nonReplyTweetsCounter
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
}.flatten
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def getGeoHashForUsers(
|
||||
inputTarget: Target
|
||||
): Future[Option[String]] = {
|
||||
|
||||
inputTarget.targetUser.flatMap {
|
||||
case Some(user) =>
|
||||
user.userType match {
|
||||
case UserType.Soft =>
|
||||
softUserGeoLocationStore
|
||||
.get(inputTarget.targetId)
|
||||
.map(_.flatMap(_.geohash.flatMap(_.stringGeohash)))
|
||||
|
||||
case _ =>
|
||||
geoduckStoreV2.get(inputTarget.targetId).map(_.flatMap(_.geohash))
|
||||
}
|
||||
|
||||
case None => Future.None
|
||||
}
|
||||
}
|
||||
|
||||
private def buildTopTweetsByGeoRawCandidates(
|
||||
target: PushTypes.Target,
|
||||
locationName: Option[String],
|
||||
topTweets: Map[Long, Option[TweetyPieResult]]
|
||||
): Seq[RawCandidate with TweetCandidate] = {
|
||||
val candidates = topTweets.map { tweetIdTweetyPieResultMap =>
|
||||
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
|
||||
inputTarget = target,
|
||||
id = tweetIdTweetyPieResultMap._1,
|
||||
mediaCRT = MediaCRT(
|
||||
CommonRecommendationType.GeoPopTweet,
|
||||
CommonRecommendationType.GeoPopTweet,
|
||||
CommonRecommendationType.GeoPopTweet
|
||||
),
|
||||
result = tweetIdTweetyPieResultMap._2,
|
||||
localizedEntity = None
|
||||
)
|
||||
}.toSeq
|
||||
outputTopTweetsByGeoCounter.add(candidates.length)
|
||||
candidates
|
||||
}
|
||||
|
||||
private val topTweetsByGeoFrequencyPredicate = {
|
||||
TargetPredicates
|
||||
.pushRecTypeFatiguePredicate(
|
||||
CommonRecommendationType.GeoPopTweet,
|
||||
FS.TopTweetsByGeoPushInterval,
|
||||
FS.MaxTopTweetsByGeoPushGivenInterval,
|
||||
stats
|
||||
)
|
||||
}
|
||||
|
||||
def getAvailabilityForDormantUser(target: Target): Future[Boolean] = {
|
||||
lazy val isDormantUserNotFatigued = topTweetsByGeoFrequencyPredicate(Seq(target)).map(_.head)
|
||||
lazy val enableTopTweetsByGeoForDormantUsers =
|
||||
target.params(FS.EnableTopTweetsByGeoCandidatesForDormantUsers)
|
||||
|
||||
target.lastHTLVisitTimestamp.flatMap {
|
||||
case Some(lastHTLTimestamp) =>
|
||||
val minTimeSinceLastLogin =
|
||||
target.params(FS.MinimumTimeSinceLastLoginForGeoPopTweetPush).ago
|
||||
val timeSinceInactive = target.params(FS.TimeSinceLastLoginForGeoPopTweetPush).ago
|
||||
val lastActiveTimestamp = Time.fromMilliseconds(lastHTLTimestamp)
|
||||
if (lastActiveTimestamp > minTimeSinceLastLogin) {
|
||||
nonDormantUsersSince14DaysCounter.incr()
|
||||
Future.False
|
||||
} else {
|
||||
dormantUsersSince14DaysCounter.incr()
|
||||
isDormantUserNotFatigued.map { isUserNotFatigued =>
|
||||
lastActiveTimestamp < timeSinceInactive &&
|
||||
enableTopTweetsByGeoForDormantUsers &&
|
||||
isUserNotFatigued
|
||||
}
|
||||
}
|
||||
case _ =>
|
||||
dormantUsersSince30DaysCounter.incr()
|
||||
isDormantUserNotFatigued.map { isUserNotFatigued =>
|
||||
enableTopTweetsByGeoForDormantUsers && isUserNotFatigued
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def getAvailabilityForPlaybookSetUp(target: Target): Future[Boolean] = {
|
||||
lazy val enableTopTweetsByGeoForNewUsers = target.params(FS.EnableTopTweetsByGeoCandidates)
|
||||
val isTargetEligibleForMrFatigueCheck = target.isAccountAtleastNDaysOld(
|
||||
target.params(FS.MrMinDurationSincePushForTopTweetsByGeoPushes))
|
||||
val isMrFatigueCheckEnabled =
|
||||
target.params(FS.EnableMrMinDurationSinceMrPushFatigue)
|
||||
val applyPredicateForTopTweetsByGeo =
|
||||
if (isMrFatigueCheckEnabled) {
|
||||
if (isTargetEligibleForMrFatigueCheck) {
|
||||
DiscoverTwitterPredicate
|
||||
.minDurationElapsedSinceLastMrPushPredicate(
|
||||
name,
|
||||
FS.MrMinDurationSincePushForTopTweetsByGeoPushes,
|
||||
stats
|
||||
).andThen(
|
||||
topTweetsByGeoFrequencyPredicate
|
||||
)(Seq(target)).map(_.head)
|
||||
} else {
|
||||
Future.False
|
||||
}
|
||||
} else {
|
||||
topTweetsByGeoFrequencyPredicate(Seq(target)).map(_.head)
|
||||
}
|
||||
applyPredicateForTopTweetsByGeo.map { predicateResult =>
|
||||
predicateResult && enableTopTweetsByGeoForNewUsers
|
||||
}
|
||||
}
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
if (target.isLoggedOutUser) {
|
||||
Future.True
|
||||
} else {
|
||||
PushDeviceUtil
|
||||
.isRecommendationsEligible(target).map(
|
||||
_ && target.params(PushParams.PopGeoCandidatesDecider)).flatMap { isAvailable =>
|
||||
if (isAvailable) {
|
||||
Future
|
||||
.join(getAvailabilityForDormantUser(target), getAvailabilityForPlaybookSetUp(target))
|
||||
.map {
|
||||
case (isAvailableForDormantUser, isAvailableForPlaybook) =>
|
||||
isAvailableForDormantUser || isAvailableForPlaybook
|
||||
case _ => false
|
||||
}
|
||||
} else Future.False
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,215 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.events.recos.thriftscala.DisplayLocation
|
||||
import com.twitter.events.recos.thriftscala.TrendsContext
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.base.TrendTweetCandidate
|
||||
import com.twitter.frigate.common.base.TrendsCandidate
|
||||
import com.twitter.frigate.common.candidate.RecommendedTrendsCandidateSource
|
||||
import com.twitter.frigate.common.candidate.RecommendedTrendsCandidateSource.Query
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.adaptor.TrendsCandidatesAdaptor._
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.frigate.pushservice.predicate.TargetPredicates
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.geoduck.common.thriftscala.Location
|
||||
import com.twitter.gizmoduck.thriftscala.UserType
|
||||
import com.twitter.hermit.store.tweetypie.UserTweet
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
import scala.collection.Map
|
||||
|
||||
object TrendsCandidatesAdaptor {
|
||||
type TweetId = Long
|
||||
type EventId = Long
|
||||
}
|
||||
|
||||
case class TrendsCandidatesAdaptor(
|
||||
softUserGeoLocationStore: ReadableStore[Long, Location],
|
||||
recommendedTrendsCandidateSource: RecommendedTrendsCandidateSource,
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
safeUserTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
|
||||
statsReceiver: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
override val name = this.getClass.getSimpleName
|
||||
|
||||
private val trendAdaptorStats = statsReceiver.scope("TrendsCandidatesAdaptor")
|
||||
private val trendTweetCandidateNumber = trendAdaptorStats.counter("trend_tweet_candidate")
|
||||
private val nonReplyTweetsCounter = trendAdaptorStats.counter("non_reply_tweets")
|
||||
|
||||
private def getQuery(target: Target): Future[Query] = {
|
||||
def getUserCountryCode(target: Target): Future[Option[String]] = {
|
||||
target.targetUser.flatMap {
|
||||
case Some(user) if user.userType == UserType.Soft =>
|
||||
softUserGeoLocationStore
|
||||
.get(user.id)
|
||||
.map(_.flatMap(_.simpleRgcResult.flatMap(_.countryCodeAlpha2)))
|
||||
|
||||
case _ => target.accountCountryCode
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
countryCode <- getUserCountryCode(target)
|
||||
inferredLanguage <- target.inferredUserDeviceLanguage
|
||||
} yield {
|
||||
Query(
|
||||
userId = target.targetId,
|
||||
displayLocation = DisplayLocation.MagicRecs,
|
||||
languageCode = inferredLanguage,
|
||||
countryCode = countryCode,
|
||||
maxResults = target.params(PushFeatureSwitchParams.MaxRecommendedTrendsToQuery)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Query candidates only if sent at most [[PushFeatureSwitchParams.MaxTrendTweetNotificationsInDuration]]
|
||||
* trend tweet notifications in [[PushFeatureSwitchParams.TrendTweetNotificationsFatigueDuration]]
|
||||
*/
|
||||
val trendTweetFatiguePredicate = TargetPredicates.pushRecTypeFatiguePredicate(
|
||||
CommonRecommendationType.TrendTweet,
|
||||
PushFeatureSwitchParams.TrendTweetNotificationsFatigueDuration,
|
||||
PushFeatureSwitchParams.MaxTrendTweetNotificationsInDuration,
|
||||
trendAdaptorStats
|
||||
)
|
||||
|
||||
private val recommendedTrendsWithTweetsCandidateSource: CandidateSource[
|
||||
Target,
|
||||
RawCandidate with TrendsCandidate
|
||||
] = recommendedTrendsCandidateSource
|
||||
.convert[Target, TrendsCandidate](
|
||||
getQuery,
|
||||
recommendedTrendsCandidateSource.identityCandidateMapper
|
||||
)
|
||||
.batchMapValues[Target, RawCandidate with TrendsCandidate](
|
||||
trendsCandidatesToTweetCandidates(_, _, getTweetyPieResults))
|
||||
|
||||
private def getTweetyPieResults(
|
||||
tweetIds: Seq[TweetId],
|
||||
target: Target
|
||||
): Future[Map[TweetId, TweetyPieResult]] = {
|
||||
if (target.params(PushFeatureSwitchParams.EnableSafeUserTweetTweetypieStore)) {
|
||||
Future
|
||||
.collect(
|
||||
safeUserTweetTweetyPieStore.multiGet(
|
||||
tweetIds.toSet.map(UserTweet(_, Some(target.targetId))))).map {
|
||||
_.collect {
|
||||
case (userTweet, Some(tweetyPieResult)) => userTweet.tweetId -> tweetyPieResult
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Future
|
||||
.collect((target.params(PushFeatureSwitchParams.EnableVFInTweetypie) match {
|
||||
case true => tweetyPieStore
|
||||
case false => tweetyPieStoreNoVF
|
||||
}).multiGet(tweetIds.toSet)).map { tweetyPieResultMap =>
|
||||
filterOutReplyTweet(tweetyPieResultMap, nonReplyTweetsCounter).collect {
|
||||
case (tweetId, Some(tweetyPieResult)) => tweetId -> tweetyPieResult
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param _target: [[Target]] object representing notificaion recipient user
|
||||
* @param trendsCandidates: Sequence of [[TrendsCandidate]] returned from ERS
|
||||
* @return: Seq of trends candidates expanded to associated tweets.
|
||||
*/
|
||||
private def trendsCandidatesToTweetCandidates(
|
||||
_target: Target,
|
||||
trendsCandidates: Seq[TrendsCandidate],
|
||||
getTweetyPieResults: (Seq[TweetId], Target) => Future[Map[TweetId, TweetyPieResult]]
|
||||
): Future[Seq[RawCandidate with TrendsCandidate]] = {
|
||||
|
||||
def generateTrendTweetCandidates(
|
||||
trendCandidate: TrendsCandidate,
|
||||
tweetyPieResults: Map[TweetId, TweetyPieResult]
|
||||
) = {
|
||||
val tweetIds = trendCandidate.context.curatedRepresentativeTweets.getOrElse(Seq.empty) ++
|
||||
trendCandidate.context.algoRepresentativeTweets.getOrElse(Seq.empty)
|
||||
|
||||
tweetIds.flatMap { tweetId =>
|
||||
tweetyPieResults.get(tweetId).map { _tweetyPieResult =>
|
||||
new RawCandidate with TrendTweetCandidate {
|
||||
override val trendId: String = trendCandidate.trendId
|
||||
override val trendName: String = trendCandidate.trendName
|
||||
override val landingUrl: String = trendCandidate.landingUrl
|
||||
override val timeBoundedLandingUrl: Option[String] =
|
||||
trendCandidate.timeBoundedLandingUrl
|
||||
override val context: TrendsContext = trendCandidate.context
|
||||
override val tweetyPieResult: Option[TweetyPieResult] = Some(_tweetyPieResult)
|
||||
override val tweetId: TweetId = _tweetyPieResult.tweet.id
|
||||
override val target: Target = _target
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collect all tweet ids associated with all trends
|
||||
val allTweetIds = trendsCandidates.flatMap { trendsCandidate =>
|
||||
val context = trendsCandidate.context
|
||||
context.curatedRepresentativeTweets.getOrElse(Seq.empty) ++
|
||||
context.algoRepresentativeTweets.getOrElse(Seq.empty)
|
||||
}
|
||||
|
||||
getTweetyPieResults(allTweetIds, _target)
|
||||
.map { tweetIdToTweetyPieResult =>
|
||||
val trendTweetCandidates = trendsCandidates.flatMap { trendCandidate =>
|
||||
val allTrendTweetCandidates = generateTrendTweetCandidates(
|
||||
trendCandidate,
|
||||
tweetIdToTweetyPieResult
|
||||
)
|
||||
|
||||
val (tweetCandidatesFromCuratedTrends, tweetCandidatesFromNonCuratedTrends) =
|
||||
allTrendTweetCandidates.partition(_.isCuratedTrend)
|
||||
|
||||
tweetCandidatesFromCuratedTrends.filter(
|
||||
_.target.params(PushFeatureSwitchParams.EnableCuratedTrendTweets)) ++
|
||||
tweetCandidatesFromNonCuratedTrends.filter(
|
||||
_.target.params(PushFeatureSwitchParams.EnableNonCuratedTrendTweets))
|
||||
}
|
||||
|
||||
trendTweetCandidateNumber.incr(trendTweetCandidates.size)
|
||||
trendTweetCandidates
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param target: [[Target]] user
|
||||
* @return: true if customer is eligible to receive trend tweet notifications
|
||||
*
|
||||
*/
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
PushDeviceUtil
|
||||
.isRecommendationsEligible(target)
|
||||
.map(target.params(PushParams.TrendsCandidateDecider) && _)
|
||||
}
|
||||
|
||||
override def get(target: Target): Future[Option[Seq[RawCandidate with TrendsCandidate]]] = {
|
||||
recommendedTrendsWithTweetsCandidateSource
|
||||
.get(target)
|
||||
.flatMap {
|
||||
case Some(candidates) if candidates.nonEmpty =>
|
||||
trendTweetFatiguePredicate(Seq(target))
|
||||
.map(_.head)
|
||||
.map { isTargetFatigueEligible =>
|
||||
if (isTargetFatigueEligible) Some(candidates)
|
||||
else None
|
||||
}
|
||||
|
||||
case _ => Future.None
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,188 @@
|
||||
package com.twitter.frigate.pushservice.adaptor
|
||||
|
||||
import com.twitter.content_mixer.thriftscala.ContentMixerProductResponse
|
||||
import com.twitter.content_mixer.thriftscala.ContentMixerRequest
|
||||
import com.twitter.content_mixer.thriftscala.ContentMixerResponse
|
||||
import com.twitter.content_mixer.thriftscala.NotificationsTripTweetsProductContext
|
||||
import com.twitter.content_mixer.thriftscala.Product
|
||||
import com.twitter.content_mixer.thriftscala.ProductContext
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateSource
|
||||
import com.twitter.frigate.common.base.CandidateSourceEligible
|
||||
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.frigate.pushservice.util.MediaCRT
|
||||
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.geoduck.util.country.CountryInfo
|
||||
import com.twitter.product_mixer.core.thriftscala.ClientContext
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
|
||||
import com.twitter.util.Future
|
||||
|
||||
case class TripGeoCandidatesAdaptor(
|
||||
tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
|
||||
contentMixerStore: ReadableStore[ContentMixerRequest, ContentMixerResponse],
|
||||
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
|
||||
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
|
||||
statsReceiver: StatsReceiver)
|
||||
extends CandidateSource[Target, RawCandidate]
|
||||
with CandidateSourceEligible[Target, RawCandidate] {
|
||||
|
||||
override def name: String = this.getClass.getSimpleName
|
||||
|
||||
private val stats = statsReceiver.scope(name.stripSuffix("$"))
|
||||
|
||||
private val contentMixerRequests = stats.counter("getTripCandidatesContentMixerRequests")
|
||||
private val loggedOutTripTweetIds = stats.counter("logged_out_trip_tweet_ids_count")
|
||||
private val loggedOutRawCandidates = stats.counter("logged_out_raw_candidates_count")
|
||||
private val rawCandidates = stats.counter("raw_candidates_count")
|
||||
private val loggedOutEmptyplaceId = stats.counter("logged_out_empty_place_id_count")
|
||||
private val loggedOutPlaceId = stats.counter("logged_out_place_id_count")
|
||||
private val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
|
||||
|
||||
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
|
||||
if (target.isLoggedOutUser) {
|
||||
Future.True
|
||||
} else {
|
||||
for {
|
||||
isRecommendationsSettingEnabled <- PushDeviceUtil.isRecommendationsEligible(target)
|
||||
inferredLanguage <- target.inferredUserDeviceLanguage
|
||||
} yield {
|
||||
isRecommendationsSettingEnabled &&
|
||||
inferredLanguage.nonEmpty &&
|
||||
target.params(PushParams.TripGeoTweetCandidatesDecider)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private def buildRawCandidate(target: Target, tweetyPieResult: TweetyPieResult): RawCandidate = {
|
||||
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
|
||||
inputTarget = target,
|
||||
id = tweetyPieResult.tweet.id,
|
||||
mediaCRT = MediaCRT(
|
||||
CommonRecommendationType.TripGeoTweet,
|
||||
CommonRecommendationType.TripGeoTweet,
|
||||
CommonRecommendationType.TripGeoTweet
|
||||
),
|
||||
result = Some(tweetyPieResult),
|
||||
localizedEntity = None
|
||||
)
|
||||
}
|
||||
|
||||
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
|
||||
if (target.isLoggedOutUser) {
|
||||
for {
|
||||
tripTweetIds <- getTripCandidatesForLoggedOutTarget(target)
|
||||
tweetyPieResults <- Future.collect(tweetyPieStoreNoVF.multiGet(tripTweetIds))
|
||||
} yield {
|
||||
val candidates = tweetyPieResults.values.flatten.map(buildRawCandidate(target, _))
|
||||
if (candidates.nonEmpty) {
|
||||
loggedOutRawCandidates.incr(candidates.size)
|
||||
Some(candidates.toSeq)
|
||||
} else None
|
||||
}
|
||||
} else {
|
||||
for {
|
||||
tripTweetIds <- getTripCandidatesContentMixer(target)
|
||||
tweetyPieResults <-
|
||||
Future.collect((target.params(PushFeatureSwitchParams.EnableVFInTweetypie) match {
|
||||
case true => tweetyPieStore
|
||||
case false => tweetyPieStoreNoVF
|
||||
}).multiGet(tripTweetIds))
|
||||
} yield {
|
||||
val nonReplyTweets = filterOutReplyTweet(tweetyPieResults, nonReplyTweetsCounter)
|
||||
val candidates = nonReplyTweets.values.flatten.map(buildRawCandidate(target, _))
|
||||
if (candidates.nonEmpty && target.params(
|
||||
PushFeatureSwitchParams.TripTweetCandidateReturnEnable)) {
|
||||
rawCandidates.incr(candidates.size)
|
||||
Some(candidates.toSeq)
|
||||
} else None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def getTripCandidatesContentMixer(
|
||||
target: Target
|
||||
): Future[Set[Long]] = {
|
||||
contentMixerRequests.incr()
|
||||
Future
|
||||
.join(
|
||||
target.inferredUserDeviceLanguage,
|
||||
target.deviceInfo
|
||||
)
|
||||
.flatMap {
|
||||
case (languageOpt, deviceInfoOpt) =>
|
||||
contentMixerStore
|
||||
.get(
|
||||
ContentMixerRequest(
|
||||
clientContext = ClientContext(
|
||||
userId = Some(target.targetId),
|
||||
languageCode = languageOpt,
|
||||
userAgent = deviceInfoOpt.flatMap(_.guessedPrimaryDeviceUserAgent.map(_.toString))
|
||||
),
|
||||
product = Product.NotificationsTripTweets,
|
||||
productContext = Some(
|
||||
ProductContext.NotificationsTripTweetsProductContext(
|
||||
NotificationsTripTweetsProductContext()
|
||||
)),
|
||||
cursor = None,
|
||||
maxResults =
|
||||
Some(target.params(PushFeatureSwitchParams.TripTweetMaxTotalCandidates))
|
||||
)
|
||||
).map {
|
||||
_.map { rawResponse =>
|
||||
val tripResponse =
|
||||
rawResponse.contentMixerProductResponse
|
||||
.asInstanceOf[
|
||||
ContentMixerProductResponse.NotificationsTripTweetsProductResponse]
|
||||
.notificationsTripTweetsProductResponse
|
||||
|
||||
tripResponse.results.map(_.tweetResult.tweetId).toSet
|
||||
}.getOrElse(Set.empty)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def getTripCandidatesForLoggedOutTarget(
|
||||
target: Target
|
||||
): Future[Set[Long]] = {
|
||||
Future.join(target.targetLanguage, target.countryCode).flatMap {
|
||||
case (Some(lang), Some(country)) =>
|
||||
val placeId = CountryInfo.lookupByCode(country).map(_.placeIdLong)
|
||||
if (placeId.nonEmpty) {
|
||||
loggedOutPlaceId.incr()
|
||||
} else {
|
||||
loggedOutEmptyplaceId.incr()
|
||||
}
|
||||
val tripSource = "TOP_GEO_V3_LR"
|
||||
val tripQuery = TripDomain(
|
||||
sourceId = tripSource,
|
||||
language = Some(lang),
|
||||
placeId = placeId,
|
||||
topicId = None
|
||||
)
|
||||
val response = tripTweetCandidateStore.get(tripQuery)
|
||||
val tripTweetIds =
|
||||
response.map { res =>
|
||||
if (res.isDefined) {
|
||||
res.get.tweets
|
||||
.sortBy(_.score)(Ordering[Double].reverse).map(_.tweetId).toSet
|
||||
} else {
|
||||
Set.empty[Long]
|
||||
}
|
||||
}
|
||||
tripTweetIds.map { ids => loggedOutTripTweetIds.incr(ids.size) }
|
||||
tripTweetIds
|
||||
|
||||
case (_, _) => Future.value(Set.empty)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,461 @@
|
||||
package com.twitter.frigate.pushservice.config
|
||||
|
||||
import com.twitter.abdecider.LoggingABDecider
|
||||
import com.twitter.abuse.detection.scoring.thriftscala.TweetScoringRequest
|
||||
import com.twitter.abuse.detection.scoring.thriftscala.TweetScoringResponse
|
||||
import com.twitter.audience_rewards.thriftscala.HasSuperFollowingRelationshipRequest
|
||||
import com.twitter.channels.common.thriftscala.ApiList
|
||||
import com.twitter.datatools.entityservice.entities.sports.thriftscala._
|
||||
import com.twitter.decider.Decider
|
||||
import com.twitter.discovery.common.configapi.ConfigParamsBuilder
|
||||
import com.twitter.escherbird.common.thriftscala.QualifiedId
|
||||
import com.twitter.escherbird.metadata.thriftscala.EntityMegadata
|
||||
import com.twitter.eventbus.client.EventBusPublisher
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.finagle.thrift.ClientId
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.candidate._
|
||||
import com.twitter.frigate.common.history._
|
||||
import com.twitter.frigate.common.ml.base._
|
||||
import com.twitter.frigate.common.ml.feature._
|
||||
import com.twitter.frigate.common.store._
|
||||
import com.twitter.frigate.common.store.deviceinfo.DeviceInfo
|
||||
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
|
||||
import com.twitter.frigate.common.store.interests.UserId
|
||||
import com.twitter.frigate.common.util._
|
||||
import com.twitter.frigate.data_pipeline.features_common._
|
||||
import com.twitter.frigate.data_pipeline.thriftscala.UserHistoryKey
|
||||
import com.twitter.frigate.data_pipeline.thriftscala.UserHistoryValue
|
||||
import com.twitter.frigate.dau_model.thriftscala.DauProbability
|
||||
import com.twitter.frigate.magic_events.thriftscala.FanoutEvent
|
||||
import com.twitter.frigate.pushcap.thriftscala.PushcapUserHistory
|
||||
import com.twitter.frigate.pushservice.ml._
|
||||
import com.twitter.frigate.pushservice.params.DeciderKey
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitches
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.frigate.pushservice.send_handler.SendHandlerPushCandidateHydrator
|
||||
import com.twitter.frigate.pushservice.refresh_handler.PushCandidateHydrator
|
||||
import com.twitter.frigate.pushservice.store._
|
||||
import com.twitter.frigate.pushservice.store.{Ibis2Store => PushIbis2Store}
|
||||
import com.twitter.frigate.pushservice.take.NotificationServiceRequest
|
||||
import com.twitter.frigate.pushservice.thriftscala.PushRequestScribe
|
||||
import com.twitter.frigate.scribe.thriftscala.NotificationScribe
|
||||
import com.twitter.frigate.thriftscala._
|
||||
import com.twitter.frigate.user_states.thriftscala.MRUserHmmState
|
||||
import com.twitter.geoduck.common.thriftscala.{Location => GeoLocation}
|
||||
import com.twitter.geoduck.service.thriftscala.LocationResponse
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.hermit.pop_geo.thriftscala.PopTweetsInPlace
|
||||
import com.twitter.hermit.predicate.socialgraph.RelationEdge
|
||||
import com.twitter.hermit.predicate.tweetypie.Perspective
|
||||
import com.twitter.hermit.predicate.tweetypie.UserTweet
|
||||
import com.twitter.hermit.store.semantic_core.SemanticEntityForQuery
|
||||
import com.twitter.hermit.store.tweetypie.{UserTweet => TweetyPieUserTweet}
|
||||
import com.twitter.hermit.stp.thriftscala.STPResult
|
||||
import com.twitter.hss.api.thriftscala.UserHealthSignalResponse
|
||||
import com.twitter.interests.thriftscala.InterestId
|
||||
import com.twitter.interests.thriftscala.{UserInterests => Interests}
|
||||
import com.twitter.interests_discovery.thriftscala.NonPersonalizedRecommendedLists
|
||||
import com.twitter.interests_discovery.thriftscala.RecommendedListsRequest
|
||||
import com.twitter.interests_discovery.thriftscala.RecommendedListsResponse
|
||||
import com.twitter.livevideo.timeline.domain.v2.{Event => LiveEvent}
|
||||
import com.twitter.ml.api.thriftscala.{DataRecord => ThriftDataRecord}
|
||||
import com.twitter.ml.featurestore.lib.dynamic.DynamicFeatureStoreClient
|
||||
import com.twitter.notificationservice.genericfeedbackstore.FeedbackPromptValue
|
||||
import com.twitter.notificationservice.genericfeedbackstore.GenericFeedbackStore
|
||||
import com.twitter.notificationservice.scribe.manhattan.GenericNotificationsFeedbackRequest
|
||||
import com.twitter.notificationservice.thriftscala.CaretFeedbackDetails
|
||||
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationResponse
|
||||
import com.twitter.nrel.heavyranker.CandidateFeatureHydrator
|
||||
import com.twitter.nrel.heavyranker.{FeatureHydrator => MRFeatureHydrator}
|
||||
import com.twitter.nrel.heavyranker.{TargetFeatureHydrator => RelevanceTargetFeatureHydrator}
|
||||
import com.twitter.onboarding.task.service.thriftscala.FatigueFlowEnrollment
|
||||
import com.twitter.permissions_storage.thriftscala.AppPermission
|
||||
import com.twitter.recommendation.interests.discovery.core.model.InterestDomain
|
||||
import com.twitter.recos.user_tweet_entity_graph.thriftscala.RecommendTweetEntityRequest
|
||||
import com.twitter.recos.user_tweet_entity_graph.thriftscala.RecommendTweetEntityResponse
|
||||
import com.twitter.recos.user_user_graph.thriftscala.RecommendUserRequest
|
||||
import com.twitter.recos.user_user_graph.thriftscala.RecommendUserResponse
|
||||
import com.twitter.rux.common.strato.thriftscala.UserTargetingProperty
|
||||
import com.twitter.scio.nsfw_user_segmentation.thriftscala.NSFWProducer
|
||||
import com.twitter.scio.nsfw_user_segmentation.thriftscala.NSFWUserSegmentation
|
||||
import com.twitter.search.common.features.thriftscala.ThriftSearchResultFeatures
|
||||
import com.twitter.search.earlybird.thriftscala.EarlybirdRequest
|
||||
import com.twitter.search.earlybird.thriftscala.ThriftSearchResult
|
||||
import com.twitter.service.gen.scarecrow.thriftscala.Event
|
||||
import com.twitter.service.gen.scarecrow.thriftscala.TieredActionResult
|
||||
import com.twitter.service.metastore.gen.thriftscala.Location
|
||||
import com.twitter.service.metastore.gen.thriftscala.UserLanguages
|
||||
import com.twitter.servo.decider.DeciderGateBuilder
|
||||
import com.twitter.simclusters_v2.thriftscala.SimClustersInferredEntities
|
||||
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.strato.columns.frigate.logged_out_web_notifications.thriftscala.LOWebNotificationMetadata
|
||||
import com.twitter.strato.columns.notifications.thriftscala.SourceDestUserRequest
|
||||
import com.twitter.strato.client.{UserId => StratoUserId}
|
||||
import com.twitter.timelines.configapi
|
||||
import com.twitter.timelines.configapi.CompositeConfig
|
||||
import com.twitter.timelinescorer.thriftscala.v1.ScoredTweet
|
||||
import com.twitter.topiclisting.TopicListing
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
|
||||
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
|
||||
import com.twitter.ubs.thriftscala.SellerTrack
|
||||
import com.twitter.ubs.thriftscala.AudioSpace
|
||||
import com.twitter.ubs.thriftscala.Participants
|
||||
import com.twitter.ubs.thriftscala.SellerApplicationState
|
||||
import com.twitter.user_session_store.thriftscala.UserSession
|
||||
import com.twitter.util.Duration
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.wtf.scalding.common.thriftscala.UserFeatures
|
||||
|
||||
trait Config {
|
||||
self =>
|
||||
|
||||
def isServiceLocal: Boolean
|
||||
|
||||
def localConfigRepoPath: String
|
||||
|
||||
def inMemCacheOff: Boolean
|
||||
|
||||
def historyStore: PushServiceHistoryStore
|
||||
|
||||
def emailHistoryStore: PushServiceHistoryStore
|
||||
|
||||
def strongTiesStore: ReadableStore[Long, STPResult]
|
||||
|
||||
def safeUserStore: ReadableStore[Long, User]
|
||||
|
||||
def deviceInfoStore: ReadableStore[Long, DeviceInfo]
|
||||
|
||||
def edgeStore: ReadableStore[RelationEdge, Boolean]
|
||||
|
||||
def socialGraphServiceProcessStore: ReadableStore[RelationEdge, Boolean]
|
||||
|
||||
def userUtcOffsetStore: ReadableStore[Long, Duration]
|
||||
|
||||
def cachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult]
|
||||
|
||||
def safeCachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult]
|
||||
|
||||
def userTweetTweetyPieStore: ReadableStore[TweetyPieUserTweet, TweetyPieResult]
|
||||
|
||||
def safeUserTweetTweetyPieStore: ReadableStore[TweetyPieUserTweet, TweetyPieResult]
|
||||
|
||||
def cachedTweetyPieStoreV2NoVF: ReadableStore[Long, TweetyPieResult]
|
||||
|
||||
def tweetContentFeatureCacheStore: ReadableStore[Long, ThriftDataRecord]
|
||||
|
||||
def scarecrowCheckEventStore: ReadableStore[Event, TieredActionResult]
|
||||
|
||||
def userTweetPerspectiveStore: ReadableStore[UserTweet, Perspective]
|
||||
|
||||
def userCountryStore: ReadableStore[Long, Location]
|
||||
|
||||
def pushInfoStore: ReadableStore[Long, UserForPushTargeting]
|
||||
|
||||
def loggedOutPushInfoStore: ReadableStore[Long, LOWebNotificationMetadata]
|
||||
|
||||
def tweetImpressionStore: ReadableStore[Long, Seq[Long]]
|
||||
|
||||
def audioSpaceStore: ReadableStore[String, AudioSpace]
|
||||
|
||||
def basketballGameScoreStore: ReadableStore[QualifiedId, BasketballGameLiveUpdate]
|
||||
|
||||
def baseballGameScoreStore: ReadableStore[QualifiedId, BaseballGameLiveUpdate]
|
||||
|
||||
def cricketMatchScoreStore: ReadableStore[QualifiedId, CricketMatchLiveUpdate]
|
||||
|
||||
def soccerMatchScoreStore: ReadableStore[QualifiedId, SoccerMatchLiveUpdate]
|
||||
|
||||
def nflGameScoreStore: ReadableStore[QualifiedId, NflFootballGameLiveUpdate]
|
||||
|
||||
def topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse]
|
||||
|
||||
def spaceDeviceFollowStore: ReadableStore[SourceDestUserRequest, Boolean]
|
||||
|
||||
def audioSpaceParticipantsStore: ReadableStore[String, Participants]
|
||||
|
||||
def notificationServiceSender: ReadableStore[
|
||||
NotificationServiceRequest,
|
||||
CreateGenericNotificationResponse
|
||||
]
|
||||
|
||||
def ocfFatigueStore: ReadableStore[OCFHistoryStoreKey, FatigueFlowEnrollment]
|
||||
|
||||
def dauProbabilityStore: ReadableStore[Long, DauProbability]
|
||||
|
||||
def hydratedLabeledPushRecsStore: ReadableStore[UserHistoryKey, UserHistoryValue]
|
||||
|
||||
def userHTLLastVisitStore: ReadableStore[Long, Seq[Long]]
|
||||
|
||||
def userLanguagesStore: ReadableStore[Long, UserLanguages]
|
||||
|
||||
def topTweetsByGeoStore: ReadableStore[InterestDomain[String], Map[String, List[
|
||||
(Long, Double)
|
||||
]]]
|
||||
|
||||
def topTweetsByGeoV2VersionedStore: ReadableStore[String, PopTweetsInPlace]
|
||||
|
||||
lazy val pushRecItemStore: ReadableStore[PushRecItemsKey, RecItems] = PushRecItemStore(
|
||||
hydratedLabeledPushRecsStore
|
||||
)
|
||||
|
||||
lazy val labeledPushRecsVerifyingStore: ReadableStore[
|
||||
LabeledPushRecsVerifyingStoreKey,
|
||||
LabeledPushRecsVerifyingStoreResponse
|
||||
] =
|
||||
LabeledPushRecsVerifyingStore(
|
||||
hydratedLabeledPushRecsStore,
|
||||
historyStore
|
||||
)
|
||||
|
||||
lazy val labeledPushRecsDecideredStore: ReadableStore[LabeledPushRecsStoreKey, UserHistoryValue] =
|
||||
LabeledPushRecsDecideredStore(
|
||||
labeledPushRecsVerifyingStore,
|
||||
useHydratedLabeledSendsForFeaturesDeciderKey,
|
||||
verifyHydratedLabeledSendsForFeaturesDeciderKey
|
||||
)
|
||||
|
||||
def onlineUserHistoryStore: ReadableStore[OnlineUserHistoryKey, UserHistoryValue]
|
||||
|
||||
def nsfwConsumerStore: ReadableStore[Long, NSFWUserSegmentation]
|
||||
|
||||
def nsfwProducerStore: ReadableStore[Long, NSFWProducer]
|
||||
|
||||
def popGeoLists: ReadableStore[String, NonPersonalizedRecommendedLists]
|
||||
|
||||
def listAPIStore: ReadableStore[Long, ApiList]
|
||||
|
||||
def openedPushByHourAggregatedStore: ReadableStore[Long, Map[Int, Int]]
|
||||
|
||||
def userHealthSignalStore: ReadableStore[Long, UserHealthSignalResponse]
|
||||
|
||||
def reactivatedUserInfoStore: ReadableStore[Long, String]
|
||||
|
||||
def weightedOpenOrNtabClickModelScorer: PushMLModelScorer
|
||||
|
||||
def optoutModelScorer: PushMLModelScorer
|
||||
|
||||
def filteringModelScorer: PushMLModelScorer
|
||||
|
||||
def recentFollowsStore: ReadableStore[Long, Seq[Long]]
|
||||
|
||||
def geoDuckV2Store: ReadableStore[UserId, LocationResponse]
|
||||
|
||||
def realGraphScoresTop500InStore: ReadableStore[Long, Map[Long, Double]]
|
||||
|
||||
def tweetEntityGraphStore: ReadableStore[
|
||||
RecommendTweetEntityRequest,
|
||||
RecommendTweetEntityResponse
|
||||
]
|
||||
|
||||
def userUserGraphStore: ReadableStore[RecommendUserRequest, RecommendUserResponse]
|
||||
|
||||
def userFeaturesStore: ReadableStore[Long, UserFeatures]
|
||||
|
||||
def userTargetingPropertyStore: ReadableStore[Long, UserTargetingProperty]
|
||||
|
||||
def timelinesUserSessionStore: ReadableStore[Long, UserSession]
|
||||
|
||||
def optOutUserInterestsStore: ReadableStore[UserId, Seq[InterestId]]
|
||||
|
||||
def ntabCaretFeedbackStore: ReadableStore[GenericNotificationsFeedbackRequest, Seq[
|
||||
CaretFeedbackDetails
|
||||
]]
|
||||
|
||||
def genericFeedbackStore: ReadableStore[FeedbackRequest, Seq[
|
||||
FeedbackPromptValue
|
||||
]]
|
||||
|
||||
def genericNotificationFeedbackStore: GenericFeedbackStore
|
||||
|
||||
def semanticCoreMegadataStore: ReadableStore[
|
||||
SemanticEntityForQuery,
|
||||
EntityMegadata
|
||||
]
|
||||
|
||||
def tweetHealthScoreStore: ReadableStore[TweetScoringRequest, TweetScoringResponse]
|
||||
|
||||
def earlybirdFeatureStore: ReadableStore[Long, ThriftSearchResultFeatures]
|
||||
|
||||
def earlybirdFeatureBuilder: FeatureBuilder[Long]
|
||||
|
||||
// Feature builders
|
||||
|
||||
def tweetAuthorLocationFeatureBuilder: FeatureBuilder[Location]
|
||||
|
||||
def tweetAuthorLocationFeatureBuilderById: FeatureBuilder[Long]
|
||||
|
||||
def socialContextActionsFeatureBuilder: FeatureBuilder[SocialContextActions]
|
||||
|
||||
def tweetContentFeatureBuilder: FeatureBuilder[Long]
|
||||
|
||||
def tweetAuthorRecentRealGraphFeatureBuilder: FeatureBuilder[RealGraphEdge]
|
||||
|
||||
def socialContextRecentRealGraphFeatureBuilder: FeatureBuilder[Set[RealGraphEdge]]
|
||||
|
||||
def tweetSocialProofFeatureBuilder: FeatureBuilder[TweetSocialProofKey]
|
||||
|
||||
def targetUserFullRealGraphFeatureBuilder: FeatureBuilder[TargetFullRealGraphFeatureKey]
|
||||
|
||||
def postProcessingFeatureBuilder: PostProcessingFeatureBuilder
|
||||
|
||||
def mrOfflineUserCandidateSparseAggregatesFeatureBuilder: FeatureBuilder[
|
||||
OfflineSparseAggregateKey
|
||||
]
|
||||
|
||||
def mrOfflineUserAggregatesFeatureBuilder: FeatureBuilder[Long]
|
||||
|
||||
def mrOfflineUserCandidateAggregatesFeatureBuilder: FeatureBuilder[OfflineAggregateKey]
|
||||
|
||||
def tweetAnnotationsFeatureBuilder: FeatureBuilder[Long]
|
||||
|
||||
def targetUserMediaRepresentationFeatureBuilder: FeatureBuilder[Long]
|
||||
|
||||
def targetLevelFeatureBuilder: FeatureBuilder[MrRequestContextForFeatureStore]
|
||||
|
||||
def candidateLevelFeatureBuilder: FeatureBuilder[EntityRequestContextForFeatureStore]
|
||||
|
||||
def targetFeatureHydrator: RelevanceTargetFeatureHydrator
|
||||
|
||||
def useHydratedLabeledSendsForFeaturesDeciderKey: String =
|
||||
DeciderKey.useHydratedLabeledSendsForFeaturesDeciderKey.toString
|
||||
|
||||
def verifyHydratedLabeledSendsForFeaturesDeciderKey: String =
|
||||
DeciderKey.verifyHydratedLabeledSendsForFeaturesDeciderKey.toString
|
||||
|
||||
def lexServiceStore: ReadableStore[EventRequest, LiveEvent]
|
||||
|
||||
def userMediaRepresentationStore: ReadableStore[Long, UserMediaRepresentation]
|
||||
|
||||
def producerMediaRepresentationStore: ReadableStore[Long, UserMediaRepresentation]
|
||||
|
||||
def mrUserStatePredictionStore: ReadableStore[Long, MRUserHmmState]
|
||||
|
||||
def pushcapDynamicPredictionStore: ReadableStore[Long, PushcapUserHistory]
|
||||
|
||||
def earlybirdCandidateSource: EarlybirdCandidateSource
|
||||
|
||||
def earlybirdSearchStore: ReadableStore[EarlybirdRequest, Seq[ThriftSearchResult]]
|
||||
|
||||
def earlybirdSearchDest: String
|
||||
|
||||
def pushserviceThriftClientId: ClientId
|
||||
|
||||
def simClusterToEntityStore: ReadableStore[Int, SimClustersInferredEntities]
|
||||
|
||||
def fanoutMetadataStore: ReadableStore[(Long, Long), FanoutEvent]
|
||||
|
||||
/**
|
||||
* PostRanking Feature Store Client
|
||||
*/
|
||||
def postRankingFeatureStoreClient: DynamicFeatureStoreClient[MrRequestContextForFeatureStore]
|
||||
|
||||
/**
|
||||
* ReadableStore to fetch [[UserInterests]] from INTS service
|
||||
*/
|
||||
def interestsWithLookupContextStore: ReadableStore[InterestsLookupRequestWithContext, Interests]
|
||||
|
||||
/**
|
||||
*
|
||||
* @return: [[TopicListing]] object to fetch paused topics and scope from productId
|
||||
*/
|
||||
def topicListing: TopicListing
|
||||
|
||||
/**
|
||||
*
|
||||
* @return: [[UttEntityHydrationStore]] object
|
||||
*/
|
||||
def uttEntityHydrationStore: UttEntityHydrationStore
|
||||
|
||||
def appPermissionStore: ReadableStore[(Long, (String, String)), AppPermission]
|
||||
|
||||
lazy val userTweetEntityGraphCandidates: UserTweetEntityGraphCandidates =
|
||||
UserTweetEntityGraphCandidates(
|
||||
cachedTweetyPieStoreV2,
|
||||
tweetEntityGraphStore,
|
||||
PushParams.UTEGTweetCandidateSourceParam,
|
||||
PushFeatureSwitchParams.NumberOfMaxUTEGCandidatesQueriedParam,
|
||||
PushParams.AllowOneSocialProofForTweetInUTEGParam,
|
||||
PushParams.OutNetworkTweetsOnlyForUTEGParam,
|
||||
PushFeatureSwitchParams.MaxTweetAgeParam
|
||||
)(statsReceiver)
|
||||
|
||||
def pushSendEventBusPublisher: EventBusPublisher[NotificationScribe]
|
||||
|
||||
// miscs.
|
||||
|
||||
def isProd: Boolean
|
||||
|
||||
implicit def statsReceiver: StatsReceiver
|
||||
|
||||
def decider: Decider
|
||||
|
||||
def abDecider: LoggingABDecider
|
||||
|
||||
def casLock: CasLock
|
||||
|
||||
def pushIbisV2Store: PushIbis2Store
|
||||
|
||||
// scribe
|
||||
def notificationScribe(data: NotificationScribe): Unit
|
||||
|
||||
def requestScribe(data: PushRequestScribe): Unit
|
||||
|
||||
def init(): Future[Unit] = Future.Done
|
||||
|
||||
def configParamsBuilder: ConfigParamsBuilder
|
||||
|
||||
def candidateFeatureHydrator: CandidateFeatureHydrator
|
||||
|
||||
def featureHydrator: MRFeatureHydrator
|
||||
|
||||
def candidateHydrator: PushCandidateHydrator
|
||||
|
||||
def sendHandlerCandidateHydrator: SendHandlerPushCandidateHydrator
|
||||
|
||||
lazy val overridesConfig: configapi.Config = {
|
||||
val pushFeatureSwitchConfigs: configapi.Config = PushFeatureSwitches(
|
||||
deciderGateBuilder = new DeciderGateBuilder(decider),
|
||||
statsReceiver = statsReceiver
|
||||
).config
|
||||
|
||||
new CompositeConfig(Seq(pushFeatureSwitchConfigs))
|
||||
}
|
||||
|
||||
def realTimeClientEventStore: RealTimeClientEventStore
|
||||
|
||||
def inlineActionHistoryStore: ReadableStore[Long, Seq[(Long, String)]]
|
||||
|
||||
def softUserGeoLocationStore: ReadableStore[Long, GeoLocation]
|
||||
|
||||
def tweetTranslationStore: ReadableStore[TweetTranslationStore.Key, TweetTranslationStore.Value]
|
||||
|
||||
def tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets]
|
||||
|
||||
def softUserFollowingStore: ReadableStore[User, Seq[Long]]
|
||||
|
||||
def superFollowEligibilityUserStore: ReadableStore[Long, Boolean]
|
||||
|
||||
def superFollowCreatorTweetCountStore: ReadableStore[StratoUserId, Int]
|
||||
|
||||
def hasSuperFollowingRelationshipStore: ReadableStore[
|
||||
HasSuperFollowingRelationshipRequest,
|
||||
Boolean
|
||||
]
|
||||
|
||||
def superFollowApplicationStatusStore: ReadableStore[(Long, SellerTrack), SellerApplicationState]
|
||||
|
||||
def recentHistoryCacheClient: RecentHistoryCacheClient
|
||||
|
||||
def openAppUserStore: ReadableStore[Long, Boolean]
|
||||
|
||||
def loggedOutHistoryStore: PushServiceHistoryStore
|
||||
|
||||
def idsStore: ReadableStore[RecommendedListsRequest, RecommendedListsResponse]
|
||||
|
||||
def htlScoreStore(userId: Long): ReadableStore[Long, ScoredTweet]
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,16 @@
|
||||
package com.twitter.frigate.pushservice.config
|
||||
|
||||
import com.twitter.frigate.common.util.Experiments
|
||||
|
||||
object ExperimentsWithStats {
|
||||
|
||||
/**
|
||||
* Add an experiment here to collect detailed pushservice stats.
|
||||
*
|
||||
* ! Important !
|
||||
* Keep this set small and remove experiments when you don't need the stats anymore.
|
||||
*/
|
||||
final val PushExperiments: Set[String] = Set(
|
||||
Experiments.MRAndroidInlineActionHoldback.exptName,
|
||||
)
|
||||
}
|
@ -0,0 +1,230 @@
|
||||
package com.twitter.frigate.pushservice.config
|
||||
|
||||
import com.twitter.abdecider.LoggingABDecider
|
||||
import com.twitter.bijection.scrooge.BinaryScalaCodec
|
||||
import com.twitter.bijection.Base64String
|
||||
import com.twitter.bijection.Injection
|
||||
import com.twitter.conversions.DurationOps._
|
||||
import com.twitter.decider.Decider
|
||||
import com.twitter.featureswitches.v2.FeatureSwitches
|
||||
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.finagle.thrift.ClientId
|
||||
import com.twitter.finagle.thrift.RichClientParam
|
||||
import com.twitter.finagle.util.DefaultTimer
|
||||
import com.twitter.frigate.common.config.RateLimiterGenerator
|
||||
import com.twitter.frigate.common.filter.DynamicRequestMeterFilter
|
||||
import com.twitter.frigate.common.history.ManhattanHistoryStore
|
||||
import com.twitter.frigate.common.history.InvalidatingAfterWritesPushServiceHistoryStore
|
||||
import com.twitter.frigate.common.history.ManhattanKVHistoryStore
|
||||
import com.twitter.frigate.common.history.PushServiceHistoryStore
|
||||
import com.twitter.frigate.common.history.SimplePushServiceHistoryStore
|
||||
import com.twitter.frigate.common.util._
|
||||
import com.twitter.frigate.data_pipeline.features_common.FeatureStoreUtil
|
||||
import com.twitter.frigate.data_pipeline.features_common.TargetLevelFeaturesConfig
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.DeciderKey
|
||||
import com.twitter.frigate.pushservice.params.PushQPSLimitConstants
|
||||
import com.twitter.frigate.pushservice.params.PushServiceTunableKeys
|
||||
import com.twitter.frigate.pushservice.params.ShardParams
|
||||
import com.twitter.frigate.pushservice.store.PushIbis2Store
|
||||
import com.twitter.frigate.pushservice.thriftscala.PushRequestScribe
|
||||
import com.twitter.frigate.scribe.thriftscala.NotificationScribe
|
||||
import com.twitter.ibis2.service.thriftscala.Ibis2Service
|
||||
import com.twitter.logging.Logger
|
||||
import com.twitter.notificationservice.api.thriftscala.DeleteCurrentTimelineForUserRequest
|
||||
import com.twitter.notificationservice.api.thriftscala.NotificationApi
|
||||
import com.twitter.notificationservice.api.thriftscala.NotificationApi$FinagleClient
|
||||
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationRequest
|
||||
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationResponse
|
||||
import com.twitter.notificationservice.thriftscala.DeleteGenericNotificationRequest
|
||||
import com.twitter.notificationservice.thriftscala.NotificationService
|
||||
import com.twitter.notificationservice.thriftscala.NotificationService$FinagleClient
|
||||
import com.twitter.servo.decider.DeciderGateBuilder
|
||||
import com.twitter.util.tunable.TunableMap
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.Timer
|
||||
|
||||
case class ProdConfig(
|
||||
override val isServiceLocal: Boolean,
|
||||
override val localConfigRepoPath: String,
|
||||
override val inMemCacheOff: Boolean,
|
||||
override val decider: Decider,
|
||||
override val abDecider: LoggingABDecider,
|
||||
override val featureSwitches: FeatureSwitches,
|
||||
override val shardParams: ShardParams,
|
||||
override val serviceIdentifier: ServiceIdentifier,
|
||||
override val tunableMap: TunableMap,
|
||||
)(
|
||||
implicit val statsReceiver: StatsReceiver)
|
||||
extends {
|
||||
// Due to trait initialization logic in Scala, any abstract members declared in Config or
|
||||
// DeployConfig should be declared in this block. Otherwise the abstract member might initialize to
|
||||
// null if invoked before object creation finishing.
|
||||
|
||||
val log = Logger("ProdConfig")
|
||||
|
||||
// Deciders
|
||||
val isPushserviceCanaryDeepbirdv2CanaryClusterEnabled = decider
|
||||
.feature(DeciderKey.enablePushserviceDeepbirdv2CanaryClusterDeciderKey.toString).isAvailable
|
||||
|
||||
// Client ids
|
||||
val notifierThriftClientId = ClientId("frigate-notifier.prod")
|
||||
val loggedOutNotifierThriftClientId = ClientId("frigate-logged-out-notifier.prod")
|
||||
val pushserviceThriftClientId: ClientId = ClientId("frigate-pushservice.prod")
|
||||
|
||||
// Dests
|
||||
val frigateHistoryCacheDest = "/s/cache/frigate_history"
|
||||
val memcacheCASDest = "/s/cache/magic_recs_cas:twemcaches"
|
||||
val historyStoreMemcacheDest =
|
||||
"/srv#/prod/local/cache/magic_recs_history:twemcaches"
|
||||
|
||||
val deepbirdv2PredictionServiceDest =
|
||||
if (serviceIdentifier.service.equals("frigate-pushservice-canary") &&
|
||||
isPushserviceCanaryDeepbirdv2CanaryClusterEnabled)
|
||||
"/s/frigate/deepbirdv2-magicrecs-canary"
|
||||
else "/s/frigate/deepbirdv2-magicrecs"
|
||||
|
||||
override val fanoutMetadataColumn = "frigate/magicfanout/prod/mh/fanoutMetadata"
|
||||
|
||||
override val timer: Timer = DefaultTimer
|
||||
override val featureStoreUtil = FeatureStoreUtil.withParams(Some(serviceIdentifier))
|
||||
override val targetLevelFeaturesConfig = TargetLevelFeaturesConfig()
|
||||
val pushServiceMHCacheDest = "/s/cache/pushservice_mh"
|
||||
|
||||
val pushServiceCoreSvcsCacheDest = "/srv#/prod/local/cache/pushservice_core_svcs"
|
||||
|
||||
val userTweetEntityGraphDest = "/s/cassowary/user_tweet_entity_graph"
|
||||
val userUserGraphDest = "/s/cassowary/user_user_graph"
|
||||
val lexServiceDest = "/s/live-video/timeline-thrift"
|
||||
val entityGraphCacheDest = "/s/cache/pushservice_entity_graph"
|
||||
|
||||
override val pushIbisV2Store = {
|
||||
val service = Finagle.readOnlyThriftService(
|
||||
"ibis-v2-service",
|
||||
"/s/ibis2/ibis2",
|
||||
statsReceiver,
|
||||
notifierThriftClientId,
|
||||
requestTimeout = 3.seconds,
|
||||
tries = 3,
|
||||
mTLSServiceIdentifier = Some(serviceIdentifier)
|
||||
)
|
||||
|
||||
// according to ibis team, it is safe to retry on timeout, write & channel closed exceptions.
|
||||
val pushIbisClient = new Ibis2Service.FinagledClient(
|
||||
new DynamicRequestMeterFilter(
|
||||
tunableMap(PushServiceTunableKeys.IbisQpsLimitTunableKey),
|
||||
RateLimiterGenerator.asTuple(_, shardParams.numShards, 20),
|
||||
PushQPSLimitConstants.IbisOrNTabQPSForRFPH
|
||||
)(timer).andThen(service),
|
||||
RichClientParam(serviceName = "ibis-v2-service")
|
||||
)
|
||||
|
||||
PushIbis2Store(pushIbisClient)
|
||||
}
|
||||
|
||||
val notificationServiceClient: NotificationService$FinagleClient = {
|
||||
val service = Finagle.readWriteThriftService(
|
||||
"notificationservice",
|
||||
"/s/notificationservice/notificationservice",
|
||||
statsReceiver,
|
||||
pushserviceThriftClientId,
|
||||
requestTimeout = 10.seconds,
|
||||
mTLSServiceIdentifier = Some(serviceIdentifier)
|
||||
)
|
||||
|
||||
new NotificationService.FinagledClient(
|
||||
new DynamicRequestMeterFilter(
|
||||
tunableMap(PushServiceTunableKeys.NtabQpsLimitTunableKey),
|
||||
RateLimiterGenerator.asTuple(_, shardParams.numShards, 20),
|
||||
PushQPSLimitConstants.IbisOrNTabQPSForRFPH)(timer).andThen(service),
|
||||
RichClientParam(serviceName = "notificationservice")
|
||||
)
|
||||
}
|
||||
|
||||
val notificationServiceApiClient: NotificationApi$FinagleClient = {
|
||||
val service = Finagle.readWriteThriftService(
|
||||
"notificationservice-api",
|
||||
"/s/notificationservice/notificationservice-api:thrift",
|
||||
statsReceiver,
|
||||
pushserviceThriftClientId,
|
||||
requestTimeout = 10.seconds,
|
||||
mTLSServiceIdentifier = Some(serviceIdentifier)
|
||||
)
|
||||
|
||||
new NotificationApi.FinagledClient(
|
||||
new DynamicRequestMeterFilter(
|
||||
tunableMap(PushServiceTunableKeys.NtabQpsLimitTunableKey),
|
||||
RateLimiterGenerator.asTuple(_, shardParams.numShards, 20),
|
||||
PushQPSLimitConstants.IbisOrNTabQPSForRFPH)(timer).andThen(service),
|
||||
RichClientParam(serviceName = "notificationservice-api")
|
||||
)
|
||||
}
|
||||
|
||||
val mrRequestScriberNode = "mr_request_scribe"
|
||||
val loggedOutMrRequestScriberNode = "lo_mr_request_scribe"
|
||||
|
||||
override val pushSendEventStreamName = "frigate_pushservice_send_event_prod"
|
||||
} with DeployConfig {
|
||||
// Scribe
|
||||
private val notificationScribeLog = Logger("notification_scribe")
|
||||
private val notificationScribeInjection: Injection[NotificationScribe, String] = BinaryScalaCodec(
|
||||
NotificationScribe
|
||||
) andThen Injection.connect[Array[Byte], Base64String, String]
|
||||
|
||||
override def notificationScribe(data: NotificationScribe): Unit = {
|
||||
val logEntry: String = notificationScribeInjection(data)
|
||||
notificationScribeLog.info(logEntry)
|
||||
}
|
||||
|
||||
// History Store - Invalidates cached history after writes
|
||||
override val historyStore = new InvalidatingAfterWritesPushServiceHistoryStore(
|
||||
ManhattanHistoryStore(notificationHistoryStore, statsReceiver),
|
||||
recentHistoryCacheClient,
|
||||
new DeciderGateBuilder(decider)
|
||||
.idGate(DeciderKey.enableInvalidatingCachedHistoryStoreAfterWrites)
|
||||
)
|
||||
|
||||
override val emailHistoryStore: PushServiceHistoryStore = {
|
||||
statsReceiver.scope("frigate_email_history").counter("request").incr()
|
||||
new SimplePushServiceHistoryStore(emailNotificationHistoryStore)
|
||||
}
|
||||
|
||||
override val loggedOutHistoryStore =
|
||||
new InvalidatingAfterWritesPushServiceHistoryStore(
|
||||
ManhattanKVHistoryStore(
|
||||
manhattanKVLoggedOutHistoryStoreEndpoint,
|
||||
"frigate_notification_logged_out_history"),
|
||||
recentHistoryCacheClient,
|
||||
new DeciderGateBuilder(decider)
|
||||
.idGate(DeciderKey.enableInvalidatingCachedLoggedOutHistoryStoreAfterWrites)
|
||||
)
|
||||
|
||||
private val requestScribeLog = Logger("request_scribe")
|
||||
private val requestScribeInjection: Injection[PushRequestScribe, String] = BinaryScalaCodec(
|
||||
PushRequestScribe
|
||||
) andThen Injection.connect[Array[Byte], Base64String, String]
|
||||
|
||||
override def requestScribe(data: PushRequestScribe): Unit = {
|
||||
val logEntry: String = requestScribeInjection(data)
|
||||
requestScribeLog.info(logEntry)
|
||||
}
|
||||
|
||||
// generic notification server
|
||||
override def notificationServiceSend(
|
||||
target: Target,
|
||||
request: CreateGenericNotificationRequest
|
||||
): Future[CreateGenericNotificationResponse] =
|
||||
notificationServiceClient.createGenericNotification(request)
|
||||
|
||||
// generic notification server
|
||||
override def notificationServiceDelete(
|
||||
request: DeleteGenericNotificationRequest
|
||||
): Future[Unit] = notificationServiceClient.deleteGenericNotification(request)
|
||||
|
||||
// NTab-api
|
||||
override def notificationServiceDeleteTimeline(
|
||||
request: DeleteCurrentTimelineForUserRequest
|
||||
): Future[Unit] = notificationServiceApiClient.deleteCurrentTimelineForUser(request)
|
||||
|
||||
}
|
@ -0,0 +1,193 @@
|
||||
package com.twitter.frigate.pushservice.config
|
||||
|
||||
import com.twitter.abdecider.LoggingABDecider
|
||||
import com.twitter.conversions.DurationOps._
|
||||
import com.twitter.decider.Decider
|
||||
import com.twitter.featureswitches.v2.FeatureSwitches
|
||||
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.finagle.thrift.ClientId
|
||||
import com.twitter.finagle.thrift.RichClientParam
|
||||
import com.twitter.finagle.util.DefaultTimer
|
||||
import com.twitter.frigate.common.config.RateLimiterGenerator
|
||||
import com.twitter.frigate.common.filter.DynamicRequestMeterFilter
|
||||
import com.twitter.frigate.common.history.InvalidatingAfterWritesPushServiceHistoryStore
|
||||
import com.twitter.frigate.common.history.ManhattanHistoryStore
|
||||
import com.twitter.frigate.common.history.ManhattanKVHistoryStore
|
||||
import com.twitter.frigate.common.history.ReadOnlyHistoryStore
|
||||
import com.twitter.frigate.common.history.PushServiceHistoryStore
|
||||
import com.twitter.frigate.common.history.SimplePushServiceHistoryStore
|
||||
import com.twitter.frigate.common.util.Finagle
|
||||
import com.twitter.frigate.data_pipeline.features_common.FeatureStoreUtil
|
||||
import com.twitter.frigate.data_pipeline.features_common.TargetLevelFeaturesConfig
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.DeciderKey
|
||||
import com.twitter.frigate.pushservice.params.PushQPSLimitConstants
|
||||
import com.twitter.frigate.pushservice.params.PushServiceTunableKeys
|
||||
import com.twitter.frigate.pushservice.params.ShardParams
|
||||
import com.twitter.frigate.pushservice.store._
|
||||
import com.twitter.frigate.pushservice.thriftscala.PushRequestScribe
|
||||
import com.twitter.frigate.scribe.thriftscala.NotificationScribe
|
||||
import com.twitter.ibis2.service.thriftscala.Ibis2Service
|
||||
import com.twitter.logging.Logger
|
||||
import com.twitter.notificationservice.api.thriftscala.DeleteCurrentTimelineForUserRequest
|
||||
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationRequest
|
||||
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationResponse
|
||||
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationResponseType
|
||||
import com.twitter.notificationservice.thriftscala.DeleteGenericNotificationRequest
|
||||
import com.twitter.notificationservice.thriftscala.NotificationService
|
||||
import com.twitter.notificationservice.thriftscala.NotificationService$FinagleClient
|
||||
import com.twitter.servo.decider.DeciderGateBuilder
|
||||
import com.twitter.util.tunable.TunableMap
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.Timer
|
||||
|
||||
case class StagingConfig(
|
||||
override val isServiceLocal: Boolean,
|
||||
override val localConfigRepoPath: String,
|
||||
override val inMemCacheOff: Boolean,
|
||||
override val decider: Decider,
|
||||
override val abDecider: LoggingABDecider,
|
||||
override val featureSwitches: FeatureSwitches,
|
||||
override val shardParams: ShardParams,
|
||||
override val serviceIdentifier: ServiceIdentifier,
|
||||
override val tunableMap: TunableMap,
|
||||
)(
|
||||
implicit val statsReceiver: StatsReceiver)
|
||||
extends {
|
||||
// Due to trait initialization logic in Scala, any abstract members declared in Config or
|
||||
// DeployConfig should be declared in this block. Otherwise the abstract member might initialize to
|
||||
// null if invoked before object creation finishing.
|
||||
|
||||
val log = Logger("StagingConfig")
|
||||
|
||||
// Client ids
|
||||
val notifierThriftClientId = ClientId("frigate-notifier.dev")
|
||||
val loggedOutNotifierThriftClientId = ClientId("frigate-logged-out-notifier.dev")
|
||||
val pushserviceThriftClientId: ClientId = ClientId("frigate-pushservice.staging")
|
||||
|
||||
override val fanoutMetadataColumn = "frigate/magicfanout/staging/mh/fanoutMetadata"
|
||||
|
||||
// dest
|
||||
val frigateHistoryCacheDest = "/srv#/test/local/cache/twemcache_frigate_history"
|
||||
val memcacheCASDest = "/srv#/test/local/cache/twemcache_magic_recs_cas_dev:twemcaches"
|
||||
val pushServiceMHCacheDest = "/srv#/test/local/cache/twemcache_pushservice_test"
|
||||
val entityGraphCacheDest = "/srv#/test/local/cache/twemcache_pushservice_test"
|
||||
val pushServiceCoreSvcsCacheDest = "/srv#/test/local/cache/twemcache_pushservice_core_svcs_test"
|
||||
val historyStoreMemcacheDest = "/srv#/test/local/cache/twemcache_eventstream_test:twemcaches"
|
||||
val userTweetEntityGraphDest = "/cluster/local/cassowary/staging/user_tweet_entity_graph"
|
||||
val userUserGraphDest = "/cluster/local/cassowary/staging/user_user_graph"
|
||||
val lexServiceDest = "/srv#/staging/local/live-video/timeline-thrift"
|
||||
val deepbirdv2PredictionServiceDest = "/cluster/local/frigate/staging/deepbirdv2-magicrecs"
|
||||
|
||||
override val featureStoreUtil = FeatureStoreUtil.withParams(Some(serviceIdentifier))
|
||||
override val targetLevelFeaturesConfig = TargetLevelFeaturesConfig()
|
||||
val mrRequestScriberNode = "validation_mr_request_scribe"
|
||||
val loggedOutMrRequestScriberNode = "lo_mr_request_scribe"
|
||||
|
||||
override val timer: Timer = DefaultTimer
|
||||
|
||||
override val pushSendEventStreamName = "frigate_pushservice_send_event_staging"
|
||||
|
||||
override val pushIbisV2Store = {
|
||||
val service = Finagle.readWriteThriftService(
|
||||
"ibis-v2-service",
|
||||
"/s/ibis2/ibis2",
|
||||
statsReceiver,
|
||||
notifierThriftClientId,
|
||||
requestTimeout = 6.seconds,
|
||||
mTLSServiceIdentifier = Some(serviceIdentifier)
|
||||
)
|
||||
|
||||
val pushIbisClient = new Ibis2Service.FinagledClient(
|
||||
new DynamicRequestMeterFilter(
|
||||
tunableMap(PushServiceTunableKeys.IbisQpsLimitTunableKey),
|
||||
RateLimiterGenerator.asTuple(_, shardParams.numShards, 20),
|
||||
PushQPSLimitConstants.IbisOrNTabQPSForRFPH
|
||||
)(timer).andThen(service),
|
||||
RichClientParam(serviceName = "ibis-v2-service")
|
||||
)
|
||||
|
||||
StagingIbis2Store(PushIbis2Store(pushIbisClient))
|
||||
}
|
||||
|
||||
val notificationServiceClient: NotificationService$FinagleClient = {
|
||||
val service = Finagle.readWriteThriftService(
|
||||
"notificationservice",
|
||||
"/s/notificationservice/notificationservice",
|
||||
statsReceiver,
|
||||
pushserviceThriftClientId,
|
||||
requestTimeout = 10.seconds,
|
||||
mTLSServiceIdentifier = Some(serviceIdentifier)
|
||||
)
|
||||
|
||||
new NotificationService.FinagledClient(
|
||||
new DynamicRequestMeterFilter(
|
||||
tunableMap(PushServiceTunableKeys.NtabQpsLimitTunableKey),
|
||||
RateLimiterGenerator.asTuple(_, shardParams.numShards, 20),
|
||||
PushQPSLimitConstants.IbisOrNTabQPSForRFPH)(timer).andThen(service),
|
||||
RichClientParam(serviceName = "notificationservice")
|
||||
)
|
||||
}
|
||||
} with DeployConfig {
|
||||
|
||||
// Scribe
|
||||
private val notificationScribeLog = Logger("StagingNotificationScribe")
|
||||
|
||||
override def notificationScribe(data: NotificationScribe): Unit = {
|
||||
notificationScribeLog.info(data.toString)
|
||||
}
|
||||
private val requestScribeLog = Logger("StagingRequestScribe")
|
||||
|
||||
override def requestScribe(data: PushRequestScribe): Unit = {
|
||||
requestScribeLog.info(data.toString)
|
||||
}
|
||||
|
||||
// history store
|
||||
override val historyStore = new InvalidatingAfterWritesPushServiceHistoryStore(
|
||||
ReadOnlyHistoryStore(
|
||||
ManhattanHistoryStore(notificationHistoryStore, statsReceiver)
|
||||
),
|
||||
recentHistoryCacheClient,
|
||||
new DeciderGateBuilder(decider)
|
||||
.idGate(DeciderKey.enableInvalidatingCachedHistoryStoreAfterWrites)
|
||||
)
|
||||
|
||||
override val emailHistoryStore: PushServiceHistoryStore = new SimplePushServiceHistoryStore(
|
||||
emailNotificationHistoryStore)
|
||||
|
||||
// history store
|
||||
override val loggedOutHistoryStore =
|
||||
new InvalidatingAfterWritesPushServiceHistoryStore(
|
||||
ReadOnlyHistoryStore(
|
||||
ManhattanKVHistoryStore(
|
||||
manhattanKVLoggedOutHistoryStoreEndpoint,
|
||||
"frigate_notification_logged_out_history")),
|
||||
recentHistoryCacheClient,
|
||||
new DeciderGateBuilder(decider)
|
||||
.idGate(DeciderKey.enableInvalidatingCachedLoggedOutHistoryStoreAfterWrites)
|
||||
)
|
||||
|
||||
override def notificationServiceSend(
|
||||
target: Target,
|
||||
request: CreateGenericNotificationRequest
|
||||
): Future[CreateGenericNotificationResponse] =
|
||||
target.isTeamMember.flatMap { isTeamMember =>
|
||||
if (isTeamMember) {
|
||||
notificationServiceClient.createGenericNotification(request)
|
||||
} else {
|
||||
log.info(s"Mock creating generic notification $request for user: ${target.targetId}")
|
||||
Future.value(
|
||||
CreateGenericNotificationResponse(CreateGenericNotificationResponseType.Success)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
override def notificationServiceDelete(
|
||||
request: DeleteGenericNotificationRequest
|
||||
): Future[Unit] = Future.Unit
|
||||
|
||||
override def notificationServiceDeleteTimeline(
|
||||
request: DeleteCurrentTimelineForUserRequest
|
||||
): Future[Unit] = Future.Unit
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
package com.twitter.frigate.pushservice.config.mlconfig
|
||||
|
||||
import com.twitter.cortex.deepbird.thriftjava.DeepbirdPredictionService
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.ml.prediction.DeepbirdPredictionEngineServiceStore
|
||||
import com.twitter.nrel.heavyranker.PushDBv2PredictionServiceStore
|
||||
|
||||
object DeepbirdV2ModelConfig {
|
||||
def buildPredictionServiceScoreStore(
|
||||
predictionServiceClient: DeepbirdPredictionService.ServiceToClient,
|
||||
serviceName: String
|
||||
)(
|
||||
implicit statsReceiver: StatsReceiver
|
||||
): PushDBv2PredictionServiceStore = {
|
||||
|
||||
val stats = statsReceiver.scope(serviceName)
|
||||
val serviceStats = statsReceiver.scope("dbv2PredictionServiceStore")
|
||||
|
||||
new PushDBv2PredictionServiceStore(
|
||||
DeepbirdPredictionEngineServiceStore(predictionServiceClient, batchSize = Some(32))(stats)
|
||||
)(serviceStats)
|
||||
}
|
||||
}
|
@ -0,0 +1,114 @@
|
||||
package com.twitter.frigate.pushservice.controller
|
||||
|
||||
import com.google.inject.Inject
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.finagle.thrift.ClientId
|
||||
import com.twitter.finatra.thrift.Controller
|
||||
import com.twitter.frigate.pushservice.exception.DisplayLocationNotSupportedException
|
||||
import com.twitter.frigate.pushservice.refresh_handler.RefreshForPushHandler
|
||||
import com.twitter.frigate.pushservice.send_handler.SendHandler
|
||||
import com.twitter.frigate.pushservice.refresh_handler.LoggedOutRefreshForPushHandler
|
||||
import com.twitter.frigate.pushservice.thriftscala.PushService.Loggedout
|
||||
import com.twitter.frigate.pushservice.thriftscala.PushService.Refresh
|
||||
import com.twitter.frigate.pushservice.thriftscala.PushService.Send
|
||||
import com.twitter.frigate.pushservice.{thriftscala => t}
|
||||
import com.twitter.frigate.thriftscala.NotificationDisplayLocation
|
||||
import com.twitter.util.logging.Logging
|
||||
import com.twitter.util.Future
|
||||
|
||||
class PushServiceController @Inject() (
|
||||
sendHandler: SendHandler,
|
||||
refreshForPushHandler: RefreshForPushHandler,
|
||||
loggedOutRefreshForPushHandler: LoggedOutRefreshForPushHandler,
|
||||
statsReceiver: StatsReceiver)
|
||||
extends Controller(t.PushService)
|
||||
with Logging {
|
||||
|
||||
private val stats: StatsReceiver = statsReceiver.scope(s"${this.getClass.getSimpleName}")
|
||||
private val failureCount = stats.counter("failures")
|
||||
private val failureStatsScope = stats.scope("failures")
|
||||
private val uncaughtErrorCount = failureStatsScope.counter("uncaught")
|
||||
private val uncaughtErrorScope = failureStatsScope.scope("uncaught")
|
||||
private val clientIdScope = stats.scope("client_id")
|
||||
|
||||
handle(t.PushService.Send) { request: Send.Args =>
|
||||
send(request)
|
||||
}
|
||||
|
||||
handle(t.PushService.Refresh) { args: Refresh.Args =>
|
||||
refresh(args)
|
||||
}
|
||||
|
||||
handle(t.PushService.Loggedout) { request: Loggedout.Args =>
|
||||
loggedOutRefresh(request)
|
||||
}
|
||||
|
||||
private def loggedOutRefresh(
|
||||
request: t.PushService.Loggedout.Args
|
||||
): Future[t.PushService.Loggedout.SuccessType] = {
|
||||
val fut = request.request.notificationDisplayLocation match {
|
||||
case NotificationDisplayLocation.PushToMobileDevice =>
|
||||
loggedOutRefreshForPushHandler.refreshAndSend(request.request)
|
||||
case _ =>
|
||||
Future.exception(
|
||||
new DisplayLocationNotSupportedException(
|
||||
"Specified notification display location is not supported"))
|
||||
}
|
||||
fut.onFailure { ex =>
|
||||
logger.error(
|
||||
s"Failure in push service for logged out refresh request: $request - ${ex.getMessage} - ${ex.getStackTrace
|
||||
.mkString(", \n\t")}",
|
||||
ex)
|
||||
failureCount.incr()
|
||||
uncaughtErrorCount.incr()
|
||||
uncaughtErrorScope.counter(ex.getClass.getCanonicalName).incr()
|
||||
}
|
||||
}
|
||||
|
||||
private def refresh(
|
||||
request: t.PushService.Refresh.Args
|
||||
): Future[t.PushService.Refresh.SuccessType] = {
|
||||
|
||||
val fut = request.request.notificationDisplayLocation match {
|
||||
case NotificationDisplayLocation.PushToMobileDevice =>
|
||||
val clientId: String =
|
||||
ClientId.current
|
||||
.flatMap { cid => Option(cid.name) }
|
||||
.getOrElse("none")
|
||||
clientIdScope.counter(clientId).incr()
|
||||
refreshForPushHandler.refreshAndSend(request.request)
|
||||
case _ =>
|
||||
Future.exception(
|
||||
new DisplayLocationNotSupportedException(
|
||||
"Specified notification display location is not supported"))
|
||||
}
|
||||
fut.onFailure { ex =>
|
||||
logger.error(
|
||||
s"Failure in push service for refresh request: $request - ${ex.getMessage} - ${ex.getStackTrace
|
||||
.mkString(", \n\t")}",
|
||||
ex
|
||||
)
|
||||
|
||||
failureCount.incr()
|
||||
uncaughtErrorCount.incr()
|
||||
uncaughtErrorScope.counter(ex.getClass.getCanonicalName).incr()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private def send(
|
||||
request: t.PushService.Send.Args
|
||||
): Future[t.PushService.Send.SuccessType] = {
|
||||
sendHandler(request.request).onFailure { ex =>
|
||||
logger.error(
|
||||
s"Failure in push service for send request: $request - ${ex.getMessage} - ${ex.getStackTrace
|
||||
.mkString(", \n\t")}",
|
||||
ex
|
||||
)
|
||||
|
||||
failureCount.incr()
|
||||
uncaughtErrorCount.incr()
|
||||
uncaughtErrorScope.counter(ex.getClass.getCanonicalName).incr()
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
package com.twitter.frigate.pushservice.exception
|
||||
|
||||
import scala.util.control.NoStackTrace
|
||||
|
||||
/**
|
||||
* Throw exception if DisplayLocation is not supported
|
||||
*
|
||||
* @param message Exception message
|
||||
*/
|
||||
class DisplayLocationNotSupportedException(private val message: String)
|
||||
extends Exception(message)
|
||||
with NoStackTrace
|
@ -0,0 +1,12 @@
|
||||
package com.twitter.frigate.pushservice.exception
|
||||
|
||||
import scala.util.control.NoStackTrace
|
||||
|
||||
/**
|
||||
* Throw exception if the sport domain is not supported by MagicFanoutSports
|
||||
*
|
||||
* @param message Exception message
|
||||
*/
|
||||
class InvalidSportDomainException(private val message: String)
|
||||
extends Exception(message)
|
||||
with NoStackTrace
|
@ -0,0 +1,7 @@
|
||||
package com.twitter.frigate.pushservice.exception
|
||||
|
||||
import scala.util.control.NoStackTrace
|
||||
|
||||
class TweetNTabRequestHydratorException(private val message: String)
|
||||
extends Exception(message)
|
||||
with NoStackTrace
|
@ -0,0 +1,11 @@
|
||||
package com.twitter.frigate.pushservice.exception
|
||||
|
||||
import scala.util.control.NoStackTrace
|
||||
|
||||
/**
|
||||
* Exception for CRT not expected in the scope
|
||||
* @param message Exception message to log the UnsupportedCrt
|
||||
*/
|
||||
class UnsupportedCrtException(private val message: String)
|
||||
extends Exception(message)
|
||||
with NoStackTrace
|
@ -0,0 +1,12 @@
|
||||
package com.twitter.frigate.pushservice.exception
|
||||
|
||||
import scala.util.control.NoStackTrace
|
||||
|
||||
/**
|
||||
* Throw exception if UttEntity is not found where it might be a required data field
|
||||
*
|
||||
* @param message Exception message
|
||||
*/
|
||||
class UttEntityNotFoundException(private val message: String)
|
||||
extends Exception(message)
|
||||
with NoStackTrace
|
@ -0,0 +1,220 @@
|
||||
package com.twitter.frigate.pushservice.ml
|
||||
|
||||
import com.twitter.abuse.detection.scoring.thriftscala.{Model => TweetHealthModel}
|
||||
import com.twitter.abuse.detection.scoring.thriftscala.TweetScoringRequest
|
||||
import com.twitter.abuse.detection.scoring.thriftscala.TweetScoringResponse
|
||||
import com.twitter.frigate.common.base.FeatureMap
|
||||
import com.twitter.frigate.common.base.TweetAuthor
|
||||
import com.twitter.frigate.common.base.TweetAuthorDetails
|
||||
import com.twitter.frigate.common.base.TweetCandidate
|
||||
import com.twitter.frigate.common.rec_types.RecTypes
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.params.PushConstants
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.predicate.HealthPredicates.userHealthSignalValueToDouble
|
||||
import com.twitter.frigate.pushservice.util.CandidateHydrationUtil
|
||||
import com.twitter.frigate.pushservice.util.CandidateUtil
|
||||
import com.twitter.frigate.pushservice.util.MediaAnnotationsUtil
|
||||
import com.twitter.frigate.thriftscala.UserMediaRepresentation
|
||||
import com.twitter.hss.api.thriftscala.SignalValue
|
||||
import com.twitter.hss.api.thriftscala.UserHealthSignal
|
||||
import com.twitter.hss.api.thriftscala.UserHealthSignal.AgathaCalibratedNsfwDouble
|
||||
import com.twitter.hss.api.thriftscala.UserHealthSignal.NsfwTextUserScoreDouble
|
||||
import com.twitter.hss.api.thriftscala.UserHealthSignalResponse
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.Time
|
||||
|
||||
object HealthFeatureGetter {
|
||||
|
||||
def getFeatures(
|
||||
pushCandidate: PushCandidate,
|
||||
producerMediaRepresentationStore: ReadableStore[Long, UserMediaRepresentation],
|
||||
userHealthScoreStore: ReadableStore[Long, UserHealthSignalResponse],
|
||||
tweetHealthScoreStoreOpt: Option[ReadableStore[TweetScoringRequest, TweetScoringResponse]] =
|
||||
None
|
||||
): Future[FeatureMap] = {
|
||||
|
||||
pushCandidate match {
|
||||
case cand: PushCandidate with TweetCandidate with TweetAuthor with TweetAuthorDetails =>
|
||||
val pMediaNsfwRequest =
|
||||
TweetScoringRequest(cand.tweetId, TweetHealthModel.ExperimentalHealthModelScore4)
|
||||
val pTweetTextNsfwRequest =
|
||||
TweetScoringRequest(cand.tweetId, TweetHealthModel.ExperimentalHealthModelScore1)
|
||||
|
||||
cand.authorId match {
|
||||
case Some(authorId) =>
|
||||
Future
|
||||
.join(
|
||||
userHealthScoreStore.get(authorId),
|
||||
producerMediaRepresentationStore.get(authorId),
|
||||
tweetHealthScoreStoreOpt.map(_.get(pMediaNsfwRequest)).getOrElse(Future.None),
|
||||
tweetHealthScoreStoreOpt.map(_.get(pTweetTextNsfwRequest)).getOrElse(Future.None),
|
||||
cand.tweetAuthor
|
||||
).map {
|
||||
case (
|
||||
healthSignalsResponseOpt,
|
||||
producerMuOpt,
|
||||
pMediaNsfwOpt,
|
||||
pTweetTextNsfwOpt,
|
||||
tweetAuthorOpt) =>
|
||||
val healthSignalScoreMap = healthSignalsResponseOpt
|
||||
.map(_.signalValues).getOrElse(Map.empty[UserHealthSignal, SignalValue])
|
||||
val agathaNSFWScore = userHealthSignalValueToDouble(
|
||||
healthSignalScoreMap
|
||||
.getOrElse(AgathaCalibratedNsfwDouble, SignalValue.DoubleValue(0.5)))
|
||||
val userTextNSFWScore = userHealthSignalValueToDouble(
|
||||
healthSignalScoreMap
|
||||
.getOrElse(NsfwTextUserScoreDouble, SignalValue.DoubleValue(0.15)))
|
||||
val pMediaNsfwScore = pMediaNsfwOpt.map(_.score).getOrElse(0.0)
|
||||
val pTweetTextNsfwScore = pTweetTextNsfwOpt.map(_.score).getOrElse(0.0)
|
||||
|
||||
val mediaRepresentationMap =
|
||||
producerMuOpt.map(_.mediaRepresentation).getOrElse(Map.empty[String, Double])
|
||||
val sumScore: Double = mediaRepresentationMap.values.sum
|
||||
val nudityRate =
|
||||
if (sumScore > 0)
|
||||
mediaRepresentationMap.getOrElse(
|
||||
MediaAnnotationsUtil.nudityCategoryId,
|
||||
0.0) / sumScore
|
||||
else 0.0
|
||||
val beautyRate =
|
||||
if (sumScore > 0)
|
||||
mediaRepresentationMap.getOrElse(
|
||||
MediaAnnotationsUtil.beautyCategoryId,
|
||||
0.0) / sumScore
|
||||
else 0.0
|
||||
val singlePersonRate =
|
||||
if (sumScore > 0)
|
||||
mediaRepresentationMap.getOrElse(
|
||||
MediaAnnotationsUtil.singlePersonCategoryId,
|
||||
0.0) / sumScore
|
||||
else 0.0
|
||||
val dislikeCt = cand.numericFeatures.getOrElse(
|
||||
"tweet.magic_recs_tweet_real_time_aggregates_v2.pair.v2.magicrecs.realtime.is_ntab_disliked.any_feature.Duration.Top.count",
|
||||
0.0)
|
||||
val sentCt = cand.numericFeatures.getOrElse(
|
||||
"tweet.magic_recs_tweet_real_time_aggregates_v2.pair.v2.magicrecs.realtime.is_sent.any_feature.Duration.Top.count",
|
||||
0.0)
|
||||
val dislikeRate = if (sentCt > 0) dislikeCt / sentCt else 0.0
|
||||
|
||||
val authorDislikeCt = cand.numericFeatures.getOrElse(
|
||||
"tweet_author_aggregate.pair.label.ntab.isDisliked.any_feature.28.days.count",
|
||||
0.0)
|
||||
val authorReportCt = cand.numericFeatures.getOrElse(
|
||||
"tweet_author_aggregate.pair.label.reportTweetDone.any_feature.28.days.count",
|
||||
0.0)
|
||||
val authorSentCt = cand.numericFeatures
|
||||
.getOrElse(
|
||||
"tweet_author_aggregate.pair.any_label.any_feature.28.days.count",
|
||||
0.0)
|
||||
val authorDislikeRate =
|
||||
if (authorSentCt > 0) authorDislikeCt / authorSentCt else 0.0
|
||||
val authorReportRate =
|
||||
if (authorSentCt > 0) authorReportCt / authorSentCt else 0.0
|
||||
|
||||
val (isNsfwAccount, authorAccountAge) = tweetAuthorOpt match {
|
||||
case Some(tweetAuthor) =>
|
||||
(
|
||||
CandidateHydrationUtil.isNsfwAccount(
|
||||
tweetAuthor,
|
||||
cand.target.params(PushFeatureSwitchParams.NsfwTokensParam)),
|
||||
(Time.now - Time.fromMilliseconds(tweetAuthor.createdAtMsec)).inHours
|
||||
)
|
||||
case _ => (false, 0)
|
||||
}
|
||||
|
||||
val tweetSemanticCoreIds = cand.sparseBinaryFeatures
|
||||
.getOrElse(PushConstants.TweetSemanticCoreIdFeature, Set.empty[String])
|
||||
|
||||
val continuousFeatures = Map[String, Double](
|
||||
"agathaNsfwScore" -> agathaNSFWScore,
|
||||
"textNsfwScore" -> userTextNSFWScore,
|
||||
"pMediaNsfwScore" -> pMediaNsfwScore,
|
||||
"pTweetTextNsfwScore" -> pTweetTextNsfwScore,
|
||||
"nudityRate" -> nudityRate,
|
||||
"beautyRate" -> beautyRate,
|
||||
"singlePersonRate" -> singlePersonRate,
|
||||
"numSources" -> CandidateUtil.getTagsCRCount(cand),
|
||||
"favCount" -> cand.numericFeatures
|
||||
.getOrElse("tweet.core.tweet_counts.favorite_count", 0.0),
|
||||
"activeFollowers" -> cand.numericFeatures
|
||||
.getOrElse("RecTweetAuthor.User.ActiveFollowers", 0.0),
|
||||
"favorsRcvd28Days" -> cand.numericFeatures
|
||||
.getOrElse("RecTweetAuthor.User.FavorsRcvd28Days", 0.0),
|
||||
"tweets28Days" -> cand.numericFeatures
|
||||
.getOrElse("RecTweetAuthor.User.Tweets28Days", 0.0),
|
||||
"dislikeCount" -> dislikeCt,
|
||||
"dislikeRate" -> dislikeRate,
|
||||
"sentCount" -> sentCt,
|
||||
"authorDislikeCount" -> authorDislikeCt,
|
||||
"authorDislikeRate" -> authorDislikeRate,
|
||||
"authorReportCount" -> authorReportCt,
|
||||
"authorReportRate" -> authorReportRate,
|
||||
"authorSentCount" -> authorSentCt,
|
||||
"authorAgeInHour" -> authorAccountAge.toDouble
|
||||
)
|
||||
|
||||
val booleanFeatures = Map[String, Boolean](
|
||||
"isSimclusterBased" -> RecTypes.simclusterBasedTweets
|
||||
.contains(cand.commonRecType),
|
||||
"isTopicTweet" -> RecTypes.isTopicTweetType(cand.commonRecType),
|
||||
"isHashSpace" -> RecTypes.tagspaceTypes.contains(cand.commonRecType),
|
||||
"isFRS" -> RecTypes.frsTypes.contains(cand.commonRecType),
|
||||
"isModelingBased" -> RecTypes.mrModelingBasedTypes.contains(cand.commonRecType),
|
||||
"isGeoPop" -> RecTypes.GeoPopTweetTypes.contains(cand.commonRecType),
|
||||
"hasPhoto" -> cand.booleanFeatures
|
||||
.getOrElse("RecTweet.TweetyPieResult.HasPhoto", false),
|
||||
"hasVideo" -> cand.booleanFeatures
|
||||
.getOrElse("RecTweet.TweetyPieResult.HasVideo", false),
|
||||
"hasUrl" -> cand.booleanFeatures
|
||||
.getOrElse("RecTweet.TweetyPieResult.HasUrl", false),
|
||||
"isMrTwistly" -> CandidateUtil.isMrTwistlyCandidate(cand),
|
||||
"abuseStrikeTop2Percent" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.AbuseStrike_Top2Percent_Id),
|
||||
"abuseStrikeTop1Percent" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.AbuseStrike_Top1Percent_Id),
|
||||
"abuseStrikeTop05Percent" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.AbuseStrike_Top05Percent_Id),
|
||||
"abuseStrikeTop025Percent" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.AbuseStrike_Top025Percent_Id),
|
||||
"allSpamReportsPerFavTop1Percent" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.AllSpamReportsPerFav_Top1Percent_Id),
|
||||
"reportsPerFavTop1Percent" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.ReportsPerFav_Top1Percent_Id),
|
||||
"reportsPerFavTop2Percent" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.ReportsPerFav_Top2Percent_Id),
|
||||
"isNudity" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.MediaUnderstanding_Nudity_Id),
|
||||
"beautyStyleFashion" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.MediaUnderstanding_Beauty_Id),
|
||||
"singlePerson" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.MediaUnderstanding_SinglePerson_Id),
|
||||
"pornList" -> tweetSemanticCoreIds.contains(PushConstants.PornList_Id),
|
||||
"pornographyAndNsfwContent" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.PornographyAndNsfwContent_Id),
|
||||
"sexLife" -> tweetSemanticCoreIds.contains(PushConstants.SexLife_Id),
|
||||
"sexLifeOrSexualOrientation" -> tweetSemanticCoreIds.contains(
|
||||
PushConstants.SexLifeOrSexualOrientation_Id),
|
||||
"profanity" -> tweetSemanticCoreIds.contains(PushConstants.ProfanityFilter_Id),
|
||||
"isVerified" -> cand.booleanFeatures
|
||||
.getOrElse("RecTweetAuthor.User.IsVerified", false),
|
||||
"hasNsfwToken" -> isNsfwAccount
|
||||
)
|
||||
|
||||
val stringFeatures = Map[String, String](
|
||||
"tweetLanguage" -> cand.categoricalFeatures
|
||||
.getOrElse("tweet.core.tweet_text.language", "")
|
||||
)
|
||||
|
||||
FeatureMap(
|
||||
booleanFeatures = booleanFeatures,
|
||||
numericFeatures = continuousFeatures,
|
||||
categoricalFeatures = stringFeatures)
|
||||
}
|
||||
case _ => Future.value(FeatureMap())
|
||||
}
|
||||
case _ => Future.value(FeatureMap())
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,179 @@
|
||||
package com.twitter.frigate.pushservice.ml
|
||||
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.ml.feature.TweetSocialProofKey
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.predicate.quality_model_predicate.PDauCohortUtil
|
||||
import com.twitter.nrel.hydration.base.FeatureInput
|
||||
import com.twitter.nrel.hydration.push.HydrationContext
|
||||
import com.twitter.nrel.hydration.frigate.{FeatureInputs => FI}
|
||||
import com.twitter.util.Future
|
||||
|
||||
object HydrationContextBuilder {
|
||||
|
||||
private def getRecUserInputs(
|
||||
pushCandidate: PushCandidate
|
||||
): Set[FI.RecUser] = {
|
||||
pushCandidate match {
|
||||
case userCandidate: UserCandidate =>
|
||||
Set(FI.RecUser(userCandidate.userId))
|
||||
case _ => Set.empty
|
||||
}
|
||||
}
|
||||
|
||||
private def getRecTweetInputs(
|
||||
pushCandidate: PushCandidate
|
||||
): Set[FI.RecTweet] =
|
||||
pushCandidate match {
|
||||
case tweetCandidateWithAuthor: TweetCandidate with TweetAuthor with TweetAuthorDetails =>
|
||||
val authorIdOpt = tweetCandidateWithAuthor.authorId
|
||||
Set(FI.RecTweet(tweetCandidateWithAuthor.tweetId, authorIdOpt))
|
||||
case _ => Set.empty
|
||||
}
|
||||
|
||||
private def getMediaInputs(
|
||||
pushCandidate: PushCandidate
|
||||
): Set[FI.Media] =
|
||||
pushCandidate match {
|
||||
case tweetCandidateWithMedia: TweetCandidate with TweetDetails =>
|
||||
tweetCandidateWithMedia.mediaKeys
|
||||
.map { mk =>
|
||||
Set(FI.Media(mk))
|
||||
}.getOrElse(Set.empty)
|
||||
case _ => Set.empty
|
||||
}
|
||||
|
||||
private def getEventInputs(
|
||||
pushCandidate: PushCandidate
|
||||
): Set[FI.Event] = pushCandidate match {
|
||||
case mrEventCandidate: EventCandidate =>
|
||||
Set(FI.Event(mrEventCandidate.eventId))
|
||||
case mfEventCandidate: MagicFanoutEventCandidate =>
|
||||
Set(FI.Event(mfEventCandidate.eventId))
|
||||
case _ => Set.empty
|
||||
}
|
||||
|
||||
private def getTopicInputs(
|
||||
pushCandidate: PushCandidate
|
||||
): Set[FI.Topic] =
|
||||
pushCandidate match {
|
||||
case mrTopicCandidate: TopicCandidate =>
|
||||
mrTopicCandidate.semanticCoreEntityId match {
|
||||
case Some(topicId) => Set(FI.Topic(topicId))
|
||||
case _ => Set.empty
|
||||
}
|
||||
case _ => Set.empty
|
||||
}
|
||||
|
||||
private def getTweetSocialProofKey(
|
||||
pushCandidate: PushCandidate
|
||||
): Future[Set[FI.SocialProofKey]] = {
|
||||
pushCandidate match {
|
||||
case candidate: TweetCandidate with SocialContextActions =>
|
||||
val target = pushCandidate.target
|
||||
target.seedsWithWeight.map { seedsWithWeightOpt =>
|
||||
Set(
|
||||
FI.SocialProofKey(
|
||||
TweetSocialProofKey(
|
||||
seedsWithWeightOpt.getOrElse(Map.empty),
|
||||
candidate.socialContextAllTypeActions
|
||||
))
|
||||
)
|
||||
}
|
||||
case _ => Future.value(Set.empty)
|
||||
}
|
||||
}
|
||||
|
||||
private def getSocialContextInputs(
|
||||
pushCandidate: PushCandidate
|
||||
): Future[Set[FeatureInput]] =
|
||||
pushCandidate match {
|
||||
case candidateWithSC: Candidate with SocialContextActions =>
|
||||
val tweetSocialProofKeyFut = getTweetSocialProofKey(pushCandidate)
|
||||
tweetSocialProofKeyFut.map { tweetSocialProofKeyOpt =>
|
||||
val socialContextUsers = FI.SocialContextUsers(candidateWithSC.socialContextUserIds.toSet)
|
||||
val socialContextActions =
|
||||
FI.SocialContextActions(candidateWithSC.socialContextAllTypeActions)
|
||||
val socialProofKeyOpt = tweetSocialProofKeyOpt
|
||||
Set(Set(socialContextUsers), Set(socialContextActions), socialProofKeyOpt).flatten
|
||||
}
|
||||
case _ => Future.value(Set.empty)
|
||||
}
|
||||
|
||||
private def getPushStringGroupInputs(
|
||||
pushCandidate: PushCandidate
|
||||
): Set[FI.PushStringGroup] =
|
||||
Set(
|
||||
FI.PushStringGroup(
|
||||
pushCandidate.getPushCopy.flatMap(_.pushStringGroup).map(_.toString).getOrElse("")
|
||||
))
|
||||
|
||||
private def getCRTInputs(
|
||||
pushCandidate: PushCandidate
|
||||
): Set[FI.CommonRecommendationType] =
|
||||
Set(FI.CommonRecommendationType(pushCandidate.commonRecType))
|
||||
|
||||
private def getFrigateNotification(
|
||||
pushCandidate: PushCandidate
|
||||
): Set[FI.CandidateFrigateNotification] =
|
||||
Set(FI.CandidateFrigateNotification(pushCandidate.frigateNotification))
|
||||
|
||||
private def getCopyId(
|
||||
pushCandidate: PushCandidate
|
||||
): Set[FI.CopyId] =
|
||||
Set(FI.CopyId(pushCandidate.pushCopyId, pushCandidate.ntabCopyId))
|
||||
|
||||
def build(candidate: PushCandidate): Future[HydrationContext] = {
|
||||
val socialContextInputsFut = getSocialContextInputs(candidate)
|
||||
socialContextInputsFut.map { socialContextInputs =>
|
||||
val featureInputs: Set[FeatureInput] =
|
||||
socialContextInputs ++
|
||||
getRecUserInputs(candidate) ++
|
||||
getRecTweetInputs(candidate) ++
|
||||
getEventInputs(candidate) ++
|
||||
getTopicInputs(candidate) ++
|
||||
getCRTInputs(candidate) ++
|
||||
getPushStringGroupInputs(candidate) ++
|
||||
getMediaInputs(candidate) ++
|
||||
getFrigateNotification(candidate) ++
|
||||
getCopyId(candidate)
|
||||
|
||||
HydrationContext(
|
||||
candidate.target.targetId,
|
||||
featureInputs
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
def build(target: Target): Future[HydrationContext] = {
|
||||
val realGraphFeaturesFut = target.realGraphFeatures
|
||||
for {
|
||||
realGraphFeaturesOpt <- realGraphFeaturesFut
|
||||
dauProb <- PDauCohortUtil.getDauProb(target)
|
||||
mrUserStateOpt <- target.targetMrUserState
|
||||
historyInputOpt <-
|
||||
if (target.params(PushFeatureSwitchParams.EnableHydratingOnlineMRHistoryFeatures)) {
|
||||
target.onlineLabeledPushRecs.map { mrHistoryValueOpt =>
|
||||
mrHistoryValueOpt.map(FI.MrHistory)
|
||||
}
|
||||
} else Future.None
|
||||
} yield {
|
||||
val realGraphFeaturesInputOpt = realGraphFeaturesOpt.map { realGraphFeatures =>
|
||||
FI.TargetRealGraphFeatures(realGraphFeatures)
|
||||
}
|
||||
val dauProbInput = FI.DauProb(dauProb)
|
||||
val mrUserStateInput = FI.MrUserState(mrUserStateOpt.map(_.name).getOrElse("unknown"))
|
||||
HydrationContext(
|
||||
target.targetId,
|
||||
Seq(
|
||||
realGraphFeaturesInputOpt,
|
||||
historyInputOpt,
|
||||
Some(dauProbInput),
|
||||
Some(mrUserStateInput)
|
||||
).flatten.toSet
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,188 @@
|
||||
package com.twitter.frigate.pushservice.ml
|
||||
|
||||
import com.twitter.cortex.deepbird.thriftjava.ModelSelector
|
||||
import com.twitter.finagle.stats.Counter
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.CandidateDetails
|
||||
import com.twitter.frigate.common.base.FeatureMap
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.params.PushMLModel
|
||||
import com.twitter.frigate.pushservice.params.PushModelName
|
||||
import com.twitter.frigate.pushservice.params.WeightedOpenOrNtabClickModel
|
||||
import com.twitter.nrel.heavyranker.PushCandidateHydrationContextWithModel
|
||||
import com.twitter.nrel.heavyranker.PushPredictionServiceStore
|
||||
import com.twitter.nrel.heavyranker.TargetFeatureMapWithModel
|
||||
import com.twitter.timelines.configapi.FSParam
|
||||
import com.twitter.util.Future
|
||||
|
||||
/**
|
||||
* PushMLModelScorer scores the Candidates and populates their ML scores
|
||||
*
|
||||
* @param pushMLModel Enum to specify which model to use for scoring the Candidates
|
||||
* @param modelToPredictionServiceStoreMap Supports all other prediction services. Specifies model ID -> dbv2 ReadableStore
|
||||
* @param defaultDBv2PredictionServiceStore: Supports models that are not specified in the previous maps (which will be directly configured in the config repo)
|
||||
* @param scoringStats StatsReceiver for scoping stats
|
||||
*/
|
||||
class PushMLModelScorer(
|
||||
pushMLModel: PushMLModel.Value,
|
||||
modelToPredictionServiceStoreMap: Map[
|
||||
WeightedOpenOrNtabClickModel.ModelNameType,
|
||||
PushPredictionServiceStore
|
||||
],
|
||||
defaultDBv2PredictionServiceStore: PushPredictionServiceStore,
|
||||
scoringStats: StatsReceiver) {
|
||||
|
||||
val queriesOutsideTheModelMaps: StatsReceiver =
|
||||
scoringStats.scope("queries_outside_the_model_maps")
|
||||
val totalQueriesOutsideTheModelMaps: Counter =
|
||||
queriesOutsideTheModelMaps.counter("total")
|
||||
|
||||
private def scoreByBatchPredictionForModelFromMultiModelService(
|
||||
predictionServiceStore: PushPredictionServiceStore,
|
||||
modelVersion: WeightedOpenOrNtabClickModel.ModelNameType,
|
||||
candidatesDetails: Seq[CandidateDetails[PushCandidate]],
|
||||
useCommonFeatures: Boolean,
|
||||
overridePushMLModel: PushMLModel.Value
|
||||
): Seq[CandidateDetails[PushCandidate]] = {
|
||||
val modelName =
|
||||
PushModelName(overridePushMLModel, modelVersion).toString
|
||||
val modelSelector = new ModelSelector()
|
||||
modelSelector.setId(modelName)
|
||||
|
||||
val candidateHydrationWithFeaturesMap = candidatesDetails.map { candidatesDetail =>
|
||||
(
|
||||
candidatesDetail.candidate.candidateHydrationContext,
|
||||
candidatesDetail.candidate.candidateFeatureMap())
|
||||
}
|
||||
if (candidatesDetails.nonEmpty) {
|
||||
val candidatesWithScore = predictionServiceStore.getBatchPredictionsForModel(
|
||||
candidatesDetails.head.candidate.target.targetHydrationContext,
|
||||
candidatesDetails.head.candidate.target.featureMap,
|
||||
candidateHydrationWithFeaturesMap,
|
||||
Some(modelSelector),
|
||||
useCommonFeatures
|
||||
)
|
||||
candidatesDetails.zip(candidatesWithScore).foreach {
|
||||
case (candidateDetail, (_, scoreOptFut)) =>
|
||||
candidateDetail.candidate.populateQualityModelScore(
|
||||
overridePushMLModel,
|
||||
modelVersion,
|
||||
scoreOptFut
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
candidatesDetails
|
||||
}
|
||||
|
||||
private def scoreByBatchPrediction(
|
||||
modelVersion: WeightedOpenOrNtabClickModel.ModelNameType,
|
||||
candidatesDetails: Seq[CandidateDetails[PushCandidate]],
|
||||
useCommonFeaturesForDBv2Service: Boolean,
|
||||
overridePushMLModel: PushMLModel.Value
|
||||
): Seq[CandidateDetails[PushCandidate]] = {
|
||||
if (modelToPredictionServiceStoreMap.contains(modelVersion)) {
|
||||
scoreByBatchPredictionForModelFromMultiModelService(
|
||||
modelToPredictionServiceStoreMap(modelVersion),
|
||||
modelVersion,
|
||||
candidatesDetails,
|
||||
useCommonFeaturesForDBv2Service,
|
||||
overridePushMLModel
|
||||
)
|
||||
} else {
|
||||
totalQueriesOutsideTheModelMaps.incr()
|
||||
queriesOutsideTheModelMaps.counter(modelVersion).incr()
|
||||
scoreByBatchPredictionForModelFromMultiModelService(
|
||||
defaultDBv2PredictionServiceStore,
|
||||
modelVersion,
|
||||
candidatesDetails,
|
||||
useCommonFeaturesForDBv2Service,
|
||||
overridePushMLModel
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
def scoreByBatchPredictionForModelVersion(
|
||||
target: Target,
|
||||
candidatesDetails: Seq[CandidateDetails[PushCandidate]],
|
||||
modelVersionParam: FSParam[WeightedOpenOrNtabClickModel.ModelNameType],
|
||||
useCommonFeaturesForDBv2Service: Boolean = true,
|
||||
overridePushMLModelOpt: Option[PushMLModel.Value] = None
|
||||
): Seq[CandidateDetails[PushCandidate]] = {
|
||||
scoreByBatchPrediction(
|
||||
target.params(modelVersionParam),
|
||||
candidatesDetails,
|
||||
useCommonFeaturesForDBv2Service,
|
||||
overridePushMLModelOpt.getOrElse(pushMLModel)
|
||||
)
|
||||
}
|
||||
|
||||
def singlePredicationForModelVersion(
|
||||
modelVersion: String,
|
||||
candidate: PushCandidate,
|
||||
overridePushMLModelOpt: Option[PushMLModel.Value] = None
|
||||
): Future[Option[Double]] = {
|
||||
val modelSelector = new ModelSelector()
|
||||
modelSelector.setId(
|
||||
PushModelName(overridePushMLModelOpt.getOrElse(pushMLModel), modelVersion).toString
|
||||
)
|
||||
if (modelToPredictionServiceStoreMap.contains(modelVersion)) {
|
||||
modelToPredictionServiceStoreMap(modelVersion).get(
|
||||
PushCandidateHydrationContextWithModel(
|
||||
candidate.target.targetHydrationContext,
|
||||
candidate.target.featureMap,
|
||||
candidate.candidateHydrationContext,
|
||||
candidate.candidateFeatureMap(),
|
||||
Some(modelSelector)
|
||||
)
|
||||
)
|
||||
} else {
|
||||
totalQueriesOutsideTheModelMaps.incr()
|
||||
queriesOutsideTheModelMaps.counter(modelVersion).incr()
|
||||
defaultDBv2PredictionServiceStore.get(
|
||||
PushCandidateHydrationContextWithModel(
|
||||
candidate.target.targetHydrationContext,
|
||||
candidate.target.featureMap,
|
||||
candidate.candidateHydrationContext,
|
||||
candidate.candidateFeatureMap(),
|
||||
Some(modelSelector)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
def singlePredictionForTargetLevel(
|
||||
modelVersion: String,
|
||||
targetId: Long,
|
||||
featureMap: Future[FeatureMap]
|
||||
): Future[Option[Double]] = {
|
||||
val modelSelector = new ModelSelector()
|
||||
modelSelector.setId(
|
||||
PushModelName(pushMLModel, modelVersion).toString
|
||||
)
|
||||
defaultDBv2PredictionServiceStore.getForTargetLevel(
|
||||
TargetFeatureMapWithModel(targetId, featureMap, Some(modelSelector))
|
||||
)
|
||||
}
|
||||
|
||||
def getScoreHistogramCounters(
|
||||
stats: StatsReceiver,
|
||||
scopeName: String,
|
||||
histogramBinSize: Double
|
||||
): IndexedSeq[Counter] = {
|
||||
val histogramScopedStatsReceiver = stats.scope(scopeName)
|
||||
val numBins = math.ceil(1.0 / histogramBinSize).toInt
|
||||
|
||||
(0 to numBins) map { k =>
|
||||
if (k == 0)
|
||||
histogramScopedStatsReceiver.counter("candidates_with_scores_zero")
|
||||
else {
|
||||
val counterName = "candidates_with_scores_from_%s_to_%s".format(
|
||||
"%.2f".format(histogramBinSize * (k - 1)).replace(".", ""),
|
||||
"%.2f".format(math.min(1.0, histogramBinSize * k)).replace(".", ""))
|
||||
histogramScopedStatsReceiver.counter(counterName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,89 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.DiscoverTwitterCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.DiscoverTwitterPushIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.DiscoverTwitterNtabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.predicate.PredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicRFPHPredicates
|
||||
import com.twitter.frigate.pushservice.take.predicates.OutOfNetworkTweetPredicates
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
|
||||
class DiscoverTwitterPushCandidate(
|
||||
candidate: RawCandidate with DiscoverTwitterCandidate,
|
||||
copyIds: CopyIds,
|
||||
)(
|
||||
implicit val statsScoped: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with DiscoverTwitterCandidate
|
||||
with DiscoverTwitterPushIbis2Hydrator
|
||||
with DiscoverTwitterNtabRequestHydrator {
|
||||
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
|
||||
override val target: Target = candidate.target
|
||||
|
||||
override lazy val commonRecType: CommonRecommendationType = candidate.commonRecType
|
||||
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
|
||||
override val statsReceiver: StatsReceiver =
|
||||
statsScoped.scope("DiscoverTwitterPushCandidate")
|
||||
}
|
||||
|
||||
case class AddressBookPushCandidatePredicates(config: Config)
|
||||
extends BasicRFPHPredicates[DiscoverTwitterPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val predicates: List[
|
||||
NamedPredicate[DiscoverTwitterPushCandidate]
|
||||
] =
|
||||
List(
|
||||
PredicatesForCandidate.paramPredicate(
|
||||
PushFeatureSwitchParams.EnableAddressBookPush
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
case class CompleteOnboardingPushCandidatePredicates(config: Config)
|
||||
extends BasicRFPHPredicates[DiscoverTwitterPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val predicates: List[
|
||||
NamedPredicate[DiscoverTwitterPushCandidate]
|
||||
] =
|
||||
List(
|
||||
PredicatesForCandidate.paramPredicate(
|
||||
PushFeatureSwitchParams.EnableCompleteOnboardingPush
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
case class PopGeoTweetCandidatePredicates(override val config: Config)
|
||||
extends OutOfNetworkTweetPredicates[OutOfNetworkTweetPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override def postCandidateSpecificPredicates: List[
|
||||
NamedPredicate[OutOfNetworkTweetPushCandidate]
|
||||
] = List(
|
||||
PredicatesForCandidate.htlFatiguePredicate(
|
||||
PushFeatureSwitchParams.NewUserPlaybookAllowedLastLoginHours
|
||||
)
|
||||
)
|
||||
}
|
@ -0,0 +1,60 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.F1FirstDegree
|
||||
import com.twitter.frigate.common.base.SocialContextAction
|
||||
import com.twitter.frigate.common.base.SocialGraphServiceRelationshipMap
|
||||
import com.twitter.frigate.common.base.TweetAuthorDetails
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes._
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.F1FirstDegreeTweetIbis2HydratorForCandidate
|
||||
import com.twitter.frigate.pushservice.model.ntab.F1FirstDegreeTweetNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicTweetPredicatesForRFPHWithoutSGSPredicates
|
||||
import com.twitter.frigate.pushservice.util.CandidateHydrationUtil.TweetWithSocialContextTraits
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.hermit.predicate.socialgraph.RelationEdge
|
||||
import com.twitter.stitch.tweetypie.TweetyPie
|
||||
import com.twitter.util.Future
|
||||
|
||||
class F1TweetPushCandidate(
|
||||
candidate: RawCandidate with TweetWithSocialContextTraits,
|
||||
author: Future[Option[User]],
|
||||
socialGraphServiceResultMap: Map[RelationEdge, Boolean],
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit stats: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with F1FirstDegree
|
||||
with TweetAuthorDetails
|
||||
with SocialGraphServiceRelationshipMap
|
||||
with F1FirstDegreeTweetNTabRequestHydrator
|
||||
with F1FirstDegreeTweetIbis2HydratorForCandidate {
|
||||
override val socialContextActions: Seq[SocialContextAction] =
|
||||
candidate.socialContextActions
|
||||
override val socialContextAllTypeActions: Seq[SocialContextAction] =
|
||||
candidate.socialContextActions
|
||||
override val statsReceiver: StatsReceiver = stats
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
override val tweetId: Long = candidate.tweetId
|
||||
override lazy val tweetyPieResult: Option[TweetyPie.TweetyPieResult] =
|
||||
candidate.tweetyPieResult
|
||||
override lazy val tweetAuthor: Future[Option[User]] = author
|
||||
override val target: PushTypes.Target = candidate.target
|
||||
override lazy val commonRecType: CommonRecommendationType =
|
||||
candidate.commonRecType
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
|
||||
override val relationshipMap: Map[RelationEdge, Boolean] = socialGraphServiceResultMap
|
||||
}
|
||||
|
||||
case class F1TweetCandidatePredicates(override val config: Config)
|
||||
extends BasicTweetPredicatesForRFPHWithoutSGSPredicates[F1TweetPushCandidate] {
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.channels.common.thriftscala.ApiList
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.ListPushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.ListIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.ListCandidateNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.predicate.ListPredicates
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicRFPHPredicates
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
|
||||
class ListRecommendationPushCandidate(
|
||||
val apiListStore: ReadableStore[Long, ApiList],
|
||||
candidate: RawCandidate with ListPushCandidate,
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit stats: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with ListPushCandidate
|
||||
with ListIbis2Hydrator
|
||||
with ListCandidateNTabRequestHydrator {
|
||||
|
||||
override val commonRecType: CommonRecommendationType = candidate.commonRecType
|
||||
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
|
||||
override val statsReceiver: StatsReceiver = stats
|
||||
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
|
||||
override val target: PushTypes.Target = candidate.target
|
||||
|
||||
override val listId: Long = candidate.listId
|
||||
|
||||
lazy val apiList: Future[Option[ApiList]] = apiListStore.get(listId)
|
||||
|
||||
lazy val listName: Future[Option[String]] = apiList.map { apiListOpt =>
|
||||
apiListOpt.map(_.name)
|
||||
}
|
||||
|
||||
lazy val listOwnerId: Future[Option[Long]] = apiList.map { apiListOpt =>
|
||||
apiListOpt.map(_.ownerId)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
case class ListRecommendationPredicates(config: Config)
|
||||
extends BasicRFPHPredicates[ListRecommendationPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val predicates: List[NamedPredicate[ListRecommendationPushCandidate]] = List(
|
||||
ListPredicates.listNameExistsPredicate(),
|
||||
ListPredicates.listAuthorExistsPredicate(),
|
||||
ListPredicates.listAuthorAcceptableToTargetUser(config.edgeStore),
|
||||
ListPredicates.listAcceptablePredicate(),
|
||||
ListPredicates.listSubscriberCountPredicate()
|
||||
)
|
||||
}
|
@ -0,0 +1,136 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.HydratedMagicFanoutCreatorEventCandidate
|
||||
import com.twitter.frigate.common.base.MagicFanoutCreatorEventCandidate
|
||||
import com.twitter.frigate.magic_events.thriftscala.CreatorFanoutType
|
||||
import com.twitter.frigate.magic_events.thriftscala.MagicEventsReason
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.MagicFanoutCreatorEventIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.MagicFanoutCreatorEventNtabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.predicate.PredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.magic_fanout.MagicFanoutPredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.ntab_caret_fatigue.MagicFanoutNtabCaretFatiguePredicate
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicSendHandlerPredicates
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.frigate.thriftscala.FrigateNotification
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.strato.client.UserId
|
||||
import com.twitter.util.Future
|
||||
import scala.util.control.NoStackTrace
|
||||
|
||||
class MagicFanoutCreatorEventPushCandidateHydratorException(private val message: String)
|
||||
extends Exception(message)
|
||||
with NoStackTrace
|
||||
|
||||
class MagicFanoutCreatorEventPushCandidate(
|
||||
candidate: RawCandidate with MagicFanoutCreatorEventCandidate,
|
||||
creatorUser: Option[User],
|
||||
copyIds: CopyIds,
|
||||
creatorTweetCountStore: ReadableStore[UserId, Int]
|
||||
)(
|
||||
implicit val statsScoped: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with MagicFanoutCreatorEventIbis2Hydrator
|
||||
with MagicFanoutCreatorEventNtabRequestHydrator
|
||||
with MagicFanoutCreatorEventCandidate
|
||||
with HydratedMagicFanoutCreatorEventCandidate {
|
||||
override def creatorId: Long = candidate.creatorId
|
||||
|
||||
override def hydratedCreator: Option[User] = creatorUser
|
||||
|
||||
override lazy val numberOfTweetsFut: Future[Option[Int]] =
|
||||
creatorTweetCountStore.get(UserId(creatorId))
|
||||
|
||||
lazy val userProfile = hydratedCreator
|
||||
.flatMap(_.profile).getOrElse(
|
||||
throw new MagicFanoutCreatorEventPushCandidateHydratorException(
|
||||
s"Unable to get user profile to generate tapThrough for userId: $creatorId"))
|
||||
|
||||
override val frigateNotification: FrigateNotification = candidate.frigateNotification
|
||||
|
||||
override def subscriberId: Option[Long] = candidate.subscriberId
|
||||
|
||||
override def creatorFanoutType: CreatorFanoutType = candidate.creatorFanoutType
|
||||
|
||||
override def target: PushTypes.Target = candidate.target
|
||||
|
||||
override def pushId: Long = candidate.pushId
|
||||
|
||||
override def candidateMagicEventsReasons: Seq[MagicEventsReason] =
|
||||
candidate.candidateMagicEventsReasons
|
||||
|
||||
override def statsReceiver: StatsReceiver = statsScoped
|
||||
|
||||
override def pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
|
||||
override def ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
|
||||
override def copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
|
||||
override def commonRecType: CommonRecommendationType = candidate.commonRecType
|
||||
|
||||
override def weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
|
||||
}
|
||||
|
||||
case class MagicFanouCreatorSubscriptionEventPushPredicates(config: Config)
|
||||
extends BasicSendHandlerPredicates[MagicFanoutCreatorEventPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val preCandidateSpecificPredicates: List[
|
||||
NamedPredicate[MagicFanoutCreatorEventPushCandidate]
|
||||
] =
|
||||
List(
|
||||
PredicatesForCandidate.paramPredicate(
|
||||
PushFeatureSwitchParams.EnableCreatorSubscriptionPush
|
||||
),
|
||||
PredicatesForCandidate.isDeviceEligibleForCreatorPush,
|
||||
MagicFanoutPredicatesForCandidate.creatorPushTargetIsNotCreator(),
|
||||
MagicFanoutPredicatesForCandidate.duplicateCreatorPredicate,
|
||||
MagicFanoutPredicatesForCandidate.magicFanoutCreatorPushFatiguePredicate(),
|
||||
)
|
||||
|
||||
override val postCandidateSpecificPredicates: List[
|
||||
NamedPredicate[MagicFanoutCreatorEventPushCandidate]
|
||||
] =
|
||||
List(
|
||||
MagicFanoutNtabCaretFatiguePredicate(),
|
||||
MagicFanoutPredicatesForCandidate.isSuperFollowingCreator()(config, statsReceiver).flip
|
||||
)
|
||||
}
|
||||
|
||||
case class MagicFanoutNewCreatorEventPushPredicates(config: Config)
|
||||
extends BasicSendHandlerPredicates[MagicFanoutCreatorEventPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val preCandidateSpecificPredicates: List[
|
||||
NamedPredicate[MagicFanoutCreatorEventPushCandidate]
|
||||
] =
|
||||
List(
|
||||
PredicatesForCandidate.paramPredicate(
|
||||
PushFeatureSwitchParams.EnableNewCreatorPush
|
||||
),
|
||||
PredicatesForCandidate.isDeviceEligibleForCreatorPush,
|
||||
MagicFanoutPredicatesForCandidate.duplicateCreatorPredicate,
|
||||
MagicFanoutPredicatesForCandidate.magicFanoutCreatorPushFatiguePredicate,
|
||||
)
|
||||
|
||||
override val postCandidateSpecificPredicates: List[
|
||||
NamedPredicate[MagicFanoutCreatorEventPushCandidate]
|
||||
] =
|
||||
List(
|
||||
MagicFanoutNtabCaretFatiguePredicate(),
|
||||
MagicFanoutPredicatesForCandidate.isSuperFollowingCreator()(config, statsReceiver).flip
|
||||
)
|
||||
}
|
@ -0,0 +1,303 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.escherbird.metadata.thriftscala.EntityMegadata
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.MagicFanoutEventCandidate
|
||||
import com.twitter.frigate.common.base.RecommendationType
|
||||
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
|
||||
import com.twitter.frigate.common.util.HighPriorityLocaleUtil
|
||||
import com.twitter.frigate.magic_events.thriftscala.FanoutEvent
|
||||
import com.twitter.frigate.magic_events.thriftscala.FanoutMetadata
|
||||
import com.twitter.frigate.magic_events.thriftscala.MagicEventsReason
|
||||
import com.twitter.frigate.magic_events.thriftscala.NewsForYouMetadata
|
||||
import com.twitter.frigate.magic_events.thriftscala.ReasonSource
|
||||
import com.twitter.frigate.magic_events.thriftscala.TargetID
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.Ibis2HydratorForCandidate
|
||||
import com.twitter.frigate.pushservice.model.ntab.EventNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.predicate.magic_fanout.MagicFanoutPredicatesUtil
|
||||
import com.twitter.frigate.pushservice.store.EventRequest
|
||||
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
|
||||
import com.twitter.frigate.pushservice.util.PushDeviceUtil
|
||||
import com.twitter.frigate.pushservice.util.TopicsUtil
|
||||
import com.twitter.frigate.thriftscala.FrigateNotification
|
||||
import com.twitter.frigate.thriftscala.MagicFanoutEventNotificationDetails
|
||||
import com.twitter.hermit.store.semantic_core.SemanticEntityForQuery
|
||||
import com.twitter.interests.thriftscala.InterestId.SemanticCore
|
||||
import com.twitter.interests.thriftscala.UserInterests
|
||||
import com.twitter.livevideo.common.ids.CountryId
|
||||
import com.twitter.livevideo.common.ids.UserId
|
||||
import com.twitter.livevideo.timeline.domain.v2.Event
|
||||
import com.twitter.livevideo.timeline.domain.v2.HydrationOptions
|
||||
import com.twitter.livevideo.timeline.domain.v2.LookupContext
|
||||
import com.twitter.simclusters_v2.thriftscala.SimClustersInferredEntities
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.topiclisting.utt.LocalizedEntity
|
||||
import com.twitter.util.Future
|
||||
|
||||
abstract class MagicFanoutEventPushCandidate(
|
||||
candidate: RawCandidate with MagicFanoutEventCandidate with RecommendationType,
|
||||
copyIds: CopyIds,
|
||||
override val fanoutEvent: Option[FanoutEvent],
|
||||
override val semanticEntityResults: Map[SemanticEntityForQuery, Option[EntityMegadata]],
|
||||
simClusterToEntities: Map[Int, Option[SimClustersInferredEntities]],
|
||||
lexServiceStore: ReadableStore[EventRequest, Event],
|
||||
interestsLookupStore: ReadableStore[InterestsLookupRequestWithContext, UserInterests],
|
||||
uttEntityHydrationStore: UttEntityHydrationStore
|
||||
)(
|
||||
implicit statsScoped: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with MagicFanoutEventHydratedCandidate
|
||||
with MagicFanoutEventCandidate
|
||||
with EventNTabRequestHydrator
|
||||
with RecommendationType
|
||||
with Ibis2HydratorForCandidate {
|
||||
|
||||
override lazy val eventFut: Future[Option[Event]] = {
|
||||
eventRequestFut.flatMap {
|
||||
case Some(eventRequest) => lexServiceStore.get(eventRequest)
|
||||
case _ => Future.None
|
||||
}
|
||||
}
|
||||
|
||||
override val frigateNotification: FrigateNotification = candidate.frigateNotification
|
||||
|
||||
override val pushId: Long = candidate.pushId
|
||||
|
||||
override val candidateMagicEventsReasons: Seq[MagicEventsReason] =
|
||||
candidate.candidateMagicEventsReasons
|
||||
|
||||
override val eventId: Long = candidate.eventId
|
||||
|
||||
override val momentId: Option[Long] = candidate.momentId
|
||||
|
||||
override val target: Target = candidate.target
|
||||
|
||||
override val eventLanguage: Option[String] = candidate.eventLanguage
|
||||
|
||||
override val details: Option[MagicFanoutEventNotificationDetails] = candidate.details
|
||||
|
||||
override lazy val stats: StatsReceiver = statsScoped.scope("MagicFanoutEventPushCandidate")
|
||||
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
|
||||
override val statsReceiver: StatsReceiver = statsScoped.scope("MagicFanoutEventPushCandidate")
|
||||
|
||||
override val effectiveMagicEventsReasons: Option[Seq[MagicEventsReason]] = Some(
|
||||
candidateMagicEventsReasons)
|
||||
|
||||
lazy val newsForYouMetadata: Option[NewsForYouMetadata] =
|
||||
fanoutEvent.flatMap { event =>
|
||||
{
|
||||
event.fanoutMetadata.collect {
|
||||
case FanoutMetadata.NewsForYouMetadata(nfyMetadata) => nfyMetadata
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val reverseIndexedTopicIds = candidate.candidateMagicEventsReasons
|
||||
.filter(_.source.contains(ReasonSource.UttTopicFollowGraph))
|
||||
.map(_.reason).collect {
|
||||
case TargetID.SemanticCoreID(semanticCoreID) => semanticCoreID.entityId
|
||||
}.toSet
|
||||
|
||||
val ergSemanticCoreIds = candidate.candidateMagicEventsReasons
|
||||
.filter(_.source.contains(ReasonSource.ErgShortTermInterestSemanticCore)).map(
|
||||
_.reason).collect {
|
||||
case TargetID.SemanticCoreID(semanticCoreID) => semanticCoreID.entityId
|
||||
}.toSet
|
||||
|
||||
override lazy val ergLocalizedEntities = TopicsUtil
|
||||
.getLocalizedEntityMap(target, ergSemanticCoreIds, uttEntityHydrationStore)
|
||||
.map { localizedEntityMap =>
|
||||
ergSemanticCoreIds.collect {
|
||||
case topicId if localizedEntityMap.contains(topicId) => localizedEntityMap(topicId)
|
||||
}
|
||||
}
|
||||
|
||||
val eventSemanticCoreEntityIds: Seq[Long] = {
|
||||
val entityIds = for {
|
||||
event <- fanoutEvent
|
||||
targets <- event.targets
|
||||
} yield {
|
||||
targets.flatMap {
|
||||
_.whitelist.map {
|
||||
_.collect {
|
||||
case TargetID.SemanticCoreID(semanticCoreID) => semanticCoreID.entityId
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
entityIds.map(_.flatten).getOrElse(Seq.empty)
|
||||
}
|
||||
|
||||
val eventSemanticCoreDomainIds: Seq[Long] = {
|
||||
val domainIds = for {
|
||||
event <- fanoutEvent
|
||||
targets <- event.targets
|
||||
} yield {
|
||||
targets.flatMap {
|
||||
_.whitelist.map {
|
||||
_.collect {
|
||||
case TargetID.SemanticCoreID(semanticCoreID) => semanticCoreID.domainId
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
domainIds.map(_.flatten).getOrElse(Seq.empty)
|
||||
}
|
||||
|
||||
override lazy val followedTopicLocalizedEntities: Future[Set[LocalizedEntity]] = {
|
||||
|
||||
val isNewSignupTargetingReason = candidateMagicEventsReasons.size == 1 &&
|
||||
candidateMagicEventsReasons.headOption.exists(_.source.contains(ReasonSource.NewSignup))
|
||||
|
||||
val shouldFetchTopicFollows = reverseIndexedTopicIds.nonEmpty || isNewSignupTargetingReason
|
||||
|
||||
val topicFollows = if (shouldFetchTopicFollows) {
|
||||
TopicsUtil
|
||||
.getTopicsFollowedByUser(
|
||||
candidate.target,
|
||||
interestsLookupStore,
|
||||
stats.stat("followed_topics")
|
||||
).map { _.getOrElse(Seq.empty) }.map {
|
||||
_.flatMap {
|
||||
_.interestId match {
|
||||
case SemanticCore(semanticCore) => Some(semanticCore.id)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
}
|
||||
} else Future.Nil
|
||||
|
||||
topicFollows.flatMap { followedTopicIds =>
|
||||
val topicIds = if (isNewSignupTargetingReason) {
|
||||
// if new signup is the only targeting reason then we check the event targeting reason
|
||||
// against realtime topic follows.
|
||||
eventSemanticCoreEntityIds.toSet.intersect(followedTopicIds.toSet)
|
||||
} else {
|
||||
// check against the fanout reason of topics
|
||||
followedTopicIds.toSet.intersect(reverseIndexedTopicIds)
|
||||
}
|
||||
|
||||
TopicsUtil
|
||||
.getLocalizedEntityMap(target, topicIds, uttEntityHydrationStore)
|
||||
.map { localizedEntityMap =>
|
||||
topicIds.collect {
|
||||
case topicId if localizedEntityMap.contains(topicId) => localizedEntityMap(topicId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lazy val simClusterToEntityMapping: Map[Int, Seq[Long]] =
|
||||
simClusterToEntities.flatMap {
|
||||
case (clusterId, Some(inferredEntities)) =>
|
||||
statsReceiver.counter("with_cluster_to_entity_mapping").incr()
|
||||
Some(
|
||||
(
|
||||
clusterId,
|
||||
inferredEntities.entities
|
||||
.map(_.entityId)))
|
||||
case _ =>
|
||||
statsReceiver.counter("without_cluster_to_entity_mapping").incr()
|
||||
None
|
||||
}
|
||||
|
||||
lazy val annotatedAndInferredSemanticCoreEntities: Seq[Long] =
|
||||
(simClusterToEntityMapping, eventFanoutReasonEntities) match {
|
||||
case (entityMapping, eventFanoutReasons) =>
|
||||
entityMapping.values.flatten.toSeq ++
|
||||
eventFanoutReasons.semanticCoreIds.map(_.entityId)
|
||||
}
|
||||
|
||||
lazy val shouldHydrateSquareImage = target.deviceInfo.map { deviceInfo =>
|
||||
(PushDeviceUtil.isPrimaryDeviceIOS(deviceInfo) &&
|
||||
target.params(PushFeatureSwitchParams.EnableEventSquareMediaIosMagicFanoutNewsEvent)) ||
|
||||
(PushDeviceUtil.isPrimaryDeviceAndroid(deviceInfo) &&
|
||||
target.params(PushFeatureSwitchParams.EnableEventSquareMediaAndroid))
|
||||
}
|
||||
|
||||
lazy val shouldHydratePrimaryImage: Future[Boolean] = target.deviceInfo.map { deviceInfo =>
|
||||
(PushDeviceUtil.isPrimaryDeviceAndroid(deviceInfo) &&
|
||||
target.params(PushFeatureSwitchParams.EnableEventPrimaryMediaAndroid))
|
||||
}
|
||||
|
||||
lazy val eventRequestFut: Future[Option[EventRequest]] =
|
||||
Future
|
||||
.join(
|
||||
target.inferredUserDeviceLanguage,
|
||||
target.accountCountryCode,
|
||||
shouldHydrateSquareImage,
|
||||
shouldHydratePrimaryImage).map {
|
||||
case (
|
||||
inferredUserDeviceLanguage,
|
||||
accountCountryCode,
|
||||
shouldHydrateSquareImage,
|
||||
shouldHydratePrimaryImage) =>
|
||||
if (shouldHydrateSquareImage || shouldHydratePrimaryImage) {
|
||||
Some(
|
||||
EventRequest(
|
||||
eventId,
|
||||
lookupContext = LookupContext(
|
||||
hydrationOptions = HydrationOptions(
|
||||
includeSquareImage = shouldHydrateSquareImage,
|
||||
includePrimaryImage = shouldHydratePrimaryImage
|
||||
),
|
||||
language = inferredUserDeviceLanguage,
|
||||
countryCode = accountCountryCode,
|
||||
userId = Some(UserId(target.targetId))
|
||||
)
|
||||
))
|
||||
} else {
|
||||
Some(
|
||||
EventRequest(
|
||||
eventId,
|
||||
lookupContext = LookupContext(
|
||||
language = inferredUserDeviceLanguage,
|
||||
countryCode = accountCountryCode
|
||||
)
|
||||
))
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
|
||||
lazy val isHighPriorityEvent: Future[Boolean] = target.accountCountryCode.map { countryCodeOpt =>
|
||||
val isHighPriorityPushOpt = for {
|
||||
countryCode <- countryCodeOpt
|
||||
nfyMetadata <- newsForYouMetadata
|
||||
eventContext <- nfyMetadata.eventContextScribe
|
||||
} yield {
|
||||
val highPriorityLocales = HighPriorityLocaleUtil.getHighPriorityLocales(
|
||||
eventContext = eventContext,
|
||||
defaultLocalesOpt = nfyMetadata.locales)
|
||||
val highPriorityGeos = HighPriorityLocaleUtil.getHighPriorityGeos(
|
||||
eventContext = eventContext,
|
||||
defaultGeoPlaceIdsOpt = nfyMetadata.placeIds)
|
||||
val isHighPriorityLocalePush =
|
||||
highPriorityLocales.flatMap(_.country).map(CountryId(_)).contains(CountryId(countryCode))
|
||||
val isHighPriorityGeoPush = MagicFanoutPredicatesUtil
|
||||
.geoPlaceIdsFromReasons(candidateMagicEventsReasons)
|
||||
.intersect(highPriorityGeos.toSet)
|
||||
.nonEmpty
|
||||
stats.scope("is_high_priority_locale_push").counter(s"$isHighPriorityLocalePush").incr()
|
||||
stats.scope("is_high_priority_geo_push").counter(s"$isHighPriorityGeoPush").incr()
|
||||
isHighPriorityLocalePush || isHighPriorityGeoPush
|
||||
}
|
||||
isHighPriorityPushOpt.getOrElse(false)
|
||||
}
|
||||
}
|
@ -0,0 +1,147 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.escherbird.common.thriftscala.QualifiedId
|
||||
import com.twitter.escherbird.metadata.thriftscala.BasicMetadata
|
||||
import com.twitter.escherbird.metadata.thriftscala.EntityIndexFields
|
||||
import com.twitter.escherbird.metadata.thriftscala.EntityMegadata
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.MagicFanoutCandidate
|
||||
import com.twitter.frigate.common.base.MagicFanoutEventCandidate
|
||||
import com.twitter.frigate.common.base.RichEventFutCandidate
|
||||
import com.twitter.frigate.magic_events.thriftscala
|
||||
import com.twitter.frigate.magic_events.thriftscala.AnnotationAlg
|
||||
import com.twitter.frigate.magic_events.thriftscala.FanoutEvent
|
||||
import com.twitter.frigate.magic_events.thriftscala.MagicEventsReason
|
||||
import com.twitter.frigate.magic_events.thriftscala.SemanticCoreID
|
||||
import com.twitter.frigate.magic_events.thriftscala.SimClusterID
|
||||
import com.twitter.frigate.magic_events.thriftscala.TargetID
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.hermit.store.semantic_core.SemanticEntityForQuery
|
||||
import com.twitter.livevideo.timeline.domain.v2.Event
|
||||
import com.twitter.topiclisting.utt.LocalizedEntity
|
||||
import com.twitter.util.Future
|
||||
|
||||
case class FanoutReasonEntities(
|
||||
userIds: Set[Long],
|
||||
placeIds: Set[Long],
|
||||
semanticCoreIds: Set[SemanticCoreID],
|
||||
simclusterIds: Set[SimClusterID]) {
|
||||
val qualifiedIds: Set[QualifiedId] =
|
||||
semanticCoreIds.map(e => QualifiedId(e.domainId, e.entityId))
|
||||
}
|
||||
|
||||
object FanoutReasonEntities {
|
||||
val empty = FanoutReasonEntities(
|
||||
userIds = Set.empty,
|
||||
placeIds = Set.empty,
|
||||
semanticCoreIds = Set.empty,
|
||||
simclusterIds = Set.empty
|
||||
)
|
||||
|
||||
def from(reasons: Seq[TargetID]): FanoutReasonEntities = {
|
||||
val userIds: Set[Long] = reasons.collect {
|
||||
case TargetID.UserID(userId) => userId.id
|
||||
}.toSet
|
||||
val placeIds: Set[Long] = reasons.collect {
|
||||
case TargetID.PlaceID(placeId) => placeId.id
|
||||
}.toSet
|
||||
val semanticCoreIds: Set[SemanticCoreID] = reasons.collect {
|
||||
case TargetID.SemanticCoreID(semanticCoreID) => semanticCoreID
|
||||
}.toSet
|
||||
val simclusterIds: Set[SimClusterID] = reasons.collect {
|
||||
case TargetID.SimClusterID(simClusterID) => simClusterID
|
||||
}.toSet
|
||||
|
||||
FanoutReasonEntities(
|
||||
userIds = userIds,
|
||||
placeIds,
|
||||
semanticCoreIds = semanticCoreIds,
|
||||
simclusterIds = simclusterIds
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
trait MagicFanoutHydratedCandidate extends PushCandidate with MagicFanoutCandidate {
|
||||
lazy val fanoutReasonEntities: FanoutReasonEntities =
|
||||
FanoutReasonEntities.from(candidateMagicEventsReasons.map(_.reason))
|
||||
}
|
||||
|
||||
trait MagicFanoutEventHydratedCandidate
|
||||
extends MagicFanoutHydratedCandidate
|
||||
with MagicFanoutEventCandidate
|
||||
with RichEventFutCandidate {
|
||||
|
||||
def target: PushTypes.Target
|
||||
|
||||
def stats: StatsReceiver
|
||||
|
||||
def fanoutEvent: Option[FanoutEvent]
|
||||
|
||||
def eventFut: Future[Option[Event]]
|
||||
|
||||
def semanticEntityResults: Map[SemanticEntityForQuery, Option[EntityMegadata]]
|
||||
|
||||
def effectiveMagicEventsReasons: Option[Seq[MagicEventsReason]]
|
||||
|
||||
def followedTopicLocalizedEntities: Future[Set[LocalizedEntity]]
|
||||
|
||||
def ergLocalizedEntities: Future[Set[LocalizedEntity]]
|
||||
|
||||
lazy val entityAnnotationAlg: Map[TargetID, Set[AnnotationAlg]] =
|
||||
fanoutEvent
|
||||
.flatMap { metadata =>
|
||||
metadata.eventAnnotationInfo.map { eventAnnotationInfo =>
|
||||
eventAnnotationInfo.map {
|
||||
case (target, annotationInfoSet) => target -> annotationInfoSet.map(_.alg).toSet
|
||||
}.toMap
|
||||
}
|
||||
}.getOrElse(Map.empty)
|
||||
|
||||
lazy val eventSource: Option[String] = fanoutEvent.map { metadata =>
|
||||
val source = metadata.eventSource.getOrElse("undefined")
|
||||
stats.scope("eventSource").counter(source).incr()
|
||||
source
|
||||
}
|
||||
|
||||
lazy val semanticCoreEntityTags: Map[(Long, Long), Set[String]] =
|
||||
semanticEntityResults.flatMap {
|
||||
case (semanticEntityForQuery, entityMegadataOpt: Option[EntityMegadata]) =>
|
||||
for {
|
||||
entityMegadata <- entityMegadataOpt
|
||||
basicMetadata: BasicMetadata <- entityMegadata.basicMetadata
|
||||
indexableFields: EntityIndexFields <- basicMetadata.indexableFields
|
||||
tags <- indexableFields.tags
|
||||
} yield {
|
||||
((semanticEntityForQuery.domainId, semanticEntityForQuery.entityId), tags.toSet)
|
||||
}
|
||||
}
|
||||
|
||||
lazy val owningTwitterUserIds: Seq[Long] = semanticEntityResults.values.flatten
|
||||
.flatMap {
|
||||
_.basicMetadata.flatMap(_.twitter.flatMap(_.owningTwitterUserIds))
|
||||
}.flatten
|
||||
.toSeq
|
||||
.distinct
|
||||
|
||||
lazy val eventFanoutReasonEntities: FanoutReasonEntities =
|
||||
fanoutEvent match {
|
||||
case Some(fanout) =>
|
||||
fanout.targets
|
||||
.map { targets: Seq[thriftscala.Target] =>
|
||||
FanoutReasonEntities.from(targets.flatMap(_.whitelist).flatten)
|
||||
}.getOrElse(FanoutReasonEntities.empty)
|
||||
case _ => FanoutReasonEntities.empty
|
||||
}
|
||||
|
||||
override lazy val eventResultFut: Future[Event] = eventFut.map {
|
||||
case Some(eventResult) => eventResult
|
||||
case _ =>
|
||||
throw new IllegalArgumentException("event is None for MagicFanoutEventHydratedCandidate")
|
||||
}
|
||||
override val rankScore: Option[Double] = None
|
||||
override val predictionScore: Option[Double] = None
|
||||
}
|
||||
|
||||
case class MagicFanoutEventHydratedInfo(
|
||||
fanoutEvent: Option[FanoutEvent],
|
||||
semanticEntityResults: Map[SemanticEntityForQuery, Option[EntityMegadata]])
|
@ -0,0 +1,99 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.escherbird.metadata.thriftscala.EntityMegadata
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.MagicFanoutNewsEventCandidate
|
||||
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
|
||||
import com.twitter.frigate.magic_events.thriftscala.FanoutEvent
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.MagicFanoutNewsEventIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.MagicFanoutNewsEventNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.predicate.PredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.event.EventPredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.magic_fanout.MagicFanoutPredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.magic_fanout.MagicFanoutTargetingPredicateWrappersForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.ntab_caret_fatigue.MagicFanoutNtabCaretFatiguePredicate
|
||||
import com.twitter.frigate.pushservice.store.EventRequest
|
||||
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicSendHandlerPredicates
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
import com.twitter.hermit.store.semantic_core.SemanticEntityForQuery
|
||||
import com.twitter.interests.thriftscala.UserInterests
|
||||
import com.twitter.livevideo.timeline.domain.v2.Event
|
||||
import com.twitter.simclusters_v2.thriftscala.SimClustersInferredEntities
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
|
||||
class MagicFanoutNewsEventPushCandidate(
|
||||
candidate: RawCandidate with MagicFanoutNewsEventCandidate,
|
||||
copyIds: CopyIds,
|
||||
override val fanoutEvent: Option[FanoutEvent],
|
||||
override val semanticEntityResults: Map[SemanticEntityForQuery, Option[EntityMegadata]],
|
||||
simClusterToEntities: Map[Int, Option[SimClustersInferredEntities]],
|
||||
lexServiceStore: ReadableStore[EventRequest, Event],
|
||||
interestsLookupStore: ReadableStore[InterestsLookupRequestWithContext, UserInterests],
|
||||
uttEntityHydrationStore: UttEntityHydrationStore
|
||||
)(
|
||||
implicit statsScoped: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends MagicFanoutEventPushCandidate(
|
||||
candidate,
|
||||
copyIds,
|
||||
fanoutEvent,
|
||||
semanticEntityResults,
|
||||
simClusterToEntities,
|
||||
lexServiceStore,
|
||||
interestsLookupStore,
|
||||
uttEntityHydrationStore
|
||||
)(statsScoped, pushModelScorer)
|
||||
with MagicFanoutNewsEventCandidate
|
||||
with MagicFanoutNewsEventIbis2Hydrator
|
||||
with MagicFanoutNewsEventNTabRequestHydrator {
|
||||
|
||||
override lazy val stats: StatsReceiver = statsScoped.scope("MagicFanoutNewsEventPushCandidate")
|
||||
override val statsReceiver: StatsReceiver = statsScoped.scope("MagicFanoutNewsEventPushCandidate")
|
||||
}
|
||||
|
||||
case class MagicFanoutNewsEventCandidatePredicates(config: Config)
|
||||
extends BasicSendHandlerPredicates[MagicFanoutNewsEventPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val preCandidateSpecificPredicates: List[
|
||||
NamedPredicate[MagicFanoutNewsEventPushCandidate]
|
||||
] =
|
||||
List(
|
||||
EventPredicatesForCandidate.accountCountryPredicateWithAllowlist,
|
||||
PredicatesForCandidate.isDeviceEligibleForNewsOrSports,
|
||||
MagicFanoutPredicatesForCandidate.inferredUserDeviceLanguagePredicate,
|
||||
PredicatesForCandidate.secondaryDormantAccountPredicate(statsReceiver),
|
||||
MagicFanoutPredicatesForCandidate.highPriorityNewsEventExceptedPredicate(
|
||||
MagicFanoutTargetingPredicateWrappersForCandidate
|
||||
.magicFanoutTargetingPredicate(statsReceiver, config)
|
||||
)(config),
|
||||
MagicFanoutPredicatesForCandidate.geoOptOutPredicate(config.safeUserStore),
|
||||
EventPredicatesForCandidate.isNotDuplicateWithEventId,
|
||||
MagicFanoutPredicatesForCandidate.highPriorityNewsEventExceptedPredicate(
|
||||
MagicFanoutPredicatesForCandidate.newsNotificationFatigue()
|
||||
)(config),
|
||||
MagicFanoutPredicatesForCandidate.highPriorityNewsEventExceptedPredicate(
|
||||
MagicFanoutNtabCaretFatiguePredicate()
|
||||
)(config),
|
||||
MagicFanoutPredicatesForCandidate.escherbirdMagicfanoutEventParam()(statsReceiver),
|
||||
MagicFanoutPredicatesForCandidate.hasCustomTargetingForNewsEventsParam(
|
||||
statsReceiver
|
||||
)
|
||||
)
|
||||
|
||||
override val postCandidateSpecificPredicates: List[
|
||||
NamedPredicate[MagicFanoutNewsEventPushCandidate]
|
||||
] =
|
||||
List(
|
||||
MagicFanoutPredicatesForCandidate.magicFanoutNoOptoutInterestPredicate,
|
||||
MagicFanoutPredicatesForCandidate.geoTargetingHoldback(),
|
||||
MagicFanoutPredicatesForCandidate.userGeneratedEventsPredicate,
|
||||
EventPredicatesForCandidate.hasTitle,
|
||||
)
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.MagicFanoutProductLaunchCandidate
|
||||
import com.twitter.frigate.common.util.{FeatureSwitchParams => FS}
|
||||
import com.twitter.frigate.magic_events.thriftscala.MagicEventsReason
|
||||
import com.twitter.frigate.magic_events.thriftscala.ProductType
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.predicate.magic_fanout.MagicFanoutPredicatesUtil
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.MagicFanoutProductLaunchIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.MagicFanoutProductLaunchNtabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.predicate.PredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.magic_fanout.MagicFanoutPredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.ntab_caret_fatigue.MagicFanoutNtabCaretFatiguePredicate
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicSendHandlerPredicates
|
||||
import com.twitter.frigate.thriftscala.FrigateNotification
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
|
||||
class MagicFanoutProductLaunchPushCandidate(
|
||||
candidate: RawCandidate with MagicFanoutProductLaunchCandidate,
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit val statsScoped: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with MagicFanoutProductLaunchCandidate
|
||||
with MagicFanoutProductLaunchIbis2Hydrator
|
||||
with MagicFanoutProductLaunchNtabRequestHydrator {
|
||||
|
||||
override val frigateNotification: FrigateNotification = candidate.frigateNotification
|
||||
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
|
||||
override val pushId: Long = candidate.pushId
|
||||
|
||||
override val productLaunchType: ProductType = candidate.productLaunchType
|
||||
|
||||
override val candidateMagicEventsReasons: Seq[MagicEventsReason] =
|
||||
candidate.candidateMagicEventsReasons
|
||||
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
|
||||
override val target: Target = candidate.target
|
||||
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
|
||||
override val statsReceiver: StatsReceiver =
|
||||
statsScoped.scope("MagicFanoutProductLaunchPushCandidate")
|
||||
}
|
||||
|
||||
case class MagicFanoutProductLaunchPushCandidatePredicates(config: Config)
|
||||
extends BasicSendHandlerPredicates[MagicFanoutProductLaunchPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val preCandidateSpecificPredicates: List[
|
||||
NamedPredicate[MagicFanoutProductLaunchPushCandidate]
|
||||
] =
|
||||
List(
|
||||
PredicatesForCandidate.isDeviceEligibleForCreatorPush,
|
||||
PredicatesForCandidate.exceptedPredicate(
|
||||
"excepted_is_target_blue_verified",
|
||||
MagicFanoutPredicatesUtil.shouldSkipBlueVerifiedCheckForCandidate,
|
||||
PredicatesForCandidate.isTargetBlueVerified.flip
|
||||
), // no need to send if target is already Blue Verified
|
||||
PredicatesForCandidate.exceptedPredicate(
|
||||
"excepted_is_target_legacy_verified",
|
||||
MagicFanoutPredicatesUtil.shouldSkipLegacyVerifiedCheckForCandidate,
|
||||
PredicatesForCandidate.isTargetLegacyVerified.flip
|
||||
), // no need to send if target is already Legacy Verified
|
||||
PredicatesForCandidate.exceptedPredicate(
|
||||
"excepted_is_target_super_follow_creator",
|
||||
MagicFanoutPredicatesUtil.shouldSkipSuperFollowCreatorCheckForCandidate,
|
||||
PredicatesForCandidate.isTargetSuperFollowCreator.flip
|
||||
), // no need to send if target is already Super Follow Creator
|
||||
PredicatesForCandidate.paramPredicate(
|
||||
FS.EnableMagicFanoutProductLaunch
|
||||
),
|
||||
MagicFanoutPredicatesForCandidate.magicFanoutProductLaunchFatigue(),
|
||||
)
|
||||
|
||||
override val postCandidateSpecificPredicates: List[
|
||||
NamedPredicate[MagicFanoutProductLaunchPushCandidate]
|
||||
] =
|
||||
List(
|
||||
MagicFanoutNtabCaretFatiguePredicate(),
|
||||
)
|
||||
}
|
@ -0,0 +1,119 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.escherbird.metadata.thriftscala.EntityMegadata
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.BaseGameScore
|
||||
import com.twitter.frigate.common.base.MagicFanoutSportsEventCandidate
|
||||
import com.twitter.frigate.common.base.MagicFanoutSportsScoreInformation
|
||||
import com.twitter.frigate.common.base.TeamInfo
|
||||
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
|
||||
import com.twitter.frigate.magic_events.thriftscala.FanoutEvent
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.MagicFanoutSportsEventIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.MagicFanoutSportsEventNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.predicate.PredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.magic_fanout.MagicFanoutPredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.magic_fanout.MagicFanoutTargetingPredicateWrappersForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.ntab_caret_fatigue.MagicFanoutNtabCaretFatiguePredicate
|
||||
import com.twitter.frigate.pushservice.store.EventRequest
|
||||
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicSendHandlerPredicates
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
import com.twitter.hermit.store.semantic_core.SemanticEntityForQuery
|
||||
import com.twitter.interests.thriftscala.UserInterests
|
||||
import com.twitter.livevideo.timeline.domain.v2.Event
|
||||
import com.twitter.livevideo.timeline.domain.v2.HydrationOptions
|
||||
import com.twitter.livevideo.timeline.domain.v2.LookupContext
|
||||
import com.twitter.simclusters_v2.thriftscala.SimClustersInferredEntities
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.util.Future
|
||||
|
||||
class MagicFanoutSportsPushCandidate(
|
||||
candidate: RawCandidate
|
||||
with MagicFanoutSportsEventCandidate
|
||||
with MagicFanoutSportsScoreInformation,
|
||||
copyIds: CopyIds,
|
||||
override val fanoutEvent: Option[FanoutEvent],
|
||||
override val semanticEntityResults: Map[SemanticEntityForQuery, Option[EntityMegadata]],
|
||||
simClusterToEntities: Map[Int, Option[SimClustersInferredEntities]],
|
||||
lexServiceStore: ReadableStore[EventRequest, Event],
|
||||
interestsLookupStore: ReadableStore[InterestsLookupRequestWithContext, UserInterests],
|
||||
uttEntityHydrationStore: UttEntityHydrationStore
|
||||
)(
|
||||
implicit statsScoped: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends MagicFanoutEventPushCandidate(
|
||||
candidate,
|
||||
copyIds,
|
||||
fanoutEvent,
|
||||
semanticEntityResults,
|
||||
simClusterToEntities,
|
||||
lexServiceStore,
|
||||
interestsLookupStore,
|
||||
uttEntityHydrationStore)(statsScoped, pushModelScorer)
|
||||
with MagicFanoutSportsEventCandidate
|
||||
with MagicFanoutSportsScoreInformation
|
||||
with MagicFanoutSportsEventNTabRequestHydrator
|
||||
with MagicFanoutSportsEventIbis2Hydrator {
|
||||
|
||||
override val isScoreUpdate: Boolean = candidate.isScoreUpdate
|
||||
override val gameScores: Future[Option[BaseGameScore]] = candidate.gameScores
|
||||
override val homeTeamInfo: Future[Option[TeamInfo]] = candidate.homeTeamInfo
|
||||
override val awayTeamInfo: Future[Option[TeamInfo]] = candidate.awayTeamInfo
|
||||
|
||||
override lazy val stats: StatsReceiver = statsScoped.scope("MagicFanoutSportsPushCandidate")
|
||||
override val statsReceiver: StatsReceiver = statsScoped.scope("MagicFanoutSportsPushCandidate")
|
||||
|
||||
override lazy val eventRequestFut: Future[Option[EventRequest]] = {
|
||||
Future.join(target.inferredUserDeviceLanguage, target.accountCountryCode).map {
|
||||
case (inferredUserDeviceLanguage, accountCountryCode) =>
|
||||
Some(
|
||||
EventRequest(
|
||||
eventId,
|
||||
lookupContext = LookupContext(
|
||||
hydrationOptions = HydrationOptions(
|
||||
includeSquareImage = true,
|
||||
includePrimaryImage = true
|
||||
),
|
||||
language = inferredUserDeviceLanguage,
|
||||
countryCode = accountCountryCode
|
||||
)
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class MagicFanoutSportsEventCandidatePredicates(config: Config)
|
||||
extends BasicSendHandlerPredicates[MagicFanoutSportsPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val preCandidateSpecificPredicates: List[
|
||||
NamedPredicate[MagicFanoutSportsPushCandidate]
|
||||
] =
|
||||
List(
|
||||
PredicatesForCandidate.paramPredicate(PushFeatureSwitchParams.EnableScoreFanoutNotification)
|
||||
)
|
||||
|
||||
override val postCandidateSpecificPredicates: List[
|
||||
NamedPredicate[MagicFanoutSportsPushCandidate]
|
||||
] =
|
||||
List(
|
||||
PredicatesForCandidate.isDeviceEligibleForNewsOrSports,
|
||||
MagicFanoutPredicatesForCandidate.inferredUserDeviceLanguagePredicate,
|
||||
MagicFanoutPredicatesForCandidate.highPriorityEventExceptedPredicate(
|
||||
MagicFanoutTargetingPredicateWrappersForCandidate
|
||||
.magicFanoutTargetingPredicate(statsReceiver, config)
|
||||
)(config),
|
||||
PredicatesForCandidate.secondaryDormantAccountPredicate(
|
||||
statsReceiver
|
||||
),
|
||||
MagicFanoutPredicatesForCandidate.highPriorityEventExceptedPredicate(
|
||||
MagicFanoutNtabCaretFatiguePredicate()
|
||||
)(config),
|
||||
)
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.contentrecommender.thriftscala.MetricTag
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.OutOfNetworkTweetCandidate
|
||||
import com.twitter.frigate.common.base.TopicCandidate
|
||||
import com.twitter.frigate.common.base.TweetAuthorDetails
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.OutOfNetworkTweetIbis2HydratorForCandidate
|
||||
import com.twitter.frigate.pushservice.model.ntab.OutOfNetworkTweetNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.predicate.HealthPredicates
|
||||
import com.twitter.frigate.pushservice.take.predicates.OutOfNetworkTweetPredicates
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
import com.twitter.stitch.tweetypie.TweetyPie
|
||||
import com.twitter.topiclisting.utt.LocalizedEntity
|
||||
import com.twitter.util.Future
|
||||
|
||||
class OutOfNetworkTweetPushCandidate(
|
||||
candidate: RawCandidate with OutOfNetworkTweetCandidate with TopicCandidate,
|
||||
author: Future[Option[User]],
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit stats: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with OutOfNetworkTweetCandidate
|
||||
with TopicCandidate
|
||||
with TweetAuthorDetails
|
||||
with OutOfNetworkTweetNTabRequestHydrator
|
||||
with OutOfNetworkTweetIbis2HydratorForCandidate {
|
||||
override val statsReceiver: StatsReceiver = stats
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
override val tweetId: Long = candidate.tweetId
|
||||
override lazy val tweetyPieResult: Option[TweetyPie.TweetyPieResult] =
|
||||
candidate.tweetyPieResult
|
||||
override lazy val tweetAuthor: Future[Option[User]] = author
|
||||
override val target: PushTypes.Target = candidate.target
|
||||
override lazy val commonRecType: CommonRecommendationType =
|
||||
candidate.commonRecType
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
override lazy val semanticCoreEntityId: Option[Long] = candidate.semanticCoreEntityId
|
||||
override lazy val localizedUttEntity: Option[LocalizedEntity] = candidate.localizedUttEntity
|
||||
override lazy val algorithmCR: Option[String] = candidate.algorithmCR
|
||||
override lazy val isMrBackfillCR: Option[Boolean] = candidate.isMrBackfillCR
|
||||
override lazy val tagsCR: Option[Seq[MetricTag]] = candidate.tagsCR
|
||||
}
|
||||
|
||||
case class OutOfNetworkTweetCandidatePredicates(override val config: Config)
|
||||
extends OutOfNetworkTweetPredicates[OutOfNetworkTweetPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override def postCandidateSpecificPredicates: List[
|
||||
NamedPredicate[OutOfNetworkTweetPushCandidate]
|
||||
] =
|
||||
List(
|
||||
HealthPredicates.agathaAbusiveTweetAuthorPredicateMrTwistly(),
|
||||
)
|
||||
|
||||
}
|
@ -0,0 +1,61 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.candidate.UserLanguage
|
||||
import com.twitter.frigate.common.candidate._
|
||||
import com.twitter.frigate.data_pipeline.features_common.RequestContextForFeatureStore
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyInfo
|
||||
import com.twitter.frigate.pushservice.model.candidate.MLScores
|
||||
import com.twitter.frigate.pushservice.model.candidate.QualityScribing
|
||||
import com.twitter.frigate.pushservice.model.candidate.Scriber
|
||||
import com.twitter.frigate.pushservice.model.ibis.Ibis2HydratorForCandidate
|
||||
import com.twitter.frigate.pushservice.model.ntab.NTabRequest
|
||||
import com.twitter.frigate.pushservice.take.ChannelForCandidate
|
||||
import com.twitter.frigate.pushservice.target._
|
||||
import com.twitter.util.Time
|
||||
|
||||
object PushTypes {
|
||||
|
||||
trait Target
|
||||
extends TargetUser
|
||||
with UserDetails
|
||||
with TargetWithPushContext
|
||||
with TargetDecider
|
||||
with TargetABDecider
|
||||
with FrigateHistory
|
||||
with PushTargeting
|
||||
with TargetScoringDetails
|
||||
with TweetImpressionHistory
|
||||
with CustomConfigForExpt
|
||||
with CaretFeedbackHistory
|
||||
with NotificationFeedbackHistory
|
||||
with PromptFeedbackHistory
|
||||
with HTLVisitHistory
|
||||
with MaxTweetAge
|
||||
with NewUserDetails
|
||||
with ResurrectedUserDetails
|
||||
with TargetWithSeedUsers
|
||||
with MagicFanoutHistory
|
||||
with OptOutUserInterests
|
||||
with RequestContextForFeatureStore
|
||||
with TargetAppPermissions
|
||||
with UserLanguage
|
||||
with InlineActionHistory
|
||||
with TargetPlaces
|
||||
|
||||
trait RawCandidate extends Candidate with TargetInfo[PushTypes.Target] with RecommendationType {
|
||||
|
||||
val createdAt: Time = Time.now
|
||||
}
|
||||
|
||||
trait PushCandidate
|
||||
extends RawCandidate
|
||||
with CandidateScoringDetails
|
||||
with MLScores
|
||||
with QualityScribing
|
||||
with CopyInfo
|
||||
with Scriber
|
||||
with Ibis2HydratorForCandidate
|
||||
with NTabRequest
|
||||
with ChannelForCandidate
|
||||
}
|
@ -0,0 +1,85 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.ScheduledSpaceSpeakerCandidate
|
||||
import com.twitter.frigate.common.base.SpaceCandidateFanoutDetails
|
||||
import com.twitter.frigate.common.util.FeatureSwitchParams
|
||||
import com.twitter.frigate.magic_events.thriftscala.SpaceMetadata
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.ScheduledSpaceSpeakerIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.ScheduledSpaceSpeakerNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.predicate.PredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.predicate.SpacePredicate
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicSendHandlerPredicates
|
||||
import com.twitter.frigate.thriftscala.FrigateNotification
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.ubs.thriftscala.AudioSpace
|
||||
import com.twitter.util.Future
|
||||
|
||||
class ScheduledSpaceSpeakerPushCandidate(
|
||||
candidate: RawCandidate with ScheduledSpaceSpeakerCandidate,
|
||||
hostUser: Option[User],
|
||||
copyIds: CopyIds,
|
||||
audioSpaceStore: ReadableStore[String, AudioSpace]
|
||||
)(
|
||||
implicit val statsScoped: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with ScheduledSpaceSpeakerCandidate
|
||||
with ScheduledSpaceSpeakerIbis2Hydrator
|
||||
with SpaceCandidateFanoutDetails
|
||||
with ScheduledSpaceSpeakerNTabRequestHydrator {
|
||||
|
||||
override val startTime: Long = candidate.startTime
|
||||
|
||||
override val hydratedHost: Option[User] = hostUser
|
||||
|
||||
override val spaceId: String = candidate.spaceId
|
||||
|
||||
override val hostId: Option[Long] = candidate.hostId
|
||||
|
||||
override val speakerIds: Option[Seq[Long]] = candidate.speakerIds
|
||||
|
||||
override val listenerIds: Option[Seq[Long]] = candidate.listenerIds
|
||||
|
||||
override val frigateNotification: FrigateNotification = candidate.frigateNotification
|
||||
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
|
||||
override val target: Target = candidate.target
|
||||
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
|
||||
override lazy val audioSpaceFut: Future[Option[AudioSpace]] = audioSpaceStore.get(spaceId)
|
||||
|
||||
override val spaceFanoutMetadata: Option[SpaceMetadata] = None
|
||||
|
||||
override val statsReceiver: StatsReceiver =
|
||||
statsScoped.scope("ScheduledSpaceSpeakerCandidate")
|
||||
}
|
||||
|
||||
case class ScheduledSpaceSpeakerCandidatePredicates(config: Config)
|
||||
extends BasicSendHandlerPredicates[ScheduledSpaceSpeakerPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val preCandidateSpecificPredicates: List[
|
||||
NamedPredicate[ScheduledSpaceSpeakerPushCandidate]
|
||||
] = List(
|
||||
SpacePredicate.scheduledSpaceStarted(
|
||||
config.audioSpaceStore
|
||||
),
|
||||
PredicatesForCandidate.paramPredicate(FeatureSwitchParams.EnableScheduledSpaceSpeakers)
|
||||
)
|
||||
}
|
@ -0,0 +1,86 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.ScheduledSpaceSubscriberCandidate
|
||||
import com.twitter.frigate.common.base.SpaceCandidateFanoutDetails
|
||||
import com.twitter.frigate.common.util.FeatureSwitchParams
|
||||
import com.twitter.frigate.magic_events.thriftscala.SpaceMetadata
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.Target
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.ScheduledSpaceSubscriberIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.ScheduledSpaceSubscriberNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.predicate._
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicSendHandlerPredicates
|
||||
import com.twitter.frigate.thriftscala.FrigateNotification
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
import com.twitter.storehaus.ReadableStore
|
||||
import com.twitter.ubs.thriftscala.AudioSpace
|
||||
import com.twitter.util.Future
|
||||
|
||||
class ScheduledSpaceSubscriberPushCandidate(
|
||||
candidate: RawCandidate with ScheduledSpaceSubscriberCandidate,
|
||||
hostUser: Option[User],
|
||||
copyIds: CopyIds,
|
||||
audioSpaceStore: ReadableStore[String, AudioSpace]
|
||||
)(
|
||||
implicit val statsScoped: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with ScheduledSpaceSubscriberCandidate
|
||||
with SpaceCandidateFanoutDetails
|
||||
with ScheduledSpaceSubscriberIbis2Hydrator
|
||||
with ScheduledSpaceSubscriberNTabRequestHydrator {
|
||||
|
||||
override val startTime: Long = candidate.startTime
|
||||
|
||||
override val hydratedHost: Option[User] = hostUser
|
||||
|
||||
override val spaceId: String = candidate.spaceId
|
||||
|
||||
override val hostId: Option[Long] = candidate.hostId
|
||||
|
||||
override val speakerIds: Option[Seq[Long]] = candidate.speakerIds
|
||||
|
||||
override val listenerIds: Option[Seq[Long]] = candidate.listenerIds
|
||||
|
||||
override val frigateNotification: FrigateNotification = candidate.frigateNotification
|
||||
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
|
||||
override val target: Target = candidate.target
|
||||
|
||||
override lazy val audioSpaceFut: Future[Option[AudioSpace]] = audioSpaceStore.get(spaceId)
|
||||
|
||||
override val spaceFanoutMetadata: Option[SpaceMetadata] = None
|
||||
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
|
||||
override val statsReceiver: StatsReceiver =
|
||||
statsScoped.scope("ScheduledSpaceSubscriberCandidate")
|
||||
}
|
||||
|
||||
case class ScheduledSpaceSubscriberCandidatePredicates(config: Config)
|
||||
extends BasicSendHandlerPredicates[ScheduledSpaceSubscriberPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val preCandidateSpecificPredicates: List[
|
||||
NamedPredicate[ScheduledSpaceSubscriberPushCandidate]
|
||||
] =
|
||||
List(
|
||||
PredicatesForCandidate.paramPredicate(FeatureSwitchParams.EnableScheduledSpaceSubscribers),
|
||||
SpacePredicate.narrowCastSpace,
|
||||
SpacePredicate.targetInSpace(config.audioSpaceParticipantsStore),
|
||||
SpacePredicate.spaceHostTargetUserBlocking(config.edgeStore),
|
||||
PredicatesForCandidate.duplicateSpacesPredicate
|
||||
)
|
||||
}
|
@ -0,0 +1,56 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.SubscribedSearchTweetCandidate
|
||||
import com.twitter.frigate.common.base.TweetAuthorDetails
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.SubscribedSearchTweetIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.SubscribedSearchTweetNtabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicTweetPredicatesForRFPH
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.stitch.tweetypie.TweetyPie
|
||||
import com.twitter.util.Future
|
||||
|
||||
class SubscribedSearchTweetPushCandidate(
|
||||
candidate: RawCandidate with SubscribedSearchTweetCandidate,
|
||||
author: Option[User],
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit stats: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with SubscribedSearchTweetCandidate
|
||||
with TweetAuthorDetails
|
||||
with SubscribedSearchTweetIbis2Hydrator
|
||||
with SubscribedSearchTweetNtabRequestHydrator {
|
||||
override def tweetAuthor: Future[Option[User]] = Future.value(author)
|
||||
|
||||
override def weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
|
||||
override def tweetId: Long = candidate.tweetId
|
||||
|
||||
override def pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
|
||||
override def ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
|
||||
override def copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
|
||||
override def target: PushTypes.Target = candidate.target
|
||||
|
||||
override def searchTerm: String = candidate.searchTerm
|
||||
|
||||
override def timeBoundedLandingUrl: Option[String] = None
|
||||
|
||||
override def statsReceiver: StatsReceiver = stats
|
||||
|
||||
override def tweetyPieResult: Option[TweetyPie.TweetyPieResult] = candidate.tweetyPieResult
|
||||
}
|
||||
|
||||
case class SubscribedSearchTweetCandidatePredicates(override val config: Config)
|
||||
extends BasicTweetPredicatesForRFPH[SubscribedSearchTweetPushCandidate] {
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
}
|
@ -0,0 +1,70 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.TopTweetImpressionsCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.TopTweetImpressionsCandidateIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.TopTweetImpressionsNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.predicate.TopTweetImpressionsPredicates
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicTweetPredicatesForRFPH
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
import com.twitter.notificationservice.thriftscala.StoryContext
|
||||
import com.twitter.notificationservice.thriftscala.StoryContextValue
|
||||
import com.twitter.stitch.tweetypie.TweetyPie
|
||||
|
||||
/**
|
||||
* This class defines a hydrated [[TopTweetImpressionsCandidate]]
|
||||
*
|
||||
* @param candidate: [[TopTweetImpressionsCandidate]] for the candidate representing the user's Tweet with the most impressions
|
||||
* @param copyIds: push and ntab notification copy
|
||||
* @param stats: finagle scoped states receiver
|
||||
* @param pushModelScorer: ML model score object for fetching prediction scores
|
||||
*/
|
||||
class TopTweetImpressionsPushCandidate(
|
||||
candidate: RawCandidate with TopTweetImpressionsCandidate,
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit stats: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with TopTweetImpressionsCandidate
|
||||
with TopTweetImpressionsNTabRequestHydrator
|
||||
with TopTweetImpressionsCandidateIbis2Hydrator {
|
||||
override val target: PushTypes.Target = candidate.target
|
||||
override val commonRecType: CommonRecommendationType = candidate.commonRecType
|
||||
override val tweetId: Long = candidate.tweetId
|
||||
override lazy val tweetyPieResult: Option[TweetyPie.TweetyPieResult] =
|
||||
candidate.tweetyPieResult
|
||||
override val impressionsCount: Long = candidate.impressionsCount
|
||||
|
||||
override val statsReceiver: StatsReceiver = stats.scope(getClass.getSimpleName)
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
override val storyContext: Option[StoryContext] =
|
||||
Some(StoryContext(altText = "", value = Some(StoryContextValue.Tweets(Seq(tweetId)))))
|
||||
}
|
||||
|
||||
case class TopTweetImpressionsPushCandidatePredicates(config: Config)
|
||||
extends BasicTweetPredicatesForRFPH[TopTweetImpressionsPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val preCandidateSpecificPredicates: List[
|
||||
NamedPredicate[TopTweetImpressionsPushCandidate]
|
||||
] = List(
|
||||
TopTweetImpressionsPredicates.topTweetImpressionsFatiguePredicate
|
||||
)
|
||||
|
||||
override val postCandidateSpecificPredicates: List[
|
||||
NamedPredicate[TopTweetImpressionsPushCandidate]
|
||||
] = List(
|
||||
TopTweetImpressionsPredicates.topTweetImpressionsThreshold()
|
||||
)
|
||||
}
|
@ -0,0 +1,71 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.TopicProofTweetCandidate
|
||||
import com.twitter.frigate.common.base.TweetAuthorDetails
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.TopicProofTweetIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.TopicProofTweetNtabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.predicate.PredicatesForCandidate
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicTweetPredicatesForRFPH
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.hermit.predicate.NamedPredicate
|
||||
import com.twitter.stitch.tweetypie.TweetyPie
|
||||
import com.twitter.util.Future
|
||||
|
||||
/**
|
||||
* This class defines a hydrated [[TopicProofTweetCandidate]]
|
||||
*
|
||||
* @param candidate : [[TopicProofTweetCandidate]] for the candidate representint a Tweet recommendation for followed Topic
|
||||
* @param author : Tweet author representated as Gizmoduck user object
|
||||
* @param copyIds : push and ntab notification copy
|
||||
* @param stats : finagle scoped states receiver
|
||||
* @param pushModelScorer : ML model score object for fetching prediction scores
|
||||
*/
|
||||
class TopicProofTweetPushCandidate(
|
||||
candidate: RawCandidate with TopicProofTweetCandidate,
|
||||
author: Option[User],
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit stats: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with TopicProofTweetCandidate
|
||||
with TweetAuthorDetails
|
||||
with TopicProofTweetNtabRequestHydrator
|
||||
with TopicProofTweetIbis2Hydrator {
|
||||
override val statsReceiver: StatsReceiver = stats
|
||||
override val target: PushTypes.Target = candidate.target
|
||||
override val tweetId: Long = candidate.tweetId
|
||||
override lazy val tweetyPieResult: Option[TweetyPie.TweetyPieResult] = candidate.tweetyPieResult
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
override val semanticCoreEntityId = candidate.semanticCoreEntityId
|
||||
override val localizedUttEntity = candidate.localizedUttEntity
|
||||
override val tweetAuthor = Future.value(author)
|
||||
override val topicListingSetting = candidate.topicListingSetting
|
||||
override val algorithmCR = candidate.algorithmCR
|
||||
override val commonRecType: CommonRecommendationType = candidate.commonRecType
|
||||
override val tagsCR = candidate.tagsCR
|
||||
override val isOutOfNetwork = candidate.isOutOfNetwork
|
||||
}
|
||||
|
||||
case class TopicProofTweetCandidatePredicates(override val config: Config)
|
||||
extends BasicTweetPredicatesForRFPH[TopicProofTweetPushCandidate] {
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val preCandidateSpecificPredicates: List[NamedPredicate[TopicProofTweetPushCandidate]] =
|
||||
List(
|
||||
PredicatesForCandidate.paramPredicate(
|
||||
PushFeatureSwitchParams.EnableTopicProofTweetRecs
|
||||
),
|
||||
)
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.events.recos.thriftscala.TrendsContext
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.TrendTweetCandidate
|
||||
import com.twitter.frigate.common.base.TweetAuthorDetails
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.TrendTweetIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.TrendTweetNtabHydrator
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicTweetPredicatesForRFPH
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.stitch.tweetypie.TweetyPie
|
||||
import com.twitter.util.Future
|
||||
|
||||
class TrendTweetPushCandidate(
|
||||
candidate: RawCandidate with TrendTweetCandidate,
|
||||
author: Option[User],
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit stats: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with TrendTweetCandidate
|
||||
with TweetAuthorDetails
|
||||
with TrendTweetIbis2Hydrator
|
||||
with TrendTweetNtabHydrator {
|
||||
override val statsReceiver: StatsReceiver = stats
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
override val tweetId: Long = candidate.tweetId
|
||||
override lazy val tweetyPieResult: Option[TweetyPie.TweetyPieResult] = candidate.tweetyPieResult
|
||||
override lazy val tweetAuthor: Future[Option[User]] = Future.value(author)
|
||||
override val target: PushTypes.Target = candidate.target
|
||||
override val landingUrl: String = candidate.landingUrl
|
||||
override val timeBoundedLandingUrl: Option[String] = candidate.timeBoundedLandingUrl
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
override val trendId: String = candidate.trendId
|
||||
override val trendName: String = candidate.trendName
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
override val context: TrendsContext = candidate.context
|
||||
}
|
||||
|
||||
case class TrendTweetPredicates(override val config: Config)
|
||||
extends BasicTweetPredicatesForRFPH[TrendTweetPushCandidate] {
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
}
|
@ -0,0 +1,60 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.OutOfNetworkTweetCandidate
|
||||
import com.twitter.frigate.common.base.TopicCandidate
|
||||
import com.twitter.frigate.common.base.TripCandidate
|
||||
import com.twitter.frigate.common.base.TweetAuthorDetails
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.OutOfNetworkTweetIbis2HydratorForCandidate
|
||||
import com.twitter.frigate.pushservice.model.ntab.OutOfNetworkTweetNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.take.predicates.OutOfNetworkTweetPredicates
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.stitch.tweetypie.TweetyPie
|
||||
import com.twitter.topiclisting.utt.LocalizedEntity
|
||||
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
|
||||
import com.twitter.util.Future
|
||||
|
||||
class TripTweetPushCandidate(
|
||||
candidate: RawCandidate with OutOfNetworkTweetCandidate with TripCandidate,
|
||||
author: Future[Option[User]],
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit stats: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with TripCandidate
|
||||
with TopicCandidate
|
||||
with OutOfNetworkTweetCandidate
|
||||
with TweetAuthorDetails
|
||||
with OutOfNetworkTweetNTabRequestHydrator
|
||||
with OutOfNetworkTweetIbis2HydratorForCandidate {
|
||||
override val statsReceiver: StatsReceiver = stats
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
override val tweetId: Long = candidate.tweetId
|
||||
override lazy val tweetyPieResult: Option[TweetyPie.TweetyPieResult] =
|
||||
candidate.tweetyPieResult
|
||||
override lazy val tweetAuthor: Future[Option[User]] = author
|
||||
override val target: PushTypes.Target = candidate.target
|
||||
override lazy val commonRecType: CommonRecommendationType =
|
||||
candidate.commonRecType
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
override lazy val semanticCoreEntityId: Option[Long] = None
|
||||
override lazy val localizedUttEntity: Option[LocalizedEntity] = None
|
||||
override lazy val algorithmCR: Option[String] = None
|
||||
override val tripDomain: Option[collection.Set[TripDomain]] = candidate.tripDomain
|
||||
}
|
||||
|
||||
case class TripTweetCandidatePredicates(override val config: Config)
|
||||
extends OutOfNetworkTweetPredicates[TripTweetPushCandidate] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.SocialContextActions
|
||||
import com.twitter.frigate.common.base.TweetCandidate
|
||||
import com.twitter.frigate.common.base.TweetDetails
|
||||
import com.twitter.frigate.pushservice.model.PushTypes._
|
||||
import com.twitter.frigate.pushservice.config.Config
|
||||
import com.twitter.frigate.pushservice.predicate._
|
||||
import com.twitter.frigate.pushservice.take.predicates.BasicTweetPredicatesForRFPH
|
||||
|
||||
case class TweetActionCandidatePredicates(override val config: Config)
|
||||
extends BasicTweetPredicatesForRFPH[
|
||||
PushCandidate with TweetCandidate with TweetDetails with SocialContextActions
|
||||
] {
|
||||
|
||||
implicit val statsReceiver: StatsReceiver = config.statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override val preCandidateSpecificPredicates = List(PredicatesForCandidate.minSocialContext(1))
|
||||
|
||||
override val postCandidateSpecificPredicates = List(
|
||||
PredicatesForCandidate.socialContextBeingFollowed(config.edgeStore),
|
||||
PredicatesForCandidate.socialContextBlockingOrMuting(config.edgeStore),
|
||||
PredicatesForCandidate.socialContextNotRetweetFollowing(config.edgeStore)
|
||||
)
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.SocialContextAction
|
||||
import com.twitter.frigate.common.base.SocialContextUserDetails
|
||||
import com.twitter.frigate.common.base.TweetAuthorDetails
|
||||
import com.twitter.frigate.common.base.TweetFavoriteCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.TweetFavoriteCandidateIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.TweetFavoriteNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.util.CandidateHydrationUtil.TweetWithSocialContextTraits
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.stitch.tweetypie.TweetyPie
|
||||
import com.twitter.util.Future
|
||||
|
||||
class TweetFavoritePushCandidate(
|
||||
candidate: RawCandidate with TweetWithSocialContextTraits,
|
||||
socialContextUserMap: Future[Map[Long, Option[User]]],
|
||||
author: Future[Option[User]],
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit stats: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with TweetFavoriteCandidate
|
||||
with SocialContextUserDetails
|
||||
with TweetAuthorDetails
|
||||
with TweetFavoriteNTabRequestHydrator
|
||||
with TweetFavoriteCandidateIbis2Hydrator {
|
||||
override val statsReceiver: StatsReceiver = stats
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
override val tweetId: Long = candidate.tweetId
|
||||
override val socialContextActions: Seq[SocialContextAction] =
|
||||
candidate.socialContextActions
|
||||
|
||||
override val socialContextAllTypeActions: Seq[SocialContextAction] =
|
||||
candidate.socialContextAllTypeActions
|
||||
|
||||
override lazy val scUserMap: Future[Map[Long, Option[User]]] = socialContextUserMap
|
||||
override lazy val tweetAuthor: Future[Option[User]] = author
|
||||
override lazy val commonRecType: CommonRecommendationType =
|
||||
candidate.commonRecType
|
||||
override val target: PushTypes.Target = candidate.target
|
||||
override lazy val tweetyPieResult: Option[TweetyPie.TweetyPieResult] =
|
||||
candidate.tweetyPieResult
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
package com.twitter.frigate.pushservice.model
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.SocialContextAction
|
||||
import com.twitter.frigate.common.base.SocialContextUserDetails
|
||||
import com.twitter.frigate.common.base.TweetAuthorDetails
|
||||
import com.twitter.frigate.common.base.TweetRetweetCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.model.candidate.CopyIds
|
||||
import com.twitter.frigate.pushservice.model.ibis.TweetRetweetCandidateIbis2Hydrator
|
||||
import com.twitter.frigate.pushservice.model.ntab.TweetRetweetNTabRequestHydrator
|
||||
import com.twitter.frigate.pushservice.util.CandidateHydrationUtil.TweetWithSocialContextTraits
|
||||
import com.twitter.frigate.thriftscala.CommonRecommendationType
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.stitch.tweetypie.TweetyPie
|
||||
import com.twitter.util.Future
|
||||
|
||||
class TweetRetweetPushCandidate(
|
||||
candidate: RawCandidate with TweetWithSocialContextTraits,
|
||||
socialContextUserMap: Future[Map[Long, Option[User]]],
|
||||
author: Future[Option[User]],
|
||||
copyIds: CopyIds
|
||||
)(
|
||||
implicit stats: StatsReceiver,
|
||||
pushModelScorer: PushMLModelScorer)
|
||||
extends PushCandidate
|
||||
with TweetRetweetCandidate
|
||||
with SocialContextUserDetails
|
||||
with TweetAuthorDetails
|
||||
with TweetRetweetNTabRequestHydrator
|
||||
with TweetRetweetCandidateIbis2Hydrator {
|
||||
override val statsReceiver: StatsReceiver = stats
|
||||
override val weightedOpenOrNtabClickModelScorer: PushMLModelScorer = pushModelScorer
|
||||
override val tweetId: Long = candidate.tweetId
|
||||
override val socialContextActions: Seq[SocialContextAction] =
|
||||
candidate.socialContextActions
|
||||
|
||||
override val socialContextAllTypeActions: Seq[SocialContextAction] =
|
||||
candidate.socialContextAllTypeActions
|
||||
|
||||
override lazy val scUserMap: Future[Map[Long, Option[User]]] = socialContextUserMap
|
||||
override lazy val tweetAuthor: Future[Option[User]] = author
|
||||
override lazy val commonRecType: CommonRecommendationType = candidate.commonRecType
|
||||
override val target: PushTypes.Target = candidate.target
|
||||
override lazy val tweetyPieResult: Option[TweetyPie.TweetyPieResult] = candidate.tweetyPieResult
|
||||
override val pushCopyId: Option[Int] = copyIds.pushCopyId
|
||||
override val ntabCopyId: Option[Int] = copyIds.ntabCopyId
|
||||
override val copyAggregationId: Option[String] = copyIds.aggregationId
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
package com.twitter.frigate.pushservice.model.candidate
|
||||
|
||||
import com.twitter.frigate.common.util.MRPushCopy
|
||||
import com.twitter.frigate.common.util.MrPushCopyObjects
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.util.CandidateUtil
|
||||
|
||||
case class CopyIds(
|
||||
pushCopyId: Option[Int] = None,
|
||||
ntabCopyId: Option[Int] = None,
|
||||
aggregationId: Option[String] = None)
|
||||
|
||||
trait CopyInfo {
|
||||
self: PushCandidate =>
|
||||
|
||||
import com.twitter.frigate.data_pipeline.common.FrigateNotificationUtil._
|
||||
|
||||
def getPushCopy: Option[MRPushCopy] =
|
||||
pushCopyId match {
|
||||
case Some(pushCopyId) => MrPushCopyObjects.getCopyFromId(pushCopyId)
|
||||
case _ =>
|
||||
crt2PushCopy(
|
||||
commonRecType,
|
||||
CandidateUtil.getSocialContextActionsFromCandidate(self).size
|
||||
)
|
||||
}
|
||||
|
||||
def pushCopyId: Option[Int]
|
||||
|
||||
def ntabCopyId: Option[Int]
|
||||
|
||||
def copyAggregationId: Option[String]
|
||||
}
|
@ -0,0 +1,307 @@
|
||||
package com.twitter.frigate.pushservice.model.candidate
|
||||
|
||||
import com.twitter.frigate.common.base.FeatureMap
|
||||
import com.twitter.frigate.common.rec_types.RecTypes
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.ml.HydrationContextBuilder
|
||||
import com.twitter.frigate.pushservice.ml.PushMLModelScorer
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.params.PushMLModel
|
||||
import com.twitter.frigate.pushservice.params.WeightedOpenOrNtabClickModel
|
||||
import com.twitter.nrel.hydration.push.HydrationContext
|
||||
import com.twitter.timelines.configapi.FSParam
|
||||
import com.twitter.util.Future
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import scala.collection.concurrent.{Map => CMap}
|
||||
import scala.collection.convert.decorateAsScala._
|
||||
|
||||
trait MLScores {
|
||||
|
||||
self: PushCandidate =>
|
||||
|
||||
lazy val candidateHydrationContext: Future[HydrationContext] = HydrationContextBuilder.build(self)
|
||||
|
||||
def weightedOpenOrNtabClickModelScorer: PushMLModelScorer
|
||||
|
||||
// Used to store the scores and avoid duplicate prediction
|
||||
private val qualityModelScores: CMap[
|
||||
(PushMLModel.Value, WeightedOpenOrNtabClickModel.ModelNameType),
|
||||
Future[Option[Double]]
|
||||
] =
|
||||
new ConcurrentHashMap[(PushMLModel.Value, WeightedOpenOrNtabClickModel.ModelNameType), Future[
|
||||
Option[Double]
|
||||
]]().asScala
|
||||
|
||||
def populateQualityModelScore(
|
||||
pushMLModel: PushMLModel.Value,
|
||||
modelVersion: WeightedOpenOrNtabClickModel.ModelNameType,
|
||||
prob: Future[Option[Double]]
|
||||
) = {
|
||||
val modelAndVersion = (pushMLModel, modelVersion)
|
||||
if (!qualityModelScores.contains(modelAndVersion)) {
|
||||
qualityModelScores += modelAndVersion -> prob
|
||||
}
|
||||
}
|
||||
|
||||
// The ML scores that also depend on other candidates and are only available after all candidates are processed
|
||||
// For example, the likelihood info for Importance Sampling
|
||||
private lazy val crossCandidateMlScores: CMap[String, Double] =
|
||||
new ConcurrentHashMap[String, Double]().asScala
|
||||
|
||||
def populateCrossCandidateMlScores(scoreName: String, score: Double): Unit = {
|
||||
if (crossCandidateMlScores.contains(scoreName)) {
|
||||
throw new Exception(
|
||||
s"$scoreName has been populated in the CrossCandidateMlScores!\n" +
|
||||
s"Existing crossCandidateMlScores are ${crossCandidateMlScores}\n"
|
||||
)
|
||||
}
|
||||
crossCandidateMlScores += scoreName -> score
|
||||
}
|
||||
|
||||
def getMLModelScore(
|
||||
pushMLModel: PushMLModel.Value,
|
||||
modelVersion: WeightedOpenOrNtabClickModel.ModelNameType
|
||||
): Future[Option[Double]] = {
|
||||
qualityModelScores.getOrElseUpdate(
|
||||
(pushMLModel, modelVersion),
|
||||
weightedOpenOrNtabClickModelScorer
|
||||
.singlePredicationForModelVersion(modelVersion, self, Some(pushMLModel))
|
||||
)
|
||||
}
|
||||
|
||||
def getMLModelScoreWithoutUpdate(
|
||||
pushMLModel: PushMLModel.Value,
|
||||
modelVersion: WeightedOpenOrNtabClickModel.ModelNameType
|
||||
): Future[Option[Double]] = {
|
||||
qualityModelScores.getOrElse(
|
||||
(pushMLModel, modelVersion),
|
||||
Future.None
|
||||
)
|
||||
}
|
||||
|
||||
def getWeightedOpenOrNtabClickModelScore(
|
||||
weightedOONCModelParam: FSParam[WeightedOpenOrNtabClickModel.ModelNameType]
|
||||
): Future[Option[Double]] = {
|
||||
getMLModelScore(
|
||||
PushMLModel.WeightedOpenOrNtabClickProbability,
|
||||
target.params(weightedOONCModelParam)
|
||||
)
|
||||
}
|
||||
|
||||
/* After we unify the ranking and filtering models, we follow the iteration process below
|
||||
When improving the WeightedOONC model,
|
||||
1) Run experiment which only replace the ranking model
|
||||
2) Make decisions according to the experiment results
|
||||
3) Use the ranking model for filtering
|
||||
4) Adjust percentile thresholds if necessary
|
||||
*/
|
||||
lazy val mrWeightedOpenOrNtabClickRankingProbability: Future[Option[Double]] =
|
||||
target.rankingModelParam.flatMap { modelParam =>
|
||||
getWeightedOpenOrNtabClickModelScore(modelParam)
|
||||
}
|
||||
|
||||
def getBigFilteringScore(
|
||||
pushMLModel: PushMLModel.Value,
|
||||
modelVersion: WeightedOpenOrNtabClickModel.ModelNameType
|
||||
): Future[Option[Double]] = {
|
||||
mrWeightedOpenOrNtabClickRankingProbability.flatMap {
|
||||
case Some(rankingScore) =>
|
||||
// Adds ranking score to feature map (we must ensure the feature key is also in the feature context)
|
||||
mergeFeatures(
|
||||
FeatureMap(
|
||||
numericFeatures = Map("scribe.WeightedOpenOrNtabClickProbability" -> rankingScore)
|
||||
)
|
||||
)
|
||||
getMLModelScore(pushMLModel, modelVersion)
|
||||
case _ => Future.None
|
||||
}
|
||||
}
|
||||
|
||||
def getWeightedOpenOrNtabClickScoreForScribing(): Seq[Future[Map[String, Double]]] = {
|
||||
Seq(
|
||||
mrWeightedOpenOrNtabClickRankingProbability.map {
|
||||
case Some(score) => Map(PushMLModel.WeightedOpenOrNtabClickProbability.toString -> score)
|
||||
case _ => Map.empty[String, Double]
|
||||
},
|
||||
Future
|
||||
.join(
|
||||
target.rankingModelParam,
|
||||
mrWeightedOpenOrNtabClickRankingProbability
|
||||
).map {
|
||||
case (rankingModelParam, Some(score)) =>
|
||||
Map(target.params(rankingModelParam).toString -> score)
|
||||
case _ => Map.empty[String, Double]
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
def getNsfwScoreForScribing(): Seq[Future[Map[String, Double]]] = {
|
||||
val nsfwScoreFut = getMLModelScoreWithoutUpdate(
|
||||
PushMLModel.HealthNsfwProbability,
|
||||
target.params(PushFeatureSwitchParams.BqmlHealthModelTypeParam))
|
||||
Seq(nsfwScoreFut.map { nsfwScoreOpt =>
|
||||
nsfwScoreOpt
|
||||
.map(nsfwScore => Map(PushMLModel.HealthNsfwProbability.toString -> nsfwScore)).getOrElse(
|
||||
Map.empty[String, Double])
|
||||
})
|
||||
}
|
||||
|
||||
def getBigFilteringSupervisedScoresForScribing(): Seq[Future[Map[String, Double]]] = {
|
||||
if (target.params(
|
||||
PushFeatureSwitchParams.EnableMrRequestScribingBigFilteringSupervisedScores)) {
|
||||
Seq(
|
||||
mrBigFilteringSupervisedSendingScore.map {
|
||||
case Some(score) =>
|
||||
Map(PushMLModel.BigFilteringSupervisedSendingModel.toString -> score)
|
||||
case _ => Map.empty[String, Double]
|
||||
},
|
||||
mrBigFilteringSupervisedWithoutSendingScore.map {
|
||||
case Some(score) =>
|
||||
Map(PushMLModel.BigFilteringSupervisedWithoutSendingModel.toString -> score)
|
||||
case _ => Map.empty[String, Double]
|
||||
}
|
||||
)
|
||||
} else Seq.empty[Future[Map[String, Double]]]
|
||||
}
|
||||
|
||||
def getBigFilteringRLScoresForScribing(): Seq[Future[Map[String, Double]]] = {
|
||||
if (target.params(PushFeatureSwitchParams.EnableMrRequestScribingBigFilteringRLScores)) {
|
||||
Seq(
|
||||
mrBigFilteringRLSendingScore.map {
|
||||
case Some(score) => Map(PushMLModel.BigFilteringRLSendingModel.toString -> score)
|
||||
case _ => Map.empty[String, Double]
|
||||
},
|
||||
mrBigFilteringRLWithoutSendingScore.map {
|
||||
case Some(score) => Map(PushMLModel.BigFilteringRLWithoutSendingModel.toString -> score)
|
||||
case _ => Map.empty[String, Double]
|
||||
}
|
||||
)
|
||||
} else Seq.empty[Future[Map[String, Double]]]
|
||||
}
|
||||
|
||||
def buildModelScoresSeqForScribing(): Seq[Future[Map[String, Double]]] = {
|
||||
getWeightedOpenOrNtabClickScoreForScribing() ++
|
||||
getBigFilteringSupervisedScoresForScribing() ++
|
||||
getBigFilteringRLScoresForScribing() ++
|
||||
getNsfwScoreForScribing()
|
||||
}
|
||||
|
||||
lazy val mrBigFilteringSupervisedSendingScore: Future[Option[Double]] =
|
||||
getBigFilteringScore(
|
||||
PushMLModel.BigFilteringSupervisedSendingModel,
|
||||
target.params(PushFeatureSwitchParams.BigFilteringSupervisedSendingModelParam)
|
||||
)
|
||||
|
||||
lazy val mrBigFilteringSupervisedWithoutSendingScore: Future[Option[Double]] =
|
||||
getBigFilteringScore(
|
||||
PushMLModel.BigFilteringSupervisedWithoutSendingModel,
|
||||
target.params(PushFeatureSwitchParams.BigFilteringSupervisedWithoutSendingModelParam)
|
||||
)
|
||||
|
||||
lazy val mrBigFilteringRLSendingScore: Future[Option[Double]] =
|
||||
getBigFilteringScore(
|
||||
PushMLModel.BigFilteringRLSendingModel,
|
||||
target.params(PushFeatureSwitchParams.BigFilteringRLSendingModelParam)
|
||||
)
|
||||
|
||||
lazy val mrBigFilteringRLWithoutSendingScore: Future[Option[Double]] =
|
||||
getBigFilteringScore(
|
||||
PushMLModel.BigFilteringRLWithoutSendingModel,
|
||||
target.params(PushFeatureSwitchParams.BigFilteringRLWithoutSendingModelParam)
|
||||
)
|
||||
|
||||
lazy val mrWeightedOpenOrNtabClickFilteringProbability: Future[Option[Double]] =
|
||||
getWeightedOpenOrNtabClickModelScore(
|
||||
target.filteringModelParam
|
||||
)
|
||||
|
||||
lazy val mrQualityUprankingProbability: Future[Option[Double]] =
|
||||
getMLModelScore(
|
||||
PushMLModel.FilteringProbability,
|
||||
target.params(PushFeatureSwitchParams.QualityUprankingModelTypeParam)
|
||||
)
|
||||
|
||||
lazy val mrNsfwScore: Future[Option[Double]] =
|
||||
getMLModelScoreWithoutUpdate(
|
||||
PushMLModel.HealthNsfwProbability,
|
||||
target.params(PushFeatureSwitchParams.BqmlHealthModelTypeParam)
|
||||
)
|
||||
|
||||
// MR quality upranking param
|
||||
private val qualityUprankingBoost: String = "QualityUprankingBoost"
|
||||
private val producerQualityUprankingBoost: String = "ProducerQualityUprankingBoost"
|
||||
private val qualityUprankingInfo: CMap[String, Double] =
|
||||
new ConcurrentHashMap[String, Double]().asScala
|
||||
|
||||
lazy val mrQualityUprankingBoost: Option[Double] =
|
||||
qualityUprankingInfo.get(qualityUprankingBoost)
|
||||
lazy val mrProducerQualityUprankingBoost: Option[Double] =
|
||||
qualityUprankingInfo.get(producerQualityUprankingBoost)
|
||||
|
||||
def setQualityUprankingBoost(boost: Double) =
|
||||
if (qualityUprankingInfo.contains(qualityUprankingBoost)) {
|
||||
qualityUprankingInfo(qualityUprankingBoost) = boost
|
||||
} else {
|
||||
qualityUprankingInfo += qualityUprankingBoost -> boost
|
||||
}
|
||||
def setProducerQualityUprankingBoost(boost: Double) =
|
||||
if (qualityUprankingInfo.contains(producerQualityUprankingBoost)) {
|
||||
qualityUprankingInfo(producerQualityUprankingBoost) = boost
|
||||
} else {
|
||||
qualityUprankingInfo += producerQualityUprankingBoost -> boost
|
||||
}
|
||||
|
||||
private lazy val mrModelScoresFut: Future[Map[String, Double]] = {
|
||||
if (self.target.isLoggedOutUser) {
|
||||
Future.value(Map.empty[String, Double])
|
||||
} else {
|
||||
Future
|
||||
.collectToTry {
|
||||
buildModelScoresSeqForScribing()
|
||||
}.map { scoreTrySeq =>
|
||||
scoreTrySeq
|
||||
.collect {
|
||||
case result if result.isReturn => result.get()
|
||||
}.reduce(_ ++ _)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Internal model scores (scores that are independent of other candidates) for scribing
|
||||
lazy val modelScores: Future[Map[String, Double]] =
|
||||
target.dauProbability.flatMap { dauProbabilityOpt =>
|
||||
val dauProbScoreMap = dauProbabilityOpt
|
||||
.map(_.probability).map { dauProb =>
|
||||
PushMLModel.DauProbability.toString -> dauProb
|
||||
}.toMap
|
||||
|
||||
// Avoid unnecessary MR model scribing
|
||||
if (target.isDarkWrite) {
|
||||
mrModelScoresFut.map(dauProbScoreMap ++ _)
|
||||
} else if (RecTypes.isSendHandlerType(commonRecType) && !RecTypes
|
||||
.sendHandlerTypesUsingMrModel(commonRecType)) {
|
||||
Future.value(dauProbScoreMap)
|
||||
} else {
|
||||
mrModelScoresFut.map(dauProbScoreMap ++ _)
|
||||
}
|
||||
}
|
||||
|
||||
// We will scribe both internal ML scores and cross-Candidate scores
|
||||
def getModelScoresforScribing(): Future[Map[String, Double]] = {
|
||||
if (RecTypes.notEligibleForModelScoreTracking(commonRecType) || self.target.isLoggedOutUser) {
|
||||
Future.value(Map.empty[String, Double])
|
||||
} else {
|
||||
modelScores.map { internalScores =>
|
||||
if (internalScores.keySet.intersect(crossCandidateMlScores.keySet).nonEmpty) {
|
||||
throw new Exception(
|
||||
"crossCandidateMlScores overlap internalModelScores\n" +
|
||||
s"internalScores keySet: ${internalScores.keySet}\n" +
|
||||
s"crossCandidateScores keySet: ${crossCandidateMlScores.keySet}\n"
|
||||
)
|
||||
}
|
||||
|
||||
internalScores ++ crossCandidateMlScores
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,104 @@
|
||||
package com.twitter.frigate.pushservice.model.candidate
|
||||
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.params.HighQualityScribingScores
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.params.PushMLModel
|
||||
import com.twitter.util.Future
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import scala.collection.concurrent.{Map => CMap}
|
||||
import scala.collection.convert.decorateAsScala._
|
||||
|
||||
trait QualityScribing {
|
||||
self: PushCandidate with MLScores =>
|
||||
|
||||
// Use to store other scores (to avoid duplicate queries to other services, e.g. HSS)
|
||||
private val externalCachedScores: CMap[String, Future[Option[Double]]] =
|
||||
new ConcurrentHashMap[String, Future[Option[Double]]]().asScala
|
||||
|
||||
/**
|
||||
* Retrieves the model version as specified by the corresponding FS param.
|
||||
* This model version will be used for getting the cached score or triggering
|
||||
* a prediction request.
|
||||
*
|
||||
* @param modelName The score we will like to scribe
|
||||
*/
|
||||
private def getModelVersion(
|
||||
modelName: HighQualityScribingScores.Name
|
||||
): String = {
|
||||
modelName match {
|
||||
case HighQualityScribingScores.HeavyRankingScore =>
|
||||
target.params(PushFeatureSwitchParams.HighQualityCandidatesHeavyRankingModel)
|
||||
case HighQualityScribingScores.NonPersonalizedQualityScoreUsingCnn =>
|
||||
target.params(PushFeatureSwitchParams.HighQualityCandidatesNonPersonalizedQualityCnnModel)
|
||||
case HighQualityScribingScores.BqmlNsfwScore =>
|
||||
target.params(PushFeatureSwitchParams.HighQualityCandidatesBqmlNsfwModel)
|
||||
case HighQualityScribingScores.BqmlReportScore =>
|
||||
target.params(PushFeatureSwitchParams.HighQualityCandidatesBqmlReportModel)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the score for scribing either from a cached value or
|
||||
* by generating a prediction request. This will increase model QPS
|
||||
*
|
||||
* @param pushMLModel This represents the prefix of the model name (i.e. [pushMLModel]_[version])
|
||||
* @param scoreName The name to be use when scribing this score
|
||||
*/
|
||||
def getScribingScore(
|
||||
pushMLModel: PushMLModel.Value,
|
||||
scoreName: HighQualityScribingScores.Name
|
||||
): Future[(String, Option[Double])] = {
|
||||
getMLModelScore(
|
||||
pushMLModel,
|
||||
getModelVersion(scoreName)
|
||||
).map { scoreOpt =>
|
||||
scoreName.toString -> scoreOpt
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the score for scribing if it has been computed/cached before otherwise
|
||||
* it will return Future.None
|
||||
*
|
||||
* @param pushMLModel This represents the prefix of the model name (i.e. [pushMLModel]_[version])
|
||||
* @param scoreName The name to be use when scribing this score
|
||||
*/
|
||||
def getScribingScoreWithoutUpdate(
|
||||
pushMLModel: PushMLModel.Value,
|
||||
scoreName: HighQualityScribingScores.Name
|
||||
): Future[(String, Option[Double])] = {
|
||||
getMLModelScoreWithoutUpdate(
|
||||
pushMLModel,
|
||||
getModelVersion(scoreName)
|
||||
).map { scoreOpt =>
|
||||
scoreName.toString -> scoreOpt
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Caches the given score future
|
||||
*
|
||||
* @param scoreName The name to be use when scribing this score
|
||||
* @param scoreFut Future mapping scoreName -> scoreOpt
|
||||
*/
|
||||
def cacheExternalScore(scoreName: String, scoreFut: Future[Option[Double]]) = {
|
||||
if (!externalCachedScores.contains(scoreName)) {
|
||||
externalCachedScores += scoreName -> scoreFut
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns all external scores future cached as a sequence
|
||||
*/
|
||||
def getExternalCachedScores: Seq[Future[(String, Option[Double])]] = {
|
||||
externalCachedScores.map {
|
||||
case (modelName, scoreFut) =>
|
||||
scoreFut.map { scoreOpt => modelName -> scoreOpt }
|
||||
}.toSeq
|
||||
}
|
||||
|
||||
def getExternalCachedScoreByName(name: String): Future[Option[Double]] = {
|
||||
externalCachedScores.getOrElse(name, Future.None)
|
||||
}
|
||||
}
|
@ -0,0 +1,277 @@
|
||||
package com.twitter.frigate.pushservice.model.candidate
|
||||
|
||||
import com.twitter.frigate.data_pipeline.features_common.PushQualityModelFeatureContext.featureContext
|
||||
import com.twitter.frigate.data_pipeline.features_common.PushQualityModelUtil
|
||||
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
|
||||
import com.twitter.frigate.pushservice.params.PushParams
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base._
|
||||
import com.twitter.frigate.common.rec_types.RecTypes
|
||||
import com.twitter.frigate.common.util.NotificationScribeUtil
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.OutOfNetworkTweetPushCandidate
|
||||
import com.twitter.frigate.pushservice.model.TopicProofTweetPushCandidate
|
||||
import com.twitter.frigate.pushservice.ml.HydrationContextBuilder
|
||||
import com.twitter.frigate.pushservice.predicate.quality_model_predicate.PDauCohort
|
||||
import com.twitter.frigate.pushservice.predicate.quality_model_predicate.PDauCohortUtil
|
||||
import com.twitter.frigate.pushservice.util.Candidate2FrigateNotification
|
||||
import com.twitter.frigate.pushservice.util.MediaAnnotationsUtil.sensitiveMediaCategoryFeatureName
|
||||
import com.twitter.frigate.scribe.thriftscala.FrigateNotificationScribeType
|
||||
import com.twitter.frigate.scribe.thriftscala.NotificationScribe
|
||||
import com.twitter.frigate.scribe.thriftscala.PredicateDetailedInfo
|
||||
import com.twitter.frigate.scribe.thriftscala.PushCapInfo
|
||||
import com.twitter.frigate.thriftscala.ChannelName
|
||||
import com.twitter.frigate.thriftscala.FrigateNotification
|
||||
import com.twitter.frigate.thriftscala.OverrideInfo
|
||||
import com.twitter.gizmoduck.thriftscala.User
|
||||
import com.twitter.hermit.model.user_state.UserState.UserState
|
||||
import com.twitter.ibis2.service.thriftscala.Ibis2Response
|
||||
import com.twitter.ml.api.util.ScalaToJavaDataRecordConversions
|
||||
import com.twitter.nrel.heavyranker.FeatureHydrator
|
||||
import com.twitter.util.Future
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import scala.collection.concurrent.{Map => CMap}
|
||||
import scala.collection.Map
|
||||
import scala.collection.convert.decorateAsScala._
|
||||
|
||||
trait Scriber {
|
||||
self: PushCandidate =>
|
||||
|
||||
def statsReceiver: StatsReceiver
|
||||
|
||||
def frigateNotification: FrigateNotification = Candidate2FrigateNotification
|
||||
.getFrigateNotification(self)(statsReceiver)
|
||||
.copy(copyAggregationId = self.copyAggregationId)
|
||||
|
||||
lazy val impressionId: String = UUID.randomUUID.toString.replaceAll("-", "")
|
||||
|
||||
// Used to store the score and threshold for predicates
|
||||
// Map(predicate name, (score, threshold, filter?))
|
||||
private val predicateScoreAndThreshold: CMap[String, PredicateDetailedInfo] =
|
||||
new ConcurrentHashMap[String, PredicateDetailedInfo]().asScala
|
||||
|
||||
def cachePredicateInfo(
|
||||
predName: String,
|
||||
predScore: Double,
|
||||
predThreshold: Double,
|
||||
predResult: Boolean,
|
||||
additionalInformation: Option[Map[String, Double]] = None
|
||||
) = {
|
||||
if (!predicateScoreAndThreshold.contains(predName)) {
|
||||
predicateScoreAndThreshold += predName -> PredicateDetailedInfo(
|
||||
predName,
|
||||
predScore,
|
||||
predThreshold,
|
||||
predResult,
|
||||
additionalInformation)
|
||||
}
|
||||
}
|
||||
|
||||
def getCachedPredicateInfo(): Seq[PredicateDetailedInfo] = predicateScoreAndThreshold.values.toSeq
|
||||
|
||||
def frigateNotificationForPersistence(
|
||||
channels: Seq[ChannelName],
|
||||
isSilentPush: Boolean,
|
||||
overrideInfoOpt: Option[OverrideInfo] = None,
|
||||
copyFeaturesList: Set[String]
|
||||
): Future[FrigateNotification] = {
|
||||
|
||||
// record display location for frigate notification
|
||||
statsReceiver
|
||||
.scope("FrigateNotificationForPersistence")
|
||||
.scope("displayLocation")
|
||||
.counter(frigateNotification.notificationDisplayLocation.name)
|
||||
.incr()
|
||||
|
||||
val getModelScores = self.getModelScoresforScribing()
|
||||
|
||||
Future.join(getModelScores, self.target.targetMrUserState).map {
|
||||
case (mlScores, mrUserState) =>
|
||||
frigateNotification.copy(
|
||||
impressionId = Some(impressionId),
|
||||
isSilentPush = Some(isSilentPush),
|
||||
overrideInfo = overrideInfoOpt,
|
||||
mlModelScores = Some(mlScores),
|
||||
mrUserState = mrUserState.map(_.name),
|
||||
copyFeatures = Some(copyFeaturesList.toSeq)
|
||||
)
|
||||
}
|
||||
}
|
||||
// scribe data
|
||||
private def getNotificationScribe(
|
||||
notifForPersistence: FrigateNotification,
|
||||
userState: Option[UserState],
|
||||
dauCohort: PDauCohort.Value,
|
||||
ibis2Response: Option[Ibis2Response],
|
||||
tweetAuthorId: Option[Long],
|
||||
recUserId: Option[Long],
|
||||
modelScoresMap: Option[Map[String, Double]],
|
||||
primaryClient: Option[String],
|
||||
isMrBackfillCR: Option[Boolean] = None,
|
||||
tagsCR: Option[Seq[String]] = None,
|
||||
gizmoduckTargetUser: Option[User],
|
||||
predicateDetailedInfoList: Option[Seq[PredicateDetailedInfo]] = None,
|
||||
pushCapInfoList: Option[Seq[PushCapInfo]] = None
|
||||
): NotificationScribe = {
|
||||
NotificationScribe(
|
||||
FrigateNotificationScribeType.SendMessage,
|
||||
System.currentTimeMillis(),
|
||||
targetUserId = Some(self.target.targetId),
|
||||
timestampKeyForHistoryV2 = Some(createdAt.inSeconds),
|
||||
sendType = NotificationScribeUtil.convertToScribeDisplayLocation(
|
||||
self.frigateNotification.notificationDisplayLocation
|
||||
),
|
||||
recommendationType = NotificationScribeUtil.convertToScribeRecommendationType(
|
||||
self.frigateNotification.commonRecommendationType
|
||||
),
|
||||
commonRecommendationType = Some(self.frigateNotification.commonRecommendationType),
|
||||
fromPushService = Some(true),
|
||||
frigateNotification = Some(notifForPersistence),
|
||||
impressionId = Some(impressionId),
|
||||
skipModelInfo = target.skipModelInfo,
|
||||
ibis2Response = ibis2Response,
|
||||
tweetAuthorId = tweetAuthorId,
|
||||
scribeFeatures = Some(target.noSkipButScribeFeatures),
|
||||
userState = userState.map(_.toString),
|
||||
pDauCohort = Some(dauCohort.toString),
|
||||
recommendedUserId = recUserId,
|
||||
modelScores = modelScoresMap,
|
||||
primaryClient = primaryClient,
|
||||
isMrBackfillCR = isMrBackfillCR,
|
||||
tagsCR = tagsCR,
|
||||
targetUserType = gizmoduckTargetUser.map(_.userType),
|
||||
predicateDetailedInfoList = predicateDetailedInfoList,
|
||||
pushCapInfoList = pushCapInfoList
|
||||
)
|
||||
}
|
||||
|
||||
def scribeData(
|
||||
ibis2Response: Option[Ibis2Response] = None,
|
||||
isSilentPush: Boolean = false,
|
||||
overrideInfoOpt: Option[OverrideInfo] = None,
|
||||
copyFeaturesList: Set[String] = Set.empty,
|
||||
channels: Seq[ChannelName] = Seq.empty
|
||||
): Future[NotificationScribe] = {
|
||||
|
||||
val recTweetAuthorId = self match {
|
||||
case t: TweetCandidate with TweetAuthor => t.authorId
|
||||
case _ => None
|
||||
}
|
||||
|
||||
val recUserId = self match {
|
||||
case u: UserCandidate => Some(u.userId)
|
||||
case _ => None
|
||||
}
|
||||
|
||||
val isMrBackfillCR = self match {
|
||||
case t: OutOfNetworkTweetPushCandidate => t.isMrBackfillCR
|
||||
case _ => None
|
||||
}
|
||||
|
||||
val tagsCR = self match {
|
||||
case t: OutOfNetworkTweetPushCandidate =>
|
||||
t.tagsCR.map { tags =>
|
||||
tags.map(_.toString)
|
||||
}
|
||||
case t: TopicProofTweetPushCandidate =>
|
||||
t.tagsCR.map { tags =>
|
||||
tags.map(_.toString)
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
|
||||
Future
|
||||
.join(
|
||||
frigateNotificationForPersistence(
|
||||
channels = channels,
|
||||
isSilentPush = isSilentPush,
|
||||
overrideInfoOpt = overrideInfoOpt,
|
||||
copyFeaturesList = copyFeaturesList
|
||||
),
|
||||
target.targetUserState,
|
||||
PDauCohortUtil.getPDauCohort(target),
|
||||
target.deviceInfo,
|
||||
target.targetUser
|
||||
)
|
||||
.flatMap {
|
||||
case (notifForPersistence, userState, dauCohort, deviceInfo, gizmoduckTargetUserOpt) =>
|
||||
val primaryClient = deviceInfo.flatMap(_.guessedPrimaryClient).map(_.toString)
|
||||
val cachedPredicateInfo =
|
||||
if (self.target.params(PushParams.EnablePredicateDetailedInfoScribing)) {
|
||||
Some(getCachedPredicateInfo())
|
||||
} else None
|
||||
|
||||
val cachedPushCapInfo =
|
||||
if (self.target
|
||||
.params(PushParams.EnablePushCapInfoScribing)) {
|
||||
Some(target.finalPushcapAndFatigue.values.toSeq)
|
||||
} else None
|
||||
|
||||
val data = getNotificationScribe(
|
||||
notifForPersistence,
|
||||
userState,
|
||||
dauCohort,
|
||||
ibis2Response,
|
||||
recTweetAuthorId,
|
||||
recUserId,
|
||||
notifForPersistence.mlModelScores,
|
||||
primaryClient,
|
||||
isMrBackfillCR,
|
||||
tagsCR,
|
||||
gizmoduckTargetUserOpt,
|
||||
cachedPredicateInfo,
|
||||
cachedPushCapInfo
|
||||
)
|
||||
//Don't scribe features for CRTs not eligible for ML Layer
|
||||
if ((target.isModelTrainingData || target.scribeFeatureWithoutHydratingNewFeatures)
|
||||
&& !RecTypes.notEligibleForModelScoreTracking(self.commonRecType)) {
|
||||
// scribe all the features for the model training data
|
||||
self.getFeaturesForScribing.map { scribedFeatureMap =>
|
||||
if (target.params(PushParams.EnableScribingMLFeaturesAsDataRecord) && !target.params(
|
||||
PushFeatureSwitchParams.EnableMrScribingMLFeaturesAsFeatureMapForStaging)) {
|
||||
val scribedFeatureDataRecord =
|
||||
ScalaToJavaDataRecordConversions.javaDataRecord2ScalaDataRecord(
|
||||
PushQualityModelUtil.adaptToDataRecord(scribedFeatureMap, featureContext))
|
||||
data.copy(
|
||||
featureDataRecord = Some(scribedFeatureDataRecord)
|
||||
)
|
||||
} else {
|
||||
data.copy(features =
|
||||
Some(PushQualityModelUtil.convertFeatureMapToFeatures(scribedFeatureMap)))
|
||||
}
|
||||
}
|
||||
} else Future.value(data)
|
||||
}
|
||||
}
|
||||
|
||||
def getFeaturesForScribing: Future[FeatureMap] = {
|
||||
target.featureMap
|
||||
.flatMap { targetFeatureMap =>
|
||||
val onlineFeatureMap = targetFeatureMap ++ self
|
||||
.candidateFeatureMap() // targetFeatureMap includes target core user history features
|
||||
|
||||
val filteredFeatureMap = {
|
||||
onlineFeatureMap.copy(
|
||||
sparseContinuousFeatures = onlineFeatureMap.sparseContinuousFeatures.filterKeys(
|
||||
!_.equals(sensitiveMediaCategoryFeatureName))
|
||||
)
|
||||
}
|
||||
|
||||
val targetHydrationContext = HydrationContextBuilder.build(self.target)
|
||||
val candidateHydrationContext = HydrationContextBuilder.build(self)
|
||||
|
||||
val featureMapFut = targetHydrationContext.join(candidateHydrationContext).flatMap {
|
||||
case (targetContext, candidateContext) =>
|
||||
FeatureHydrator.getFeatures(
|
||||
candidateHydrationContext = candidateContext,
|
||||
targetHydrationContext = targetContext,
|
||||
onlineFeatures = filteredFeatureMap,
|
||||
statsReceiver = statsReceiver)
|
||||
}
|
||||
|
||||
featureMapFut
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package com.twitter.frigate.pushservice.model.ibis
|
||||
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.ibis2.lib.util.JsonMarshal
|
||||
import com.twitter.util.Future
|
||||
|
||||
trait CustomConfigurationMapForIbis {
|
||||
self: PushCandidate =>
|
||||
|
||||
lazy val customConfigMapsJsonFut: Future[String] = {
|
||||
customFieldsMapFut.map { customFields =>
|
||||
JsonMarshal.toJson(customFields)
|
||||
}
|
||||
}
|
||||
|
||||
lazy val customConfigMapsFut: Future[Map[String, String]] = {
|
||||
if (self.target.isLoggedOutUser) {
|
||||
Future.value(Map.empty[String, String])
|
||||
} else {
|
||||
customConfigMapsJsonFut.map { customConfigMapsJson =>
|
||||
Map("custom_config" -> customConfigMapsJson)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
package com.twitter.frigate.pushservice.model.ibis
|
||||
|
||||
import com.twitter.frigate.common.base.DiscoverTwitterCandidate
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.util.PushIbisUtil.mergeFutModelValues
|
||||
import com.twitter.util.Future
|
||||
|
||||
trait DiscoverTwitterPushIbis2Hydrator extends Ibis2HydratorForCandidate {
|
||||
self: PushCandidate with DiscoverTwitterCandidate =>
|
||||
|
||||
private lazy val targetModelValues: Map[String, String] = Map(
|
||||
"target_user" -> target.targetId.toString
|
||||
)
|
||||
|
||||
override lazy val modelValues: Future[Map[String, String]] =
|
||||
mergeFutModelValues(super.modelValues, Future.value(targetModelValues))
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
package com.twitter.frigate.pushservice.model.ibis
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.frigate.common.base.F1FirstDegree
|
||||
import com.twitter.frigate.common.base.TweetAuthorDetails
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.util.Future
|
||||
|
||||
trait F1FirstDegreeTweetIbis2HydratorForCandidate
|
||||
extends TweetCandidateIbis2Hydrator
|
||||
with RankedSocialContextIbis2Hydrator {
|
||||
self: PushCandidate with F1FirstDegree with TweetAuthorDetails =>
|
||||
|
||||
override lazy val scopedStats: StatsReceiver = statsReceiver.scope(getClass.getSimpleName)
|
||||
|
||||
override lazy val tweetModelValues: Future[Map[String, String]] = {
|
||||
for {
|
||||
superModelValues <- super.tweetModelValues
|
||||
tweetInlineModelValues <- tweetInlineActionModelValue
|
||||
} yield {
|
||||
superModelValues ++ otherModelValues ++ mediaModelValue ++ tweetInlineModelValues ++ inlineVideoMediaMap
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,127 @@
|
||||
package com.twitter.frigate.pushservice.model.ibis
|
||||
|
||||
import com.twitter.frigate.common.rec_types.RecTypes
|
||||
import com.twitter.frigate.common.util.MRPushCopy
|
||||
import com.twitter.frigate.common.util.MrPushCopyObjects
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
|
||||
import com.twitter.ibis2.service.thriftscala.Flags
|
||||
import com.twitter.ibis2.service.thriftscala.Ibis2Request
|
||||
import com.twitter.ibis2.service.thriftscala.RecipientSelector
|
||||
import com.twitter.ibis2.service.thriftscala.ResponseFlags
|
||||
import com.twitter.util.Future
|
||||
import scala.util.control.NoStackTrace
|
||||
import com.twitter.ni.lib.logged_out_transform.Ibis2RequestTransform
|
||||
|
||||
class PushCopyIdNotFoundException(private val message: String)
|
||||
extends Exception(message)
|
||||
with NoStackTrace
|
||||
|
||||
class InvalidPushCopyIdException(private val message: String)
|
||||
extends Exception(message)
|
||||
with NoStackTrace
|
||||
|
||||
trait Ibis2HydratorForCandidate
|
||||
extends CandidatePushCopy
|
||||
with OverrideForIbis2Request
|
||||
with CustomConfigurationMapForIbis {
|
||||
self: PushCandidate =>
|
||||
|
||||
lazy val silentPushModelValue: Map[String, String] =
|
||||
if (RecTypes.silentPushDefaultEnabledCrts.contains(commonRecType)) {
|
||||
Map.empty
|
||||
} else {
|
||||
Map("is_silent_push" -> "true")
|
||||
}
|
||||
|
||||
private def transformRelevanceScore(
|
||||
mlScore: Double,
|
||||
scoreRange: Seq[Double]
|
||||
): Double = {
|
||||
val (lowerBound, upperBound) = (scoreRange.head, scoreRange.last)
|
||||
(mlScore * (upperBound - lowerBound)) + lowerBound
|
||||
}
|
||||
|
||||
private def getBoundedMlScore(mlScore: Double): Double = {
|
||||
if (RecTypes.isMagicFanoutEventType(commonRecType)) {
|
||||
val mfScoreRange = target.params(FS.MagicFanoutRelevanceScoreRange)
|
||||
transformRelevanceScore(mlScore, mfScoreRange)
|
||||
} else {
|
||||
val mrScoreRange = target.params(FS.MagicRecsRelevanceScoreRange)
|
||||
transformRelevanceScore(mlScore, mrScoreRange)
|
||||
}
|
||||
}
|
||||
|
||||
lazy val relevanceScoreMapFut: Future[Map[String, String]] = {
|
||||
mrWeightedOpenOrNtabClickRankingProbability.map {
|
||||
case Some(mlScore) if target.params(FS.IncludeRelevanceScoreInIbis2Payload) =>
|
||||
val boundedMlScore = getBoundedMlScore(mlScore)
|
||||
Map("relevance_score" -> boundedMlScore.toString)
|
||||
case _ => Map.empty[String, String]
|
||||
}
|
||||
}
|
||||
|
||||
def customFieldsMapFut: Future[Map[String, String]] = relevanceScoreMapFut
|
||||
|
||||
//override is only enabled for RFPH CRT
|
||||
def modelValues: Future[Map[String, String]] = {
|
||||
Future.join(overrideModelValueFut, customConfigMapsFut).map {
|
||||
case (overrideModelValue, customConfig) =>
|
||||
overrideModelValue ++ silentPushModelValue ++ customConfig
|
||||
}
|
||||
}
|
||||
|
||||
def modelName: String = pushCopy.ibisPushModelName
|
||||
|
||||
def senderId: Option[Long] = None
|
||||
|
||||
def ibis2Request: Future[Option[Ibis2Request]] = {
|
||||
Future.join(self.target.loggedOutMetadata, modelValues).map {
|
||||
case (Some(metadata), modelVals) =>
|
||||
Some(
|
||||
Ibis2RequestTransform
|
||||
.apply(metadata, modelName, modelVals).copy(
|
||||
senderId = senderId,
|
||||
flags = Some(Flags(
|
||||
darkWrite = Some(target.isDarkWrite),
|
||||
skipDupcheck = target.pushContext.flatMap(_.useDebugHandler),
|
||||
responseFlags = Some(ResponseFlags(stringTelemetry = Some(true)))
|
||||
))
|
||||
))
|
||||
case (None, modelVals) =>
|
||||
Some(
|
||||
Ibis2Request(
|
||||
recipientSelector = RecipientSelector(Some(target.targetId)),
|
||||
modelName = modelName,
|
||||
modelValues = Some(modelVals),
|
||||
senderId = senderId,
|
||||
flags = Some(
|
||||
Flags(
|
||||
darkWrite = Some(target.isDarkWrite),
|
||||
skipDupcheck = target.pushContext.flatMap(_.useDebugHandler),
|
||||
responseFlags = Some(ResponseFlags(stringTelemetry = Some(true)))
|
||||
)
|
||||
)
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait CandidatePushCopy {
|
||||
self: PushCandidate =>
|
||||
|
||||
final lazy val pushCopy: MRPushCopy =
|
||||
pushCopyId match {
|
||||
case Some(pushCopyId) =>
|
||||
MrPushCopyObjects
|
||||
.getCopyFromId(pushCopyId)
|
||||
.getOrElse(
|
||||
throw new InvalidPushCopyIdException(
|
||||
s"Invalid push copy id: $pushCopyId for ${self.commonRecType}"))
|
||||
|
||||
case None =>
|
||||
throw new PushCopyIdNotFoundException(
|
||||
s"PushCopy not found in frigateNotification for ${self.commonRecType}"
|
||||
)
|
||||
}
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
package com.twitter.frigate.pushservice.model.ibis
|
||||
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.util.InlineActionUtil
|
||||
import com.twitter.util.Future
|
||||
|
||||
trait InlineActionIbis2Hydrator {
|
||||
self: PushCandidate =>
|
||||
|
||||
lazy val tweetInlineActionModelValue: Future[Map[String, String]] =
|
||||
InlineActionUtil.getTweetInlineActionValue(target)
|
||||
}
|
@ -0,0 +1,21 @@
|
||||
package com.twitter.frigate.pushservice.model.ibis
|
||||
|
||||
import com.twitter.frigate.pushservice.model.ListRecommendationPushCandidate
|
||||
import com.twitter.util.Future
|
||||
|
||||
trait ListIbis2Hydrator extends Ibis2HydratorForCandidate {
|
||||
self: ListRecommendationPushCandidate =>
|
||||
|
||||
override lazy val senderId: Option[Long] = Some(0L)
|
||||
|
||||
override lazy val modelValues: Future[Map[String, String]] =
|
||||
Future.join(listName, listOwnerId).map {
|
||||
case (nameOpt, authorId) =>
|
||||
Map(
|
||||
"list" -> listId.toString,
|
||||
"list_name" -> nameOpt
|
||||
.getOrElse(""),
|
||||
"list_author" -> s"${authorId.getOrElse(0L)}"
|
||||
)
|
||||
}
|
||||
}
|
@ -0,0 +1,29 @@
|
||||
package com.twitter.frigate.pushservice.model.ibis
|
||||
|
||||
import com.twitter.frigate.magic_events.thriftscala.CreatorFanoutType
|
||||
import com.twitter.frigate.pushservice.model.PushTypes.PushCandidate
|
||||
import com.twitter.frigate.pushservice.model.MagicFanoutCreatorEventPushCandidate
|
||||
import com.twitter.frigate.pushservice.util.PushIbisUtil.mergeModelValues
|
||||
import com.twitter.util.Future
|
||||
|
||||
trait MagicFanoutCreatorEventIbis2Hydrator
|
||||
extends CustomConfigurationMapForIbis
|
||||
with Ibis2HydratorForCandidate {
|
||||
self: PushCandidate with MagicFanoutCreatorEventPushCandidate =>
|
||||
|
||||
val userMap = Map(
|
||||
"handle" -> userProfile.screenName,
|
||||
"display_name" -> userProfile.name
|
||||
)
|
||||
|
||||
override val senderId = hydratedCreator.map(_.id)
|
||||
|
||||
override lazy val modelValues: Future[Map[String, String]] =
|
||||
mergeModelValues(super.modelValues, userMap)
|
||||
|
||||
override val ibis2Request = creatorFanoutType match {
|
||||
case CreatorFanoutType.UserSubscription => Future.None
|
||||
case CreatorFanoutType.NewCreator => super.ibis2Request
|
||||
case _ => super.ibis2Request
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user