Merge branch 'twitter:main' into main

This commit is contained in:
Coenraad Loubser 2023-05-31 22:54:36 +02:00 committed by GitHub
commit 227e3d140a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1177 changed files with 134033 additions and 884 deletions

View File

@ -1,6 +1,6 @@
# Twitter's Recommendation Algorithm # Twitter's Recommendation Algorithm
Twitter's Recommendation Algorithm is a set of services and jobs that are responsible for serving feeds of Tweets and other content across all Twitter product surfaces (e.g. For You Timeline, Search, Explore). For an introduction to how the algorithm works, please refer to our [engineering blog](https://blog.twitter.com/engineering/en_us/topics/open-source/2023/twitter-recommendation-algorithm). Twitter's Recommendation Algorithm is a set of services and jobs that are responsible for serving feeds of Tweets and other content across all Twitter product surfaces (e.g. For You Timeline, Search, Explore, Notifications). For an introduction to how the algorithm works, please refer to our [engineering blog](https://blog.twitter.com/engineering/en_us/topics/open-source/2023/twitter-recommendation-algorithm).
## Architecture ## Architecture
@ -8,7 +8,8 @@ Product surfaces at Twitter are built on a shared set of data, models, and softw
| Type | Component | Description | | Type | Component | Description |
|------------|------------|------------| |------------|------------|------------|
| Data | [unified-user-actions](unified_user_actions/README.md) | Real-time stream of user actions on Twitter. | | Data | [tweetypie](tweetypie/server/README.md) | Core Tweet service that handles the reading and writing of Tweet data. |
| | [unified-user-actions](unified_user_actions/README.md) | Real-time stream of user actions on Twitter. |
| | [user-signal-service](user-signal-service/README.md) | Centralized platform to retrieve explicit (e.g. likes, replies) and implicit (e.g. profile visits, tweet clicks) user signals. | | | [user-signal-service](user-signal-service/README.md) | Centralized platform to retrieve explicit (e.g. likes, replies) and implicit (e.g. profile visits, tweet clicks) user signals. |
| Model | [SimClusters](src/scala/com/twitter/simclusters_v2/README.md) | Community detection and sparse embeddings into those communities. | | Model | [SimClusters](src/scala/com/twitter/simclusters_v2/README.md) | Community detection and sparse embeddings into those communities. |
| | [TwHIN](https://github.com/twitter/the-algorithm-ml/blob/main/projects/twhin/README.md) | Dense knowledge graph embeddings for Users and Tweets. | | | [TwHIN](https://github.com/twitter/the-algorithm-ml/blob/main/projects/twhin/README.md) | Dense knowledge graph embeddings for Users and Tweets. |
@ -18,11 +19,14 @@ Product surfaces at Twitter are built on a shared set of data, models, and softw
| | [recos-injector](recos-injector/README.md) | Streaming event processor for building input streams for [GraphJet](https://github.com/twitter/GraphJet) based services. | | | [recos-injector](recos-injector/README.md) | Streaming event processor for building input streams for [GraphJet](https://github.com/twitter/GraphJet) based services. |
| | [graph-feature-service](graph-feature-service/README.md) | Serves graph features for a directed pair of Users (e.g. how many of User A's following liked Tweets from User B). | | | [graph-feature-service](graph-feature-service/README.md) | Serves graph features for a directed pair of Users (e.g. how many of User A's following liked Tweets from User B). |
| | [topic-social-proof](topic-social-proof/README.md) | Identifies topics related to individual Tweets. | | | [topic-social-proof](topic-social-proof/README.md) | Identifies topics related to individual Tweets. |
| | [representation-scorer](representation-scorer/README.md) | Compute scores between pairs of entities (Users, Tweets, etc.) using embedding similarity. |
| Software framework | [navi](navi/README.md) | High performance, machine learning model serving written in Rust. | | Software framework | [navi](navi/README.md) | High performance, machine learning model serving written in Rust. |
| | [product-mixer](product-mixer/README.md) | Software framework for building feeds of content. | | | [product-mixer](product-mixer/README.md) | Software framework for building feeds of content. |
| | [timelines-aggregation-framework](timelines/data_processing/ml_util/aggregation_framework/README.md) | Framework for generating aggregate features in batch or real time. |
| | [representation-manager](representation-manager/README.md) | Service to retrieve embeddings (i.e. SimClusers and TwHIN). |
| | [twml](twml/README.md) | Legacy machine learning framework built on TensorFlow v1. | | | [twml](twml/README.md) | Legacy machine learning framework built on TensorFlow v1. |
The product surface currently included in this repository is the For You Timeline. The product surfaces currently included in this repository are the For You Timeline and Recommended Notifications.
### For You Timeline ### For You Timeline
@ -44,6 +48,16 @@ The core components of the For You Timeline included in this repository are list
| | [visibility-filters](visibilitylib/README.md) | Responsible for filtering Twitter content to support legal compliance, improve product quality, increase user trust, protect revenue through the use of hard-filtering, visible product treatments, and coarse-grained downranking. | | | [visibility-filters](visibilitylib/README.md) | Responsible for filtering Twitter content to support legal compliance, improve product quality, increase user trust, protect revenue through the use of hard-filtering, visible product treatments, and coarse-grained downranking. |
| | [timelineranker](timelineranker/README.md) | Legacy service which provides relevance-scored tweets from the Earlybird Search Index and UTEG service. | | | [timelineranker](timelineranker/README.md) | Legacy service which provides relevance-scored tweets from the Earlybird Search Index and UTEG service. |
### Recommended Notifications
The core components of Recommended Notifications included in this repository are listed below:
| Type | Component | Description |
|------------|------------|------------|
| Service | [pushservice](pushservice/README.md) | Main recommendation service at Twitter used to surface recommendations to our users via notifications.
| Ranking | [pushservice-light-ranker](pushservice/src/main/python/models/light_ranking/README.md) | Light Ranker model used by pushservice to rank Tweets. Bridges candidate generation and heavy ranking by pre-selecting highly-relevant candidates from the initial huge candidate pool. |
| | [pushservice-heavy-ranker](pushservice/src/main/python/models/heavy_ranking/README.md) | Multi-task learning model to predict the probabilities that the target users will open and engage with the sent notifications. |
## Build and test code ## Build and test code
We include Bazel BUILD files for most components, but not a top-level BUILD or WORKSPACE file. We plan to add a more complete build and test system in the future. We include Bazel BUILD files for most components, but not a top-level BUILD or WORKSPACE file. We plan to add a more complete build and test system in the future.

51
RETREIVAL_SIGNALS.md Normal file
View File

@ -0,0 +1,51 @@
# Signals for Candidate Sources
## Overview
The candidate sourcing stage within the Twitter Recommendation algorithm serves to significantly narrow down the item size from approximately 1 billion to just a few thousand. This process utilizes Twitter user behavior as the primary input for the algorithm. This document comprehensively enumerates all the signals during the candidate sourcing phase.
| Signals | Description |
| :-------------------- | :-------------------------------------------------------------------- |
| Author Follow | The accounts which user explicit follows. |
| Author Unfollow | The accounts which user recently unfollows. |
| Author Mute | The accounts which user have muted. |
| Author Block | The accounts which user have blocked |
| Tweet Favorite | The tweets which user clicked the like botton. |
| Tweet Unfavorite | The tweets which user clicked the unlike botton. |
| Retweet | The tweets which user retweeted |
| Quote Tweet | The tweets which user retweeted with comments. |
| Tweet Reply | The tweets which user replied. |
| Tweet Share | The tweets which user clicked the share botton. |
| Tweet Bookmark | The tweets which user clicked the bookmark botton. |
| Tweet Click | The tweets which user clicked and viewed the tweet detail page. |
| Tweet Video Watch | The video tweets which user watched certain seconds or percentage. |
| Tweet Don't like | The tweets which user clicked "Not interested in this tweet" botton. |
| Tweet Report | The tweets which user clicked "Report Tweet" botton. |
| Notification Open | The push notification tweets which user opened. |
| Ntab click | The tweets which user click on the Notifications page. |
| User AddressBook | The author accounts identifiers of the user's addressbook. |
## Usage Details
Twitter uses these user signals as training labels and/or ML features in the each candidate sourcing algorithms. The following tables shows how they are used in the each components.
| Signals | USS | SimClusters | TwHin | UTEG | FRS | Light Ranking |
| :-------------------- | :----------------- | :----------------- | :----------------- | :----------------- | :----------------- | :----------------- |
| Author Follow | Features | Features / Labels | Features / Labels | Features | Features / Labels | N/A |
| Author Unfollow | Features | N/A | N/A | N/A | N/A | N/A |
| Author Mute | Features | N/A | N/A | N/A | Features | N/A |
| Author Block | Features | N/A | N/A | N/A | Features | N/A |
| Tweet Favorite | Features | Features | Features / Labels | Features | Features / Labels | Features / Labels |
| Tweet Unfavorite | Features | Features | N/A | N/A | N/A | N/A |
| Retweet | Features | N/A | Features / Labels | Features | Features / Labels | Features / Labels |
| Quote Tweet | Features | N/A | Features / Labels | Features | Features / Labels | Features / Labels |
| Tweet Reply | Features | N/A | Features | Features | Features / Labels | Features |
| Tweet Share | Features | N/A | N/A | N/A | Features | N/A |
| Tweet Bookmark | Features | N/A | N/A | N/A | N/A | N/A |
| Tweet Click | Features | N/A | N/A | N/A | Features | Labels |
| Tweet Video Watch | Features | Features | N/A | N/A | N/A | Labels |
| Tweet Don't like | Features | N/A | N/A | N/A | N/A | N/A |
| Tweet Report | Features | N/A | N/A | N/A | N/A | N/A |
| Notification Open | Features | Features | Features | N/A | Features | N/A |
| Ntab click | Features | Features | Features | N/A | Features | N/A |
| User AddressBook | N/A | N/A | N/A | N/A | Features | N/A |

View File

@ -31,6 +31,11 @@ In navi/navi, you can run the following commands:
- `scripts/run_onnx.sh` for [Onnx](https://onnx.ai/) - `scripts/run_onnx.sh` for [Onnx](https://onnx.ai/)
Do note that you need to create a models directory and create some versions, preferably using epoch time, e.g., `1679693908377`. Do note that you need to create a models directory and create some versions, preferably using epoch time, e.g., `1679693908377`.
so the models structure looks like:
models/
-web_click
- 1809000
- 1809010
## Build ## Build
You can adapt the above scripts to build using Cargo. You can adapt the above scripts to build using Cargo.

View File

@ -3,7 +3,6 @@ name = "dr_transform"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
@ -12,7 +11,6 @@ bpr_thrift = { path = "../thrift_bpr_adapter/thrift/"}
segdense = { path = "../segdense/"} segdense = { path = "../segdense/"}
thrift = "0.17.0" thrift = "0.17.0"
ndarray = "0.15" ndarray = "0.15"
ort = {git ="https://github.com/pykeio/ort.git", tag="v1.14.2"}
base64 = "0.20.0" base64 = "0.20.0"
npyz = "0.7.2" npyz = "0.7.2"
log = "0.4.17" log = "0.4.17"
@ -21,6 +19,11 @@ prometheus = "0.13.1"
once_cell = "1.17.0" once_cell = "1.17.0"
rand = "0.8.5" rand = "0.8.5"
itertools = "0.10.5" itertools = "0.10.5"
anyhow = "1.0.70"
[target.'cfg(not(target_os="linux"))'.dependencies]
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling"], tag="v1.14.6"}
[target.'cfg(target_os="linux")'.dependencies]
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling", "tensorrt", "cuda", "copy-dylibs"], tag="v1.14.6"}
[dev-dependencies] [dev-dependencies]
criterion = "0.3.0" criterion = "0.3.0"

View File

@ -44,5 +44,6 @@ pub struct RenamedFeatures {
} }
pub fn parse(json_str: &str) -> Result<AllConfig, Error> { pub fn parse(json_str: &str) -> Result<AllConfig, Error> {
serde_json::from_str(json_str) let all_config: AllConfig = serde_json::from_str(json_str)?;
Ok(all_config)
} }

View File

@ -2,6 +2,9 @@ use std::collections::BTreeSet;
use std::fmt::{self, Debug, Display}; use std::fmt::{self, Debug, Display};
use std::fs; use std::fs;
use crate::all_config;
use crate::all_config::AllConfig;
use anyhow::{bail, Context};
use bpr_thrift::data::DataRecord; use bpr_thrift::data::DataRecord;
use bpr_thrift::prediction_service::BatchPredictionRequest; use bpr_thrift::prediction_service::BatchPredictionRequest;
use bpr_thrift::tensor::GeneralTensor; use bpr_thrift::tensor::GeneralTensor;
@ -16,8 +19,6 @@ use segdense::util;
use thrift::protocol::{TBinaryInputProtocol, TSerializable}; use thrift::protocol::{TBinaryInputProtocol, TSerializable};
use thrift::transport::TBufferChannel; use thrift::transport::TBufferChannel;
use crate::{all_config, all_config::AllConfig};
pub fn log_feature_match( pub fn log_feature_match(
dr: &DataRecord, dr: &DataRecord,
seg_dense_config: &DensificationTransformSpec, seg_dense_config: &DensificationTransformSpec,
@ -28,20 +29,24 @@ pub fn log_feature_match(
for (feature_id, feature_value) in dr.continuous_features.as_ref().unwrap() { for (feature_id, feature_value) in dr.continuous_features.as_ref().unwrap() {
debug!( debug!(
"{dr_type} - Continuous Datarecord => Feature ID: {feature_id}, Feature value: {feature_value}" "{} - Continous Datarecord => Feature ID: {}, Feature value: {}",
dr_type, feature_id, feature_value
); );
for input_feature in &seg_dense_config.cont.input_features { for input_feature in &seg_dense_config.cont.input_features {
if input_feature.feature_id == *feature_id { if input_feature.feature_id == *feature_id {
debug!("Matching input feature: {input_feature:?}") debug!("Matching input feature: {:?}", input_feature)
} }
} }
} }
for feature_id in dr.binary_features.as_ref().unwrap() { for feature_id in dr.binary_features.as_ref().unwrap() {
debug!("{dr_type} - Binary Datarecord => Feature ID: {feature_id}"); debug!(
"{} - Binary Datarecord => Feature ID: {}",
dr_type, feature_id
);
for input_feature in &seg_dense_config.binary.input_features { for input_feature in &seg_dense_config.binary.input_features {
if input_feature.feature_id == *feature_id { if input_feature.feature_id == *feature_id {
debug!("Found input feature: {input_feature:?}") debug!("Found input feature: {:?}", input_feature)
} }
} }
} }
@ -90,18 +95,19 @@ impl BatchPredictionRequestToTorchTensorConverter {
model_version: &str, model_version: &str,
reporting_feature_ids: Vec<(i64, &str)>, reporting_feature_ids: Vec<(i64, &str)>,
register_metric_fn: Option<impl Fn(&HistogramVec)>, register_metric_fn: Option<impl Fn(&HistogramVec)>,
) -> BatchPredictionRequestToTorchTensorConverter { ) -> anyhow::Result<BatchPredictionRequestToTorchTensorConverter> {
let all_config_path = format!("{model_dir}/{model_version}/all_config.json"); let all_config_path = format!("{}/{}/all_config.json", model_dir, model_version);
let seg_dense_config_path = let seg_dense_config_path = format!(
format!("{model_dir}/{model_version}/segdense_transform_spec_home_recap_2022.json"); "{}/{}/segdense_transform_spec_home_recap_2022.json",
let seg_dense_config = util::load_config(&seg_dense_config_path); model_dir, model_version
);
let seg_dense_config = util::load_config(&seg_dense_config_path)?;
let all_config = all_config::parse( let all_config = all_config::parse(
&fs::read_to_string(&all_config_path) &fs::read_to_string(&all_config_path)
.unwrap_or_else(|error| panic!("error loading all_config.json - {error}")), .with_context(|| "error loading all_config.json - ")?,
) )?;
.unwrap();
let feature_mapper = util::load_from_parsed_config_ref(&seg_dense_config); let feature_mapper = util::load_from_parsed_config(seg_dense_config.clone())?;
let user_embedding_feature_id = Self::get_feature_id( let user_embedding_feature_id = Self::get_feature_id(
&all_config &all_config
@ -131,11 +137,11 @@ impl BatchPredictionRequestToTorchTensorConverter {
let (discrete_feature_metrics, continuous_feature_metrics) = METRICS.get_or_init(|| { let (discrete_feature_metrics, continuous_feature_metrics) = METRICS.get_or_init(|| {
let discrete = HistogramVec::new( let discrete = HistogramVec::new(
HistogramOpts::new(":navi:feature_id:discrete", "Discrete Feature ID values") HistogramOpts::new(":navi:feature_id:discrete", "Discrete Feature ID values")
.buckets(Vec::from([ .buckets(Vec::from(&[
0.0f64, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0,
120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0,
300.0, 500.0, 1000.0, 10000.0, 100000.0, 300.0, 500.0, 1000.0, 10000.0, 100000.0,
])), ] as &'static [f64])),
&["feature_id"], &["feature_id"],
) )
.expect("metric cannot be created"); .expect("metric cannot be created");
@ -144,18 +150,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
":navi:feature_id:continuous", ":navi:feature_id:continuous",
"continuous Feature ID values", "continuous Feature ID values",
) )
.buckets(Vec::from([ .buckets(Vec::from(&[
0.0f64, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0,
120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0,
500.0, 1000.0, 10000.0, 100000.0, 1000.0, 10000.0, 100000.0,
])), ] as &'static [f64])),
&["feature_id"], &["feature_id"],
) )
.expect("metric cannot be created"); .expect("metric cannot be created");
if let Some(r) = register_metric_fn { register_metric_fn.map(|r| {
r(&discrete); r(&discrete);
r(&continuous); r(&continuous);
} });
(discrete, continuous) (discrete, continuous)
}); });
@ -164,13 +170,16 @@ impl BatchPredictionRequestToTorchTensorConverter {
for (feature_id, feature_type) in reporting_feature_ids.iter() { for (feature_id, feature_type) in reporting_feature_ids.iter() {
match *feature_type { match *feature_type {
"discrete" => discrete_features_to_report.insert(*feature_id), "discrete" => discrete_features_to_report.insert(feature_id.clone()),
"continuous" => continuous_features_to_report.insert(*feature_id), "continuous" => continuous_features_to_report.insert(feature_id.clone()),
_ => panic!("Invalid feature type {feature_type} for reporting metrics!"), _ => bail!(
"Invalid feature type {} for reporting metrics!",
feature_type
),
}; };
} }
BatchPredictionRequestToTorchTensorConverter { Ok(BatchPredictionRequestToTorchTensorConverter {
all_config, all_config,
seg_dense_config, seg_dense_config,
all_config_path, all_config_path,
@ -183,7 +192,7 @@ impl BatchPredictionRequestToTorchTensorConverter {
continuous_features_to_report, continuous_features_to_report,
discrete_feature_metrics, discrete_feature_metrics,
continuous_feature_metrics, continuous_feature_metrics,
} })
} }
fn get_feature_id(feature_name: &str, seg_dense_config: &Root) -> i64 { fn get_feature_id(feature_name: &str, seg_dense_config: &Root) -> i64 {
@ -218,9 +227,9 @@ impl BatchPredictionRequestToTorchTensorConverter {
let mut working_set = vec![0 as f32; total_size]; let mut working_set = vec![0 as f32; total_size];
let mut bpr_start = 0; let mut bpr_start = 0;
for (bpr, &bpr_end) in bprs.iter().zip(batch_size) { for (bpr, &bpr_end) in bprs.iter().zip(batch_size) {
if bpr.common_features.is_some() if bpr.common_features.is_some() {
&& bpr.common_features.as_ref().unwrap().tensors.is_some() if bpr.common_features.as_ref().unwrap().tensors.is_some() {
&& bpr if bpr
.common_features .common_features
.as_ref() .as_ref()
.unwrap() .unwrap()
@ -258,6 +267,8 @@ impl BatchPredictionRequestToTorchTensorConverter {
} }
} }
} }
}
}
// find the feature in individual feature list and add to corresponding batch. // find the feature in individual feature list and add to corresponding batch.
for (index, datarecord) in bpr.individual_features_list.iter().enumerate() { for (index, datarecord) in bpr.individual_features_list.iter().enumerate() {
if datarecord.tensors.is_some() if datarecord.tensors.is_some()
@ -298,9 +309,9 @@ impl BatchPredictionRequestToTorchTensorConverter {
// (INT64 --> INT64, DataRecord.discrete_feature) // (INT64 --> INT64, DataRecord.discrete_feature)
fn get_continuous(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor { fn get_continuous(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
// These need to be part of model schema // These need to be part of model schema
let rows = batch_ends[batch_ends.len() - 1]; let rows: usize = batch_ends[batch_ends.len() - 1];
let cols = 5293; let cols: usize = 5293;
let full_size = rows * cols; let full_size: usize = rows * cols;
let default_val = f32::NAN; let default_val = f32::NAN;
let mut tensor = vec![default_val; full_size]; let mut tensor = vec![default_val; full_size];
@ -325,16 +336,19 @@ impl BatchPredictionRequestToTorchTensorConverter {
.unwrap(); .unwrap();
for feature in common_features { for feature in common_features {
if let Some(f_info) = self.feature_mapper.get(feature.0) { match self.feature_mapper.get(feature.0) {
Some(f_info) => {
let idx = f_info.index_within_tensor as usize; let idx = f_info.index_within_tensor as usize;
if idx < cols { if idx < cols {
// Set value in each row // Set value in each row
for r in bpr_start..bpr_end { for r in bpr_start..bpr_end {
let flat_index = r * cols + idx; let flat_index: usize = r * cols + idx;
tensor[flat_index] = feature.1.into_inner() as f32; tensor[flat_index] = feature.1.into_inner() as f32;
} }
} }
} }
None => (),
}
if self.continuous_features_to_report.contains(feature.0) { if self.continuous_features_to_report.contains(feature.0) {
self.continuous_feature_metrics self.continuous_feature_metrics
.with_label_values(&[feature.0.to_string().as_str()]) .with_label_values(&[feature.0.to_string().as_str()])
@ -349,24 +363,28 @@ impl BatchPredictionRequestToTorchTensorConverter {
// Process the batch of datarecords // Process the batch of datarecords
for r in bpr_start..bpr_end { for r in bpr_start..bpr_end {
let dr: &DataRecord = &bpr.individual_features_list[r - bpr_start]; let dr: &DataRecord =
&bpr.individual_features_list[usize::try_from(r - bpr_start).unwrap()];
if dr.continuous_features.is_some() { if dr.continuous_features.is_some() {
for feature in dr.continuous_features.as_ref().unwrap() { for feature in dr.continuous_features.as_ref().unwrap() {
if let Some(f_info) = self.feature_mapper.get(feature.0) { match self.feature_mapper.get(&feature.0) {
Some(f_info) => {
let idx = f_info.index_within_tensor as usize; let idx = f_info.index_within_tensor as usize;
let flat_index = r * cols + idx; let flat_index: usize = r * cols + idx;
if flat_index < tensor.len() && idx < cols { if flat_index < tensor.len() && idx < cols {
tensor[flat_index] = feature.1.into_inner() as f32; tensor[flat_index] = feature.1.into_inner() as f32;
} }
} }
None => (),
}
if self.continuous_features_to_report.contains(feature.0) { if self.continuous_features_to_report.contains(feature.0) {
self.continuous_feature_metrics self.continuous_feature_metrics
.with_label_values(&[feature.0.to_string().as_str()]) .with_label_values(&[feature.0.to_string().as_str()])
.observe(feature.1.into_inner()) .observe(feature.1.into_inner() as f64)
} else if self.discrete_features_to_report.contains(feature.0) { } else if self.discrete_features_to_report.contains(feature.0) {
self.discrete_feature_metrics self.discrete_feature_metrics
.with_label_values(&[feature.0.to_string().as_str()]) .with_label_values(&[feature.0.to_string().as_str()])
.observe(feature.1.into_inner()) .observe(feature.1.into_inner() as f64)
} }
} }
} }
@ -383,10 +401,10 @@ impl BatchPredictionRequestToTorchTensorConverter {
fn get_binary(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor { fn get_binary(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
// These need to be part of model schema // These need to be part of model schema
let rows = batch_ends[batch_ends.len() - 1]; let rows: usize = batch_ends[batch_ends.len() - 1];
let cols = 149; let cols: usize = 149;
let full_size = rows * cols; let full_size: usize = rows * cols;
let default_val = 0; let default_val: i64 = 0;
let mut v = vec![default_val; full_size]; let mut v = vec![default_val; full_size];
@ -410,16 +428,19 @@ impl BatchPredictionRequestToTorchTensorConverter {
.unwrap(); .unwrap();
for feature in common_features { for feature in common_features {
if let Some(f_info) = self.feature_mapper.get(feature) { match self.feature_mapper.get(feature) {
Some(f_info) => {
let idx = f_info.index_within_tensor as usize; let idx = f_info.index_within_tensor as usize;
if idx < cols { if idx < cols {
// Set value in each row // Set value in each row
for r in bpr_start..bpr_end { for r in bpr_start..bpr_end {
let flat_index = r * cols + idx; let flat_index: usize = r * cols + idx;
v[flat_index] = 1; v[flat_index] = 1;
} }
} }
} }
None => (),
}
} }
} }
@ -428,11 +449,14 @@ impl BatchPredictionRequestToTorchTensorConverter {
let dr: &DataRecord = &bpr.individual_features_list[r - bpr_start]; let dr: &DataRecord = &bpr.individual_features_list[r - bpr_start];
if dr.binary_features.is_some() { if dr.binary_features.is_some() {
for feature in dr.binary_features.as_ref().unwrap() { for feature in dr.binary_features.as_ref().unwrap() {
if let Some(f_info) = self.feature_mapper.get(feature) { match self.feature_mapper.get(&feature) {
Some(f_info) => {
let idx = f_info.index_within_tensor as usize; let idx = f_info.index_within_tensor as usize;
let flat_index = r * cols + idx; let flat_index: usize = r * cols + idx;
v[flat_index] = 1; v[flat_index] = 1;
} }
None => (),
}
} }
} }
} }
@ -448,10 +472,10 @@ impl BatchPredictionRequestToTorchTensorConverter {
#[allow(dead_code)] #[allow(dead_code)]
fn get_discrete(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor { fn get_discrete(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
// These need to be part of model schema // These need to be part of model schema
let rows = batch_ends[batch_ends.len() - 1]; let rows: usize = batch_ends[batch_ends.len() - 1];
let cols = 320; let cols: usize = 320;
let full_size = rows * cols; let full_size: usize = rows * cols;
let default_val = 0; let default_val: i64 = 0;
let mut v = vec![default_val; full_size]; let mut v = vec![default_val; full_size];
@ -475,16 +499,19 @@ impl BatchPredictionRequestToTorchTensorConverter {
.unwrap(); .unwrap();
for feature in common_features { for feature in common_features {
if let Some(f_info) = self.feature_mapper.get(feature.0) { match self.feature_mapper.get(feature.0) {
Some(f_info) => {
let idx = f_info.index_within_tensor as usize; let idx = f_info.index_within_tensor as usize;
if idx < cols { if idx < cols {
// Set value in each row // Set value in each row
for r in bpr_start..bpr_end { for r in bpr_start..bpr_end {
let flat_index = r * cols + idx; let flat_index: usize = r * cols + idx;
v[flat_index] = *feature.1; v[flat_index] = *feature.1;
} }
} }
} }
None => (),
}
if self.discrete_features_to_report.contains(feature.0) { if self.discrete_features_to_report.contains(feature.0) {
self.discrete_feature_metrics self.discrete_feature_metrics
.with_label_values(&[feature.0.to_string().as_str()]) .with_label_values(&[feature.0.to_string().as_str()])
@ -495,16 +522,19 @@ impl BatchPredictionRequestToTorchTensorConverter {
// Process the batch of datarecords // Process the batch of datarecords
for r in bpr_start..bpr_end { for r in bpr_start..bpr_end {
let dr: &DataRecord = &bpr.individual_features_list[r]; let dr: &DataRecord = &bpr.individual_features_list[usize::try_from(r).unwrap()];
if dr.discrete_features.is_some() { if dr.discrete_features.is_some() {
for feature in dr.discrete_features.as_ref().unwrap() { for feature in dr.discrete_features.as_ref().unwrap() {
if let Some(f_info) = self.feature_mapper.get(feature.0) { match self.feature_mapper.get(&feature.0) {
Some(f_info) => {
let idx = f_info.index_within_tensor as usize; let idx = f_info.index_within_tensor as usize;
let flat_index = r * cols + idx; let flat_index: usize = r * cols + idx;
if flat_index < v.len() && idx < cols { if flat_index < v.len() && idx < cols {
v[flat_index] = *feature.1; v[flat_index] = *feature.1;
} }
} }
None => (),
}
if self.discrete_features_to_report.contains(feature.0) { if self.discrete_features_to_report.contains(feature.0) {
self.discrete_feature_metrics self.discrete_feature_metrics
.with_label_values(&[feature.0.to_string().as_str()]) .with_label_values(&[feature.0.to_string().as_str()])
@ -569,7 +599,7 @@ impl Converter for BatchPredictionRequestToTorchTensorConverter {
.map(|bpr| bpr.individual_features_list.len()) .map(|bpr| bpr.individual_features_list.len())
.scan(0usize, |acc, e| { .scan(0usize, |acc, e| {
//running total //running total
*acc += e; *acc = *acc + e;
Some(*acc) Some(*acc)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();

View File

@ -3,3 +3,4 @@ pub mod converter;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
pub mod util; pub mod util;
pub extern crate ort;

View File

@ -1,8 +1,7 @@
[package] [package]
name = "navi" name = "navi"
version = "2.0.42" version = "2.0.45"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[[bin]] [[bin]]
name = "navi" name = "navi"
@ -16,12 +15,19 @@ required-features=["torch"]
name = "navi_onnx" name = "navi_onnx"
path = "src/bin/navi_onnx.rs" path = "src/bin/navi_onnx.rs"
required-features=["onnx"] required-features=["onnx"]
[[bin]]
name = "navi_onnx_test"
path = "src/bin/bin_tests/navi_onnx_test.rs"
[[bin]]
name = "navi_torch_test"
path = "src/bin/bin_tests/navi_torch_test.rs"
required-features=["torch"]
[features] [features]
default=[] default=[]
navi_console=[] navi_console=[]
torch=["tch"] torch=["tch"]
onnx=["ort"] onnx=[]
tf=["tensorflow"] tf=["tensorflow"]
[dependencies] [dependencies]
itertools = "0.10.5" itertools = "0.10.5"
@ -47,6 +53,7 @@ parking_lot = "0.12.1"
rand = "0.8.5" rand = "0.8.5"
rand_pcg = "0.3.1" rand_pcg = "0.3.1"
random = "0.12.2" random = "0.12.2"
x509-parser = "0.15.0"
sha256 = "1.0.3" sha256 = "1.0.3"
tonic = { version = "0.6.2", features=['compression', 'tls'] } tonic = { version = "0.6.2", features=['compression', 'tls'] }
tokio = { version = "1.17.0", features = ["macros", "rt-multi-thread", "fs", "process"] } tokio = { version = "1.17.0", features = ["macros", "rt-multi-thread", "fs", "process"] }
@ -55,16 +62,12 @@ npyz = "0.7.3"
base64 = "0.21.0" base64 = "0.21.0"
histogram = "0.6.9" histogram = "0.6.9"
tch = {version = "0.10.3", optional = true} tch = {version = "0.10.3", optional = true}
tensorflow = { version = "0.20.0", optional = true } tensorflow = { version = "0.18.0", optional = true }
once_cell = {version = "1.17.1"} once_cell = {version = "1.17.1"}
ndarray = "0.15" ndarray = "0.15"
serde = "1.0.154" serde = "1.0.154"
serde_json = "1.0.94" serde_json = "1.0.94"
dr_transform = { path = "../dr_transform"} dr_transform = { path = "../dr_transform"}
[target.'cfg(not(target_os="linux"))'.dependencies]
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling"], optional = true, tag="v1.14.2"}
[target.'cfg(target_os="linux")'.dependencies]
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling", "tensorrt", "cuda", "copy-dylibs"], optional = true, tag="v1.14.2"}
[build-dependencies] [build-dependencies]
tonic-build = {version = "0.6.2", features=['prost', "compression"] } tonic-build = {version = "0.6.2", features=['prost', "compression"] }
[profile.release] [profile.release]
@ -74,3 +77,5 @@ ndarray-rand = "0.14.0"
tokio-test = "*" tokio-test = "*"
assert_cmd = "2.0" assert_cmd = "2.0"
criterion = "0.4.0" criterion = "0.4.0"

View File

@ -122,7 +122,7 @@ enum FullTypeId {
// TFT_TENSOR[TFT_INT32, TFT_UNKNOWN] // TFT_TENSOR[TFT_INT32, TFT_UNKNOWN]
// is a Tensor of int32 element type and unknown shape. // is a Tensor of int32 element type and unknown shape.
// //
// TODO: Define TFT_SHAPE and add more examples. // TODO(mdan): Define TFT_SHAPE and add more examples.
TFT_TENSOR = 1000; TFT_TENSOR = 1000;
// Array (or tensorflow::TensorList in the variant type registry). // Array (or tensorflow::TensorList in the variant type registry).
@ -178,7 +178,7 @@ enum FullTypeId {
// object (for now). // object (for now).
// The bool element type. // The bool element type.
// TODO // TODO(mdan): Quantized types, legacy representations (e.g. ref)
TFT_BOOL = 200; TFT_BOOL = 200;
// Integer element types. // Integer element types.
TFT_UINT8 = 201; TFT_UINT8 = 201;
@ -195,7 +195,7 @@ enum FullTypeId {
TFT_DOUBLE = 211; TFT_DOUBLE = 211;
TFT_BFLOAT16 = 215; TFT_BFLOAT16 = 215;
// Complex element types. // Complex element types.
// TODO: Represent as TFT_COMPLEX[TFT_DOUBLE] instead? // TODO(mdan): Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
TFT_COMPLEX64 = 212; TFT_COMPLEX64 = 212;
TFT_COMPLEX128 = 213; TFT_COMPLEX128 = 213;
// The string element type. // The string element type.
@ -240,7 +240,7 @@ enum FullTypeId {
// ownership is in the true sense: "the op argument representing the lock is // ownership is in the true sense: "the op argument representing the lock is
// available". // available".
// Mutex locks are the dynamic counterpart of control dependencies. // Mutex locks are the dynamic counterpart of control dependencies.
// TODO: Properly document this thing. // TODO(mdan): Properly document this thing.
// //
// Parametrization: TFT_MUTEX_LOCK[]. // Parametrization: TFT_MUTEX_LOCK[].
TFT_MUTEX_LOCK = 10202; TFT_MUTEX_LOCK = 10202;
@ -271,6 +271,6 @@ message FullTypeDef {
oneof attr { oneof attr {
string s = 3; string s = 3;
int64 i = 4; int64 i = 4;
// TODO: list/tensor, map? Need to reconcile with TFT_RECORD, etc. // TODO(mdan): list/tensor, map? Need to reconcile with TFT_RECORD, etc.
} }
} }

View File

@ -23,7 +23,7 @@ message FunctionDefLibrary {
// with a value. When a GraphDef has a call to a function, it must // with a value. When a GraphDef has a call to a function, it must
// have binding for every attr defined in the signature. // have binding for every attr defined in the signature.
// //
// TODO: // TODO(zhifengc):
// * device spec, etc. // * device spec, etc.
message FunctionDef { message FunctionDef {
// The definition of the function's name, arguments, return values, // The definition of the function's name, arguments, return values,

View File

@ -61,7 +61,7 @@ message NodeDef {
// one of the names from the corresponding OpDef's attr field). // one of the names from the corresponding OpDef's attr field).
// The values must have a type matching the corresponding OpDef // The values must have a type matching the corresponding OpDef
// attr's type field. // attr's type field.
// TODO: Add some examples here showing best practices. // TODO(josh11b): Add some examples here showing best practices.
map<string, AttrValue> attr = 5; map<string, AttrValue> attr = 5;
message ExperimentalDebugInfo { message ExperimentalDebugInfo {

View File

@ -96,7 +96,7 @@ message OpDef {
// Human-readable description. // Human-readable description.
string description = 4; string description = 4;
// TODO: bool is_optional? // TODO(josh11b): bool is_optional?
// --- Constraints --- // --- Constraints ---
// These constraints are only in effect if specified. Default is no // These constraints are only in effect if specified. Default is no
@ -139,7 +139,7 @@ message OpDef {
// taking input from multiple devices with a tree of aggregate ops // taking input from multiple devices with a tree of aggregate ops
// that aggregate locally within each device (and possibly within // that aggregate locally within each device (and possibly within
// groups of nearby devices) before communicating. // groups of nearby devices) before communicating.
// TODO: Implement that optimization. // TODO(josh11b): Implement that optimization.
bool is_aggregate = 16; // for things like add bool is_aggregate = 16; // for things like add
// Other optimizations go here, like // Other optimizations go here, like

View File

@ -53,7 +53,7 @@ message MemoryStats {
// Time/size stats recorded for a single execution of a graph node. // Time/size stats recorded for a single execution of a graph node.
message NodeExecStats { message NodeExecStats {
// TODO: Use some more compact form of node identity than // TODO(tucker): Use some more compact form of node identity than
// the full string name. Either all processes should agree on a // the full string name. Either all processes should agree on a
// global id (cost_id?) for each node, or we should use a hash of // global id (cost_id?) for each node, or we should use a hash of
// the name. // the name.

View File

@ -16,7 +16,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framewo
message TensorProto { message TensorProto {
DataType dtype = 1; DataType dtype = 1;
// Shape of the tensor. TODO: sort out the 0-rank issues. // Shape of the tensor. TODO(touts): sort out the 0-rank issues.
TensorShapeProto tensor_shape = 2; TensorShapeProto tensor_shape = 2;
// Only one of the representations below is set, one of "tensor_contents" and // Only one of the representations below is set, one of "tensor_contents" and

View File

@ -532,7 +532,7 @@ message ConfigProto {
// We removed the flag client_handles_error_formatting. Marking the tag // We removed the flag client_handles_error_formatting. Marking the tag
// number as reserved. // number as reserved.
// TODO: Should we just remove this tag so that it can be // TODO(shikharagarwal): Should we just remove this tag so that it can be
// used in future for other purpose? // used in future for other purpose?
reserved 2; reserved 2;
@ -576,7 +576,7 @@ message ConfigProto {
// - If isolate_session_state is true, session states are isolated. // - If isolate_session_state is true, session states are isolated.
// - If isolate_session_state is false, session states are shared. // - If isolate_session_state is false, session states are shared.
// //
// TODO: Add a single API that consistently treats // TODO(b/129330037): Add a single API that consistently treats
// isolate_session_state and ClusterSpec propagation. // isolate_session_state and ClusterSpec propagation.
bool share_session_state_in_clusterspec_propagation = 8; bool share_session_state_in_clusterspec_propagation = 8;
@ -704,7 +704,7 @@ message ConfigProto {
// Options for a single Run() call. // Options for a single Run() call.
message RunOptions { message RunOptions {
// TODO Turn this into a TraceOptions proto which allows // TODO(pbar) Turn this into a TraceOptions proto which allows
// tracing to be controlled in a more orthogonal manner? // tracing to be controlled in a more orthogonal manner?
enum TraceLevel { enum TraceLevel {
NO_TRACE = 0; NO_TRACE = 0;
@ -781,7 +781,7 @@ message RunMetadata {
repeated GraphDef partition_graphs = 3; repeated GraphDef partition_graphs = 3;
message FunctionGraphs { message FunctionGraphs {
// TODO: Include some sort of function/cache-key identifier? // TODO(nareshmodi): Include some sort of function/cache-key identifier?
repeated GraphDef partition_graphs = 1; repeated GraphDef partition_graphs = 1;
GraphDef pre_optimization_graph = 2; GraphDef pre_optimization_graph = 2;

View File

@ -194,7 +194,7 @@ service CoordinationService {
// Report error to the task. RPC sets the receiving instance of coordination // Report error to the task. RPC sets the receiving instance of coordination
// service agent to error state permanently. // service agent to error state permanently.
// TODO: Consider splitting this into a different RPC service. // TODO(b/195990880): Consider splitting this into a different RPC service.
rpc ReportErrorToAgent(ReportErrorToAgentRequest) rpc ReportErrorToAgent(ReportErrorToAgentRequest)
returns (ReportErrorToAgentResponse); returns (ReportErrorToAgentResponse);

View File

@ -46,7 +46,7 @@ message DebugTensorWatch {
// are to be debugged, the callers of Session::Run() must use distinct // are to be debugged, the callers of Session::Run() must use distinct
// debug_urls to make sure that the streamed or dumped events do not overlap // debug_urls to make sure that the streamed or dumped events do not overlap
// among the invocations. // among the invocations.
// TODO: More visible documentation of this in g3docs. // TODO(cais): More visible documentation of this in g3docs.
repeated string debug_urls = 4; repeated string debug_urls = 4;
// Do not error out if debug op creation fails (e.g., due to dtype // Do not error out if debug op creation fails (e.g., due to dtype

View File

@ -12,7 +12,7 @@ option java_package = "org.tensorflow.util";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
// Available modes for extracting debugging information from a Tensor. // Available modes for extracting debugging information from a Tensor.
// TODO: Document the detailed column names and semantics in a separate // TODO(cais): Document the detailed column names and semantics in a separate
// markdown file once the implementation settles. // markdown file once the implementation settles.
enum TensorDebugMode { enum TensorDebugMode {
UNSPECIFIED = 0; UNSPECIFIED = 0;
@ -223,7 +223,7 @@ message DebuggedDevice {
// A debugger-generated ID for the device. Guaranteed to be unique within // A debugger-generated ID for the device. Guaranteed to be unique within
// the scope of the debugged TensorFlow program, including single-host and // the scope of the debugged TensorFlow program, including single-host and
// multi-host settings. // multi-host settings.
// TODO: Test the uniqueness guarantee in multi-host settings. // TODO(cais): Test the uniqueness guarantee in multi-host settings.
int32 device_id = 2; int32 device_id = 2;
} }
@ -264,7 +264,7 @@ message Execution {
// field with the DebuggedDevice messages. // field with the DebuggedDevice messages.
repeated int32 output_tensor_device_ids = 9; repeated int32 output_tensor_device_ids = 9;
// TODO support, add more fields // TODO(cais): When backporting to V1 Session.run() support, add more fields
// such as fetches and feeds. // such as fetches and feeds.
} }

View File

@ -7,7 +7,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
// Used to serialize and transmit tensorflow::Status payloads through // Used to serialize and transmit tensorflow::Status payloads through
// grpc::Status `error_details` since grpc::Status lacks payload API. // grpc::Status `error_details` since grpc::Status lacks payload API.
// TODO: Use GRPC API once supported. // TODO(b/204231601): Use GRPC API once supported.
message GrpcPayloadContainer { message GrpcPayloadContainer {
map<string, bytes> payloads = 1; map<string, bytes> payloads = 1;
} }

View File

@ -172,7 +172,7 @@ message WaitQueueDoneRequest {
} }
message WaitQueueDoneResponse { message WaitQueueDoneResponse {
// TODO: Consider adding NodeExecStats here to be able to // TODO(nareshmodi): Consider adding NodeExecStats here to be able to
// propagate some stats. // propagate some stats.
} }

View File

@ -94,7 +94,7 @@ message ExtendSessionRequest {
} }
message ExtendSessionResponse { message ExtendSessionResponse {
// TODO: Return something about the operation? // TODO(mrry): Return something about the operation?
// The new version number for the extended graph, to be used in the next call // The new version number for the extended graph, to be used in the next call
// to ExtendSession. // to ExtendSession.

View File

@ -176,7 +176,7 @@ message SavedBareConcreteFunction {
// allows the ConcreteFunction to be called with nest structure inputs. This // allows the ConcreteFunction to be called with nest structure inputs. This
// field may not be populated. If this field is absent, the concrete function // field may not be populated. If this field is absent, the concrete function
// can only be called with flat inputs. // can only be called with flat inputs.
// TODO: support calling saved ConcreteFunction with structured // TODO(b/169361281): support calling saved ConcreteFunction with structured
// inputs in C++ SavedModel API. // inputs in C++ SavedModel API.
FunctionSpec function_spec = 4; FunctionSpec function_spec = 4;
} }

View File

@ -17,7 +17,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
// Special header that is associated with a bundle. // Special header that is associated with a bundle.
// //
// TODO: maybe in the future, we can add information about // TODO(zongheng,zhifengc): maybe in the future, we can add information about
// which binary produced this checkpoint, timestamp, etc. Sometime, these can be // which binary produced this checkpoint, timestamp, etc. Sometime, these can be
// valuable debugging information. And if needed, these can be used as defensive // valuable debugging information. And if needed, these can be used as defensive
// information ensuring reader (binary version) of the checkpoint and the writer // information ensuring reader (binary version) of the checkpoint and the writer

View File

@ -188,7 +188,7 @@ message DeregisterGraphRequest {
} }
message DeregisterGraphResponse { message DeregisterGraphResponse {
// TODO: Optionally add summary stats for the graph. // TODO(mrry): Optionally add summary stats for the graph.
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -294,7 +294,7 @@ message RunGraphResponse {
// If the request asked for execution stats, the cost graph, or the partition // If the request asked for execution stats, the cost graph, or the partition
// graphs, these are returned here. // graphs, these are returned here.
// TODO: Package these in a RunMetadata instead. // TODO(suharshs): Package these in a RunMetadata instead.
StepStats step_stats = 2; StepStats step_stats = 2;
CostGraphDef cost_graph = 3; CostGraphDef cost_graph = 3;
repeated GraphDef partition_graph = 4; repeated GraphDef partition_graph = 4;

View File

@ -13,5 +13,5 @@ message LogMetadata {
SamplingConfig sampling_config = 2; SamplingConfig sampling_config = 2;
// List of tags used to load the relevant MetaGraphDef from SavedModel. // List of tags used to load the relevant MetaGraphDef from SavedModel.
repeated string saved_model_tags = 3; repeated string saved_model_tags = 3;
// TODO: Add more metadata as mentioned in the bug. // TODO(b/33279154): Add more metadata as mentioned in the bug.
} }

View File

@ -58,7 +58,7 @@ message FileSystemStoragePathSourceConfig {
// A single servable name/base_path pair to monitor. // A single servable name/base_path pair to monitor.
// DEPRECATED: Use 'servables' instead. // DEPRECATED: Use 'servables' instead.
// TODO: Stop using these fields, and ultimately remove them here. // TODO(b/30898016): Stop using these fields, and ultimately remove them here.
string servable_name = 1 [deprecated = true]; string servable_name = 1 [deprecated = true];
string base_path = 2 [deprecated = true]; string base_path = 2 [deprecated = true];
@ -76,7 +76,7 @@ message FileSystemStoragePathSourceConfig {
// check for a version to appear later.) // check for a version to appear later.)
// DEPRECATED: Use 'servable_versions_always_present' instead, which includes // DEPRECATED: Use 'servable_versions_always_present' instead, which includes
// this behavior. // this behavior.
// TODO: Remove 2019-10-31 or later. // TODO(b/30898016): Remove 2019-10-31 or later.
bool fail_if_zero_versions_at_startup = 4 [deprecated = true]; bool fail_if_zero_versions_at_startup = 4 [deprecated = true];
// If true, the servable is always expected to exist on the underlying // If true, the servable is always expected to exist on the underlying

View File

@ -9,7 +9,7 @@ import "tensorflow_serving/config/logging_config.proto";
option cc_enable_arenas = true; option cc_enable_arenas = true;
// The type of model. // The type of model.
// TODO: DEPRECATED. // TODO(b/31336131): DEPRECATED.
enum ModelType { enum ModelType {
MODEL_TYPE_UNSPECIFIED = 0 [deprecated = true]; MODEL_TYPE_UNSPECIFIED = 0 [deprecated = true];
TENSORFLOW = 1 [deprecated = true]; TENSORFLOW = 1 [deprecated = true];
@ -31,7 +31,7 @@ message ModelConfig {
string base_path = 2; string base_path = 2;
// Type of model. // Type of model.
// TODO: DEPRECATED. Please use 'model_platform' instead. // TODO(b/31336131): DEPRECATED. Please use 'model_platform' instead.
ModelType model_type = 3 [deprecated = true]; ModelType model_type = 3 [deprecated = true];
// Type of model (e.g. "tensorflow"). // Type of model (e.g. "tensorflow").

View File

@ -1,10 +1,9 @@
#!/bin/sh #!/bin/sh
#RUST_LOG=debug LD_LIBRARY_PATH=so/onnx/lib target/release/navi_onnx --port 30 --num-worker-threads 8 --intra-op-parallelism 8 --inter-op-parallelism 8 \ #RUST_LOG=debug LD_LIBRARY_PATH=so/onnx/lib target/release/navi_onnx --port 30 --num-worker-threads 8 --intra-op-parallelism 8 --inter-op-parallelism 8 \
RUST_LOG=info LD_LIBRARY_PATH=so/onnx/lib cargo run --bin navi_onnx --features onnx -- \ RUST_LOG=info LD_LIBRARY_PATH=so/onnx/lib cargo run --bin navi_onnx --features onnx -- \
--port 30 --num-worker-threads 8 --intra-op-parallelism 8 --inter-op-parallelism 8 \ --port 8030 --num-worker-threads 8 \
--model-check-interval-secs 30 \ --model-check-interval-secs 30 \
--model-dir models/int8 \
--output caligrated_probabilities \
--input "" \
--modelsync-cli "echo" \ --modelsync-cli "echo" \
--onnx-ep-options use_arena=true --onnx-ep-options use_arena=true \
--model-dir models/prod_home --output caligrated_probabilities --input "" --intra-op-parallelism 8 --inter-op-parallelism 8 --max-batch-size 1 --batch-time-out-millis 1 \
--model-dir models/prod_home1 --output caligrated_probabilities --input "" --intra-op-parallelism 8 --inter-op-parallelism 8 --max-batch-size 1 --batch-time-out-millis 1 \

View File

@ -1,11 +1,24 @@
use anyhow::Result; use anyhow::Result;
use log::info;
use navi::cli_args::{ARGS, MODEL_SPECS}; use navi::cli_args::{ARGS, MODEL_SPECS};
use navi::onnx_model::onnx::OnnxModel; use navi::onnx_model::onnx::OnnxModel;
use navi::{bootstrap, metrics}; use navi::{bootstrap, metrics};
fn main() -> Result<()> { fn main() -> Result<()> {
env_logger::init(); env_logger::init();
assert_eq!(MODEL_SPECS.len(), ARGS.inter_op_parallelism.len()); info!("global: {:?}", ARGS.onnx_global_thread_pool_options);
let assert_session_params = if ARGS.onnx_global_thread_pool_options.is_empty() {
// std::env::set_var("OMP_NUM_THREADS", "1");
info!("now we use per session thread pool");
MODEL_SPECS.len()
}
else {
info!("now we use global thread pool");
0
};
assert_eq!(assert_session_params, ARGS.inter_op_parallelism.len());
assert_eq!(assert_session_params, ARGS.inter_op_parallelism.len());
metrics::register_custom_metrics(); metrics::register_custom_metrics();
bootstrap::bootstrap(OnnxModel::new) bootstrap::bootstrap(OnnxModel::new)
} }

View File

@ -1,5 +1,6 @@
use anyhow::Result; use anyhow::Result;
use log::{info, warn}; use log::{info, warn};
use x509_parser::{prelude::{parse_x509_pem}, parse_x509_certificate};
use std::collections::HashMap; use std::collections::HashMap;
use tokio::time::Instant; use tokio::time::Instant;
use tonic::{ use tonic::{
@ -27,6 +28,7 @@ use crate::cli_args::{ARGS, INPUTS, OUTPUTS};
use crate::metrics::{ use crate::metrics::{
NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL, NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR, NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR,
CERT_EXPIRY_EPOCH
}; };
use crate::predict_service::{Model, PredictService}; use crate::predict_service::{Model, PredictService};
use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version; use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version;
@ -207,6 +209,9 @@ impl<T: Model> PredictionService for PredictService<T> {
PredictResult::DropDueToOverload => Err(Status::resource_exhausted("")), PredictResult::DropDueToOverload => Err(Status::resource_exhausted("")),
PredictResult::ModelNotFound(idx) => { PredictResult::ModelNotFound(idx) => {
Err(Status::not_found(format!("model index {}", idx))) Err(Status::not_found(format!("model index {}", idx)))
},
PredictResult::ModelNotReady(idx) => {
Err(Status::unavailable(format!("model index {}", idx)))
} }
PredictResult::ModelVersionNotFound(idx, version) => Err( PredictResult::ModelVersionNotFound(idx, version) => Err(
Status::not_found(format!("model index:{}, version {}", idx, version)), Status::not_found(format!("model index:{}, version {}", idx, version)),
@ -230,6 +235,12 @@ impl<T: Model> PredictionService for PredictService<T> {
} }
} }
// A function that takes a timestamp as input and returns a ticker stream
fn report_expiry(expiry_time: i64) {
info!("Certificate expires at epoch: {:?}", expiry_time);
CERT_EXPIRY_EPOCH.set(expiry_time as i64);
}
pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> { pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS); info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
//we follow SemVer. So here we assume MAJOR.MINOR.PATCH //we follow SemVer. So here we assume MAJOR.MINOR.PATCH
@ -246,6 +257,7 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
); );
} }
tokio::runtime::Builder::new_multi_thread() tokio::runtime::Builder::new_multi_thread()
.thread_name("async worker") .thread_name("async worker")
.worker_threads(ARGS.num_worker_threads) .worker_threads(ARGS.num_worker_threads)
@ -263,6 +275,21 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
let mut builder = if ARGS.ssl_dir.is_empty() { let mut builder = if ARGS.ssl_dir.is_empty() {
Server::builder() Server::builder()
} else { } else {
// Read the pem file as a string
let pem_str = std::fs::read_to_string(format!("{}/server.crt", ARGS.ssl_dir)).unwrap();
let res = parse_x509_pem(&pem_str.as_bytes());
match res {
Ok((rem, pem_2)) => {
assert!(rem.is_empty());
assert_eq!(pem_2.label, String::from("CERTIFICATE"));
let res_x509 = parse_x509_certificate(&pem_2.contents);
info!("Certificate label: {}", pem_2.label);
assert!(res_x509.is_ok());
report_expiry(res_x509.unwrap().1.validity().not_after.timestamp());
},
_ => panic!("PEM parsing failed: {:?}", res),
}
let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir)) let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir))
.await .await
.expect("can't find key file"); .expect("can't find key file");

View File

@ -87,13 +87,11 @@ pub struct Args {
pub intra_op_parallelism: Vec<String>, pub intra_op_parallelism: Vec<String>,
#[clap( #[clap(
long, long,
default_value = "14",
help = "number of threads to parallelize computations of the graph" help = "number of threads to parallelize computations of the graph"
)] )]
pub inter_op_parallelism: Vec<String>, pub inter_op_parallelism: Vec<String>,
#[clap( #[clap(
long, long,
default_value = "serving_default",
help = "signature of a serving. only TF" help = "signature of a serving. only TF"
)] )]
pub serving_sig: Vec<String>, pub serving_sig: Vec<String>,
@ -107,6 +105,8 @@ pub struct Args {
help = "max warmup records to use. warmup only implemented for TF" help = "max warmup records to use. warmup only implemented for TF"
)] )]
pub max_warmup_records: usize, pub max_warmup_records: usize,
#[clap(long, value_parser = Args::parse_key_val::<String, String>, value_delimiter=',')]
pub onnx_global_thread_pool_options: Vec<(String, String)>,
#[clap( #[clap(
long, long,
default_value = "true", default_value = "true",

View File

@ -146,6 +146,7 @@ pub enum PredictResult {
Ok(Vec<TensorScores>, i64), Ok(Vec<TensorScores>, i64),
DropDueToOverload, DropDueToOverload,
ModelNotFound(usize), ModelNotFound(usize),
ModelNotReady(usize),
ModelVersionNotFound(usize, i64), ModelVersionNotFound(usize, i64),
} }

View File

@ -171,6 +171,9 @@ lazy_static! {
&["model_name"] &["model_name"]
) )
.expect("metric can be created"); .expect("metric can be created");
pub static ref CERT_EXPIRY_EPOCH: IntGauge =
IntGauge::new(":navi:cert_expiry_epoch", "Timestamp when the current cert expires")
.expect("metric can be created");
} }
pub fn register_custom_metrics() { pub fn register_custom_metrics() {
@ -249,6 +252,10 @@ pub fn register_custom_metrics() {
REGISTRY REGISTRY
.register(Box::new(CONVERTER_TIME_COLLECTOR.clone())) .register(Box::new(CONVERTER_TIME_COLLECTOR.clone()))
.expect("collector can be registered"); .expect("collector can be registered");
REGISTRY
.register(Box::new(CERT_EXPIRY_EPOCH.clone()))
.expect("collector can be registered");
} }
pub fn register_dynamic_metrics(c: &HistogramVec) { pub fn register_dynamic_metrics(c: &HistogramVec) {

View File

@ -13,21 +13,22 @@ pub mod onnx {
use dr_transform::converter::{BatchPredictionRequestToTorchTensorConverter, Converter}; use dr_transform::converter::{BatchPredictionRequestToTorchTensorConverter, Converter};
use itertools::Itertools; use itertools::Itertools;
use log::{debug, info}; use log::{debug, info};
use ort::environment::Environment; use dr_transform::ort::environment::Environment;
use ort::session::Session; use dr_transform::ort::session::Session;
use ort::tensor::InputTensor; use dr_transform::ort::tensor::InputTensor;
use ort::{ExecutionProvider, GraphOptimizationLevel, SessionBuilder}; use dr_transform::ort::{ExecutionProvider, GraphOptimizationLevel, SessionBuilder};
use dr_transform::ort::LoggingLevel;
use serde_json::Value; use serde_json::Value;
use std::fmt::{Debug, Display}; use std::fmt::{Debug, Display};
use std::sync::Arc; use std::sync::Arc;
use std::{fmt, fs}; use std::{fmt, fs};
use tokio::time::Instant; use tokio::time::Instant;
lazy_static! { lazy_static! {
pub static ref ENVIRONMENT: Arc<Environment> = Arc::new( pub static ref ENVIRONMENT: Arc<Environment> = Arc::new(
Environment::builder() Environment::builder()
.with_name("onnx home") .with_name("onnx home")
.with_log_level(ort::LoggingLevel::Error) .with_log_level(LoggingLevel::Error)
.with_global_thread_pool(ARGS.onnx_global_thread_pool_options.clone())
.build() .build()
.unwrap() .unwrap()
); );
@ -101,7 +102,9 @@ pub mod onnx {
let meta_info = format!("{}/{}/{}", ARGS.model_dir[idx], version, META_INFO); let meta_info = format!("{}/{}/{}", ARGS.model_dir[idx], version, META_INFO);
let mut builder = SessionBuilder::new(&ENVIRONMENT)? let mut builder = SessionBuilder::new(&ENVIRONMENT)?
.with_optimization_level(GraphOptimizationLevel::Level3)? .with_optimization_level(GraphOptimizationLevel::Level3)?
.with_parallel_execution(ARGS.onnx_use_parallel_mode == "true")? .with_parallel_execution(ARGS.onnx_use_parallel_mode == "true")?;
if ARGS.onnx_global_thread_pool_options.is_empty() {
builder = builder
.with_inter_threads( .with_inter_threads(
utils::get_config_or( utils::get_config_or(
model_config, model_config,
@ -117,7 +120,12 @@ pub mod onnx {
&ARGS.intra_op_parallelism[idx], &ARGS.intra_op_parallelism[idx],
) )
.parse()?, .parse()?,
)? )?;
}
else {
builder = builder.with_disable_per_session_threads()?;
}
builder = builder
.with_memory_pattern(ARGS.onnx_use_memory_pattern == "true")? .with_memory_pattern(ARGS.onnx_use_memory_pattern == "true")?
.with_execution_providers(&OnnxModel::ep_choices())?; .with_execution_providers(&OnnxModel::ep_choices())?;
match &ARGS.profiling { match &ARGS.profiling {
@ -181,7 +189,7 @@ pub mod onnx {
&version, &version,
reporting_feature_ids, reporting_feature_ids,
Some(metrics::register_dynamic_metrics), Some(metrics::register_dynamic_metrics),
)), )?),
}; };
onnx_model.warmup()?; onnx_model.warmup()?;
Ok(onnx_model) Ok(onnx_model)

View File

@ -1,7 +1,7 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
use itertools::Itertools; use itertools::Itertools;
use log::{error, info, warn}; use log::{error, info};
use std::fmt::{Debug, Display}; use std::fmt::{Debug, Display};
use std::string::String; use std::string::String;
use std::sync::Arc; use std::sync::Arc;
@ -24,7 +24,7 @@ use serde_json::{self, Value};
pub trait Model: Send + Sync + Display + Debug + 'static { pub trait Model: Send + Sync + Display + Debug + 'static {
fn warmup(&self) -> Result<()>; fn warmup(&self) -> Result<()>;
//TODO: refactor this to return Vec<Vec<TensorScores>>, i.e. //TODO: refactor this to return vec<vec<TensorScores>>, i.e.
//we have the underlying runtime impl to split the response to each client. //we have the underlying runtime impl to split the response to each client.
//It will eliminate some inefficient memory copy in onnx_model.rs as well as simplify code //It will eliminate some inefficient memory copy in onnx_model.rs as well as simplify code
fn do_predict( fn do_predict(
@ -179,17 +179,17 @@ impl<T: Model> PredictService<T> {
//initialize the latest version array //initialize the latest version array
let mut cur_versions = vec!["".to_owned(); MODEL_SPECS.len()]; let mut cur_versions = vec!["".to_owned(); MODEL_SPECS.len()];
loop { loop {
let config = utils::read_config(&meta_file).unwrap_or_else(|e| {
warn!("config file {} not found due to: {}", meta_file, e);
Value::Null
});
info!("***polling for models***"); //nice deliminter info!("***polling for models***"); //nice deliminter
info!("config:{}", config);
if let Some(ref cli) = ARGS.modelsync_cli { if let Some(ref cli) = ARGS.modelsync_cli {
if let Err(e) = call_external_modelsync(cli, &cur_versions).await { if let Err(e) = call_external_modelsync(cli, &cur_versions).await {
error!("model sync cli running error:{}", e) error!("model sync cli running error:{}", e)
} }
} }
let config = utils::read_config(&meta_file).unwrap_or_else(|e| {
info!("config file {} not found due to: {}", meta_file, e);
Value::Null
});
info!("config:{}", config);
for (idx, cur_version) in cur_versions.iter_mut().enumerate() { for (idx, cur_version) in cur_versions.iter_mut().enumerate() {
let model_dir = &ARGS.model_dir[idx]; let model_dir = &ARGS.model_dir[idx];
PredictService::scan_load_latest_model_from_model_dir( PredictService::scan_load_latest_model_from_model_dir(
@ -222,13 +222,18 @@ impl<T: Model> PredictService<T> {
.map(|b| b.parse().unwrap()) .map(|b| b.parse().unwrap())
.collect::<Vec<u64>>(); .collect::<Vec<u64>>();
let no_msg_wait_millis = *batch_time_out_millis.iter().min().unwrap(); let no_msg_wait_millis = *batch_time_out_millis.iter().min().unwrap();
let mut all_model_predictors = let mut all_model_predictors: ArrayVec::<ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS> =
ArrayVec::<ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS>::new(); (0 ..MAX_NUM_MODELS).map( |_| ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new()).collect();
loop { loop {
let msg = rx.try_recv(); let msg = rx.try_recv();
let no_more_msg = match msg { let no_more_msg = match msg {
Ok(PredictMessage::Predict(model_spec_at, version, val, resp, ts)) => { Ok(PredictMessage::Predict(model_spec_at, version, val, resp, ts)) => {
if let Some(model_predictors) = all_model_predictors.get_mut(model_spec_at) { if let Some(model_predictors) = all_model_predictors.get_mut(model_spec_at) {
if model_predictors.is_empty() {
resp.send(PredictResult::ModelNotReady(model_spec_at))
.unwrap_or_else(|e| error!("cannot send back model not ready error: {:?}", e));
}
else {
match version { match version {
None => model_predictors[0].push(val, resp, ts), None => model_predictors[0].push(val, resp, ts),
Some(the_version) => match model_predictors Some(the_version) => match model_predictors
@ -246,9 +251,10 @@ impl<T: Model> PredictService<T> {
Some(predictor) => predictor.push(val, resp, ts), Some(predictor) => predictor.push(val, resp, ts),
}, },
} }
}
} else { } else {
resp.send(PredictResult::ModelNotFound(model_spec_at)) resp.send(PredictResult::ModelNotFound(model_spec_at))
.unwrap_or_else(|e| error!("cannot send back model error: {:?}", e)) .unwrap_or_else(|e| error!("cannot send back model not found error: {:?}", e))
} }
MPSC_CHANNEL_SIZE.dec(); MPSC_CHANNEL_SIZE.dec();
false false
@ -266,27 +272,23 @@ impl<T: Model> PredictService<T> {
queue_reset_ts: Instant::now(), queue_reset_ts: Instant::now(),
queue_earliest_rq_ts: Instant::now(), queue_earliest_rq_ts: Instant::now(),
}; };
if idx < all_model_predictors.len() { assert!(idx < all_model_predictors.len());
metrics::NEW_MODEL_SNAPSHOT metrics::NEW_MODEL_SNAPSHOT
.with_label_values(&[&MODEL_SPECS[idx]]) .with_label_values(&[&MODEL_SPECS[idx]])
.inc(); .inc();
info!("now we serve updated model: {}", predictor.model);
//we can do this since the vector is small //we can do this since the vector is small
let predictors = &mut all_model_predictors[idx]; let predictors = &mut all_model_predictors[idx];
if predictors.len() == 0 {
info!("now we serve new model: {}", predictor.model);
}
else {
info!("now we serve updated model: {}", predictor.model);
}
if predictors.len() == ARGS.versions_per_model { if predictors.len() == ARGS.versions_per_model {
predictors.remove(predictors.len() - 1); predictors.remove(predictors.len() - 1);
} }
predictors.insert(0, predictor); predictors.insert(0, predictor);
} else {
info!("now we serve new model: {:}", predictor.model);
let mut predictors =
ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new();
predictors.push(predictor);
all_model_predictors.push(predictors);
//check the invariant that we always push the last model to the end
assert_eq!(all_model_predictors.len(), idx + 1)
}
false false
} }
Err(TryRecvError::Empty) => true, Err(TryRecvError::Empty) => true,

View File

@ -3,9 +3,9 @@ name = "segdense"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
env_logger = "0.10.0"
serde = { version = "1.0.104", features = ["derive"] } serde = { version = "1.0.104", features = ["derive"] }
serde_json = "1.0.48" serde_json = "1.0.48"
log = "0.4.17" log = "0.4.17"

View File

@ -19,11 +19,21 @@ impl Display for SegDenseError {
match self { match self {
SegDenseError::IoError(io_error) => write!(f, "{}", io_error), SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json), SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
SegDenseError::JsonMissingRoot => write!(f, "{}", "SegDense JSON: Root Node note found!"), SegDenseError::JsonMissingRoot => {
SegDenseError::JsonMissingObject => write!(f, "{}", "SegDense JSON: Object note found!"), write!(f, "{}", "SegDense JSON: Root Node note found!")
SegDenseError::JsonMissingArray => write!(f, "{}", "SegDense JSON: Array Node note found!"), }
SegDenseError::JsonArraySize => write!(f, "{}", "SegDense JSON: Array size not as expected!"), SegDenseError::JsonMissingObject => {
SegDenseError::JsonMissingInputFeature => write!(f, "{}", "SegDense JSON: Missing input feature!"), write!(f, "{}", "SegDense JSON: Object note found!")
}
SegDenseError::JsonMissingArray => {
write!(f, "{}", "SegDense JSON: Array Node note found!")
}
SegDenseError::JsonArraySize => {
write!(f, "{}", "SegDense JSON: Array size not as expected!")
}
SegDenseError::JsonMissingInputFeature => {
write!(f, "{}", "SegDense JSON: Missing input feature!")
}
} }
} }
} }

View File

@ -1,4 +1,4 @@
pub mod error; pub mod error;
pub mod segdense_transform_spec_home_recap_2022;
pub mod mapper; pub mod mapper;
pub mod segdense_transform_spec_home_recap_2022;
pub mod util; pub mod util;

View File

@ -20,4 +20,3 @@ fn main() -> Result<(), SegDenseError> {
Ok(()) Ok(())
} }

View File

@ -19,7 +19,7 @@ pub struct FeatureMapper {
impl FeatureMapper { impl FeatureMapper {
pub fn new() -> FeatureMapper { pub fn new() -> FeatureMapper {
FeatureMapper { FeatureMapper {
map: HashMap::new() map: HashMap::new(),
} }
} }
} }

View File

@ -164,7 +164,6 @@ pub struct ComplexFeatureTypeTransformSpec {
pub tensor_shape: Vec<i64>, pub tensor_shape: Vec<i64>,
} }
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct InputFeatureMapRecord { pub struct InputFeatureMapRecord {

View File

@ -1,23 +1,23 @@
use log::debug;
use std::fs; use std::fs;
use log::{debug};
use serde_json::{Value, Map}; use serde_json::{Map, Value};
use crate::error::SegDenseError; use crate::error::SegDenseError;
use crate::mapper::{FeatureMapper, FeatureInfo, MapWriter}; use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter};
use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature}; use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
pub fn load_config(file_name: &str) -> seg_dense::Root { pub fn load_config(file_name: &str) -> Result<seg_dense::Root, SegDenseError> {
let json_str = fs::read_to_string(file_name).expect( let json_str = fs::read_to_string(file_name)?;
&format!("Unable to load segdense file {}", file_name)); // &format!("Unable to load segdense file {}", file_name));
let seg_dense_config = parse(&json_str).expect( let seg_dense_config = parse(&json_str)?;
&format!("Unable to parse segdense file {}", file_name)); // &format!("Unable to parse segdense file {}", file_name));
return seg_dense_config; Ok(seg_dense_config)
} }
pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> { pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {
let root: seg_dense::Root = serde_json::from_str(json_str)?; let root: seg_dense::Root = serde_json::from_str(json_str)?;
return Ok(root); Ok(root)
} }
/** /**
@ -44,15 +44,8 @@ pub fn safe_load_config(json_str: &str) -> Result<FeatureMapper, SegDenseError>
load_from_parsed_config(root) load_from_parsed_config(root)
} }
pub fn load_from_parsed_config_ref(root: &seg_dense::Root) -> FeatureMapper {
load_from_parsed_config(root.clone()).unwrap_or_else(
|error| panic!("Error loading all_config.json - {}", error))
}
// Perf note : make 'root' un-owned // Perf note : make 'root' un-owned
pub fn load_from_parsed_config(root: seg_dense::Root) -> pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper, SegDenseError> {
Result<FeatureMapper, SegDenseError> {
let v = root.input_features_map; let v = root.input_features_map;
// Do error check // Do error check
@ -86,7 +79,7 @@ pub fn load_from_parsed_config(root: seg_dense::Root) ->
Some(info) => { Some(info) => {
debug!("{:?}", info); debug!("{:?}", info);
fm.set(feature_id, info) fm.set(feature_id, info)
}, }
None => (), None => (),
} }
} }
@ -94,7 +87,10 @@ pub fn load_from_parsed_config(root: seg_dense::Root) ->
Ok(fm) Ok(fm)
} }
#[allow(dead_code)] #[allow(dead_code)]
fn add_feature_info_to_mapper(feature_mapper: &mut FeatureMapper, input_features: &Vec<InputFeature>) { fn add_feature_info_to_mapper(
feature_mapper: &mut FeatureMapper,
input_features: &Vec<InputFeature>,
) {
for input_feature in input_features.iter() { for input_feature in input_features.iter() {
let feature_id = input_feature.feature_id; let feature_id = input_feature.feature_id;
let feature_info = to_feature_info(input_feature); let feature_info = to_feature_info(input_feature);
@ -103,7 +99,7 @@ fn add_feature_info_to_mapper(feature_mapper: &mut FeatureMapper, input_features
Some(info) => { Some(info) => {
debug!("{:?}", info); debug!("{:?}", info);
feature_mapper.set(feature_id, info) feature_mapper.set(feature_id, info)
}, }
None => (), None => (),
} }
} }
@ -139,7 +135,7 @@ pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<Featur
2 => 0, 2 => 0,
3 => 2, 3 => 2,
_ => -1, _ => -1,
} },
}; };
if input_feature.index < 0 { if input_feature.index < 0 {
@ -156,4 +152,3 @@ pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<Featur
index_within_tensor: input_feature.index, index_within_tensor: input_feature.index,
}) })
} }

48
pushservice/BUILD.bazel Normal file
View File

@ -0,0 +1,48 @@
alias(
name = "frigate-pushservice",
target = ":frigate-pushservice_lib",
)
target(
name = "frigate-pushservice_lib",
dependencies = [
"frigate/frigate-pushservice-opensource/src/main/scala/com/twitter/frigate/pushservice",
],
)
jvm_binary(
name = "bin",
basename = "frigate-pushservice",
main = "com.twitter.frigate.pushservice.PushServiceMain",
runtime_platform = "java11",
tags = ["bazel-compatible"],
dependencies = [
"3rdparty/jvm/ch/qos/logback:logback-classic",
"finatra/inject/inject-logback/src/main/scala",
"frigate/frigate-pushservice-opensource/src/main/scala/com/twitter/frigate/pushservice",
"loglens/loglens-logback/src/main/scala/com/twitter/loglens/logback",
"twitter-server/logback-classic/src/main/scala",
],
excludes = [
exclude("com.twitter.translations", "translations-twitter"),
exclude("org.apache.hadoop", "hadoop-aws"),
exclude("org.tensorflow"),
scala_exclude("com.twitter", "ckoia-scala"),
],
)
jvm_app(
name = "bundle",
basename = "frigate-pushservice-package-dist",
archive = "zip",
binary = ":bin",
tags = ["bazel-compatible"],
)
python3_library(
name = "mr_model_constants",
sources = [
"config/deepbird/constants.py",
],
tags = ["bazel-compatible"],
)

45
pushservice/README.md Normal file
View File

@ -0,0 +1,45 @@
# Pushservice
Pushservice is the main push recommendation service at Twitter used to generate recommendation-based notifications for users. It currently powers two functionalities:
- RefreshForPushHandler: This handler determines whether to send a recommendation push to a user based on their ID. It generates the best push recommendation item and coordinates with downstream services to deliver it
- SendHandler: This handler determines and manage whether send the push to users based on the given target user details and the provided push recommendation item
## Overview
### RefreshForPushHandler
RefreshForPushHandler follows these steps:
- Building Target and checking eligibility
- Builds a target user object based on the given user ID
- Performs target-level filterings to determine if the target is eligible for a recommendation push
- Fetch Candidates
- Retrieves a list of potential candidates for the push by querying various candidate sources using the target
- Candidate Hydration
- Hydrates the candidate details with batch calls to different downstream services
- Pre-rank Filtering, also called Light Filtering
- Filters the hydrated candidates with lightweight RPC calls
- Rank
- Perform feature hydration for candidates and target user
- Performs light ranking on candidates
- Performs heavy ranking on candidates
- Take Step, also called Heavy Filtering
- Takes the top-ranked candidates one by one and applies heavy filtering until one candidate passes all filter steps
- Send
- Calls the appropriate downstream service to deliver the eligible candidate as a push and in-app notification to the target user
### SendHandler
SendHandler follows these steps:
- Building Target
- Builds a target user object based on the given user ID
- Candidate Hydration
- Hydrates the candidate details with batch calls to different downstream services
- Feature Hydration
- Perform feature hydration for candidates and target user
- Take Step, also called Heavy Filtering
- Perform filterings and validation checking for the given candidate
- Send
- Calls the appropriate downstream service to deliver the given candidate as a push and/or in-app notification to the target user

View File

@ -0,0 +1,169 @@
python37_binary(
name = "update_warm_start_checkpoint",
source = "update_warm_start_checkpoint.py",
tags = ["no-mypy"],
dependencies = [
":deep_norm_lib",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:update_warm_start_checkpoint",
],
)
python3_library(
name = "params_lib",
sources = ["params.py"],
tags = ["no-mypy"],
dependencies = [
"3rdparty/python/pydantic:default",
"src/python/twitter/deepbird/projects/magic_recs/v11/lib:params_lib",
],
)
python3_library(
name = "features_lib",
sources = ["features.py"],
tags = ["no-mypy"],
dependencies = [
":params_lib",
"src/python/twitter/deepbird/projects/magic_recs/libs",
"twml:twml-nodeps",
],
)
python3_library(
name = "model_pools_lib",
sources = ["model_pools.py"],
tags = ["no-mypy"],
dependencies = [
":features_lib",
":params_lib",
"src/python/twitter/deepbird/projects/magic_recs/v11/lib:model_lib",
],
)
python3_library(
name = "graph_lib",
sources = ["graph.py"],
tags = ["no-mypy"],
dependencies = [
":params_lib",
"src/python/twitter/deepbird/projects/magic_recs/libs",
],
)
python3_library(
name = "run_args_lib",
sources = ["run_args.py"],
tags = ["no-mypy"],
dependencies = [
":features_lib",
":params_lib",
"twml:twml-nodeps",
],
)
python3_library(
name = "deep_norm_lib",
sources = ["deep_norm.py"],
tags = ["no-mypy"],
dependencies = [
":features_lib",
":graph_lib",
":model_pools_lib",
":params_lib",
":run_args_lib",
"src/python/twitter/deepbird/projects/magic_recs/libs",
"src/python/twitter/deepbird/util/data",
"twml:twml-nodeps",
],
)
python3_library(
name = "eval_lib",
sources = ["eval.py"],
tags = ["no-mypy"],
dependencies = [
":features_lib",
":graph_lib",
":model_pools_lib",
":params_lib",
":run_args_lib",
"src/python/twitter/deepbird/projects/magic_recs/libs",
"twml:twml-nodeps",
],
)
python37_binary(
name = "deep_norm",
source = "deep_norm.py",
dependencies = [
":deep_norm_lib",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:deep_norm",
"twml",
],
)
python37_binary(
name = "eval",
source = "eval.py",
dependencies = [
":eval_lib",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval",
"twml",
],
)
python3_library(
name = "mlwf_libs",
tags = ["no-mypy"],
dependencies = [
":deep_norm_lib",
"twml",
],
)
python37_binary(
name = "train_model",
source = "deep_norm.py",
dependencies = [
":deep_norm_lib",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:train_model",
],
)
python37_binary(
name = "train_model_local",
source = "deep_norm.py",
dependencies = [
":deep_norm_lib",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:train_model_local",
"twml",
],
)
python37_binary(
name = "eval_model_local",
source = "eval.py",
dependencies = [
":eval_lib",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval_model_local",
"twml",
],
)
python37_binary(
name = "eval_model",
source = "eval.py",
dependencies = [
":eval_lib",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:eval_model",
],
)
python37_binary(
name = "mlwf_model",
source = "deep_norm.py",
dependencies = [
":mlwf_libs",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/heavy_ranking:mlwf_model",
],
)

View File

@ -0,0 +1,20 @@
# Notification Heavy Ranker Model
## Model Context
There are 4 major components of Twitter notifications recommendation system: 1) candidate generation 2) light ranking 3) heavy ranking & 4) quality control. This notification heavy ranker model is the core ranking model for the personalised notifications recommendation. It's a multi-task learning model to predict the probabilities that the target users will open and engage with the sent notifications.
## Directory Structure
- BUILD: this file defines python library dependencies
- deep_norm.py: this file contains how to set up continuous training, model evaluation and model exporting for the notification heavy ranker model
- eval.py: the main python entry file to set up the overall model evaluation pipeline
- features.py: this file contains importing feature list and support functions for feature engineering
- graph.py: this file defines how to build the tensorflow graph with specified model architecture, loss function and training configuration
- model_pools.py: this file defines the available model types for the heavy ranker
- params.py: this file defines hyper-parameters used in the notification heavy ranker
- run_args.py: this file defines command line parameters to run model training & evaluation
- update_warm_start_checkpoint.py: this file contains the support to modify checkpoints of the given saved heavy ranker model
- lib/BUILD: this file defines python library dependencies for tensorflow model architecture
- lib/layers.py: this file defines different type of convolution layers to be used in the heavy ranker model
- lib/model.py: this file defines the module containing ClemNet, the heavy ranker model type
- lib/params.py: this file defines parameters used in the heavy ranker model

View File

@ -0,0 +1,136 @@
"""
Training job for the heavy ranker of the push notification service.
"""
from datetime import datetime
import json
import os
import twml
from ..libs.metric_fn_utils import flip_disliked_labels, get_metric_fn
from ..libs.model_utils import read_config
from ..libs.warm_start_utils import get_feature_list_for_heavy_ranking, warm_start_checkpoint
from .features import get_feature_config
from .model_pools import ALL_MODELS
from .params import load_graph_params
from .run_args import get_training_arg_parser
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import logging
def main() -> None:
args, _ = get_training_arg_parser().parse_known_args()
logging.info(f"Parsed args: {args}")
params = load_graph_params(args)
logging.info(f"Loaded graph params: {params}")
param_file = os.path.join(args.save_dir, "params.json")
logging.info(f"Saving graph params to: {param_file}")
with tf.io.gfile.GFile(param_file, mode="w") as file:
json.dump(params.json(), file, ensure_ascii=False, indent=4)
logging.info(f"Get Feature Config: {args.feature_list}")
feature_list = read_config(args.feature_list).items()
feature_config = get_feature_config(
data_spec_path=args.data_spec,
params=params,
feature_list_provided=feature_list,
)
feature_list_path = args.feature_list
warm_start_from = args.warm_start_from
if args.warm_start_base_dir:
logging.info(f"Get warm started model from: {args.warm_start_base_dir}.")
continuous_binary_feat_list_save_path = os.path.join(
args.warm_start_base_dir, "continuous_binary_feat_list.json"
)
warm_start_folder = os.path.join(args.warm_start_base_dir, "best_checkpoint")
job_name = os.path.basename(args.save_dir)
ws_output_ckpt_folder = os.path.join(args.warm_start_base_dir, f"warm_start_for_{job_name}")
if tf.io.gfile.exists(ws_output_ckpt_folder):
tf.io.gfile.rmtree(ws_output_ckpt_folder)
tf.io.gfile.mkdir(ws_output_ckpt_folder)
warm_start_from = warm_start_checkpoint(
warm_start_folder,
continuous_binary_feat_list_save_path,
feature_list_path,
args.data_spec,
ws_output_ckpt_folder,
)
logging.info(f"Created warm_start_from_ckpt {warm_start_from}.")
logging.info("Build Trainer.")
metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)
trainer = twml.trainers.DataRecordTrainer(
name="magic_recs",
params=args,
build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(*args),
save_dir=args.save_dir,
run_config=None,
feature_config=feature_config,
metric_fn=flip_disliked_labels(metric_fn),
warm_start_from=warm_start_from,
)
logging.info("Build train and eval input functions.")
train_input_fn = trainer.get_train_input_fn(shuffle=True)
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
learn = trainer.learn
if args.distributed or args.num_workers is not None:
learn = trainer.train_and_evaluate
if not args.directly_export_best:
logging.info("Starting training")
start = datetime.now()
learn(
early_stop_minimize=False,
early_stop_metric="pr_auc_unweighted_OONC",
early_stop_patience=args.early_stop_patience,
early_stop_tolerance=args.early_stop_tolerance,
eval_input_fn=eval_input_fn,
train_input_fn=train_input_fn,
)
logging.info(f"Total training time: {datetime.now() - start}")
else:
logging.info("Directly exporting the model")
if not args.export_dir:
args.export_dir = os.path.join(args.save_dir, "exported_models")
logging.info(f"Exporting the model to {args.export_dir}.")
start = datetime.now()
twml.contrib.export.export_fn.export_all_models(
trainer=trainer,
export_dir=args.export_dir,
parse_fn=feature_config.get_parse_fn(),
serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
)
logging.info(f"Total model export time: {datetime.now() - start}")
logging.info(f"The MLP directory is: {args.save_dir}")
continuous_binary_feat_list_save_path = os.path.join(
args.save_dir, "continuous_binary_feat_list.json"
)
logging.info(
f"Saving the list of continuous and binary features to {continuous_binary_feat_list_save_path}."
)
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
feature_list_path, args.data_spec
)
twml.util.write_file(
continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
)
if __name__ == "__main__":
main()
logging.info("Done.")

View File

@ -0,0 +1,59 @@
"""
Evaluation job for the heavy ranker of the push notification service.
"""
from datetime import datetime
import twml
from ..libs.metric_fn_utils import get_metric_fn
from ..libs.model_utils import read_config
from .features import get_feature_config
from .model_pools import ALL_MODELS
from .params import load_graph_params
from .run_args import get_eval_arg_parser
from tensorflow.compat.v1 import logging
def main():
args, _ = get_eval_arg_parser().parse_known_args()
logging.info(f"Parsed args: {args}")
params = load_graph_params(args)
logging.info(f"Loaded graph params: {params}")
logging.info(f"Get Feature Config: {args.feature_list}")
feature_list = read_config(args.feature_list).items()
feature_config = get_feature_config(
data_spec_path=args.data_spec,
params=params,
feature_list_provided=feature_list,
)
logging.info("Build DataRecordTrainer.")
metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)
trainer = twml.trainers.DataRecordTrainer(
name="magic_recs",
params=args,
build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(*args),
save_dir=args.save_dir,
run_config=None,
feature_config=feature_config,
metric_fn=metric_fn,
)
logging.info("Run the evaluation.")
start = datetime.now()
trainer._estimator.evaluate(
input_fn=trainer.get_eval_input_fn(repeat=False, shuffle=False),
steps=None if (args.eval_steps is not None and args.eval_steps < 0) else args.eval_steps,
checkpoint_path=args.eval_checkpoint,
)
logging.info(f"Evaluating time: {datetime.now() - start}.")
if __name__ == "__main__":
main()
logging.info("Job done.")

View File

@ -0,0 +1,138 @@
import os
from typing import Dict
from twitter.deepbird.projects.magic_recs.libs.model_utils import filter_nans_and_infs
import twml
from twml.layers import full_sparse, sparse_max_norm
from .params import FeaturesParams, GraphParams, SparseFeaturesParams
import tensorflow as tf
from tensorflow import Tensor
import tensorflow.compat.v1 as tf1
FEAT_CONFIG_DEFAULT_VAL = 0
DEFAULT_FEATURE_LIST_PATH = "./feature_list_default.yaml"
FEATURE_LIST_DEFAULT_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_PATH
)
def get_feature_config(data_spec_path=None, feature_list_provided=[], params: GraphParams = None):
a_string_feat_list = [feat for feat, feat_type in feature_list_provided if feat_type != "S"]
builder = twml.contrib.feature_config.FeatureConfigBuilder(
data_spec_path=data_spec_path, debug=False
)
builder = builder.extract_feature_group(
feature_regexes=a_string_feat_list,
group_name="continuous_features",
default_value=FEAT_CONFIG_DEFAULT_VAL,
type_filter=["CONTINUOUS"],
)
builder = builder.extract_feature_group(
feature_regexes=a_string_feat_list,
group_name="binary_features",
type_filter=["BINARY"],
)
if params.model.features.sparse_features:
builder = builder.extract_features_as_hashed_sparse(
feature_regexes=a_string_feat_list,
hash_space_size_bits=params.model.features.sparse_features.bits,
type_filter=["DISCRETE", "STRING", "SPARSE_BINARY"],
output_tensor_name="sparse_not_continuous",
)
builder = builder.extract_features_as_hashed_sparse(
feature_regexes=[feat for feat, feat_type in feature_list_provided if feat_type == "S"],
hash_space_size_bits=params.model.features.sparse_features.bits,
type_filter=["SPARSE_CONTINUOUS"],
output_tensor_name="sparse_continuous",
)
builder = builder.add_labels([task.label for task in params.tasks] + ["label.ntabDislike"])
if params.weight:
builder = builder.define_weight(params.weight)
return builder.build()
def dense_features(features: Dict[str, Tensor], training: bool) -> Tensor:
"""
Performs feature transformations on the raw dense features (continuous and binary).
"""
with tf.name_scope("dense_features"):
x = filter_nans_and_infs(features["continuous_features"])
x = tf.sign(x) * tf.math.log(tf.abs(x) + 1)
x = tf1.layers.batch_normalization(
x, momentum=0.9999, training=training, renorm=training, axis=1
)
x = tf.clip_by_value(x, -5, 5)
transformed_continous_features = tf.where(tf.math.is_nan(x), tf.zeros_like(x), x)
binary_features = filter_nans_and_infs(features["binary_features"])
binary_features = tf.dtypes.cast(binary_features, tf.float32)
output = tf.concat([transformed_continous_features, binary_features], axis=1)
return output
def sparse_features(
features: Dict[str, Tensor], training: bool, params: SparseFeaturesParams
) -> Tensor:
"""
Performs feature transformations on the raw sparse features.
"""
with tf.name_scope("sparse_features"):
with tf.name_scope("sparse_not_continuous"):
sparse_not_continuous = full_sparse(
inputs=features["sparse_not_continuous"],
output_size=params.embedding_size,
use_sparse_grads=training,
use_binary_values=False,
)
with tf.name_scope("sparse_continuous"):
shape_enforced_input = twml.util.limit_sparse_tensor_size(
sparse_tf=features["sparse_continuous"], input_size_bits=params.bits, mask_indices=False
)
normalized_continuous_sparse = sparse_max_norm(
inputs=shape_enforced_input, is_training=training
)
sparse_continuous = full_sparse(
inputs=normalized_continuous_sparse,
output_size=params.embedding_size,
use_sparse_grads=training,
use_binary_values=False,
)
output = tf.concat([sparse_not_continuous, sparse_continuous], axis=1)
return output
def get_features(features: Dict[str, Tensor], training: bool, params: FeaturesParams) -> Tensor:
"""
Performs feature transformations on the dense and sparse features and combine the resulting
tensors into a single one.
"""
with tf.name_scope("features"):
x = dense_features(features, training)
tf1.logging.info(f"Dense features: {x.shape}")
if params.sparse_features:
x = tf.concat([x, sparse_features(features, training, params.sparse_features)], axis=1)
return x

View File

@ -0,0 +1,129 @@
"""
Graph class defining methods to obtain key quantities such as:
* the logits
* the probabilities
* the final score
* the loss function
* the training operator
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict
from twitter.deepbird.hparam import HParams
import twml
from ..libs.model_utils import generate_disliked_mask
from .params import GraphParams
import tensorflow as tf
import tensorflow.compat.v1 as tf1
class Graph(ABC):
def __init__(self, params: GraphParams):
self.params = params
@abstractmethod
def get_logits(self, features: Dict[str, tf.Tensor], mode: tf.estimator.ModeKeys) -> tf.Tensor:
pass
def get_probabilities(self, logits: tf.Tensor) -> tf.Tensor:
return tf.math.cumprod(tf.nn.sigmoid(logits), axis=1, name="probabilities")
def get_task_weights(self, labels: tf.Tensor) -> tf.Tensor:
oonc_label = tf.reshape(labels[:, 0], shape=(-1, 1))
task_weights = tf.concat([tf.ones_like(oonc_label), oonc_label], axis=1)
n_labels = len(self.params.tasks)
task_weights = tf.reshape(task_weights[:, 0:n_labels], shape=(-1, n_labels))
return task_weights
def get_loss(self, labels: tf.Tensor, logits: tf.Tensor, **kwargs: Any) -> tf.Tensor:
with tf.name_scope("weights"):
disliked_mask = generate_disliked_mask(labels)
labels = tf.reshape(labels[:, 0:2], shape=[-1, 2])
labels = labels * tf.cast(tf.logical_not(disliked_mask), dtype=labels.dtype)
with tf.name_scope("task_weight"):
task_weights = self.get_task_weights(labels)
with tf.name_scope("batch_size"):
batch_size = tf.cast(tf.shape(labels)[0], dtype=tf.float32, name="batch_size")
weights = task_weights / batch_size
with tf.name_scope("loss"):
loss = tf.reduce_sum(
tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) * weights,
)
return loss
def get_score(self, probabilities: tf.Tensor) -> tf.Tensor:
with tf.name_scope("score_weight"):
score_weights = tf.constant([task.score_weight for task in self.params.tasks])
score_weights = score_weights / tf.reduce_sum(score_weights, axis=0)
with tf.name_scope("score"):
score = tf.reshape(tf.reduce_sum(probabilities * score_weights, axis=1), shape=[-1, 1])
return score
def get_train_op(self, loss: tf.Tensor, twml_params) -> Any:
with tf.name_scope("optimizer"):
learning_rate = twml_params.learning_rate
optimizer = tf1.train.GradientDescentOptimizer(learning_rate=learning_rate)
update_ops = set(tf1.get_collection(tf1.GraphKeys.UPDATE_OPS))
with tf.control_dependencies(update_ops):
train_op = twml.optimizers.optimize_loss(
loss=loss,
variables=tf1.trainable_variables(),
global_step=tf1.train.get_global_step(),
optimizer=optimizer,
learning_rate=None,
)
return train_op
def __call__(
self,
features: Dict[str, tf.Tensor],
labels: tf.Tensor,
mode: tf.estimator.ModeKeys,
params: HParams,
config=None,
) -> Dict[str, tf.Tensor]:
training = mode == tf.estimator.ModeKeys.TRAIN
logits = self.get_logits(features=features, training=training)
probabilities = self.get_probabilities(logits=logits)
score = None
loss = None
train_op = None
if mode == tf.estimator.ModeKeys.PREDICT:
score = self.get_score(probabilities=probabilities)
output = {"loss": loss, "train_op": train_op, "prediction": score}
elif mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
loss = self.get_loss(labels=labels, logits=logits)
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = self.get_train_op(loss=loss, twml_params=params)
output = {"loss": loss, "train_op": train_op, "output": probabilities}
else:
raise ValueError(
f"""
Invalid mode. Possible values are: {tf.estimator.ModeKeys.PREDICT}, {tf.estimator.ModeKeys.TRAIN}, and {tf.estimator.ModeKeys.EVAL}
. Passed: {mode}
"""
)
return output

View File

@ -0,0 +1,42 @@
python3_library(
name = "params_lib",
sources = [
"params.py",
],
tags = [
"bazel-compatible",
"no-mypy",
],
dependencies = [
"3rdparty/python/pydantic:default",
],
)
python3_library(
name = "layers_lib",
sources = [
"layers.py",
],
tags = [
"bazel-compatible",
"no-mypy",
],
dependencies = [
],
)
python3_library(
name = "model_lib",
sources = [
"model.py",
],
tags = [
"bazel-compatible",
"no-mypy",
],
dependencies = [
":layers_lib",
":params_lib",
"3rdparty/python/absl-py:default",
],
)

View File

@ -0,0 +1,128 @@
"""
Different type of convolution layers to be used in the ClemNet.
"""
from typing import Any
import tensorflow as tf
class KerasConv1D(tf.keras.layers.Layer):
"""
Basic Conv1D layer in a wrapper to be compatible with ClemNet.
"""
def __init__(
self,
kernel_size: int,
filters: int,
strides: int,
padding: str,
use_bias: bool = True,
kernel_initializer: str = "glorot_uniform",
bias_initializer: str = "zeros",
**kwargs: Any,
):
super(KerasConv1D, self).__init__(**kwargs)
self.kernel_size = kernel_size
self.filters = filters
self.use_bias = use_bias
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
self.strides = strides
self.padding = padding
def build(self, input_shape: tf.TensorShape) -> None:
assert (
len(input_shape) == 3
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
self.features = input_shape[1]
self.w = tf.keras.layers.Conv1D(
kernel_size=self.kernel_size,
filters=self.filters,
strides=self.strides,
padding=self.padding,
use_bias=self.use_bias,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
name=self.name,
)
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
return self.w(inputs)
class ChannelWiseDense(tf.keras.layers.Layer):
"""
Dense layer is applied to each channel separately. This is more memory and computationally
efficient than flattening the channels and performing single dense layers over it which is the
default behavior in tf1.
"""
def __init__(
self,
output_size: int,
use_bias: bool,
kernel_initializer: str = "uniform_glorot",
bias_initializer: str = "zeros",
**kwargs: Any,
):
super(ChannelWiseDense, self).__init__(**kwargs)
self.output_size = output_size
self.use_bias = use_bias
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
def build(self, input_shape: tf.TensorShape) -> None:
assert (
len(input_shape) == 3
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
input_size = input_shape[1]
channels = input_shape[2]
self.kernel = self.add_weight(
name="kernel",
shape=(channels, input_size, self.output_size),
initializer=self.kernel_initializer,
trainable=True,
)
self.bias = self.add_weight(
name="bias",
shape=(channels, self.output_size),
initializer=self.bias_initializer,
trainable=self.use_bias,
)
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
x = inputs
transposed_x = tf.transpose(x, perm=[2, 0, 1])
transposed_residual = (
tf.transpose(tf.matmul(transposed_x, self.kernel), perm=[1, 0, 2]) + self.bias
)
output = tf.transpose(transposed_residual, perm=[0, 2, 1])
return output
class ResidualLayer(tf.keras.layers.Layer):
"""
Layer implementing a 3D-residual connection.
"""
def build(self, input_shape: tf.TensorShape) -> None:
assert (
len(input_shape) == 3
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
def call(self, inputs: tf.Tensor, residual: tf.Tensor, **kwargs: Any) -> tf.Tensor:
shortcut = tf.keras.layers.Conv1D(
filters=int(residual.shape[2]), strides=1, kernel_size=1, padding="SAME", use_bias=False
)(inputs)
output = tf.add(shortcut, residual)
return output

View File

@ -0,0 +1,76 @@
"""
Module containing ClemNet.
"""
from typing import Any
from .layers import ChannelWiseDense, KerasConv1D, ResidualLayer
from .params import BlockParams, ClemNetParams
import tensorflow as tf
import tensorflow.compat.v1 as tf1
class Block2(tf.keras.layers.Layer):
"""
Possible ClemNet block. Architecture is as follow:
Optional(DenseLayer + BN + Act)
Optional(ConvLayer + BN + Act)
Optional(Residual Layer)
"""
def __init__(self, params: BlockParams, **kwargs: Any):
super(Block2, self).__init__(**kwargs)
self.params = params
def build(self, input_shape: tf.TensorShape) -> None:
assert (
len(input_shape) == 3
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
x = inputs
if self.params.dense:
x = ChannelWiseDense(**self.params.dense.dict())(inputs=x, training=training)
x = tf1.layers.batch_normalization(x, momentum=0.9999, training=training, axis=1)
x = tf.keras.layers.Activation(self.params.activation)(x)
if self.params.conv:
x = KerasConv1D(**self.params.conv.dict())(inputs=x, training=training)
x = tf1.layers.batch_normalization(x, momentum=0.9999, training=training, axis=1)
x = tf.keras.layers.Activation(self.params.activation)(x)
if self.params.residual:
x = ResidualLayer()(inputs=inputs, residual=x)
return x
class ClemNet(tf.keras.layers.Layer):
"""
A residual network stacking residual blocks composed of dense layers and convolutions.
"""
def __init__(self, params: ClemNetParams, **kwargs: Any):
super(ClemNet, self).__init__(**kwargs)
self.params = params
def build(self, input_shape: tf.TensorShape) -> None:
assert len(input_shape) in (
2,
3,
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
if len(inputs.shape) < 3:
inputs = tf.expand_dims(inputs, axis=-1)
x = inputs
for block_params in self.params.blocks:
x = Block2(block_params)(inputs=x, training=training)
x = tf.keras.layers.Flatten(name="flattened")(x)
if self.params.top:
x = tf.keras.layers.Dense(units=self.params.top.n_labels, name="logits")(x)
return x

View File

@ -0,0 +1,49 @@
"""
Parameters used in ClemNet.
"""
from typing import List, Optional
from pydantic import BaseModel, Extra, Field, PositiveInt
# checkstyle: noqa
class ExtendedBaseModel(BaseModel):
class Config:
extra = Extra.forbid
class DenseParams(ExtendedBaseModel):
name: Optional[str]
bias_initializer: str = "zeros"
kernel_initializer: str = "glorot_uniform"
output_size: PositiveInt
use_bias: bool = Field(True)
class ConvParams(ExtendedBaseModel):
name: Optional[str]
bias_initializer: str = "zeros"
filters: PositiveInt
kernel_initializer: str = "glorot_uniform"
kernel_size: PositiveInt
padding: str = "SAME"
strides: PositiveInt = 1
use_bias: bool = Field(True)
class BlockParams(ExtendedBaseModel):
activation: Optional[str]
conv: Optional[ConvParams]
dense: Optional[DenseParams]
residual: Optional[bool]
class TopLayerParams(ExtendedBaseModel):
n_labels: PositiveInt
class ClemNetParams(ExtendedBaseModel):
blocks: List[BlockParams] = []
top: Optional[TopLayerParams]

View File

@ -0,0 +1,34 @@
"""
Candidate architectures for each task's.
"""
from __future__ import annotations
from typing import Dict
from .features import get_features
from .graph import Graph
from .lib.model import ClemNet
from .params import ModelTypeEnum
import tensorflow as tf
class MagicRecsClemNet(Graph):
def get_logits(self, features: Dict[str, tf.Tensor], training: bool) -> tf.Tensor:
with tf.name_scope("logits"):
inputs = get_features(features=features, training=training, params=self.params.model.features)
with tf.name_scope("OONC_logits"):
model = ClemNet(params=self.params.model.architecture)
oonc_logit = model(inputs=inputs, training=training)
with tf.name_scope("EngagementGivenOONC_logits"):
model = ClemNet(params=self.params.model.architecture)
eng_logits = model(inputs=inputs, training=training)
return tf.concat([oonc_logit, eng_logits], axis=1)
ALL_MODELS = {ModelTypeEnum.clemnet: MagicRecsClemNet}

View File

@ -0,0 +1,89 @@
import enum
import json
from typing import List, Optional
from .lib.params import BlockParams, ClemNetParams, ConvParams, DenseParams, TopLayerParams
from pydantic import BaseModel, Extra, NonNegativeFloat
import tensorflow.compat.v1 as tf
# checkstyle: noqa
class ExtendedBaseModel(BaseModel):
class Config:
extra = Extra.forbid
class SparseFeaturesParams(ExtendedBaseModel):
bits: int
embedding_size: int
class FeaturesParams(ExtendedBaseModel):
sparse_features: Optional[SparseFeaturesParams]
class ModelTypeEnum(str, enum.Enum):
clemnet: str = "clemnet"
class ModelParams(ExtendedBaseModel):
name: ModelTypeEnum
features: FeaturesParams
architecture: ClemNetParams
class TaskNameEnum(str, enum.Enum):
oonc: str = "OONC"
engagement: str = "Engagement"
class Task(ExtendedBaseModel):
name: TaskNameEnum
label: str
score_weight: NonNegativeFloat
DEFAULT_TASKS = [
Task(name=TaskNameEnum.oonc, label="label", score_weight=0.9),
Task(name=TaskNameEnum.engagement, label="label.engagement", score_weight=0.1),
]
class GraphParams(ExtendedBaseModel):
tasks: List[Task] = DEFAULT_TASKS
model: ModelParams
weight: Optional[str]
DEFAULT_ARCHITECTURE_PARAMS = ClemNetParams(
blocks=[
BlockParams(
activation="relu",
conv=ConvParams(kernel_size=3, filters=5),
dense=DenseParams(output_size=output_size),
residual=False,
)
for output_size in [1024, 512, 256, 128]
],
top=TopLayerParams(n_labels=1),
)
DEFAULT_GRAPH_PARAMS = GraphParams(
model=ModelParams(
name=ModelTypeEnum.clemnet,
architecture=DEFAULT_ARCHITECTURE_PARAMS,
features=FeaturesParams(sparse_features=SparseFeaturesParams(bits=18, embedding_size=50)),
),
)
def load_graph_params(args) -> GraphParams:
params = DEFAULT_GRAPH_PARAMS
if args.param_file:
with tf.io.gfile.GFile(args.param_file, mode="r+") as file:
params = GraphParams.parse_obj(json.load(file))
return params

View File

@ -0,0 +1,59 @@
from twml.trainers import DataRecordTrainer
from .features import FEATURE_LIST_DEFAULT_PATH
def get_training_arg_parser():
parser = DataRecordTrainer.add_parser_arguments()
parser.add_argument(
"--feature_list",
default=FEATURE_LIST_DEFAULT_PATH,
type=str,
help="Which features to use for training",
)
parser.add_argument(
"--param_file",
default=None,
type=str,
help="Path to JSON file containing the graph parameters. If None, model will load default parameters.",
)
parser.add_argument(
"--directly_export_best",
default=False,
action="store_true",
help="whether to directly_export best_checkpoint",
)
parser.add_argument(
"--warm_start_from", default=None, type=str, help="model dir to warm start from"
)
parser.add_argument(
"--warm_start_base_dir",
default=None,
type=str,
help="latest ckpt in this folder will be used to ",
)
parser.add_argument(
"--model_type",
default=None,
type=str,
help="Which type of model to train.",
)
return parser
def get_eval_arg_parser():
parser = get_training_arg_parser()
parser.add_argument(
"--eval_checkpoint",
default=None,
type=str,
help="Which checkpoint to use for evaluation",
)
return parser

View File

@ -0,0 +1,146 @@
"""
Model for modifying the checkpoints of the magic recs cnn Model with addition, deletion, and reordering
of continuous and binary features.
"""
import os
from twitter.deepbird.projects.magic_recs.libs.get_feat_config import FEATURE_LIST_DEFAULT_PATH
from twitter.deepbird.projects.magic_recs.libs.warm_start_utils_v11 import (
get_feature_list_for_heavy_ranking,
mkdirp,
rename_dir,
rmdir,
warm_start_checkpoint,
)
import twml
from twml.trainers import DataRecordTrainer
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import logging
def get_arg_parser():
parser = DataRecordTrainer.add_parser_arguments()
parser.add_argument(
"--model_type",
default="deepnorm_gbdt_inputdrop2_rescale",
type=str,
help="specify the model type to use.",
)
parser.add_argument(
"--model_trainer_name",
default="None",
type=str,
help="deprecated, added here just for api compatibility.",
)
parser.add_argument(
"--warm_start_base_dir",
default="none",
type=str,
help="latest ckpt in this folder will be used.",
)
parser.add_argument(
"--output_checkpoint_dir",
default="none",
type=str,
help="Output folder for warm started ckpt. If none, it will move warm_start_base_dir to backup, and overwrite it",
)
parser.add_argument(
"--feature_list",
default="none",
type=str,
help="Which features to use for training",
)
parser.add_argument(
"--old_feature_list",
default="none",
type=str,
help="Which features to use for training",
)
return parser
def get_params(args=None):
parser = get_arg_parser()
if args is None:
return parser.parse_args()
else:
return parser.parse_args(args)
def _main():
opt = get_params()
logging.info("parse is: ")
logging.info(opt)
if opt.feature_list == "none":
feature_list_path = FEATURE_LIST_DEFAULT_PATH
else:
feature_list_path = opt.feature_list
if opt.warm_start_base_dir != "none" and tf.io.gfile.exists(opt.warm_start_base_dir):
if opt.output_checkpoint_dir == "none" or opt.output_checkpoint_dir == opt.warm_start_base_dir:
_warm_start_base_dir = os.path.normpath(opt.warm_start_base_dir) + "_backup_warm_start"
_output_folder_dir = opt.warm_start_base_dir
rename_dir(opt.warm_start_base_dir, _warm_start_base_dir)
tf.logging.info(f"moved {opt.warm_start_base_dir} to {_warm_start_base_dir}")
else:
_warm_start_base_dir = opt.warm_start_base_dir
_output_folder_dir = opt.output_checkpoint_dir
continuous_binary_feat_list_save_path = os.path.join(
_warm_start_base_dir, "continuous_binary_feat_list.json"
)
if opt.old_feature_list != "none":
tf.logging.info("getting old continuous_binary_feat_list")
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
opt.old_feature_list, opt.data_spec
)
rmdir(continuous_binary_feat_list_save_path)
twml.util.write_file(
continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
)
tf.logging.info(f"Finish writting files to {continuous_binary_feat_list_save_path}")
warm_start_folder = os.path.join(_warm_start_base_dir, "best_checkpoint")
if not tf.io.gfile.exists(warm_start_folder):
warm_start_folder = _warm_start_base_dir
rmdir(_output_folder_dir)
mkdirp(_output_folder_dir)
new_ckpt = warm_start_checkpoint(
warm_start_folder,
continuous_binary_feat_list_save_path,
feature_list_path,
opt.data_spec,
_output_folder_dir,
opt.model_type,
)
logging.info(f"Created new ckpt {new_ckpt} from {warm_start_folder}")
tf.logging.info("getting new continuous_binary_feat_list")
new_continuous_binary_feat_list_save_path = os.path.join(
_output_folder_dir, "continuous_binary_feat_list.json"
)
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
feature_list_path, opt.data_spec
)
rmdir(new_continuous_binary_feat_list_save_path)
twml.util.write_file(
new_continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
)
tf.logging.info(f"Finish writting files to {new_continuous_binary_feat_list_save_path}")
if __name__ == "__main__":
_main()

View File

@ -0,0 +1,16 @@
python3_library(
name = "libs",
sources = ["*.py"],
tags = [
"bazel-compatible",
"no-mypy",
],
dependencies = [
"cortex/recsys/src/python/twitter/cortex/recsys/utils",
"magicpony/common/file_access/src/python/twitter/magicpony/common/file_access",
"src/python/twitter/cortex/ml/embeddings/deepbird",
"src/python/twitter/cortex/ml/embeddings/deepbird/grouped_metrics",
"src/python/twitter/deepbird/util/data",
"twml:twml-nodeps",
],
)

View File

@ -0,0 +1,56 @@
# pylint: disable=no-member, arguments-differ, attribute-defined-outside-init, unused-argument
"""
Implementing Full Sparse Layer, allow specify use_binary_value in call() to
overide default action.
"""
from twml.layers import FullSparse as defaultFullSparse
from twml.layers.full_sparse import sparse_dense_matmul
import tensorflow.compat.v1 as tf
class FullSparse(defaultFullSparse):
def call(self, inputs, use_binary_values=None, **kwargs): # pylint: disable=unused-argument
"""The logic of the layer lives here.
Arguments:
inputs:
A SparseTensor or a list of SparseTensors.
If `inputs` is a list, all tensors must have same `dense_shape`.
Returns:
- If `inputs` is `SparseTensor`, then returns `bias + inputs * dense_b`.
- If `inputs` is a `list[SparseTensor`, then returns
`bias + add_n([sp_a * dense_b for sp_a in inputs])`.
"""
if use_binary_values is not None:
default_use_binary_values = use_binary_values
else:
default_use_binary_values = self.use_binary_values
if isinstance(default_use_binary_values, (list, tuple)):
raise ValueError(
"use_binary_values can not be %s when inputs is %s"
% (type(default_use_binary_values), type(inputs))
)
outputs = sparse_dense_matmul(
inputs,
self.weight,
self.use_sparse_grads,
default_use_binary_values,
name="sparse_mm",
partition_axis=self.partition_axis,
num_partitions=self.num_partitions,
compress_ids=self._use_compression,
cast_indices_dtype=self._cast_indices_dtype,
)
if self.bias is not None:
outputs = tf.nn.bias_add(outputs, self.bias)
if self.activation is not None:
return self.activation(outputs) # pylint: disable=not-callable
return outputs

View File

@ -0,0 +1,176 @@
import os
from twitter.deepbird.projects.magic_recs.libs.metric_fn_utils import USER_AGE_FEATURE_NAME
from twitter.deepbird.projects.magic_recs.libs.model_utils import read_config
from twml.contrib import feature_config as contrib_feature_config
# checkstyle: noqa
FEAT_CONFIG_DEFAULT_VAL = -1.23456789
DEFAULT_INPUT_SIZE_BITS = 18
DEFAULT_FEATURE_LIST_PATH = "./feature_list_default.yaml"
FEATURE_LIST_DEFAULT_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_PATH
)
DEFAULT_FEATURE_LIST_LIGHT_RANKING_PATH = "./feature_list_light_ranking.yaml"
FEATURE_LIST_DEFAULT_LIGHT_RANKING_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), DEFAULT_FEATURE_LIST_LIGHT_RANKING_PATH
)
FEATURE_LIST_DEFAULT = read_config(FEATURE_LIST_DEFAULT_PATH).items()
FEATURE_LIST_LIGHT_RANKING_DEFAULT = read_config(FEATURE_LIST_DEFAULT_LIGHT_RANKING_PATH).items()
LABELS = ["label"]
LABELS_MTL = {"OONC": ["label"], "OONC_Engagement": ["label", "label.engagement"]}
LABELS_LR = {
"Sent": ["label.sent"],
"HeavyRankPosition": ["meta.ranking.is_top3"],
"HeavyRankProbability": ["meta.ranking.weighted_oonc_model_score"],
}
def _get_new_feature_config_base(
data_spec_path,
labels,
add_sparse_continous=True,
add_gbdt=True,
add_user_id=False,
add_timestamp=False,
add_user_age=False,
feature_list_provided=[],
opt=None,
run_light_ranking_group_metrics_in_bq=False,
):
"""
Getter of the feature config based on specification.
Args:
data_spec_path: A string indicating the path of the data_spec.json file, which could be
either a local path or a hdfs path.
labels: A list of strings indicating the name of the label in the data spec.
add_sparse_continous: A bool indicating if sparse_continuous feature needs to be included.
add_gbdt: A bool indicating if gbdt feature needs to be included.
add_user_id: A bool indicating if user_id feature needs to be included.
add_timestamp: A bool indicating if timestamp feature needs to be included. This will be useful
for sequential models and meta learning models.
add_user_age: A bool indicating if the user age feature needs to be included.
feature_list_provided: A list of features thats need to be included. If not specified, will use
FEATURE_LIST_DEFAULT by default.
opt: A namespace of arguments indicating the hyparameters.
run_light_ranking_group_metrics_in_bq: A bool indicating if heavy ranker score info needs to be included to compute group metrics in BigQuery.
Returns:
A twml feature config object.
"""
input_size_bits = DEFAULT_INPUT_SIZE_BITS if opt is None else opt.input_size_bits
feature_list = feature_list_provided if feature_list_provided != [] else FEATURE_LIST_DEFAULT
a_string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
builder = contrib_feature_config.FeatureConfigBuilder(data_spec_path=data_spec_path)
builder = builder.extract_feature_group(
feature_regexes=a_string_feat_list,
group_name="continuous",
default_value=FEAT_CONFIG_DEFAULT_VAL,
type_filter=["CONTINUOUS"],
)
builder = builder.extract_features_as_hashed_sparse(
feature_regexes=a_string_feat_list,
output_tensor_name="sparse_no_continuous",
hash_space_size_bits=input_size_bits,
type_filter=["BINARY", "DISCRETE", "STRING", "SPARSE_BINARY"],
)
if add_gbdt:
builder = builder.extract_features_as_hashed_sparse(
feature_regexes=["ads\..*"],
output_tensor_name="gbdt_sparse",
hash_space_size_bits=input_size_bits,
)
if add_sparse_continous:
s_string_feat_list = [f[0] for f in feature_list if f[1] == "S"]
builder = builder.extract_features_as_hashed_sparse(
feature_regexes=s_string_feat_list,
output_tensor_name="sparse_continuous",
hash_space_size_bits=input_size_bits,
type_filter=["SPARSE_CONTINUOUS"],
)
if add_user_id:
builder = builder.extract_feature("meta.user_id")
if add_timestamp:
builder = builder.extract_feature("meta.timestamp")
if add_user_age:
builder = builder.extract_feature(USER_AGE_FEATURE_NAME)
if run_light_ranking_group_metrics_in_bq:
builder = builder.extract_feature("meta.trace_id")
builder = builder.extract_feature("meta.ranking.weighted_oonc_model_score")
builder = builder.add_labels(labels).define_weight("meta.weight")
return builder.build()
def get_feature_config_with_sparse_continuous(
data_spec_path,
feature_list_provided=[],
opt=None,
add_user_id=False,
add_timestamp=False,
add_user_age=False,
):
task_name = opt.task_name if getattr(opt, "task_name", None) is not None else "OONC"
if task_name not in LABELS_MTL:
raise ValueError("Invalid Task Name !")
return _get_new_feature_config_base(
data_spec_path=data_spec_path,
labels=LABELS_MTL[task_name],
add_sparse_continous=True,
add_user_id=add_user_id,
add_timestamp=add_timestamp,
add_user_age=add_user_age,
feature_list_provided=feature_list_provided,
opt=opt,
)
def get_feature_config_light_ranking(
data_spec_path,
feature_list_provided=[],
opt=None,
add_user_id=True,
add_timestamp=False,
add_user_age=False,
add_gbdt=False,
run_light_ranking_group_metrics_in_bq=False,
):
task_name = opt.task_name if getattr(opt, "task_name", None) is not None else "HeavyRankPosition"
if task_name not in LABELS_LR:
raise ValueError("Invalid Task Name !")
if not feature_list_provided:
feature_list_provided = FEATURE_LIST_LIGHT_RANKING_DEFAULT
return _get_new_feature_config_base(
data_spec_path=data_spec_path,
labels=LABELS_LR[task_name],
add_sparse_continous=False,
add_gbdt=add_gbdt,
add_user_id=add_user_id,
add_timestamp=add_timestamp,
add_user_age=add_user_age,
feature_list_provided=feature_list_provided,
opt=opt,
run_light_ranking_group_metrics_in_bq=run_light_ranking_group_metrics_in_bq,
)

View File

@ -0,0 +1,42 @@
"""
Utilties that aid in building the magic recs graph.
"""
import re
import tensorflow.compat.v1 as tf
def get_trainable_variables(all_trainable_variables, trainable_regexes):
"""Returns a subset of trainable variables for training.
Given a collection of trainable variables, this will return all those that match the given regexes.
Will also log those variables.
Args:
all_trainable_variables (a collection of trainable tf.Variable): The variables to search through.
trainable_regexes (a collection of regexes): Variables that match any regex will be included.
Returns a list of tf.Variable
"""
if trainable_regexes is None or len(trainable_regexes) == 0:
tf.logging.info("No trainable regexes found. Not using get_trainable_variables behavior.")
return None
assert any(
tf.is_tensor(var) for var in all_trainable_variables
), f"Non TF variable found: {all_trainable_variables}"
trainable_variables = list(
filter(
lambda var: any(re.match(regex, var.name, re.IGNORECASE) for regex in trainable_regexes),
all_trainable_variables,
)
)
tf.logging.info(f"Using filtered trainable variables: {trainable_variables}")
assert (
trainable_variables
), "Did not find trainable variables after filtering after filtering from {} number of vars originaly. All vars: {} and train regexes: {}".format(
len(all_trainable_variables), all_trainable_variables, trainable_regexes
)
return trainable_variables

View File

@ -0,0 +1,114 @@
import os
import time
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.computation import (
write_grouped_metrics_to_mldash,
)
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.configuration import (
ClassificationGroupedMetricsConfiguration,
NDCGGroupedMetricsConfiguration,
)
import twml
from .light_ranking_metrics import (
CGRGroupedMetricsConfiguration,
ExpectedLossGroupedMetricsConfiguration,
RecallGroupedMetricsConfiguration,
)
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import logging
# checkstyle: noqa
def run_group_metrics(trainer, data_dir, model_path, parse_fn, group_feature_name="meta.user_id"):
start_time = time.time()
logging.info("Evaluating with group metrics.")
metrics = write_grouped_metrics_to_mldash(
trainer=trainer,
data_dir=data_dir,
model_path=model_path,
group_fn=lambda datarecord: str(
datarecord.discreteFeatures[twml.feature_id(group_feature_name)[0]]
),
parse_fn=parse_fn,
metric_configurations=[
ClassificationGroupedMetricsConfiguration(),
NDCGGroupedMetricsConfiguration(k=[5, 10, 20]),
],
total_records_to_read=1000000000,
shuffle=False,
mldash_metrics_name="grouped_metrics",
)
end_time = time.time()
logging.info(f"Evaluated Group Metics: {metrics}.")
logging.info(f"Group metrics evaluation time {end_time - start_time}.")
def run_group_metrics_light_ranking(
trainer, data_dir, model_path, parse_fn, group_feature_name="meta.trace_id"
):
start_time = time.time()
logging.info("Evaluating with group metrics.")
metrics = write_grouped_metrics_to_mldash(
trainer=trainer,
data_dir=data_dir,
model_path=model_path,
group_fn=lambda datarecord: str(
datarecord.discreteFeatures[twml.feature_id(group_feature_name)[0]]
),
parse_fn=parse_fn,
metric_configurations=[
CGRGroupedMetricsConfiguration(lightNs=[50, 100, 200], heavyKs=[1, 3, 10, 20, 50]),
RecallGroupedMetricsConfiguration(n=[50, 100, 200], k=[1, 3, 10, 20, 50]),
ExpectedLossGroupedMetricsConfiguration(lightNs=[50, 100, 200]),
],
total_records_to_read=10000000,
num_batches_to_load=50,
batch_size=1024,
shuffle=False,
mldash_metrics_name="grouped_metrics_for_light_ranking",
)
end_time = time.time()
logging.info(f"Evaluated Group Metics for Light Ranking: {metrics}.")
logging.info(f"Group metrics evaluation time {end_time - start_time}.")
def run_group_metrics_light_ranking_in_bq(trainer, params, checkpoint_path):
logging.info("getting Test Predictions for Light Ranking Group Metrics in BigQuery !!!")
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
info_pool = []
for result in trainer.estimator.predict(
eval_input_fn, checkpoint_path=checkpoint_path, yield_single_examples=False
):
traceID = result["trace_id"]
pred = result["prediction"]
label = result["target"]
info = np.concatenate([traceID, pred, label], axis=1)
info_pool.append(info)
info_pool = np.concatenate(info_pool)
locname = "/tmp/000/"
if not os.path.exists(locname):
os.makedirs(locname)
locfile = locname + params.pred_file_name
columns = ["trace_id", "model_prediction", "meta__ranking__weighted_oonc_model_score"]
np.savetxt(locfile, info_pool, delimiter=",", header=",".join(columns))
tf.io.gfile.copy(locfile, params.pred_file_path + params.pred_file_name, overwrite=True)
if os.path.isfile(locfile):
os.remove(locfile)
logging.info("Done Prediction for Light Ranking Group Metrics in BigQuery.")

View File

@ -0,0 +1,118 @@
import numpy as np
from tensorflow.keras import backend as K
class VarianceScaling(object):
"""Initializer capable of adapting its scale to the shape of weights.
With `distribution="normal"`, samples are drawn from a truncated normal
distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
- number of input units in the weight tensor, if mode = "fan_in"
- number of output units, if mode = "fan_out"
- average of the numbers of input and output units, if mode = "fan_avg"
With `distribution="uniform"`,
samples are drawn from a uniform distribution
within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
# Arguments
scale: Scaling factor (positive float).
mode: One of "fan_in", "fan_out", "fan_avg".
distribution: Random distribution to use. One of "normal", "uniform".
seed: A Python integer. Used to seed the random generator.
# Raises
ValueError: In case of an invalid value for the "scale", mode" or
"distribution" arguments."""
def __init__(
self,
scale=1.0,
mode="fan_in",
distribution="normal",
seed=None,
fan_in=None,
fan_out=None,
):
self.fan_in = fan_in
self.fan_out = fan_out
if scale <= 0.0:
raise ValueError("`scale` must be a positive float. Got:", scale)
mode = mode.lower()
if mode not in {"fan_in", "fan_out", "fan_avg"}:
raise ValueError(
"Invalid `mode` argument: " 'expected on of {"fan_in", "fan_out", "fan_avg"} ' "but got",
mode,
)
distribution = distribution.lower()
if distribution not in {"normal", "uniform"}:
raise ValueError(
"Invalid `distribution` argument: " 'expected one of {"normal", "uniform"} ' "but got",
distribution,
)
self.scale = scale
self.mode = mode
self.distribution = distribution
self.seed = seed
def __call__(self, shape, dtype=None, partition_info=None):
fan_in = shape[-2] if self.fan_in is None else self.fan_in
fan_out = shape[-1] if self.fan_out is None else self.fan_out
scale = self.scale
if self.mode == "fan_in":
scale /= max(1.0, fan_in)
elif self.mode == "fan_out":
scale /= max(1.0, fan_out)
else:
scale /= max(1.0, float(fan_in + fan_out) / 2)
if self.distribution == "normal":
stddev = np.sqrt(scale) / 0.87962566103423978
return K.truncated_normal(shape, 0.0, stddev, dtype=dtype, seed=self.seed)
else:
limit = np.sqrt(3.0 * scale)
return K.random_uniform(shape, -limit, limit, dtype=dtype, seed=self.seed)
def get_config(self):
return {
"scale": self.scale,
"mode": self.mode,
"distribution": self.distribution,
"seed": self.seed,
}
def customized_glorot_uniform(seed=None, fan_in=None, fan_out=None):
"""Glorot uniform initializer, also called Xavier uniform initializer.
It draws samples from a uniform distribution within [-limit, limit]
where `limit` is `sqrt(6 / (fan_in + fan_out))`
where `fan_in` is the number of input units in the weight tensor
and `fan_out` is the number of output units in the weight tensor.
# Arguments
seed: A Python integer. Used to seed the random generator.
# Returns
An initializer."""
return VarianceScaling(
scale=1.0,
mode="fan_avg",
distribution="uniform",
seed=seed,
fan_in=fan_in,
fan_out=fan_out,
)
def customized_glorot_norm(seed=None, fan_in=None, fan_out=None):
"""Glorot norm initializer, also called Xavier uniform initializer.
It draws samples from a uniform distribution within [-limit, limit]
where `limit` is `sqrt(6 / (fan_in + fan_out))`
where `fan_in` is the number of input units in the weight tensor
and `fan_out` is the number of output units in the weight tensor.
# Arguments
seed: A Python integer. Used to seed the random generator.
# Returns
An initializer."""
return VarianceScaling(
scale=1.0,
mode="fan_avg",
distribution="normal",
seed=seed,
fan_in=fan_in,
fan_out=fan_out,
)

View File

@ -0,0 +1,255 @@
from functools import partial
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.configuration import (
GroupedMetricsConfiguration,
)
from twitter.cortex.ml.embeddings.deepbird.grouped_metrics.helpers import (
extract_prediction_from_prediction_record,
)
# checkstyle: noqa
def score_loss_at_n(labels, predictions, lightN):
"""
Compute the absolute ScoreLoss ranking metric
Args:
labels (list) : A list of label values (HeavyRanking Reference)
predictions (list): A list of prediction values (LightRanking Predictions)
lightN (int): size of the list at which of Initial candidates to compute ScoreLoss. (LightRanking)
"""
assert len(labels) == len(predictions)
if lightN <= 0:
return None
labels_with_predictions = zip(labels, predictions)
labels_with_sorted_predictions = sorted(
labels_with_predictions, key=lambda x: x[1], reverse=True
)[:lightN]
labels_top1_light = max([label for label, _ in labels_with_sorted_predictions])
labels_top1_heavy = max(labels)
return labels_top1_heavy - labels_top1_light
def cgr_at_nk(labels, predictions, lightN, heavyK):
"""
Compute Cumulative Gain Ratio (CGR) ranking metric
Args:
labels (list) : A list of label values (HeavyRanking Reference)
predictions (list): A list of prediction values (LightRanking Predictions)
lightN (int): size of the list at which of Initial candidates to compute CGR. (LightRanking)
heavyK (int): size of the list at which of Refined candidates to compute CGR. (HeavyRanking)
"""
assert len(labels) == len(predictions)
if (not lightN) or (not heavyK):
out = None
elif lightN <= 0 or heavyK <= 0:
out = None
else:
labels_with_predictions = zip(labels, predictions)
labels_with_sorted_predictions = sorted(
labels_with_predictions, key=lambda x: x[1], reverse=True
)[:lightN]
labels_topN_light = [label for label, _ in labels_with_sorted_predictions]
if lightN <= heavyK:
cg_light = sum(labels_topN_light)
else:
labels_topK_heavy_from_light = sorted(labels_topN_light, reverse=True)[:heavyK]
cg_light = sum(labels_topK_heavy_from_light)
ideal_ordering = sorted(labels, reverse=True)
cg_heavy = sum(ideal_ordering[: min(lightN, heavyK)])
out = 0.0
if cg_heavy != 0:
out = max(cg_light / cg_heavy, 0)
return out
def _get_weight(w, atK):
if not w:
return 1.0
elif len(w) <= atK:
return 0.0
else:
return w[atK]
def recall_at_nk(labels, predictions, n=None, k=None, w=None):
"""
Recall at N-K ranking metric
Args:
labels (list): A list of label values
predictions (list): A list of prediction values
n (int): size of the list at which of predictions to compute recall. (Light Ranking Predictions)
The default is None in which case the length of the provided predictions is used as L
k (int): size of the list at which of labels to compute recall. (Heavy Ranking Predictions)
The default is None in which case the length of the provided labels is used as L
w (list): weight vector sorted by labels
"""
assert len(labels) == len(predictions)
if not any(labels):
out = None
else:
safe_n = len(predictions) if not n else min(len(predictions), n)
safe_k = len(labels) if not k else min(len(labels), k)
labels_with_predictions = zip(labels, predictions)
sorted_labels_with_predictions = sorted(
labels_with_predictions, key=lambda x: x[0], reverse=True
)
order_sorted_labels_predictions = zip(range(len(labels)), *zip(*sorted_labels_with_predictions))
order_with_predictions = [
(order, pred) for order, label, pred in order_sorted_labels_predictions
]
order_with_sorted_predictions = sorted(order_with_predictions, key=lambda x: x[1], reverse=True)
pred_sorted_order_at_n = [order for order, _ in order_with_sorted_predictions][:safe_n]
intersection_weight = [
_get_weight(w, order) if order < safe_k else 0 for order in pred_sorted_order_at_n
]
intersection_score = sum(intersection_weight)
full_score = sum(w) if w else float(safe_k)
out = 0.0
if full_score != 0:
out = intersection_score / full_score
return out
class ExpectedLossGroupedMetricsConfiguration(GroupedMetricsConfiguration):
"""
This is the Expected Loss Grouped metric computation configuration.
"""
def __init__(self, lightNs=[]):
"""
Args:
lightNs (list): size of the list at which of Initial candidates to compute Expected Loss. (LightRanking)
"""
self.lightNs = lightNs
@property
def name(self):
return "ExpectedLoss"
@property
def metrics_dict(self):
metrics_to_compute = {}
for lightN in self.lightNs:
metric_name = "ExpectedLoss_atLight_" + str(lightN)
metrics_to_compute[metric_name] = partial(score_loss_at_n, lightN=lightN)
return metrics_to_compute
def extract_label(self, prec, drec, drec_label):
return drec_label
def extract_prediction(self, prec, drec, drec_label):
return extract_prediction_from_prediction_record(prec)
class CGRGroupedMetricsConfiguration(GroupedMetricsConfiguration):
"""
This is the Cumulative Gain Ratio (CGR) Grouped metric computation configuration.
CGR at the max length of each session is the default.
CGR at additional positions can be computed by specifying a list of 'n's and 'k's
"""
def __init__(self, lightNs=[], heavyKs=[]):
"""
Args:
lightNs (list): size of the list at which of Initial candidates to compute CGR. (LightRanking)
heavyK (int): size of the list at which of Refined candidates to compute CGR. (HeavyRanking)
"""
self.lightNs = lightNs
self.heavyKs = heavyKs
@property
def name(self):
return "cgr"
@property
def metrics_dict(self):
metrics_to_compute = {}
for lightN in self.lightNs:
for heavyK in self.heavyKs:
metric_name = "cgr_atLight_" + str(lightN) + "_atHeavy_" + str(heavyK)
metrics_to_compute[metric_name] = partial(cgr_at_nk, lightN=lightN, heavyK=heavyK)
return metrics_to_compute
def extract_label(self, prec, drec, drec_label):
return drec_label
def extract_prediction(self, prec, drec, drec_label):
return extract_prediction_from_prediction_record(prec)
class RecallGroupedMetricsConfiguration(GroupedMetricsConfiguration):
"""
This is the Recall Grouped metric computation configuration.
Recall at the max length of each session is the default.
Recall at additional positions can be computed by specifying a list of 'n's and 'k's
"""
def __init__(self, n=[], k=[], w=[]):
"""
Args:
n (list): A list of ints. List of prediction rank thresholds (for light)
k (list): A list of ints. List of label rank thresholds (for heavy)
"""
self.predN = n
self.labelK = k
self.weight = w
@property
def name(self):
return "group_recall"
@property
def metrics_dict(self):
metrics_to_compute = {"group_recall_unweighted": recall_at_nk}
if not self.weight:
metrics_to_compute["group_recall_weighted"] = partial(recall_at_nk, w=self.weight)
if self.predN and self.labelK:
for n in self.predN:
for k in self.labelK:
if n >= k:
metrics_to_compute[
"group_recall_unweighted_at_L" + str(n) + "_at_H" + str(k)
] = partial(recall_at_nk, n=n, k=k)
if self.weight:
metrics_to_compute[
"group_recall_weighted_at_L" + str(n) + "_at_H" + str(k)
] = partial(recall_at_nk, n=n, k=k, w=self.weight)
if self.labelK and not self.predN:
for k in self.labelK:
metrics_to_compute["group_recall_unweighted_at_full_at_H" + str(k)] = partial(
recall_at_nk, k=k
)
if self.weight:
metrics_to_compute["group_recall_weighted_at_full_at_H" + str(k)] = partial(
recall_at_nk, k=k, w=self.weight
)
return metrics_to_compute
def extract_label(self, prec, drec, drec_label):
return drec_label
def extract_prediction(self, prec, drec, drec_label):
return extract_prediction_from_prediction_record(prec)

View File

@ -0,0 +1,294 @@
"""
Utilties for constructing a metric_fn for magic recs.
"""
from twml.contrib.metrics.metrics import (
get_dual_binary_tasks_metric_fn,
get_numeric_metric_fn,
get_partial_multi_binary_class_metric_fn,
get_single_binary_task_metric_fn,
)
from .model_utils import generate_disliked_mask
import tensorflow.compat.v1 as tf
METRIC_BOOK = {
"OONC": ["OONC"],
"OONC_Engagement": ["OONC", "Engagement"],
"Sent": ["Sent"],
"HeavyRankPosition": ["HeavyRankPosition"],
"HeavyRankProbability": ["HeavyRankProbability"],
}
USER_AGE_FEATURE_NAME = "accountAge"
NEW_USER_AGE_CUTOFF = 0
def remove_padding_and_flatten(tensor, valid_batch_size):
"""Remove the padding of the input padded tensor given the valid batch size tensor,
then flatten the output with respect to the first dimension.
Args:
tensor: A tensor of size [META_BATCH_SIZE, BATCH_SIZE, FEATURE_DIM].
valid_batch_size: A tensor of size [META_BATCH_SIZE], with each element indicating
the effective batch size of the BATCH_SIZE dimension.
Returns:
A tesnor of size [tf.reduce_sum(valid_batch_size), FEATURE_DIM].
"""
unpadded_ragged_tensor = tf.RaggedTensor.from_tensor(tensor=tensor, lengths=valid_batch_size)
return unpadded_ragged_tensor.flat_values
def safe_mask(values, mask):
"""Mask values if possible.
Boolean mask inputed values if and only if values is a tensor of the same dimension as mask (or can be broadcasted to that dimension).
Args:
values (Any or Tensor): Input tensor to mask. Dim 0 should be size N.
mask (boolean tensor): A boolean tensor of size N.
Returns Values or Values masked.
"""
if values is None:
return values
if not tf.is_tensor(values):
return values
values_shape = values.get_shape()
if not values_shape or len(values_shape) == 0:
return values
if not mask.get_shape().is_compatible_with(values_shape[0]):
return values
return tf.boolean_mask(values, mask)
def add_new_user_metrics(metric_fn):
"""Will stratify the metric_fn by adding new user metrics.
Given an input metric_fn, double every metric: One will be the orignal and the other will only include those for new users.
Args:
metric_fn (python function): Base twml metric_fn.
Returns a metric_fn with new user metrics included.
"""
def metric_fn_with_new_users(graph_output, labels, weights):
if USER_AGE_FEATURE_NAME not in graph_output:
raise ValueError(
"In order to get metrics stratified by user age, {name} feature should be added to model graph output. However, only the following output keys were found: {keys}.".format(
name=USER_AGE_FEATURE_NAME, keys=graph_output.keys()
)
)
metric_ops = metric_fn(graph_output, labels, weights)
is_new = tf.reshape(
tf.math.less_equal(
tf.cast(graph_output[USER_AGE_FEATURE_NAME], tf.int64),
tf.cast(NEW_USER_AGE_CUTOFF, tf.int64),
),
[-1],
)
labels = safe_mask(labels, is_new)
weights = safe_mask(weights, is_new)
graph_output = {key: safe_mask(values, is_new) for key, values in graph_output.items()}
new_user_metric_ops = metric_fn(graph_output, labels, weights)
new_user_metric_ops = {name + "_new_users": ops for name, ops in new_user_metric_ops.items()}
metric_ops.update(new_user_metric_ops)
return metric_ops
return metric_fn_with_new_users
def get_meta_learn_single_binary_task_metric_fn(
metrics, classnames, top_k=(5, 5, 5), use_top_k=False
):
"""Wrapper function to use the metric_fn with meta learning evaluation scheme.
Args:
metrics: A list of string representing metric names.
classnames: A list of string repsenting class names, In case of multiple binary class models,
the names for each class or label.
top_k: A tuple of int to specify top K metrics.
use_top_k: A boolean value indicating of top K of metrics is used.
Returns:
A customized metric_fn function.
"""
def get_eval_metric_ops(graph_output, labels, weights):
"""The op func of the eval_metrics. Comparing with normal version,
the difference is we flatten the output, label, and weights.
Args:
graph_output: A dict of tensors.
labels: A tensor of int32 be the value of either 0 or 1.
weights: A tensor of float32 to indicate the per record weight.
Returns:
A dict of metric names and values.
"""
metric_op_weighted = get_partial_multi_binary_class_metric_fn(
metrics, predcols=0, classes=classnames
)
classnames_unweighted = ["unweighted_" + classname for classname in classnames]
metric_op_unweighted = get_partial_multi_binary_class_metric_fn(
metrics, predcols=0, classes=classnames_unweighted
)
valid_batch_size = graph_output["valid_batch_size"]
graph_output["output"] = remove_padding_and_flatten(graph_output["output"], valid_batch_size)
labels = remove_padding_and_flatten(labels, valid_batch_size)
weights = remove_padding_and_flatten(weights, valid_batch_size)
tf.ensure_shape(graph_output["output"], [None, 1])
tf.ensure_shape(labels, [None, 1])
tf.ensure_shape(weights, [None, 1])
metrics_weighted = metric_op_weighted(graph_output, labels, weights)
metrics_unweighted = metric_op_unweighted(graph_output, labels, None)
metrics_weighted.update(metrics_unweighted)
if use_top_k:
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=top_k, predcol=0, labelcol=1)
metrics_numeric = metric_op_numeric(graph_output, labels, weights)
metrics_weighted.update(metrics_numeric)
return metrics_weighted
return get_eval_metric_ops
def get_meta_learn_dual_binary_tasks_metric_fn(
metrics, classnames, top_k=(5, 5, 5), use_top_k=False
):
"""Wrapper function to use the metric_fn with meta learning evaluation scheme.
Args:
metrics: A list of string representing metric names.
classnames: A list of string repsenting class names, In case of multiple binary class models,
the names for each class or label.
top_k: A tuple of int to specify top K metrics.
use_top_k: A boolean value indicating of top K of metrics is used.
Returns:
A customized metric_fn function.
"""
def get_eval_metric_ops(graph_output, labels, weights):
"""The op func of the eval_metrics. Comparing with normal version,
the difference is we flatten the output, label, and weights.
Args:
graph_output: A dict of tensors.
labels: A tensor of int32 be the value of either 0 or 1.
weights: A tensor of float32 to indicate the per record weight.
Returns:
A dict of metric names and values.
"""
metric_op_weighted = get_partial_multi_binary_class_metric_fn(
metrics, predcols=[0, 1], classes=classnames
)
classnames_unweighted = ["unweighted_" + classname for classname in classnames]
metric_op_unweighted = get_partial_multi_binary_class_metric_fn(
metrics, predcols=[0, 1], classes=classnames_unweighted
)
valid_batch_size = graph_output["valid_batch_size"]
graph_output["output"] = remove_padding_and_flatten(graph_output["output"], valid_batch_size)
labels = remove_padding_and_flatten(labels, valid_batch_size)
weights = remove_padding_and_flatten(weights, valid_batch_size)
tf.ensure_shape(graph_output["output"], [None, 2])
tf.ensure_shape(labels, [None, 2])
tf.ensure_shape(weights, [None, 1])
metrics_weighted = metric_op_weighted(graph_output, labels, weights)
metrics_unweighted = metric_op_unweighted(graph_output, labels, None)
metrics_weighted.update(metrics_unweighted)
if use_top_k:
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=top_k, predcol=2, labelcol=2)
metrics_numeric = metric_op_numeric(graph_output, labels, weights)
metrics_weighted.update(metrics_numeric)
return metrics_weighted
return get_eval_metric_ops
def get_metric_fn(task_name, use_stratify_metrics, use_meta_batch=False):
"""Will retrieve the metric_fn for magic recs.
Args:
task_name (string): Which task is being used for this model.
use_stratify_metrics (boolean): Should we add stratified metrics (new user metrics).
use_meta_batch (boolean): If the output/label/weights are passed in 3D shape instead of
2D shape.
Returns:
A metric_fn function to pass in twml Trainer.
"""
if task_name not in METRIC_BOOK:
raise ValueError(
"Task name of {task_name} not recognized. Unable to retrieve metrics.".format(
task_name=task_name
)
)
class_names = METRIC_BOOK[task_name]
if use_meta_batch:
get_n_binary_task_metric_fn = (
get_meta_learn_single_binary_task_metric_fn
if len(class_names) == 1
else get_meta_learn_dual_binary_tasks_metric_fn
)
else:
get_n_binary_task_metric_fn = (
get_single_binary_task_metric_fn if len(class_names) == 1 else get_dual_binary_tasks_metric_fn
)
metric_fn = get_n_binary_task_metric_fn(metrics=None, classnames=METRIC_BOOK[task_name])
if use_stratify_metrics:
metric_fn = add_new_user_metrics(metric_fn)
return metric_fn
def flip_disliked_labels(metric_fn):
"""This function returns an adapted metric_fn which flips the labels of the OONCed evaluation data to 0 if it is disliked.
Args:
metric_fn: A metric_fn function to pass in twml Trainer.
Returns:
_adapted_metric_fn: A customized metric_fn function with disliked OONC labels flipped.
"""
def _adapted_metric_fn(graph_output, labels, weights):
"""A customized metric_fn function with disliked OONC labels flipped.
Args:
graph_output: A dict of tensors.
labels: labels of training samples, which is a 2D tensor of shape batch_size x 3: [OONCs, engagements, dislikes]
weights: A tensor of float32 to indicate the per record weight.
Returns:
A dict of metric names and values.
"""
# We want to multiply the label of the observation by 0 only when it is disliked
disliked_mask = generate_disliked_mask(labels)
# Extract OONC and engagement labels only.
labels = tf.reshape(labels[:, 0:2], shape=[-1, 2])
# Labels will be set to 0 if it is disliked.
adapted_labels = labels * tf.cast(tf.logical_not(disliked_mask), dtype=labels.dtype)
return metric_fn(graph_output, adapted_labels, weights)
return _adapted_metric_fn

View File

@ -0,0 +1,231 @@
from twml.trainers import DataRecordTrainer
# checkstyle: noqa
def get_arg_parser():
parser = DataRecordTrainer.add_parser_arguments()
parser.add_argument(
"--input_size_bits",
type=int,
default=18,
help="number of bits allocated to the input size",
)
parser.add_argument(
"--model_trainer_name",
default="magic_recs_mlp_calibration_MTL_OONC_Engagement",
type=str,
help="specify the model trainer name.",
)
parser.add_argument(
"--model_type",
default="deepnorm_gbdt_inputdrop2_rescale",
type=str,
help="specify the model type to use.",
)
parser.add_argument(
"--feat_config_type",
default="get_feature_config_with_sparse_continuous",
type=str,
help="specify the feature configure function to use.",
)
parser.add_argument(
"--directly_export_best",
default=False,
action="store_true",
help="whether to directly_export best_checkpoint",
)
parser.add_argument(
"--warm_start_base_dir",
default="none",
type=str,
help="latest ckpt in this folder will be used to ",
)
parser.add_argument(
"--feature_list",
default="none",
type=str,
help="Which features to use for training",
)
parser.add_argument(
"--warm_start_from", default=None, type=str, help="model dir to warm start from"
)
parser.add_argument(
"--momentum", default=0.99999, type=float, help="Momentum term for batch normalization"
)
parser.add_argument(
"--dropout",
default=0.2,
type=float,
help="input_dropout_rate to rescale output by (1 - input_dropout_rate)",
)
parser.add_argument(
"--out_layer_1_size", default=256, type=int, help="Size of MLP_branch layer 1"
)
parser.add_argument(
"--out_layer_2_size", default=128, type=int, help="Size of MLP_branch layer 2"
)
parser.add_argument("--out_layer_3_size", default=64, type=int, help="Size of MLP_branch layer 3")
parser.add_argument(
"--sparse_embedding_size", default=50, type=int, help="Dimensionality of sparse embedding layer"
)
parser.add_argument(
"--dense_embedding_size", default=128, type=int, help="Dimensionality of dense embedding layer"
)
parser.add_argument(
"--use_uam_label",
default=False,
type=str,
help="Whether to use uam_label or not",
)
parser.add_argument(
"--task_name",
default="OONC_Engagement",
type=str,
help="specify the task name to use: OONC or OONC_Engagement.",
)
parser.add_argument(
"--init_weight",
default=0.9,
type=float,
help="Initial OONC Task Weight MTL: OONC+Engagement.",
)
parser.add_argument(
"--use_engagement_weight",
default=False,
action="store_true",
help="whether to use engagement weight for base model.",
)
parser.add_argument(
"--mtl_num_extra_layers",
type=int,
default=1,
help="Number of Hidden Layers for each TaskBranch.",
)
parser.add_argument(
"--mtl_neuron_scale", type=int, default=4, help="Scaling Factor of Neurons in MTL Extra Layers."
)
parser.add_argument(
"--use_oonc_score",
default=False,
action="store_true",
help="whether to use oonc score only or combined score.",
)
parser.add_argument(
"--use_stratified_metrics",
default=False,
action="store_true",
help="Use stratified metrics: Break out new-user metrics.",
)
parser.add_argument(
"--run_group_metrics",
default=False,
action="store_true",
help="Will run evaluation metrics grouped by user.",
)
parser.add_argument(
"--use_full_scope",
default=False,
action="store_true",
help="Will add extra scope and naming to graph.",
)
parser.add_argument(
"--trainable_regexes",
default=None,
nargs="*",
help="The union of variables specified by the list of regexes will be considered trainable.",
)
parser.add_argument(
"--fine_tuning.ckpt_to_initialize_from",
dest="fine_tuning_ckpt_to_initialize_from",
type=str,
default=None,
help="Checkpoint path from which to warm start. Indicates the pre-trained model.",
)
parser.add_argument(
"--fine_tuning.warm_start_scope_regex",
dest="fine_tuning_warm_start_scope_regex",
type=str,
default=None,
help="All variables matching this will be restored.",
)
return parser
def get_params(args=None):
parser = get_arg_parser()
if args is None:
return parser.parse_args()
else:
return parser.parse_args(args)
def get_arg_parser_light_ranking():
parser = get_arg_parser()
parser.add_argument(
"--use_record_weight",
default=False,
action="store_true",
help="whether to use record weight for base model.",
)
parser.add_argument(
"--min_record_weight", default=0.0, type=float, help="Minimum record weight to use."
)
parser.add_argument(
"--smooth_weight", default=0.0, type=float, help="Factor to smooth Rank Position Weight."
)
parser.add_argument(
"--num_mlp_layers", type=int, default=3, help="Number of Hidden Layers for MLP model."
)
parser.add_argument(
"--mlp_neuron_scale", type=int, default=4, help="Scaling Factor of Neurons in MLP Layers."
)
parser.add_argument(
"--run_light_ranking_group_metrics",
default=False,
action="store_true",
help="Will run evaluation metrics grouped by user for Light Ranking.",
)
parser.add_argument(
"--use_missing_sub_branch",
default=False,
action="store_true",
help="Whether to use missing value sub-branch for Light Ranking.",
)
parser.add_argument(
"--use_gbdt_features",
default=False,
action="store_true",
help="Whether to use GBDT features for Light Ranking.",
)
parser.add_argument(
"--run_light_ranking_group_metrics_in_bq",
default=False,
action="store_true",
help="Whether to get_predictions for Light Ranking to compute group metrics in BigQuery.",
)
parser.add_argument(
"--pred_file_path",
default=None,
type=str,
help="path",
)
parser.add_argument(
"--pred_file_name",
default=None,
type=str,
help="path",
)
return parser

View File

@ -0,0 +1,339 @@
import sys
import twml
from .initializer import customized_glorot_uniform
import tensorflow.compat.v1 as tf
import yaml
# checkstyle: noqa
def read_config(whitelist_yaml_file):
with tf.gfile.FastGFile(whitelist_yaml_file) as f:
try:
return yaml.safe_load(f)
except yaml.YAMLError as exc:
print(exc)
sys.exit(1)
def _sparse_feature_fixup(features, input_size_bits):
"""Rebuild a sparse tensor feature so that its dense shape attribute is present.
Arguments:
features (SparseTensor): Sparse feature tensor of shape ``(B, sparse_feature_dim)``.
input_size_bits (int): Number of columns in ``log2`` scale. Must be positive.
Returns:
SparseTensor: Rebuilt and non-faulty version of `features`."""
sparse_feature_dim = tf.constant(2**input_size_bits, dtype=tf.int64)
sparse_shape = tf.stack([features.dense_shape[0], sparse_feature_dim])
sparse_tf = tf.SparseTensor(features.indices, features.values, sparse_shape)
return sparse_tf
def self_atten_dense(input, out_dim, activation=None, use_bias=True, name=None):
def safe_concat(base, suffix):
"""Concats variables name components if base is given."""
if not base:
return base
return f"{base}:{suffix}"
input_dim = input.shape.as_list()[1]
sigmoid_out = twml.layers.FullDense(
input_dim, dtype=tf.float32, activation=tf.nn.sigmoid, name=safe_concat(name, "sigmoid_out")
)(input)
atten_input = sigmoid_out * input
mlp_out = twml.layers.FullDense(
out_dim,
dtype=tf.float32,
activation=activation,
use_bias=use_bias,
name=safe_concat(name, "mlp_out"),
)(atten_input)
return mlp_out
def get_dense_out(input, out_dim, activation, dense_type):
if dense_type == "full_dense":
out = twml.layers.FullDense(out_dim, dtype=tf.float32, activation=activation)(input)
elif dense_type == "self_atten_dense":
out = self_atten_dense(input, out_dim, activation=activation)
return out
def get_input_trans_func(bn_normalized_dense, is_training):
gw_normalized_dense = tf.expand_dims(bn_normalized_dense, -1)
group_num = bn_normalized_dense.shape.as_list()[1]
gw_normalized_dense = GroupWiseTrans(group_num, 1, 8, name="groupwise_1", activation=tf.tanh)(
gw_normalized_dense
)
gw_normalized_dense = GroupWiseTrans(group_num, 8, 4, name="groupwise_2", activation=tf.tanh)(
gw_normalized_dense
)
gw_normalized_dense = GroupWiseTrans(group_num, 4, 1, name="groupwise_3", activation=tf.tanh)(
gw_normalized_dense
)
gw_normalized_dense = tf.squeeze(gw_normalized_dense, [-1])
bn_gw_normalized_dense = tf.layers.batch_normalization(
gw_normalized_dense,
training=is_training,
renorm_momentum=0.9999,
momentum=0.9999,
renorm=is_training,
trainable=True,
)
return bn_gw_normalized_dense
def tensor_dropout(
input_tensor,
rate,
is_training,
sparse_tensor=None,
):
"""
Implements dropout layer for both dense and sparse input_tensor
Arguments:
input_tensor:
B x D dense tensor, or a sparse tensor
rate (float32):
dropout rate
is_training (bool):
training stage or not.
sparse_tensor (bool):
whether the input_tensor is sparse tensor or not. Default to be None, this value has to be passed explicitly.
rescale_sparse_dropout (bool):
Do we need to do rescaling or not.
Returns:
tensor dropped out"""
if sparse_tensor == True:
if is_training:
with tf.variable_scope("sparse_dropout"):
values = input_tensor.values
keep_mask = tf.keras.backend.random_binomial(
tf.shape(values), p=1 - rate, dtype=tf.float32, seed=None
)
keep_mask.set_shape([None])
keep_mask = tf.cast(keep_mask, tf.bool)
keep_indices = tf.boolean_mask(input_tensor.indices, keep_mask, axis=0)
keep_values = tf.boolean_mask(values, keep_mask, axis=0)
dropped_tensor = tf.SparseTensor(keep_indices, keep_values, input_tensor.dense_shape)
return dropped_tensor
else:
return input_tensor
elif sparse_tensor == False:
return tf.layers.dropout(input_tensor, rate=rate, training=is_training)
def adaptive_transformation(bn_normalized_dense, is_training, func_type="default"):
assert func_type in [
"default",
"tiny",
], f"fun_type can only be one of default and tiny, but get {func_type}"
gw_normalized_dense = tf.expand_dims(bn_normalized_dense, -1)
group_num = bn_normalized_dense.shape.as_list()[1]
if func_type == "default":
gw_normalized_dense = FastGroupWiseTrans(
group_num, 1, 8, name="groupwise_1", activation=tf.tanh, init_multiplier=8
)(gw_normalized_dense)
gw_normalized_dense = FastGroupWiseTrans(
group_num, 8, 4, name="groupwise_2", activation=tf.tanh, init_multiplier=8
)(gw_normalized_dense)
gw_normalized_dense = FastGroupWiseTrans(
group_num, 4, 1, name="groupwise_3", activation=tf.tanh, init_multiplier=8
)(gw_normalized_dense)
elif func_type == "tiny":
gw_normalized_dense = FastGroupWiseTrans(
group_num, 1, 2, name="groupwise_1", activation=tf.tanh, init_multiplier=8
)(gw_normalized_dense)
gw_normalized_dense = FastGroupWiseTrans(
group_num, 2, 1, name="groupwise_2", activation=tf.tanh, init_multiplier=8
)(gw_normalized_dense)
gw_normalized_dense = FastGroupWiseTrans(
group_num, 1, 1, name="groupwise_3", activation=tf.tanh, init_multiplier=8
)(gw_normalized_dense)
gw_normalized_dense = tf.squeeze(gw_normalized_dense, [-1])
bn_gw_normalized_dense = tf.layers.batch_normalization(
gw_normalized_dense,
training=is_training,
renorm_momentum=0.9999,
momentum=0.9999,
renorm=is_training,
trainable=True,
)
return bn_gw_normalized_dense
class FastGroupWiseTrans(object):
"""
used to apply group-wise fully connected layers to the input.
it applies a tiny, unique MLP to each individual feature."""
def __init__(self, group_num, input_dim, out_dim, name, activation=None, init_multiplier=1):
self.group_num = group_num
self.input_dim = input_dim
self.out_dim = out_dim
self.activation = activation
self.init_multiplier = init_multiplier
self.w = tf.get_variable(
name + "_group_weight",
[1, group_num, input_dim, out_dim],
initializer=customized_glorot_uniform(
fan_in=input_dim * init_multiplier, fan_out=out_dim * init_multiplier
),
trainable=True,
)
self.b = tf.get_variable(
name + "_group_bias",
[1, group_num, out_dim],
initializer=tf.constant_initializer(0.0),
trainable=True,
)
def __call__(self, input_tensor):
"""
input_tensor: batch_size x group_num x input_dim
output_tensor: batch_size x group_num x out_dim"""
input_tensor_expand = tf.expand_dims(input_tensor, axis=-1)
output_tensor = tf.add(
tf.reduce_sum(tf.multiply(input_tensor_expand, self.w), axis=-2, keepdims=False),
self.b,
)
if self.activation is not None:
output_tensor = self.activation(output_tensor)
return output_tensor
class GroupWiseTrans(object):
"""
Used to apply group fully connected layers to the input.
"""
def __init__(self, group_num, input_dim, out_dim, name, activation=None):
self.group_num = group_num
self.input_dim = input_dim
self.out_dim = out_dim
self.activation = activation
w_list, b_list = [], []
for idx in range(out_dim):
this_w = tf.get_variable(
name + f"_group_weight_{idx}",
[1, group_num, input_dim],
initializer=tf.keras.initializers.glorot_uniform(),
trainable=True,
)
this_b = tf.get_variable(
name + f"_group_bias_{idx}",
[1, group_num, 1],
initializer=tf.constant_initializer(0.0),
trainable=True,
)
w_list.append(this_w)
b_list.append(this_b)
self.w_list = w_list
self.b_list = b_list
def __call__(self, input_tensor):
"""
input_tensor: batch_size x group_num x input_dim
output_tensor: batch_size x group_num x out_dim
"""
out_tensor_list = []
for idx in range(self.out_dim):
this_res = (
tf.reduce_sum(input_tensor * self.w_list[idx], axis=-1, keepdims=True) + self.b_list[idx]
)
out_tensor_list.append(this_res)
output_tensor = tf.concat(out_tensor_list, axis=-1)
if self.activation is not None:
output_tensor = self.activation(output_tensor)
return output_tensor
def add_scalar_summary(var, name, name_scope="hist_dense_feature/"):
with tf.name_scope("summaries/"):
with tf.name_scope(name_scope):
tf.summary.scalar(name, var)
def add_histogram_summary(var, name, name_scope="hist_dense_feature/"):
with tf.name_scope("summaries/"):
with tf.name_scope(name_scope):
tf.summary.histogram(name, tf.reshape(var, [-1]))
def sparse_clip_by_value(sparse_tf, min_val, max_val):
new_vals = tf.clip_by_value(sparse_tf.values, min_val, max_val)
return tf.SparseTensor(sparse_tf.indices, new_vals, sparse_tf.dense_shape)
def check_numerics_with_msg(tensor, message="", sparse_tensor=False):
if sparse_tensor:
values = tf.debugging.check_numerics(tensor.values, message=message)
return tf.SparseTensor(tensor.indices, values, tensor.dense_shape)
else:
return tf.debugging.check_numerics(tensor, message=message)
def pad_empty_sparse_tensor(tensor):
dummy_tensor = tf.SparseTensor(
indices=[[0, 0]],
values=[0.00001],
dense_shape=tensor.dense_shape,
)
result = tf.cond(
tf.equal(tf.size(tensor.values), 0),
lambda: dummy_tensor,
lambda: tensor,
)
return result
def filter_nans_and_infs(tensor, sparse_tensor=False):
if sparse_tensor:
sparse_values = tensor.values
filtered_val = tf.where(
tf.logical_or(tf.is_nan(sparse_values), tf.is_inf(sparse_values)),
tf.zeros_like(sparse_values),
sparse_values,
)
return tf.SparseTensor(tensor.indices, filtered_val, tensor.dense_shape)
else:
return tf.where(
tf.logical_or(tf.is_nan(tensor), tf.is_inf(tensor)), tf.zeros_like(tensor), tensor
)
def generate_disliked_mask(labels):
"""Generate a disliked mask where only samples with dislike labels are set to 1 otherwise set to 0.
Args:
labels: labels of training samples, which is a 2D tensor of shape batch_size x 3: [OONCs, engagements, dislikes]
Returns:
1D tensor of shape batch_size x 1: [dislikes (booleans)]
"""
return tf.equal(tf.reshape(labels[:, 2], shape=[-1, 1]), 1)

View File

@ -0,0 +1,309 @@
from collections import OrderedDict
import json
import os
from os.path import join
from twitter.magicpony.common import file_access
import twml
from .model_utils import read_config
import numpy as np
from scipy import stats
import tensorflow.compat.v1 as tf
# checkstyle: noqa
def get_model_type_to_tensors_to_change_axis():
model_type_to_tensors_to_change_axis = {
"magic_recs/model/batch_normalization/beta": ([0], "continuous"),
"magic_recs/model/batch_normalization/gamma": ([0], "continuous"),
"magic_recs/model/batch_normalization/moving_mean": ([0], "continuous"),
"magic_recs/model/batch_normalization/moving_stddev": ([0], "continuous"),
"magic_recs/model/batch_normalization/moving_variance": ([0], "continuous"),
"magic_recs/model/batch_normalization/renorm_mean": ([0], "continuous"),
"magic_recs/model/batch_normalization/renorm_stddev": ([0], "continuous"),
"magic_recs/model/logits/EngagementGivenOONC_logits/clem_net_1/block2_4/channel_wise_dense_4/kernel": (
[1],
"all",
),
"magic_recs/model/logits/OONC_logits/clem_net/block2/channel_wise_dense/kernel": ([1], "all"),
}
return model_type_to_tensors_to_change_axis
def mkdirp(dirname):
if not tf.io.gfile.exists(dirname):
tf.io.gfile.makedirs(dirname)
def rename_dir(dirname, dst):
file_access.hdfs.mv(dirname, dst)
def rmdir(dirname):
if tf.io.gfile.exists(dirname):
if tf.io.gfile.isdir(dirname):
tf.io.gfile.rmtree(dirname)
else:
tf.io.gfile.remove(dirname)
def get_var_dict(checkpoint_path):
checkpoint = tf.train.get_checkpoint_state(checkpoint_path)
var_dict = OrderedDict()
with tf.Session() as sess:
all_var_list = tf.train.list_variables(checkpoint_path)
for var_name, _ in all_var_list:
# Load the variable
var = tf.train.load_variable(checkpoint_path, var_name)
var_dict[var_name] = var
return var_dict
def get_continunous_mapping_from_feat_list(old_feature_list, new_feature_list):
"""
get var_ind for old_feature and corresponding var_ind for new_feature
"""
new_var_ind, old_var_ind = [], []
for this_new_id, this_new_name in enumerate(new_feature_list):
if this_new_name in old_feature_list:
this_old_id = old_feature_list.index(this_new_name)
new_var_ind.append(this_new_id)
old_var_ind.append(this_old_id)
return np.asarray(old_var_ind), np.asarray(new_var_ind)
def get_continuous_mapping_from_feat_dict(old_feature_dict, new_feature_dict):
"""
get var_ind for old_feature and corresponding var_ind for new_feature
"""
old_cont = old_feature_dict["continuous"]
old_bin = old_feature_dict["binary"]
new_cont = new_feature_dict["continuous"]
new_bin = new_feature_dict["binary"]
_dummy_sparse_feat = [f"sparse_feature_{_idx}" for _idx in range(100)]
cont_old_var_ind, cont_new_var_ind = get_continunous_mapping_from_feat_list(old_cont, new_cont)
all_old_var_ind, all_new_var_ind = get_continunous_mapping_from_feat_list(
old_cont + old_bin + _dummy_sparse_feat, new_cont + new_bin + _dummy_sparse_feat
)
_res = {
"continuous": (cont_old_var_ind, cont_new_var_ind),
"all": (all_old_var_ind, all_new_var_ind),
}
return _res
def warm_start_from_var_dict(
old_ckpt_path,
var_ind_dict,
output_dir,
new_len_var,
var_to_change_dict_fn=get_model_type_to_tensors_to_change_axis,
):
"""
Parameters:
old_ckpt_path (str): path to the old checkpoint path
new_var_ind (array of int): index to overlapping features in new var between old and new feature list.
old_var_ind (array of int): index to overlapping features in old var between old and new feature list.
output_dir (str): dir that used to write modified checkpoint
new_len_var ({str:int}): number of feature in the new feature list.
var_to_change_dict_fn (dict): A function to get the dictionary of format {var_name: dim_to_change}
"""
old_var_dict = get_var_dict(old_ckpt_path)
ckpt_file_name = os.path.basename(old_ckpt_path)
mkdirp(output_dir)
output_path = join(output_dir, ckpt_file_name)
tensors_to_change = var_to_change_dict_fn()
tf.compat.v1.reset_default_graph()
with tf.Session() as sess:
var_name_shape_list = tf.train.list_variables(old_ckpt_path)
count = 0
for var_name, var_shape in var_name_shape_list:
old_var = old_var_dict[var_name]
if var_name in tensors_to_change.keys():
_info_tuple = tensors_to_change[var_name]
dims_to_remove_from, var_type = _info_tuple
new_var_ind, old_var_ind = var_ind_dict[var_type]
this_shape = list(old_var.shape)
for this_dim in dims_to_remove_from:
this_shape[this_dim] = new_len_var[var_type]
stddev = np.std(old_var)
truncated_norm_generator = stats.truncnorm(-0.5, 0.5, loc=0, scale=stddev)
size = np.prod(this_shape)
new_var = truncated_norm_generator.rvs(size).reshape(this_shape)
new_var = new_var.astype(old_var.dtype)
new_var = copy_feat_based_on_mapping(
new_var, old_var, dims_to_remove_from, new_var_ind, old_var_ind
)
count = count + 1
else:
new_var = old_var
var = tf.Variable(new_var, name=var_name)
assert count == len(tensors_to_change.keys()), "not all variables are exchanged.\n"
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
saver.save(sess, output_path)
return output_path
def copy_feat_based_on_mapping(new_array, old_array, dims_to_remove_from, new_var_ind, old_var_ind):
if dims_to_remove_from == [0, 1]:
for this_new_ind, this_old_ind in zip(new_var_ind, old_var_ind):
new_array[this_new_ind, new_var_ind] = old_array[this_old_ind, old_var_ind]
elif dims_to_remove_from == [0]:
new_array[new_var_ind] = old_array[old_var_ind]
elif dims_to_remove_from == [1]:
new_array[:, new_var_ind] = old_array[:, old_var_ind]
else:
raise RuntimeError(f"undefined dims_to_remove_from pattern: ({dims_to_remove_from})")
return new_array
def read_file(filename, decode=False):
"""
Reads contents from a file and optionally decodes it.
Arguments:
filename:
path to file where the contents will be loaded from.
Accepts HDFS and local paths.
decode:
False or 'json'. When decode='json', contents is decoded
with json.loads. When False, contents is returned as is.
"""
graph = tf.Graph()
with graph.as_default():
read = tf.read_file(filename)
with tf.Session(graph=graph) as sess:
contents = sess.run(read)
if not isinstance(contents, str):
contents = contents.decode()
if decode == "json":
contents = json.loads(contents)
return contents
def read_feat_list_from_disk(file_path):
return read_file(file_path, decode="json")
def get_feature_list_for_light_ranking(feature_list_path, data_spec_path):
feature_list = read_config(feature_list_path).items()
string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
feature_config_builder = twml.contrib.feature_config.FeatureConfigBuilder(
data_spec_path=data_spec_path
)
feature_config_builder = feature_config_builder.extract_feature_group(
feature_regexes=string_feat_list,
group_name="continuous",
default_value=-1,
type_filter=["CONTINUOUS"],
)
feature_config = feature_config_builder.build()
feature_list = feature_config_builder._feature_group_extraction_configs[0].feature_map[
"CONTINUOUS"
]
return feature_list
def get_feature_list_for_heavy_ranking(feature_list_path, data_spec_path):
feature_list = read_config(feature_list_path).items()
string_feat_list = [f[0] for f in feature_list if f[1] != "S"]
feature_config_builder = twml.contrib.feature_config.FeatureConfigBuilder(
data_spec_path=data_spec_path
)
feature_config_builder = feature_config_builder.extract_feature_group(
feature_regexes=string_feat_list,
group_name="continuous",
default_value=-1,
type_filter=["CONTINUOUS"],
)
feature_config_builder = feature_config_builder.extract_feature_group(
feature_regexes=string_feat_list,
group_name="binary",
default_value=False,
type_filter=["BINARY"],
)
feature_config_builder = feature_config_builder.build()
continuous_feature_list = feature_config_builder._feature_group_extraction_configs[0].feature_map[
"CONTINUOUS"
]
binary_feature_list = feature_config_builder._feature_group_extraction_configs[1].feature_map[
"BINARY"
]
return {"continuous": continuous_feature_list, "binary": binary_feature_list}
def warm_start_checkpoint(
old_best_ckpt_folder,
old_feature_list_path,
feature_allow_list_path,
data_spec_path,
output_ckpt_folder,
*args,
):
"""
Reads old checkpoint and the old feature list, and create a new ckpt warm started from old ckpt using new features .
Arguments:
old_best_ckpt_folder:
path to the best_checkpoint_folder for old model
old_feature_list_path:
path to the json file that stores the list of continuous features used in old models.
feature_allow_list_path:
yaml file that contain the feature allow list.
data_spec_path:
path to the data_spec file
output_ckpt_folder:
folder that contains the modified ckpt.
Returns:
path to the modified ckpt."""
old_ckpt_path = tf.train.latest_checkpoint(old_best_ckpt_folder, latest_filename=None)
new_feature_dict = get_feature_list(feature_allow_list_path, data_spec_path)
old_feature_dict = read_feat_list_from_disk(old_feature_list_path)
var_ind_dict = get_continuous_mapping_from_feat_dict(new_feature_dict, old_feature_dict)
new_len_var = {
"continuous": len(new_feature_dict["continuous"]),
"all": len(new_feature_dict["continuous"] + new_feature_dict["binary"]) + 100,
}
warm_started_ckpt_path = warm_start_from_var_dict(
old_ckpt_path,
var_ind_dict,
output_dir=output_ckpt_folder,
new_len_var=new_len_var,
)
return warm_started_ckpt_path

View File

@ -0,0 +1,69 @@
#":mlwf_libs",
python37_binary(
name = "eval_model",
source = "eval_model.py",
dependencies = [
":libs",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:eval_model",
],
)
python37_binary(
name = "train_model",
source = "deep_norm.py",
dependencies = [
":libs",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:train_model",
],
)
python37_binary(
name = "train_model_local",
source = "deep_norm.py",
dependencies = [
":libs",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:train_model_local",
"twml",
],
)
python37_binary(
name = "eval_model_local",
source = "eval_model.py",
dependencies = [
":libs",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:eval_model_local",
"twml",
],
)
python37_binary(
name = "mlwf_model",
source = "deep_norm.py",
dependencies = [
":mlwf_libs",
"3rdparty/python/_closures/frigate/frigate-pushservice-opensource/src/main/python/models/light_ranking:mlwf_model",
],
)
python3_library(
name = "libs",
sources = ["**/*.py"],
tags = ["no-mypy"],
dependencies = [
"src/python/twitter/deepbird/projects/magic_recs/libs",
"src/python/twitter/deepbird/util/data",
"twml:twml-nodeps",
],
)
python3_library(
name = "mlwf_libs",
sources = ["**/*.py"],
tags = ["no-mypy"],
dependencies = [
"src/python/twitter/deepbird/projects/magic_recs/libs",
"twml",
],
)

View File

@ -0,0 +1,14 @@
# Notification Light Ranker Model
## Model Context
There are 4 major components of Twitter notifications recommendation system: 1) candidate generation 2) light ranking 3) heavy ranking & 4) quality control. This notification light ranker model bridges candidate generation and heavy ranking by pre-selecting highly-relevant candidates from the initial huge candidate pool. Its a light-weight model to reduce system cost during heavy ranking without hurting user experience.
## Directory Structure
- BUILD: this file defines python library dependencies
- model_pools_mlp.py: this file defines tensorflow model architecture for the notification light ranker model
- deep_norm.py: this file contains 1) how to build the tensorflow graph with specified model architecture, loss function and training configuration. 2) how to set up the overall model training & evaluation pipeline
- eval_model.py: the main python entry file to set up the overall model evaluation pipeline

View File

@ -0,0 +1,226 @@
from datetime import datetime
from functools import partial
import os
from twitter.cortex.ml.embeddings.common.helpers import decode_str_or_unicode
import twml
from twml.trainers import DataRecordTrainer
from ..libs.get_feat_config import get_feature_config_light_ranking, LABELS_LR
from ..libs.graph_utils import get_trainable_variables
from ..libs.group_metrics import (
run_group_metrics_light_ranking,
run_group_metrics_light_ranking_in_bq,
)
from ..libs.metric_fn_utils import get_metric_fn
from ..libs.model_args import get_arg_parser_light_ranking
from ..libs.model_utils import read_config
from ..libs.warm_start_utils import get_feature_list_for_light_ranking
from .model_pools_mlp import light_ranking_mlp_ngbdt
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import logging
# checkstyle: noqa
def build_graph(
features, label, mode, params, config=None, run_light_ranking_group_metrics_in_bq=False
):
is_training = mode == tf.estimator.ModeKeys.TRAIN
this_model_func = light_ranking_mlp_ngbdt
model_output = this_model_func(features, is_training, params, label)
logits = model_output["output"]
graph_output = {}
# --------------------------------------------------------
# define graph output dict
# --------------------------------------------------------
if mode == tf.estimator.ModeKeys.PREDICT:
loss = None
output_label = "prediction"
if params.task_name in LABELS_LR:
output = tf.nn.sigmoid(logits)
output = tf.clip_by_value(output, 0, 1)
if run_light_ranking_group_metrics_in_bq:
graph_output["trace_id"] = features["meta.trace_id"]
graph_output["target"] = features["meta.ranking.weighted_oonc_model_score"]
else:
raise ValueError("Invalid Task Name !")
else:
output_label = "output"
weights = tf.cast(features["weights"], dtype=tf.float32, name="RecordWeights")
if params.task_name in LABELS_LR:
if params.use_record_weight:
weights = tf.clip_by_value(
1.0 / (1.0 + weights + params.smooth_weight), params.min_record_weight, 1.0
)
loss = tf.reduce_sum(
tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits) * weights
) / (tf.reduce_sum(weights))
else:
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits))
output = tf.nn.sigmoid(logits)
else:
raise ValueError("Invalid Task Name !")
train_op = None
if mode == tf.estimator.ModeKeys.TRAIN:
# --------------------------------------------------------
# get train_op
# --------------------------------------------------------
optimizer = tf.train.GradientDescentOptimizer(learning_rate=params.learning_rate)
update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
variables = get_trainable_variables(
all_trainable_variables=tf.trainable_variables(), trainable_regexes=params.trainable_regexes
)
with tf.control_dependencies(update_ops):
train_op = twml.optimizers.optimize_loss(
loss=loss,
variables=variables,
global_step=tf.train.get_global_step(),
optimizer=optimizer,
learning_rate=params.learning_rate,
learning_rate_decay_fn=twml.learning_rate_decay.get_learning_rate_decay_fn(params),
)
graph_output[output_label] = output
graph_output["loss"] = loss
graph_output["train_op"] = train_op
return graph_output
def get_params(args=None):
parser = get_arg_parser_light_ranking()
if args is None:
return parser.parse_args()
else:
return parser.parse_args(args)
def _main():
opt = get_params()
logging.info("parse is: ")
logging.info(opt)
feature_list = read_config(opt.feature_list).items()
feature_config = get_feature_config_light_ranking(
data_spec_path=opt.data_spec,
feature_list_provided=feature_list,
opt=opt,
add_gbdt=opt.use_gbdt_features,
run_light_ranking_group_metrics_in_bq=opt.run_light_ranking_group_metrics_in_bq,
)
feature_list_path = opt.feature_list
# --------------------------------------------------------
# Create Trainer
# --------------------------------------------------------
trainer = DataRecordTrainer(
name=opt.model_trainer_name,
params=opt,
build_graph_fn=build_graph,
save_dir=opt.save_dir,
run_config=None,
feature_config=feature_config,
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
)
if opt.directly_export_best:
logging.info("Directly exporting the model without training")
else:
# ----------------------------------------------------
# Model Training & Evaluation
# ----------------------------------------------------
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
train_input_fn = trainer.get_train_input_fn(shuffle=True)
if opt.distributed or opt.num_workers is not None:
learn = trainer.train_and_evaluate
else:
learn = trainer.learn
logging.info("Training...")
start = datetime.now()
early_stop_metric = "rce_unweighted_" + opt.task_name
learn(
early_stop_minimize=False,
early_stop_metric=early_stop_metric,
early_stop_patience=opt.early_stop_patience,
early_stop_tolerance=opt.early_stop_tolerance,
eval_input_fn=eval_input_fn,
train_input_fn=train_input_fn,
)
end = datetime.now()
logging.info("Training time: " + str(end - start))
logging.info("Exporting the models...")
# --------------------------------------------------------
# Do the model exporting
# --------------------------------------------------------
start = datetime.now()
if not opt.export_dir:
opt.export_dir = os.path.join(opt.save_dir, "exported_models")
raw_model_path = twml.contrib.export.export_fn.export_all_models(
trainer=trainer,
export_dir=opt.export_dir,
parse_fn=feature_config.get_parse_fn(),
serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
)
export_model_dir = decode_str_or_unicode(raw_model_path)
logging.info("Model export time: " + str(datetime.now() - start))
logging.info("The saved model directory is: " + opt.save_dir)
tf.logging.info("getting default continuous_feature_list")
continuous_feature_list = get_feature_list_for_light_ranking(feature_list_path, opt.data_spec)
continous_feature_list_save_path = os.path.join(opt.save_dir, "continuous_feature_list.json")
twml.util.write_file(continous_feature_list_save_path, continuous_feature_list, encode="json")
tf.logging.info(f"Finish writting files to {continous_feature_list_save_path}")
if opt.run_light_ranking_group_metrics:
# --------------------------------------------
# Run Light Ranking Group Metrics
# --------------------------------------------
run_group_metrics_light_ranking(
trainer=trainer,
data_dir=os.path.join(opt.eval_data_dir, opt.eval_start_datetime),
model_path=export_model_dir,
parse_fn=feature_config.get_parse_fn(),
)
if opt.run_light_ranking_group_metrics_in_bq:
# ----------------------------------------------------------------------------------------
# Get Light/Heavy Ranker Predictions for Light Ranking Group Metrics in BigQuery
# ----------------------------------------------------------------------------------------
trainer_pred = DataRecordTrainer(
name=opt.model_trainer_name,
params=opt,
build_graph_fn=partial(build_graph, run_light_ranking_group_metrics_in_bq=True),
save_dir=opt.save_dir + "/tmp/",
run_config=None,
feature_config=feature_config,
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
)
checkpoint_folder = os.path.join(opt.save_dir, "best_checkpoint")
checkpoint = tf.train.latest_checkpoint(checkpoint_folder, latest_filename=None)
tf.logging.info("\n\nPrediction from Checkpoint: {:}.\n\n".format(checkpoint))
run_group_metrics_light_ranking_in_bq(
trainer=trainer_pred, params=opt, checkpoint_path=checkpoint
)
tf.logging.info("Done Training & Prediction.")
if __name__ == "__main__":
_main()

View File

@ -0,0 +1,89 @@
from datetime import datetime
from functools import partial
import os
from ..libs.group_metrics import (
run_group_metrics_light_ranking,
run_group_metrics_light_ranking_in_bq,
)
from ..libs.metric_fn_utils import get_metric_fn
from ..libs.model_args import get_arg_parser_light_ranking
from ..libs.model_utils import read_config
from .deep_norm import build_graph, DataRecordTrainer, get_config_func, logging
# checkstyle: noqa
if __name__ == "__main__":
parser = get_arg_parser_light_ranking()
parser.add_argument(
"--eval_checkpoint",
default=None,
type=str,
help="Which checkpoint to use for evaluation",
)
parser.add_argument(
"--saved_model_path",
default=None,
type=str,
help="Path to saved model for evaluation",
)
parser.add_argument(
"--run_binary_metrics",
default=False,
action="store_true",
help="Whether to compute the basic binary metrics for Light Ranking.",
)
opt = parser.parse_args()
logging.info("parse is: ")
logging.info(opt)
feature_list = read_config(opt.feature_list).items()
feature_config = get_config_func(opt.feat_config_type)(
data_spec_path=opt.data_spec,
feature_list_provided=feature_list,
opt=opt,
add_gbdt=opt.use_gbdt_features,
run_light_ranking_group_metrics_in_bq=opt.run_light_ranking_group_metrics_in_bq,
)
# -----------------------------------------------
# Create Trainer
# -----------------------------------------------
trainer = DataRecordTrainer(
name=opt.model_trainer_name,
params=opt,
build_graph_fn=partial(build_graph, run_light_ranking_group_metrics_in_bq=True),
save_dir=opt.save_dir,
run_config=None,
feature_config=feature_config,
metric_fn=get_metric_fn(opt.task_name, use_stratify_metrics=False),
)
# -----------------------------------------------
# Model Evaluation
# -----------------------------------------------
logging.info("Evaluating...")
start = datetime.now()
if opt.run_binary_metrics:
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
eval_steps = None if (opt.eval_steps is not None and opt.eval_steps < 0) else opt.eval_steps
trainer.estimator.evaluate(eval_input_fn, steps=eval_steps, checkpoint_path=opt.eval_checkpoint)
if opt.run_light_ranking_group_metrics_in_bq:
run_group_metrics_light_ranking_in_bq(
trainer=trainer, params=opt, checkpoint_path=opt.eval_checkpoint
)
if opt.run_light_ranking_group_metrics:
run_group_metrics_light_ranking(
trainer=trainer,
data_dir=os.path.join(opt.eval_data_dir, opt.eval_start_datetime),
model_path=opt.saved_model_path,
parse_fn=feature_config.get_parse_fn(),
)
end = datetime.now()
logging.info("Evaluating time: " + str(end - start))

View File

@ -0,0 +1,187 @@
import warnings
from twml.contrib.layers import ZscoreNormalization
from ...libs.customized_full_sparse import FullSparse
from ...libs.get_feat_config import FEAT_CONFIG_DEFAULT_VAL as MISSING_VALUE_MARKER
from ...libs.model_utils import (
_sparse_feature_fixup,
adaptive_transformation,
filter_nans_and_infs,
get_dense_out,
tensor_dropout,
)
import tensorflow.compat.v1 as tf
# checkstyle: noqa
def light_ranking_mlp_ngbdt(features, is_training, params, label=None):
return deepnorm_light_ranking(
features,
is_training,
params,
label=label,
decay=params.momentum,
dense_emb_size=params.dense_embedding_size,
base_activation=tf.keras.layers.LeakyReLU(),
input_dropout_rate=params.dropout,
use_gbdt=False,
)
def deepnorm_light_ranking(
features,
is_training,
params,
label=None,
decay=0.99999,
dense_emb_size=128,
base_activation=None,
input_dropout_rate=None,
input_dense_type="self_atten_dense",
emb_dense_type="self_atten_dense",
mlp_dense_type="self_atten_dense",
use_gbdt=False,
):
# --------------------------------------------------------
# Initial Parameter Checking
# --------------------------------------------------------
if base_activation is None:
base_activation = tf.keras.layers.LeakyReLU()
if label is not None:
warnings.warn(
"Label is unused in deepnorm_gbdt. Stop using this argument.",
DeprecationWarning,
)
with tf.variable_scope("helper_layers"):
full_sparse_layer = FullSparse(
output_size=params.sparse_embedding_size,
activation=base_activation,
use_sparse_grads=is_training,
use_binary_values=False,
dtype=tf.float32,
)
input_normalizing_layer = ZscoreNormalization(decay=decay, name="input_normalizing_layer")
# --------------------------------------------------------
# Feature Selection & Embedding
# --------------------------------------------------------
if use_gbdt:
sparse_gbdt_features = _sparse_feature_fixup(features["gbdt_sparse"], params.input_size_bits)
if input_dropout_rate is not None:
sparse_gbdt_features = tensor_dropout(
sparse_gbdt_features, input_dropout_rate, is_training, sparse_tensor=True
)
total_embed = full_sparse_layer(sparse_gbdt_features, use_binary_values=True)
if (input_dropout_rate is not None) and is_training:
total_embed = total_embed / (1 - input_dropout_rate)
else:
with tf.variable_scope("dense_branch"):
dense_continuous_features = filter_nans_and_infs(features["continuous"])
if params.use_missing_sub_branch:
is_missing = tf.equal(dense_continuous_features, MISSING_VALUE_MARKER)
continuous_features_filled = tf.where(
is_missing,
tf.zeros_like(dense_continuous_features),
dense_continuous_features,
)
normalized_features = input_normalizing_layer(
continuous_features_filled, is_training, tf.math.logical_not(is_missing)
)
with tf.variable_scope("missing_sub_branch"):
missing_feature_embed = get_dense_out(
tf.cast(is_missing, tf.float32),
dense_emb_size,
activation=base_activation,
dense_type=input_dense_type,
)
else:
continuous_features_filled = dense_continuous_features
normalized_features = input_normalizing_layer(continuous_features_filled, is_training)
with tf.variable_scope("continuous_sub_branch"):
normalized_features = adaptive_transformation(
normalized_features, is_training, func_type="tiny"
)
if input_dropout_rate is not None:
normalized_features = tensor_dropout(
normalized_features,
input_dropout_rate,
is_training,
sparse_tensor=False,
)
filled_feature_embed = get_dense_out(
normalized_features,
dense_emb_size,
activation=base_activation,
dense_type=input_dense_type,
)
if params.use_missing_sub_branch:
dense_embed = tf.concat(
[filled_feature_embed, missing_feature_embed], axis=1, name="merge_dense_emb"
)
else:
dense_embed = filled_feature_embed
with tf.variable_scope("sparse_branch"):
sparse_discrete_features = _sparse_feature_fixup(
features["sparse_no_continuous"], params.input_size_bits
)
if input_dropout_rate is not None:
sparse_discrete_features = tensor_dropout(
sparse_discrete_features, input_dropout_rate, is_training, sparse_tensor=True
)
discrete_features_embed = full_sparse_layer(sparse_discrete_features, use_binary_values=True)
if (input_dropout_rate is not None) and is_training:
discrete_features_embed = discrete_features_embed / (1 - input_dropout_rate)
total_embed = tf.concat(
[dense_embed, discrete_features_embed],
axis=1,
name="total_embed",
)
total_embed = tf.layers.batch_normalization(
total_embed,
training=is_training,
renorm_momentum=decay,
momentum=decay,
renorm=is_training,
trainable=True,
)
# --------------------------------------------------------
# MLP Layers
# --------------------------------------------------------
with tf.variable_scope("MLP_branch"):
assert params.num_mlp_layers >= 0
embed_list = [total_embed] + [None for _ in range(params.num_mlp_layers)]
dense_types = [emb_dense_type] + [mlp_dense_type for _ in range(params.num_mlp_layers - 1)]
for xl in range(1, params.num_mlp_layers + 1):
neurons = params.mlp_neuron_scale ** (params.num_mlp_layers + 1 - xl)
embed_list[xl] = get_dense_out(
embed_list[xl - 1], neurons, activation=base_activation, dense_type=dense_types[xl - 1]
)
if params.task_name in ["Sent", "HeavyRankPosition", "HeavyRankProbability"]:
logits = get_dense_out(embed_list[-1], 1, activation=None, dense_type=mlp_dense_type)
else:
raise ValueError("Invalid Task Name !")
output_dict = {"output": logits}
return output_dict

View File

@ -0,0 +1,337 @@
scala_library(
sources = ["**/*.scala"],
compiler_option_sets = ["fatal_warnings"],
strict_deps = True,
tags = [
"bazel-compatible",
],
dependencies = [
"3rdparty/jvm/com/twitter/bijection:scrooge",
"3rdparty/jvm/com/twitter/storehaus:core",
"abdecider",
"abuse/detection/src/main/thrift/com/twitter/abuse/detection/scoring:thrift-scala",
"ann/src/main/scala/com/twitter/ann/common",
"ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
"audience-rewards/thrift/src/main/thrift:thrift-scala",
"communities/thrift/src/main/thrift/com/twitter/communities:thrift-scala",
"configapi/configapi-core",
"configapi/configapi-decider",
"content-mixer/thrift/src/main/thrift:thrift-scala",
"content-recommender/thrift/src/main/thrift:thrift-scala",
"copyselectionservice/server/src/main/scala/com/twitter/copyselectionservice/algorithms",
"copyselectionservice/thrift/src/main/thrift:copyselectionservice-scala",
"cortex-deepbird/thrift/src/main/thrift:thrift-java",
"cr-mixer/thrift/src/main/thrift:thrift-scala",
"cuad/projects/hashspace/thrift:thrift-scala",
"cuad/projects/tagspace/thrift/src/main/thrift:thrift-scala",
"detopic/thrift/src/main/thrift:thrift-scala",
"discovery-common/src/main/scala/com/twitter/discovery/common/configapi",
"discovery-common/src/main/scala/com/twitter/discovery/common/ddg",
"discovery-common/src/main/scala/com/twitter/discovery/common/environment",
"discovery-common/src/main/scala/com/twitter/discovery/common/fatigue",
"discovery-common/src/main/scala/com/twitter/discovery/common/nackwarmupfilter",
"discovery-common/src/main/scala/com/twitter/discovery/common/server",
"discovery-ds/src/main/thrift/com/twitter/dds/scio/searcher_aggregate_history_srp:searcher_aggregate_history_srp-scala",
"escherbird/src/scala/com/twitter/escherbird/util/metadatastitch",
"escherbird/src/scala/com/twitter/escherbird/util/uttclient",
"escherbird/src/thrift/com/twitter/escherbird/utt:strato-columns-scala",
"eventbus/client",
"eventdetection/event_context/src/main/scala/com/twitter/eventdetection/event_context/util",
"events-recos/events-recos-service/src/main/thrift:events-recos-thrift-scala",
"explore/explore-ranker/thrift/src/main/thrift:thrift-scala",
"featureswitches/featureswitches-core/src/main/scala",
"featureswitches/featureswitches-core/src/main/scala:dynmap",
"featureswitches/featureswitches-core/src/main/scala:recipient",
"featureswitches/featureswitches-core/src/main/scala:useragent",
"featureswitches/featureswitches-core/src/main/scala/com/twitter/featureswitches/v2/builder",
"finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/authentication",
"finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/server",
"finagle-internal/ostrich-stats",
"finagle/finagle-core/src/main",
"finagle/finagle-http/src/main/scala",
"finagle/finagle-memcached/src/main/scala",
"finagle/finagle-stats",
"finagle/finagle-thriftmux",
"finagle/finagle-tunable/src/main/scala",
"finagle/finagle-zipkin-scribe",
"finatra-internal/abdecider",
"finatra-internal/decider",
"finatra-internal/mtls-http/src/main/scala",
"finatra-internal/mtls-thriftmux/src/main/scala",
"finatra/http-client/src/main/scala",
"finatra/http-core/src/main/java/com/twitter/finatra/http",
"finatra/http-core/src/main/scala/com/twitter/finatra/http/response",
"finatra/http-server/src/main/scala/com/twitter/finatra/http",
"finatra/http-server/src/main/scala/com/twitter/finatra/http/filters",
"finatra/inject/inject-app/src/main/java/com/twitter/inject/annotations",
"finatra/inject/inject-app/src/main/scala",
"finatra/inject/inject-core/src/main/scala",
"finatra/inject/inject-server/src/main/scala",
"finatra/inject/inject-slf4j/src/main/scala/com/twitter/inject",
"finatra/inject/inject-thrift-client/src/main/scala",
"finatra/inject/inject-utils/src/main/scala",
"finatra/utils/src/main/java/com/twitter/finatra/annotations",
"fleets/fleets-proxy/thrift/src/main/thrift:fleet-scala",
"fleets/fleets-proxy/thrift/src/main/thrift/service:baseservice-scala",
"flock-client/src/main/scala",
"flock-client/src/main/thrift:thrift-scala",
"follow-recommendations-service/thrift/src/main/thrift:thrift-scala",
"frigate/frigate-common:base",
"frigate/frigate-common:config",
"frigate/frigate-common:debug",
"frigate/frigate-common:entity_graph_client",
"frigate/frigate-common:history",
"frigate/frigate-common:logger",
"frigate/frigate-common:ml-base",
"frigate/frigate-common:ml-feature",
"frigate/frigate-common:ml-prediction",
"frigate/frigate-common:ntab",
"frigate/frigate-common:predicate",
"frigate/frigate-common:rec_types",
"frigate/frigate-common:score_summary",
"frigate/frigate-common:util",
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/candidate",
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/experiments",
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/filter",
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/modules/store:semantic_core_stores",
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store",
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/deviceinfo",
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/interests",
"frigate/frigate-common/src/main/scala/com/twitter/frigate/common/store/strato",
"frigate/push-mixer/thrift/src/main/thrift:thrift-scala",
"geo/geo-prediction/src/main/thrift:local-viral-tweets-thrift-scala",
"geoduck/service/src/main/scala/com/twitter/geoduck/service/common/clientmodules",
"geoduck/util/country",
"gizmoduck/client/src/main/scala/com/twitter/gizmoduck/testusers/client",
"hermit/hermit-core:model-user_state",
"hermit/hermit-core:predicate",
"hermit/hermit-core:predicate-gizmoduck",
"hermit/hermit-core:predicate-scarecrow",
"hermit/hermit-core:predicate-socialgraph",
"hermit/hermit-core:predicate-tweetypie",
"hermit/hermit-core:store-labeled_push_recs",
"hermit/hermit-core:store-metastore",
"hermit/hermit-core:store-timezone",
"hermit/hermit-core:store-tweetypie",
"hermit/hermit-core/src/main/scala/com/twitter/hermit/constants",
"hermit/hermit-core/src/main/scala/com/twitter/hermit/model",
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store",
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/common",
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/gizmoduck",
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/scarecrow",
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/semantic_core",
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/user_htl_session_store",
"hermit/hermit-core/src/main/scala/com/twitter/hermit/store/user_interest",
"hmli/hss/src/main/thrift/com/twitter/hss:thrift-scala",
"ibis2/service/src/main/scala/com/twitter/ibis2/lib",
"ibis2/service/src/main/thrift/com/twitter/ibis2/service:ibis2-service-scala",
"interests-service/thrift/src/main/thrift:thrift-scala",
"interests_discovery/thrift/src/main/thrift:batch-thrift-scala",
"interests_discovery/thrift/src/main/thrift:service-thrift-scala",
"kujaku/thrift/src/main/thrift:domain-scala",
"live-video-timeline/client/src/main/scala/com/twitter/livevideo/timeline/client/v2",
"live-video-timeline/domain/src/main/scala/com/twitter/livevideo/timeline/domain",
"live-video-timeline/domain/src/main/scala/com/twitter/livevideo/timeline/domain/v2",
"live-video-timeline/thrift/src/main/thrift/com/twitter/livevideo/timeline:thrift-scala",
"live-video/common/src/main/scala/com/twitter/livevideo/common/domain/v2",
"live-video/common/src/main/scala/com/twitter/livevideo/common/ids",
"notifications-platform/inbound-notifications/src/main/thrift/com/twitter/inbound_notifications:exception-scala",
"notifications-platform/inbound-notifications/src/main/thrift/com/twitter/inbound_notifications:thrift-scala",
"notifications-platform/platform-lib/src/main/thrift/com/twitter/notifications/platform:custom-notification-actions-scala",
"notifications-platform/platform-lib/src/main/thrift/com/twitter/notifications/platform:thrift-scala",
"notifications-relevance/src/scala/com/twitter/nrel/heavyranker",
"notifications-relevance/src/scala/com/twitter/nrel/hydration/base",
"notifications-relevance/src/scala/com/twitter/nrel/hydration/frigate",
"notifications-relevance/src/scala/com/twitter/nrel/hydration/push",
"notifications-relevance/src/scala/com/twitter/nrel/lightranker",
"notificationservice/common/src/main/scala/com/twitter/notificationservice/genericfeedbackstore",
"notificationservice/common/src/main/scala/com/twitter/notificationservice/model:alias",
"notificationservice/common/src/main/scala/com/twitter/notificationservice/model/service",
"notificationservice/common/src/test/scala/com/twitter/notificationservice/mocks",
"notificationservice/scribe/src/main/scala/com/twitter/notificationservice/scribe/manhattan:mh_wrapper",
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/api:thrift-scala",
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/badgecount-api:thrift-scala",
"notificationservice/thrift/src/main/thrift/com/twitter/notificationservice/generic_notifications:thrift-scala",
"notifinfra/ni-lib/src/main/scala/com/twitter/ni/lib/logged_out_transform",
"observability/observability-manhattan-client/src/main/scala",
"onboarding/service/src/main/scala/com/twitter/onboarding/task/service/models/external",
"onboarding/service/thrift/src/main/thrift:thrift-scala",
"people-discovery/api/thrift/src/main/thrift:thrift-scala",
"periscope/api-proxy-thrift/thrift/src/main/thrift:thrift-scala",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/module",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/module/stringcenter",
"product-mixer/core/src/main/thrift/com/twitter/product_mixer/core:thrift-scala",
"qig-ranker/thrift/src/main/thrift:thrift-scala",
"rux-ds/src/main/thrift/com/twitter/ruxds/jobs/user_past_aggregate:user_past_aggregate-scala",
"rux/common/src/main/scala/com/twitter/rux/common/encode",
"rux/common/thrift/src/main/thrift/rux-context:rux-context-scala",
"rux/common/thrift/src/main/thrift/strato:strato-scala",
"scribelib/marshallers/src/main/scala/com/twitter/scribelib/marshallers",
"scrooge/scrooge-core",
"scrooge/scrooge-serializer/src/main/scala",
"sensitive-ds/src/main/thrift/com/twitter/scio/nsfw_user_segmentation:nsfw_user_segmentation-scala",
"servo/decider/src/main/scala",
"servo/request/src/main/scala",
"servo/util/src/main/scala",
"src/java/com/twitter/ml/api:api-base",
"src/java/com/twitter/ml/prediction/core",
"src/scala/com/twitter/frigate/data_pipeline/common",
"src/scala/com/twitter/frigate/data_pipeline/embedding_cg:embedding_cg-test-user-ids",
"src/scala/com/twitter/frigate/data_pipeline/features_common",
"src/scala/com/twitter/frigate/news_article_recs/news_articles_metadata:thrift-scala",
"src/scala/com/twitter/frontpage/stream/util",
"src/scala/com/twitter/language/normalization",
"src/scala/com/twitter/ml/api/embedding",
"src/scala/com/twitter/ml/api/util:datarecord",
"src/scala/com/twitter/ml/featurestore/catalog/entities/core",
"src/scala/com/twitter/ml/featurestore/catalog/entities/magicrecs",
"src/scala/com/twitter/ml/featurestore/catalog/features/core:aggregate",
"src/scala/com/twitter/ml/featurestore/catalog/features/cuad:aggregate",
"src/scala/com/twitter/ml/featurestore/catalog/features/embeddings",
"src/scala/com/twitter/ml/featurestore/catalog/features/magicrecs:aggregate",
"src/scala/com/twitter/ml/featurestore/catalog/features/topic_signals:aggregate",
"src/scala/com/twitter/ml/featurestore/lib",
"src/scala/com/twitter/ml/featurestore/lib/data",
"src/scala/com/twitter/ml/featurestore/lib/dynamic",
"src/scala/com/twitter/ml/featurestore/lib/entity",
"src/scala/com/twitter/ml/featurestore/lib/online",
"src/scala/com/twitter/recommendation/interests/discovery/core/config",
"src/scala/com/twitter/recommendation/interests/discovery/core/deploy",
"src/scala/com/twitter/recommendation/interests/discovery/core/model",
"src/scala/com/twitter/recommendation/interests/discovery/popgeo/deploy",
"src/scala/com/twitter/simclusters_v2/common",
"src/scala/com/twitter/storehaus_internal/manhattan",
"src/scala/com/twitter/storehaus_internal/manhattan/config",
"src/scala/com/twitter/storehaus_internal/memcache",
"src/scala/com/twitter/storehaus_internal/memcache/config",
"src/scala/com/twitter/storehaus_internal/util",
"src/scala/com/twitter/taxi/common",
"src/scala/com/twitter/taxi/config",
"src/scala/com/twitter/taxi/deploy",
"src/scala/com/twitter/taxi/trending/common",
"src/thrift/com/twitter/ads/adserver:adserver_rpc-scala",
"src/thrift/com/twitter/clientapp/gen:clientapp-scala",
"src/thrift/com/twitter/core_workflows/user_model:user_model-scala",
"src/thrift/com/twitter/escherbird/common:constants-scala",
"src/thrift/com/twitter/escherbird/metadata:megadata-scala",
"src/thrift/com/twitter/escherbird/metadata:metadata-service-scala",
"src/thrift/com/twitter/escherbird/search:search-service-scala",
"src/thrift/com/twitter/expandodo:only-scala",
"src/thrift/com/twitter/frigate:frigate-common-thrift-scala",
"src/thrift/com/twitter/frigate:frigate-ml-thrift-scala",
"src/thrift/com/twitter/frigate:frigate-notification-thrift-scala",
"src/thrift/com/twitter/frigate:frigate-secondary-accounts-thrift-scala",
"src/thrift/com/twitter/frigate:frigate-thrift-scala",
"src/thrift/com/twitter/frigate:frigate-user-media-representation-thrift-scala",
"src/thrift/com/twitter/frigate/data_pipeline:frigate-user-history-thrift-scala",
"src/thrift/com/twitter/frigate/dau_model:frigate-dau-thrift-scala",
"src/thrift/com/twitter/frigate/magic_events:frigate-magic-events-thrift-scala",
"src/thrift/com/twitter/frigate/magic_events/scribe:thrift-scala",
"src/thrift/com/twitter/frigate/pushcap:frigate-pushcap-thrift-scala",
"src/thrift/com/twitter/frigate/pushservice:frigate-pushservice-thrift-scala",
"src/thrift/com/twitter/frigate/scribe:frigate-scribe-thrift-scala",
"src/thrift/com/twitter/frigate/subscribed_search:frigate-subscribed-search-thrift-scala",
"src/thrift/com/twitter/frigate/user_states:frigate-userstates-thrift-scala",
"src/thrift/com/twitter/geoduck:geoduck-scala",
"src/thrift/com/twitter/gizmoduck:thrift-scala",
"src/thrift/com/twitter/gizmoduck:user-thrift-scala",
"src/thrift/com/twitter/hermit:hermit-scala",
"src/thrift/com/twitter/hermit/pop_geo:hermit-pop-geo-scala",
"src/thrift/com/twitter/hermit/stp:hermit-stp-scala",
"src/thrift/com/twitter/ibis:service-scala",
"src/thrift/com/twitter/manhattan:v1-scala",
"src/thrift/com/twitter/manhattan:v2-scala",
"src/thrift/com/twitter/ml/api:data-java",
"src/thrift/com/twitter/ml/api:data-scala",
"src/thrift/com/twitter/ml/featurestore/timelines:ml-features-timelines-scala",
"src/thrift/com/twitter/ml/featurestore/timelines:ml-features-timelines-strato",
"src/thrift/com/twitter/ml/prediction_service:prediction_service-java",
"src/thrift/com/twitter/permissions_storage:thrift-scala",
"src/thrift/com/twitter/pink-floyd/thrift:thrift-scala",
"src/thrift/com/twitter/recos:recos-common-scala",
"src/thrift/com/twitter/recos/user_tweet_entity_graph:user_tweet_entity_graph-scala",
"src/thrift/com/twitter/recos/user_user_graph:user_user_graph-scala",
"src/thrift/com/twitter/relevance/feature_store:feature_store-scala",
"src/thrift/com/twitter/search:earlybird-scala",
"src/thrift/com/twitter/search/common:features-scala",
"src/thrift/com/twitter/search/query_interaction_graph:query_interaction_graph-scala",
"src/thrift/com/twitter/search/query_interaction_graph/service:qig-service-scala",
"src/thrift/com/twitter/service/metastore/gen:thrift-scala",
"src/thrift/com/twitter/service/scarecrow/gen:scarecrow-scala",
"src/thrift/com/twitter/service/scarecrow/gen:tiered-actions-scala",
"src/thrift/com/twitter/simclusters_v2:simclusters_v2-thrift-scala",
"src/thrift/com/twitter/socialgraph:thrift-scala",
"src/thrift/com/twitter/spam/rtf:safety-level-scala",
"src/thrift/com/twitter/timelinemixer:thrift-scala",
"src/thrift/com/twitter/timelinemixer/server/internal:thrift-scala",
"src/thrift/com/twitter/timelines/author_features/user_health:thrift-scala",
"src/thrift/com/twitter/timelines/real_graph:real_graph-scala",
"src/thrift/com/twitter/timelinescorer:thrift-scala",
"src/thrift/com/twitter/timelinescorer/server/internal:thrift-scala",
"src/thrift/com/twitter/timelineservice/server/internal:thrift-scala",
"src/thrift/com/twitter/timelineservice/server/suggests/logging:thrift-scala",
"src/thrift/com/twitter/trends/common:common-scala",
"src/thrift/com/twitter/trends/trip_v1:trip-tweets-thrift-scala",
"src/thrift/com/twitter/tweetypie:service-scala",
"src/thrift/com/twitter/tweetypie:tweet-scala",
"src/thrift/com/twitter/user_session_store:thrift-scala",
"src/thrift/com/twitter/wtf/candidate:wtf-candidate-scala",
"src/thrift/com/twitter/wtf/interest:interest-thrift-scala",
"src/thrift/com/twitter/wtf/scalding/common:thrift-scala",
"stitch/stitch-core",
"stitch/stitch-gizmoduck",
"stitch/stitch-socialgraph/src/main/scala",
"stitch/stitch-storehaus/src/main/scala",
"stitch/stitch-tweetypie/src/main/scala",
"storage/clients/manhattan/client/src/main/scala",
"strato/config/columns/clients:clients-strato-client",
"strato/config/columns/geo/user:user-strato-client",
"strato/config/columns/globe/curation:curation-strato-client",
"strato/config/columns/interests:interests-strato-client",
"strato/config/columns/ml/featureStore:featureStore-strato-client",
"strato/config/columns/notifications:notifications-strato-client",
"strato/config/columns/notifinfra:notifinfra-strato-client",
"strato/config/columns/periscope:periscope-strato-client",
"strato/config/columns/rux",
"strato/config/columns/rux:rux-strato-client",
"strato/config/columns/rux/open-app:open-app-strato-client",
"strato/config/columns/socialgraph/graphs:graphs-strato-client",
"strato/config/columns/socialgraph/service/soft_users:soft_users-strato-client",
"strato/config/columns/translation/service:service-strato-client",
"strato/config/columns/translation/service/platform:platform-strato-client",
"strato/config/columns/trends/trip:trip-strato-client",
"strato/config/src/thrift/com/twitter/strato/columns/frigate:logged-out-web-notifications-scala",
"strato/config/src/thrift/com/twitter/strato/columns/notifications:thrift-scala",
"strato/src/main/scala/com/twitter/strato/config",
"strato/src/main/scala/com/twitter/strato/response",
"thrift-web-forms",
"timeline-training-service/service/thrift/src/main/thrift:thrift-scala",
"timelines/src/main/scala/com/twitter/timelines/features/app",
"topic-social-proof/server/src/main/thrift:thrift-scala",
"topiclisting/topiclisting-core/src/main/scala/com/twitter/topiclisting",
"topiclisting/topiclisting-utt/src/main/scala/com/twitter/topiclisting/utt",
"trends/common/src/main/thrift/com/twitter/trends/common:thrift-scala",
"tweetypie/src/scala/com/twitter/tweetypie/tweettext",
"twitter-context/src/main/scala",
"twitter-server-internal",
"twitter-server/server/src/main/scala",
"twitter-text/lib/java/src/main/java/com/twitter/twittertext",
"twml/runtime/src/main/scala/com/twitter/deepbird/runtime/prediction_engine:prediction_engine_mkl",
"ubs/common/src/main/thrift/com/twitter/ubs:broadcast-thrift-scala",
"ubs/common/src/main/thrift/com/twitter/ubs:seller_application-thrift-scala",
"user_session_store/src/main/scala/com/twitter/user_session_store/impl/manhattan/readwrite",
"util-internal/scribe",
"util-internal/tunable/src/main/scala/com/twitter/util/tunable",
"util/util-app",
"util/util-hashing/src/main/scala",
"util/util-slf4j-api/src/main/scala",
"util/util-stats/src/main/scala",
"visibility/lib/src/main/scala/com/twitter/visibility/builder",
"visibility/lib/src/main/scala/com/twitter/visibility/interfaces/push_service",
"visibility/lib/src/main/scala/com/twitter/visibility/interfaces/spaces",
"visibility/lib/src/main/scala/com/twitter/visibility/util",
],
exports = [
"strato/config/src/thrift/com/twitter/strato/columns/frigate:logged-out-web-notifications-scala",
],
)

View File

@ -0,0 +1,93 @@
package com.twitter.frigate.pushservice
import com.google.inject.Inject
import com.google.inject.Singleton
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
import com.twitter.finagle.thrift.ClientId
import com.twitter.finatra.thrift.routing.ThriftWarmup
import com.twitter.util.logging.Logging
import com.twitter.inject.utils.Handler
import com.twitter.frigate.pushservice.{thriftscala => t}
import com.twitter.frigate.thriftscala.NotificationDisplayLocation
import com.twitter.util.Stopwatch
import com.twitter.scrooge.Request
import com.twitter.scrooge.Response
import com.twitter.util.Return
import com.twitter.util.Throw
import com.twitter.util.Try
/**
* Warms up the refresh request path.
* If service is running as pushservice-send then the warmup does nothing.
*
* When making the warmup refresh requests we
* - Set skipFilters to true to execute as much of the request path as possible
* - Set darkWrite to true to prevent sending a push
*/
@Singleton
class PushMixerThriftServerWarmupHandler @Inject() (
warmup: ThriftWarmup,
serviceIdentifier: ServiceIdentifier)
extends Handler
with Logging {
private val clientId = ClientId("thrift-warmup-client")
def handle(): Unit = {
val refreshServices = Set(
"frigate-pushservice",
"frigate-pushservice-canary",
"frigate-pushservice-canary-control",
"frigate-pushservice-canary-treatment"
)
val isRefresh = refreshServices.contains(serviceIdentifier.service)
if (isRefresh && !serviceIdentifier.isLocal) refreshWarmup()
}
def refreshWarmup(): Unit = {
val elapsed = Stopwatch.start()
val testIds = Seq(
1,
2,
3
)
try {
clientId.asCurrent {
testIds.foreach { id =>
val warmupReq = warmupQuery(id)
info(s"Sending warm-up request to service with query: $warmupReq")
warmup.sendRequest(
method = t.PushService.Refresh,
req = Request(t.PushService.Refresh.Args(warmupReq)))(assertWarmupResponse)
}
}
} catch {
case e: Throwable =>
error(e.getMessage, e)
}
info(s"Warm up complete. Time taken: ${elapsed().toString}")
}
private def warmupQuery(userId: Long): t.RefreshRequest = {
t.RefreshRequest(
userId = userId,
notificationDisplayLocation = NotificationDisplayLocation.PushToMobileDevice,
context = Some(
t.PushContext(
skipFilters = Some(true),
darkWrite = Some(true)
))
)
}
private def assertWarmupResponse(
result: Try[Response[t.PushService.Refresh.SuccessType]]
): Unit = {
result match {
case Return(_) => // ok
case Throw(exception) =>
warn("Error performing warm-up request.")
error(exception.getMessage, exception)
}
}
}

View File

@ -0,0 +1,193 @@
package com.twitter.frigate.pushservice
import com.twitter.discovery.common.environment.modules.EnvironmentModule
import com.twitter.finagle.Filter
import com.twitter.finatra.annotations.DarkTrafficFilterType
import com.twitter.finatra.decider.modules.DeciderModule
import com.twitter.finatra.http.HttpServer
import com.twitter.finatra.http.filters.CommonFilters
import com.twitter.finatra.http.routing.HttpRouter
import com.twitter.finatra.mtls.http.{Mtls => HttpMtls}
import com.twitter.finatra.mtls.thriftmux.{Mtls => ThriftMtls}
import com.twitter.finatra.mtls.thriftmux.filters.MtlsServerSessionTrackerFilter
import com.twitter.finatra.thrift.ThriftServer
import com.twitter.finatra.thrift.filters.ExceptionMappingFilter
import com.twitter.finatra.thrift.filters.LoggingMDCFilter
import com.twitter.finatra.thrift.filters.StatsFilter
import com.twitter.finatra.thrift.filters.ThriftMDCFilter
import com.twitter.finatra.thrift.filters.TraceIdMDCFilter
import com.twitter.finatra.thrift.routing.ThriftRouter
import com.twitter.frigate.common.logger.MRLoggerGlobalVariables
import com.twitter.frigate.pushservice.controller.PushServiceController
import com.twitter.frigate.pushservice.module._
import com.twitter.inject.TwitterModule
import com.twitter.inject.annotations.Flags
import com.twitter.inject.thrift.modules.ThriftClientIdModule
import com.twitter.logging.BareFormatter
import com.twitter.logging.Level
import com.twitter.logging.LoggerFactory
import com.twitter.logging.{Logging => JLogging}
import com.twitter.logging.QueueingHandler
import com.twitter.logging.ScribeHandler
import com.twitter.product_mixer.core.module.product_mixer_flags.ProductMixerFlagModule
import com.twitter.product_mixer.core.module.ABDeciderModule
import com.twitter.product_mixer.core.module.FeatureSwitchesModule
import com.twitter.product_mixer.core.module.StratoClientModule
object PushServiceMain extends PushServiceFinatraServer
class PushServiceFinatraServer
extends ThriftServer
with ThriftMtls
with HttpServer
with HttpMtls
with JLogging {
override val name = "PushService"
override val modules: Seq[TwitterModule] = {
Seq(
ABDeciderModule,
DeciderModule,
FeatureSwitchesModule,
FilterModule,
FlagModule,
EnvironmentModule,
ThriftClientIdModule,
DeployConfigModule,
ProductMixerFlagModule,
StratoClientModule,
PushHandlerModule,
PushTargetUserBuilderModule,
PushServiceDarkTrafficModule,
LoggedOutPushTargetUserBuilderModule,
new ThriftWebFormsModule(this),
)
}
override def configureThrift(router: ThriftRouter): Unit = {
router
.filter[ExceptionMappingFilter]
.filter[LoggingMDCFilter]
.filter[TraceIdMDCFilter]
.filter[ThriftMDCFilter]
.filter[MtlsServerSessionTrackerFilter]
.filter[StatsFilter]
.filter[Filter.TypeAgnostic, DarkTrafficFilterType]
.add[PushServiceController]
}
override def configureHttp(router: HttpRouter): Unit =
router
.filter[CommonFilters]
override protected def start(): Unit = {
MRLoggerGlobalVariables.setRequiredFlags(
traceLogFlag = injector.instance[Boolean](Flags.named(FlagModule.mrLoggerIsTraceAll.name)),
nthLogFlag = injector.instance[Boolean](Flags.named(FlagModule.mrLoggerNthLog.name)),
nthLogValFlag = injector.instance[Long](Flags.named(FlagModule.mrLoggerNthVal.name))
)
}
override protected def warmup(): Unit = {
handle[PushMixerThriftServerWarmupHandler]()
}
override protected def configureLoggerFactories(): Unit = {
loggerFactories.foreach { _() }
}
override def loggerFactories: List[LoggerFactory] = {
val scribeScope = statsReceiver.scope("scribe")
List(
LoggerFactory(
level = Some(levelFlag()),
handlers = handlers
),
LoggerFactory(
node = "request_scribe",
level = Some(Level.INFO),
useParents = false,
handlers = QueueingHandler(
maxQueueSize = 10000,
handler = ScribeHandler(
category = "frigate_pushservice_log",
formatter = BareFormatter,
statsReceiver = scribeScope.scope("frigate_pushservice_log")
)
) :: Nil
),
LoggerFactory(
node = "notification_scribe",
level = Some(Level.INFO),
useParents = false,
handlers = QueueingHandler(
maxQueueSize = 10000,
handler = ScribeHandler(
category = "frigate_notifier",
formatter = BareFormatter,
statsReceiver = scribeScope.scope("frigate_notifier")
)
) :: Nil
),
LoggerFactory(
node = "push_scribe",
level = Some(Level.INFO),
useParents = false,
handlers = QueueingHandler(
maxQueueSize = 10000,
handler = ScribeHandler(
category = "test_frigate_push",
formatter = BareFormatter,
statsReceiver = scribeScope.scope("test_frigate_push")
)
) :: Nil
),
LoggerFactory(
node = "push_subsample_scribe",
level = Some(Level.INFO),
useParents = false,
handlers = QueueingHandler(
maxQueueSize = 2500,
handler = ScribeHandler(
category = "magicrecs_candidates_subsample_scribe",
maxMessagesPerTransaction = 250,
maxMessagesToBuffer = 2500,
formatter = BareFormatter,
statsReceiver = scribeScope.scope("magicrecs_candidates_subsample_scribe")
)
) :: Nil
),
LoggerFactory(
node = "mr_request_scribe",
level = Some(Level.INFO),
useParents = false,
handlers = QueueingHandler(
maxQueueSize = 2500,
handler = ScribeHandler(
category = "mr_request_scribe",
maxMessagesPerTransaction = 250,
maxMessagesToBuffer = 2500,
formatter = BareFormatter,
statsReceiver = scribeScope.scope("mr_request_scribe")
)
) :: Nil
),
LoggerFactory(
node = "high_quality_candidates_scribe",
level = Some(Level.INFO),
useParents = false,
handlers = QueueingHandler(
maxQueueSize = 2500,
handler = ScribeHandler(
category = "frigate_high_quality_candidates_log",
maxMessagesPerTransaction = 250,
maxMessagesToBuffer = 2500,
formatter = BareFormatter,
statsReceiver = scribeScope.scope("high_quality_candidates_scribe")
)
) :: Nil
),
)
}
}

View File

@ -0,0 +1,323 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.contentrecommender.thriftscala.MetricTag
import com.twitter.cr_mixer.thriftscala.CrMixerTweetRequest
import com.twitter.cr_mixer.thriftscala.NotificationsContext
import com.twitter.cr_mixer.thriftscala.Product
import com.twitter.cr_mixer.thriftscala.ProductContext
import com.twitter.cr_mixer.thriftscala.{MetricTag => CrMixerMetricTag}
import com.twitter.finagle.stats.Stat
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.AlgorithmScore
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base.CrMixerCandidate
import com.twitter.frigate.common.base.TopicCandidate
import com.twitter.frigate.common.base.TopicProofTweetCandidate
import com.twitter.frigate.common.base.TweetCandidate
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutInNetworkTweets
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.params.PushParams
import com.twitter.frigate.pushservice.store.CrMixerTweetStore
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
import com.twitter.frigate.pushservice.util.AdaptorUtils
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.pushservice.util.TopicsUtil
import com.twitter.frigate.pushservice.util.TweetWithTopicProof
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.hermit.predicate.socialgraph.RelationEdge
import com.twitter.product_mixer.core.thriftscala.ClientContext
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.topiclisting.utt.LocalizedEntity
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
import com.twitter.util.Future
import scala.collection.Map
case class ContentRecommenderMixerAdaptor(
crMixerTweetStore: CrMixerTweetStore,
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
edgeStore: ReadableStore[RelationEdge, Boolean],
topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse],
uttEntityHydrationStore: UttEntityHydrationStore,
globalStats: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
override val name: String = this.getClass.getSimpleName
private[this] val stats = globalStats.scope("ContentRecommenderMixerAdaptor")
private[this] val numOfValidAuthors = stats.stat("num_of_valid_authors")
private[this] val numOutOfMaximumDropped = stats.stat("dropped_due_out_of_maximum")
private[this] val totalInputRecs = stats.counter("input_recs")
private[this] val totalOutputRecs = stats.stat("output_recs")
private[this] val totalRequests = stats.counter("total_requests")
private[this] val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
private[this] val totalOutNetworkRecs = stats.counter("out_network_tweets")
private[this] val totalInNetworkRecs = stats.counter("in_network_tweets")
/**
* Builds OON raw candidates based on input OON Tweets
*/
def buildOONRawCandidates(
inputTarget: Target,
oonTweets: Seq[TweetyPieResult],
tweetScoreMap: Map[Long, Double],
tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]],
maxNumOfCandidates: Int
): Option[Seq[RawCandidate]] = {
val cands = oonTweets.flatMap { tweetResult =>
val tweetId = tweetResult.tweet.id
generateOONRawCandidate(
inputTarget,
tweetId,
Some(tweetResult),
tweetScoreMap,
tweetIdToTagsMap
)
}
val candidates = restrict(
maxNumOfCandidates,
cands,
numOutOfMaximumDropped,
totalOutputRecs
)
Some(candidates)
}
/**
* Builds a single RawCandidate With TopicProofTweetCandidate
*/
def buildTopicTweetRawCandidate(
inputTarget: Target,
tweetWithTopicProof: TweetWithTopicProof,
localizedEntity: LocalizedEntity,
tags: Option[Seq[MetricTag]],
): RawCandidate with TopicProofTweetCandidate = {
new RawCandidate with TopicProofTweetCandidate {
override def target: Target = inputTarget
override def topicListingSetting: Option[String] = Some(
tweetWithTopicProof.topicListingSetting)
override def tweetId: Long = tweetWithTopicProof.tweetId
override def tweetyPieResult: Option[TweetyPieResult] = Some(
tweetWithTopicProof.tweetyPieResult)
override def semanticCoreEntityId: Option[Long] = Some(tweetWithTopicProof.topicId)
override def localizedUttEntity: Option[LocalizedEntity] = Some(localizedEntity)
override def algorithmCR: Option[String] = tweetWithTopicProof.algorithmCR
override def tagsCR: Option[Seq[MetricTag]] = tags
override def isOutOfNetwork: Boolean = tweetWithTopicProof.isOON
}
}
/**
* Takes a group of TopicTweets and transforms them into RawCandidates
*/
def buildTopicTweetRawCandidates(
inputTarget: Target,
topicProofCandidates: Seq[TweetWithTopicProof],
tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]],
maxNumberOfCands: Int
): Future[Option[Seq[RawCandidate]]] = {
val semanticCoreEntityIds = topicProofCandidates
.map(_.topicId)
.toSet
TopicsUtil
.getLocalizedEntityMap(inputTarget, semanticCoreEntityIds, uttEntityHydrationStore)
.map { localizedEntityMap =>
val rawCandidates = topicProofCandidates.collect {
case topicSocialProof: TweetWithTopicProof
if localizedEntityMap.contains(topicSocialProof.topicId) =>
// Once we deprecate CR calls, we should replace this code to use the CrMixerMetricTag
val tags = tweetIdToTagsMap.get(topicSocialProof.tweetId).map {
_.flatMap { tag => MetricTag.get(tag.value) }
}
buildTopicTweetRawCandidate(
inputTarget,
topicSocialProof,
localizedEntityMap(topicSocialProof.topicId),
tags
)
}
val candResult = restrict(
maxNumberOfCands,
rawCandidates,
numOutOfMaximumDropped,
totalOutputRecs
)
Some(candResult)
}
}
private def generateOONRawCandidate(
inputTarget: Target,
id: Long,
result: Option[TweetyPieResult],
tweetScoreMap: Map[Long, Double],
tweetIdToTagsMap: Map[Long, Seq[CrMixerMetricTag]]
): Option[RawCandidate with TweetCandidate] = {
val tagsFromCR = tweetIdToTagsMap.get(id).map { _.flatMap { tag => MetricTag.get(tag.value) } }
val candidate = new RawCandidate with CrMixerCandidate with TopicCandidate with AlgorithmScore {
override val tweetId = id
override val target = inputTarget
override val tweetyPieResult = result
override val localizedUttEntity = None
override val semanticCoreEntityId = None
override def commonRecType =
getMediaBasedCRT(
CommonRecommendationType.TwistlyTweet,
CommonRecommendationType.TwistlyPhoto,
CommonRecommendationType.TwistlyVideo)
override def tagsCR = tagsFromCR
override def algorithmScore = tweetScoreMap.get(id)
override def algorithmCR = None
}
Some(candidate)
}
private def restrict(
maxNumToReturn: Int,
candidates: Seq[RawCandidate],
numOutOfMaximumDropped: Stat,
totalOutputRecs: Stat
): Seq[RawCandidate] = {
val newCandidates = candidates.take(maxNumToReturn)
val numDropped = candidates.length - newCandidates.length
numOutOfMaximumDropped.add(numDropped)
totalOutputRecs.add(newCandidates.size)
newCandidates
}
private def buildCrMixerRequest(
target: Target,
countryCode: Option[String],
language: Option[String],
seenTweets: Seq[Long]
): CrMixerTweetRequest = {
CrMixerTweetRequest(
clientContext = ClientContext(
userId = Some(target.targetId),
countryCode = countryCode,
languageCode = language
),
product = Product.Notifications,
productContext = Some(ProductContext.NotificationsContext(NotificationsContext())),
excludedTweetIds = Some(seenTweets)
)
}
private def selectCandidatesToSendBasedOnSettings(
isRecommendationsEligible: Boolean,
isTopicsEligible: Boolean,
oonRawCandidates: Option[Seq[RawCandidate]],
topicTweetCandidates: Option[Seq[RawCandidate]]
): Option[Seq[RawCandidate]] = {
if (isRecommendationsEligible && isTopicsEligible) {
Some(topicTweetCandidates.getOrElse(Seq.empty) ++ oonRawCandidates.getOrElse(Seq.empty))
} else if (isRecommendationsEligible) {
oonRawCandidates
} else if (isTopicsEligible) {
topicTweetCandidates
} else None
}
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
Future
.join(
target.seenTweetIds,
target.countryCode,
target.inferredUserDeviceLanguage,
PushDeviceUtil.isTopicsEligible(target),
PushDeviceUtil.isRecommendationsEligible(target)
).flatMap {
case (seenTweets, countryCode, language, isTopicsEligible, isRecommendationsEligible) =>
val request = buildCrMixerRequest(target, countryCode, language, seenTweets)
crMixerTweetStore.getTweetRecommendations(request).flatMap {
case Some(response) =>
totalInputRecs.incr(response.tweets.size)
totalRequests.incr()
AdaptorUtils
.getTweetyPieResults(
response.tweets.map(_.tweetId).toSet,
tweetyPieStore).flatMap { tweetyPieResultMap =>
filterOutInNetworkTweets(
target,
filterOutReplyTweet(tweetyPieResultMap.toMap, nonReplyTweetsCounter),
edgeStore,
numOfValidAuthors).flatMap {
outNetworkTweetsWithId: Seq[(Long, TweetyPieResult)] =>
totalOutNetworkRecs.incr(outNetworkTweetsWithId.size)
totalInNetworkRecs.incr(response.tweets.size - outNetworkTweetsWithId.size)
val outNetworkTweets: Seq[TweetyPieResult] = outNetworkTweetsWithId.map {
case (_, tweetyPieResult) => tweetyPieResult
}
val tweetIdToTagsMap = response.tweets.map { tweet =>
tweet.tweetId -> tweet.metricTags.getOrElse(Seq.empty)
}.toMap
val tweetScoreMap = response.tweets.map { tweet =>
tweet.tweetId -> tweet.score
}.toMap
val maxNumOfCandidates =
target.params(PushFeatureSwitchParams.NumberOfMaxCrMixerCandidatesParam)
val oonRawCandidates =
buildOONRawCandidates(
target,
outNetworkTweets,
tweetScoreMap,
tweetIdToTagsMap,
maxNumOfCandidates)
TopicsUtil
.getTopicSocialProofs(
target,
outNetworkTweets,
topicSocialProofServiceStore,
edgeStore,
PushFeatureSwitchParams.TopicProofTweetCandidatesTopicScoreThreshold).flatMap {
tweetsWithTopicProof =>
buildTopicTweetRawCandidates(
target,
tweetsWithTopicProof,
tweetIdToTagsMap,
maxNumOfCandidates)
}.map { topicTweetCandidates =>
selectCandidatesToSendBasedOnSettings(
isRecommendationsEligible,
isTopicsEligible,
oonRawCandidates,
topicTweetCandidates)
}
}
}
case _ => Future.None
}
}
}
/**
* For a user to be available the following news to happen
*/
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
Future
.join(
PushDeviceUtil.isRecommendationsEligible(target),
PushDeviceUtil.isTopicsEligible(target)
).map {
case (isRecommendationsEligible, isTopicsEligible) =>
(isRecommendationsEligible || isTopicsEligible) &&
target.params(PushParams.ContentRecommenderMixerAdaptorDecider)
}
}
}

View File

@ -0,0 +1,293 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.finagle.stats.Stat
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base._
import com.twitter.frigate.common.candidate._
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.params.PushParams
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.hermit.store.tweetypie.UserTweet
import com.twitter.recos.recos_common.thriftscala.SocialProofType
import com.twitter.search.common.features.thriftscala.ThriftSearchResultFeatures
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.timelines.configapi.Param
import com.twitter.util.Future
import com.twitter.util.Time
import scala.collection.Map
case class EarlyBirdFirstDegreeCandidateAdaptor(
earlyBirdFirstDegreeCandidates: CandidateSource[
EarlybirdCandidateSource.Query,
EarlybirdCandidate
],
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
userTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
maxResultsParam: Param[Int],
globalStats: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
type EBCandidate = EarlybirdCandidate with TweetDetails
private val stats = globalStats.scope("EarlyBirdFirstDegreeAdaptor")
private val earlyBirdCandsStat: Stat = stats.stat("early_bird_cands_dist")
private val emptyEarlyBirdCands = stats.counter("empty_early_bird_candidates")
private val seedSetEmpty = stats.counter("empty_seedset")
private val seenTweetsStat = stats.stat("filtered_by_seen_tweets")
private val emptyTweetyPieResult = stats.stat("empty_tweetypie_result")
private val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
private val enableRetweets = stats.counter("enable_retweets")
private val f1withoutSocialContexts = stats.counter("f1_without_social_context")
private val userTweetTweetyPieStoreCounter = stats.counter("user_tweet_tweetypie_store")
override val name: String = earlyBirdFirstDegreeCandidates.name
private def getAllSocialContextActions(
socialProofTypes: Seq[(SocialProofType, Seq[Long])]
): Seq[SocialContextAction] = {
socialProofTypes.flatMap {
case (SocialProofType.Favorite, scIds) =>
scIds.map { scId =>
SocialContextAction(
scId,
Time.now.inMilliseconds,
socialContextActionType = Some(SocialContextActionType.Favorite)
)
}
case (SocialProofType.Retweet, scIds) =>
scIds.map { scId =>
SocialContextAction(
scId,
Time.now.inMilliseconds,
socialContextActionType = Some(SocialContextActionType.Retweet)
)
}
case (SocialProofType.Reply, scIds) =>
scIds.map { scId =>
SocialContextAction(
scId,
Time.now.inMilliseconds,
socialContextActionType = Some(SocialContextActionType.Reply)
)
}
case (SocialProofType.Tweet, scIds) =>
scIds.map { scId =>
SocialContextAction(
scId,
Time.now.inMilliseconds,
socialContextActionType = Some(SocialContextActionType.Tweet)
)
}
case _ => Nil
}
}
private def generateRetweetCandidate(
inputTarget: Target,
candidate: EBCandidate,
scIds: Seq[Long],
socialProofTypes: Seq[(SocialProofType, Seq[Long])]
): RawCandidate = {
val scActions = scIds.map { scId => SocialContextAction(scId, Time.now.inMilliseconds) }
new RawCandidate with TweetRetweetCandidate with EarlybirdTweetFeatures {
override val socialContextActions = scActions
override val socialContextAllTypeActions = getAllSocialContextActions(socialProofTypes)
override val tweetId = candidate.tweetId
override val target = inputTarget
override val tweetyPieResult = candidate.tweetyPieResult
override val features = candidate.features
}
}
private def generateF1CandidateWithoutSocialContext(
inputTarget: Target,
candidate: EBCandidate
): RawCandidate = {
f1withoutSocialContexts.incr()
new RawCandidate with F1FirstDegree with EarlybirdTweetFeatures {
override val tweetId = candidate.tweetId
override val target = inputTarget
override val tweetyPieResult = candidate.tweetyPieResult
override val features = candidate.features
}
}
private def generateEarlyBirdCandidate(
id: Long,
result: Option[TweetyPieResult],
ebFeatures: Option[ThriftSearchResultFeatures]
): EBCandidate = {
new EarlybirdCandidate with TweetDetails {
override val tweetyPieResult: Option[TweetyPieResult] = result
override val tweetId: Long = id
override val features: Option[ThriftSearchResultFeatures] = ebFeatures
}
}
private def filterOutSeenTweets(seenTweetIds: Seq[Long], inputTweetIds: Seq[Long]): Seq[Long] = {
inputTweetIds.filterNot(seenTweetIds.contains)
}
private def filterInvalidTweets(
tweetIds: Seq[Long],
target: Target
): Future[Seq[(Long, TweetyPieResult)]] = {
val resMap = {
if (target.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors)) {
userTweetTweetyPieStoreCounter.incr()
val keys = tweetIds.map { tweetId =>
UserTweet(tweetId, Some(target.targetId))
}
userTweetTweetyPieStore
.multiGet(keys.toSet).map {
case (userTweet, resultFut) =>
userTweet.tweetId -> resultFut
}.toMap
} else {
(target.params(PushFeatureSwitchParams.EnableVFInTweetypie) match {
case true => tweetyPieStore
case false => tweetyPieStoreNoVF
}).multiGet(tweetIds.toSet)
}
}
Future.collect(resMap).map { tweetyPieResultMap =>
val cands = filterOutReplyTweet(tweetyPieResultMap, nonReplyTweetsCounter).collect {
case (id: Long, Some(result)) =>
id -> result
}
emptyTweetyPieResult.add(tweetyPieResultMap.size - cands.size)
cands.toSeq
}
}
private def getEBRetweetCandidates(
inputTarget: Target,
retweets: Seq[(Long, TweetyPieResult)]
): Seq[RawCandidate] = {
retweets.flatMap {
case (_, tweetypieResult) =>
tweetypieResult.tweet.coreData.flatMap { coreData =>
tweetypieResult.sourceTweet.map { sourceTweet =>
val tweetId = sourceTweet.id
val scId = coreData.userId
val socialProofTypes = Seq((SocialProofType.Retweet, Seq(scId)))
val candidate = generateEarlyBirdCandidate(
tweetId,
Some(TweetyPieResult(sourceTweet, None, None)),
None
)
generateRetweetCandidate(
inputTarget,
candidate,
Seq(scId),
socialProofTypes
)
}
}
}
}
private def getEBFirstDegreeCands(
tweets: Seq[(Long, TweetyPieResult)],
ebTweetIdMap: Map[Long, Option[ThriftSearchResultFeatures]]
): Seq[EBCandidate] = {
tweets.map {
case (id, tweetypieResult) =>
val features = ebTweetIdMap.getOrElse(id, None)
generateEarlyBirdCandidate(id, Some(tweetypieResult), features)
}
}
/**
* Returns a combination of raw candidates made of: f1 recs, topic social proof recs, sc recs and retweet candidates
*/
def buildRawCandidates(
inputTarget: Target,
firstDegreeCandidates: Seq[EBCandidate],
retweetCandidates: Seq[RawCandidate]
): Seq[RawCandidate] = {
val hydratedF1Recs =
firstDegreeCandidates.map(generateF1CandidateWithoutSocialContext(inputTarget, _))
hydratedF1Recs ++ retweetCandidates
}
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
inputTarget.seedsWithWeight.flatMap { seedsetOpt =>
val seedsetMap = seedsetOpt.getOrElse(Map.empty)
if (seedsetMap.isEmpty) {
seedSetEmpty.incr()
Future.None
} else {
val maxResultsToReturn = inputTarget.params(maxResultsParam)
val maxTweetAge = inputTarget.params(PushFeatureSwitchParams.F1CandidateMaxTweetAgeParam)
val earlybirdQuery = EarlybirdCandidateSource.Query(
maxNumResultsToReturn = maxResultsToReturn,
seedset = seedsetMap,
maxConsecutiveResultsByTheSameUser = Some(1),
maxTweetAge = maxTweetAge,
disableTimelinesMLModel = false,
searcherId = Some(inputTarget.targetId),
isProtectTweetsEnabled =
inputTarget.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors),
followedUserIds = Some(seedsetMap.keySet.toSeq)
)
Future
.join(inputTarget.seenTweetIds, earlyBirdFirstDegreeCandidates.get(earlybirdQuery))
.flatMap {
case (seenTweetIds, Some(candidates)) =>
earlyBirdCandsStat.add(candidates.size)
val ebTweetIdMap = candidates.map { cand => cand.tweetId -> cand.features }.toMap
val ebTweetIds = ebTweetIdMap.keys.toSeq
val tweetIds = filterOutSeenTweets(seenTweetIds, ebTweetIds)
seenTweetsStat.add(ebTweetIds.size - tweetIds.size)
filterInvalidTweets(tweetIds, inputTarget)
.map { validTweets =>
val (retweets, tweets) = validTweets.partition {
case (_, tweetypieResult) =>
tweetypieResult.sourceTweet.isDefined
}
val firstDegreeCandidates = getEBFirstDegreeCands(tweets, ebTweetIdMap)
val retweetCandidates = {
if (inputTarget.params(PushParams.EarlyBirdSCBasedCandidatesParam) &&
inputTarget.params(PushParams.MRTweetRetweetRecsParam)) {
enableRetweets.incr()
getEBRetweetCandidates(inputTarget, retweets)
} else Nil
}
Some(
buildRawCandidates(
inputTarget,
firstDegreeCandidates,
retweetCandidates
))
}
case _ =>
emptyEarlyBirdCands.incr()
Future.None
}
}
}
}
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
PushDeviceUtil.isRecommendationsEligible(target)
}
}

View File

@ -0,0 +1,120 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.explore_ranker.thriftscala.ExploreRankerProductResponse
import com.twitter.explore_ranker.thriftscala.ExploreRankerRequest
import com.twitter.explore_ranker.thriftscala.ExploreRankerResponse
import com.twitter.explore_ranker.thriftscala.ExploreRecommendation
import com.twitter.explore_ranker.thriftscala.ImmersiveRecsResponse
import com.twitter.explore_ranker.thriftscala.ImmersiveRecsResult
import com.twitter.explore_ranker.thriftscala.NotificationsVideoRecs
import com.twitter.explore_ranker.thriftscala.Product
import com.twitter.explore_ranker.thriftscala.ProductContext
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base.OutOfNetworkTweetCandidate
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.util.AdaptorUtils
import com.twitter.frigate.pushservice.util.MediaCRT
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.product_mixer.core.thriftscala.ClientContext
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future
case class ExploreVideoTweetCandidateAdaptor(
exploreRankerStore: ReadableStore[ExploreRankerRequest, ExploreRankerResponse],
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
globalStats: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
override def name: String = this.getClass.getSimpleName
private[this] val stats = globalStats.scope("ExploreVideoTweetCandidateAdaptor")
private[this] val totalInputRecs = stats.stat("input_recs")
private[this] val totalRequests = stats.counter("total_requests")
private[this] val totalEmptyResponse = stats.counter("total_empty_response")
private def buildExploreRankerRequest(
target: Target,
countryCode: Option[String],
language: Option[String],
): ExploreRankerRequest = {
ExploreRankerRequest(
clientContext = ClientContext(
userId = Some(target.targetId),
countryCode = countryCode,
languageCode = language,
),
product = Product.NotificationsVideoRecs,
productContext = Some(ProductContext.NotificationsVideoRecs(NotificationsVideoRecs())),
maxResults = Some(target.params(PushFeatureSwitchParams.MaxExploreVideoTweets))
)
}
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
Future
.join(
target.countryCode,
target.inferredUserDeviceLanguage
).flatMap {
case (countryCode, language) =>
val request = buildExploreRankerRequest(target, countryCode, language)
exploreRankerStore.get(request).flatMap {
case Some(response) =>
val exploreResonseTweetIds = response match {
case ExploreRankerResponse(ExploreRankerProductResponse
.ImmersiveRecsResponse(ImmersiveRecsResponse(immersiveRecsResult))) =>
immersiveRecsResult.collect {
case ImmersiveRecsResult(ExploreRecommendation
.ExploreTweetRecommendation(exploreTweetRecommendation)) =>
exploreTweetRecommendation.tweetId
}
case _ =>
Seq.empty
}
totalInputRecs.add(exploreResonseTweetIds.size)
totalRequests.incr()
AdaptorUtils
.getTweetyPieResults(exploreResonseTweetIds.toSet, tweetyPieStore).map {
tweetyPieResultMap =>
val candidates = tweetyPieResultMap.values.flatten
.map(buildVideoRawCandidates(target, _))
Some(candidates.toSeq)
}
case _ =>
totalEmptyResponse.incr()
Future.None
}
case _ =>
Future.None
}
}
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
PushDeviceUtil.isRecommendationsEligible(target).map { userRecommendationsEligible =>
userRecommendationsEligible && target.params(PushFeatureSwitchParams.EnableExploreVideoTweets)
}
}
private def buildVideoRawCandidates(
target: Target,
tweetyPieResult: TweetyPieResult
): RawCandidate with OutOfNetworkTweetCandidate = {
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
inputTarget = target,
id = tweetyPieResult.tweet.id,
mediaCRT = MediaCRT(
CommonRecommendationType.ExploreVideoTweet,
CommonRecommendationType.ExploreVideoTweet,
CommonRecommendationType.ExploreVideoTweet
),
result = Some(tweetyPieResult),
localizedEntity = None
)
}
}

View File

@ -0,0 +1,272 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.cr_mixer.thriftscala.FrsTweetRequest
import com.twitter.cr_mixer.thriftscala.NotificationsContext
import com.twitter.cr_mixer.thriftscala.Product
import com.twitter.cr_mixer.thriftscala.ProductContext
import com.twitter.finagle.stats.Counter
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base._
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.store.CrMixerTweetStore
import com.twitter.frigate.pushservice.store.UttEntityHydrationStore
import com.twitter.frigate.pushservice.util.MediaCRT
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.pushservice.util.TopicsUtil
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.hermit.constants.AlgorithmFeedbackTokens
import com.twitter.hermit.model.Algorithm.Algorithm
import com.twitter.hermit.model.Algorithm.CrowdSearchAccounts
import com.twitter.hermit.model.Algorithm.ForwardEmailBook
import com.twitter.hermit.model.Algorithm.ForwardPhoneBook
import com.twitter.hermit.model.Algorithm.ReverseEmailBookIbis
import com.twitter.hermit.model.Algorithm.ReversePhoneBook
import com.twitter.hermit.store.tweetypie.UserTweet
import com.twitter.product_mixer.core.thriftscala.ClientContext
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
import com.twitter.util.Future
object FRSAlgorithmFeedbackTokenUtil {
private val crtsByAlgoToken = Map(
getAlgorithmToken(ReverseEmailBookIbis) -> CommonRecommendationType.ReverseAddressbookTweet,
getAlgorithmToken(ReversePhoneBook) -> CommonRecommendationType.ReverseAddressbookTweet,
getAlgorithmToken(ForwardEmailBook) -> CommonRecommendationType.ForwardAddressbookTweet,
getAlgorithmToken(ForwardPhoneBook) -> CommonRecommendationType.ForwardAddressbookTweet,
getAlgorithmToken(CrowdSearchAccounts) -> CommonRecommendationType.CrowdSearchTweet
)
def getAlgorithmToken(algorithm: Algorithm): Int = {
AlgorithmFeedbackTokens.AlgorithmToFeedbackTokenMap(algorithm)
}
def getCRTForAlgoToken(algorithmToken: Int): Option[CommonRecommendationType] = {
crtsByAlgoToken.get(algorithmToken)
}
}
case class FRSTweetCandidateAdaptor(
crMixerTweetStore: CrMixerTweetStore,
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
userTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
uttEntityHydrationStore: UttEntityHydrationStore,
topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse],
globalStats: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
private val stats = globalStats.scope(this.getClass.getSimpleName)
private val crtStats = stats.scope("CandidateDistribution")
private val totalRequests = stats.counter("total_requests")
// Candidate Distribution stats
private val reverseAddressbookCounter = crtStats.counter("reverse_addressbook")
private val forwardAddressbookCounter = crtStats.counter("forward_addressbook")
private val frsTweetCounter = crtStats.counter("frs_tweet")
private val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
private val crtToCounterMapping: Map[CommonRecommendationType, Counter] = Map(
CommonRecommendationType.ReverseAddressbookTweet -> reverseAddressbookCounter,
CommonRecommendationType.ForwardAddressbookTweet -> forwardAddressbookCounter,
CommonRecommendationType.FrsTweet -> frsTweetCounter
)
private val emptyTweetyPieResult = stats.stat("empty_tweetypie_result")
private[this] val numberReturnedCandidates = stats.stat("returned_candidates_from_earlybird")
private[this] val numberCandidateWithTopic: Counter = stats.counter("num_can_with_topic")
private[this] val numberCandidateWithoutTopic: Counter = stats.counter("num_can_without_topic")
private val userTweetTweetyPieStoreCounter = stats.counter("user_tweet_tweetypie_store")
override val name: String = this.getClass.getSimpleName
private def filterInvalidTweets(
tweetIds: Seq[Long],
target: Target
): Future[Map[Long, TweetyPieResult]] = {
val resMap = {
if (target.params(PushFeatureSwitchParams.EnableF1FromProtectedTweetAuthors)) {
userTweetTweetyPieStoreCounter.incr()
val keys = tweetIds.map { tweetId =>
UserTweet(tweetId, Some(target.targetId))
}
userTweetTweetyPieStore
.multiGet(keys.toSet).map {
case (userTweet, resultFut) =>
userTweet.tweetId -> resultFut
}.toMap
} else {
(if (target.params(PushFeatureSwitchParams.EnableVFInTweetypie)) {
tweetyPieStore
} else {
tweetyPieStoreNoVF
}).multiGet(tweetIds.toSet)
}
}
Future.collect(resMap).map { tweetyPieResultMap =>
// Filter out replies and generate earlybird candidates only for non-empty tweetypie result
val cands = filterOutReplyTweet(tweetyPieResultMap, nonReplyTweetsCounter).collect {
case (id: Long, Some(result)) =>
id -> result
}
emptyTweetyPieResult.add(tweetyPieResultMap.size - cands.size)
cands
}
}
private def buildRawCandidates(
target: Target,
ebCandidates: Seq[FRSTweetCandidate]
): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
val enableTopic = target.params(PushFeatureSwitchParams.EnableFrsTweetCandidatesTopicAnnotation)
val topicScoreThre =
target.params(PushFeatureSwitchParams.FrsTweetCandidatesTopicScoreThreshold)
val ebTweets = ebCandidates.map { ebCandidate =>
ebCandidate.tweetId -> ebCandidate.tweetyPieResult
}.toMap
val tweetIdLocalizedEntityMapFut = TopicsUtil.getTweetIdLocalizedEntityMap(
target,
ebTweets,
uttEntityHydrationStore,
topicSocialProofServiceStore,
enableTopic,
topicScoreThre
)
Future.join(target.deviceInfo, tweetIdLocalizedEntityMapFut).map {
case (Some(deviceInfo), tweetIdLocalizedEntityMap) =>
val candidates = ebCandidates
.map { ebCandidate =>
val crt = ebCandidate.commonRecType
crtToCounterMapping.get(crt).foreach(_.incr())
val tweetId = ebCandidate.tweetId
val localizedEntityOpt = {
if (tweetIdLocalizedEntityMap
.contains(tweetId) && tweetIdLocalizedEntityMap.contains(
tweetId) && deviceInfo.isTopicsEligible) {
tweetIdLocalizedEntityMap(tweetId)
} else {
None
}
}
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
inputTarget = target,
id = ebCandidate.tweetId,
mediaCRT = MediaCRT(
crt,
crt,
crt
),
result = ebCandidate.tweetyPieResult,
localizedEntity = localizedEntityOpt)
}.filter { candidate =>
// If user only has the topic setting enabled, filter out all non-topic cands
deviceInfo.isRecommendationsEligible || (deviceInfo.isTopicsEligible && candidate.semanticCoreEntityId.nonEmpty)
}
candidates.map { candidate =>
if (candidate.semanticCoreEntityId.nonEmpty) {
numberCandidateWithTopic.incr()
} else {
numberCandidateWithoutTopic.incr()
}
}
numberReturnedCandidates.add(candidates.length)
Some(candidates)
case _ => Some(Seq.empty)
}
}
def getTweetCandidatesFromCrMixer(
inputTarget: Target,
showAllResultsFromFrs: Boolean,
): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
Future
.join(
inputTarget.seenTweetIds,
inputTarget.pushRecItems,
inputTarget.countryCode,
inputTarget.targetLanguage).flatMap {
case (seenTweetIds, pastRecItems, countryCode, language) =>
val pastUserRecs = pastRecItems.userIds.toSeq
val request = FrsTweetRequest(
clientContext = ClientContext(
userId = Some(inputTarget.targetId),
countryCode = countryCode,
languageCode = language
),
product = Product.Notifications,
productContext = Some(ProductContext.NotificationsContext(NotificationsContext())),
excludedUserIds = Some(pastUserRecs),
excludedTweetIds = Some(seenTweetIds)
)
crMixerTweetStore.getFRSTweetCandidates(request).flatMap {
case Some(response) =>
val tweetIds = response.tweets.map(_.tweetId)
val validTweets = filterInvalidTweets(tweetIds, inputTarget)
validTweets.flatMap { tweetypieMap =>
val ebCandidates = response.tweets
.map { frsTweet =>
val candidateTweetId = frsTweet.tweetId
val resultFromTweetyPie = tweetypieMap.get(candidateTweetId)
new FRSTweetCandidate {
override val tweetId = candidateTweetId
override val features = None
override val tweetyPieResult = resultFromTweetyPie
override val feedbackToken = frsTweet.frsPrimarySource
override val commonRecType: CommonRecommendationType = feedbackToken
.flatMap(token =>
FRSAlgorithmFeedbackTokenUtil.getCRTForAlgoToken(token)).getOrElse(
CommonRecommendationType.FrsTweet)
}
}.filter { ebCandidate =>
showAllResultsFromFrs || ebCandidate.commonRecType == CommonRecommendationType.ReverseAddressbookTweet
}
numberReturnedCandidates.add(ebCandidates.length)
buildRawCandidates(
inputTarget,
ebCandidates
)
}
case _ => Future.None
}
}
}
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate with TweetCandidate]]] = {
totalRequests.incr()
val enableResultsFromFrs =
inputTarget.params(PushFeatureSwitchParams.EnableResultFromFrsCandidates)
getTweetCandidatesFromCrMixer(inputTarget, enableResultsFromFrs)
}
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
lazy val enableFrsCandidates = target.params(PushFeatureSwitchParams.EnableFrsCandidates)
PushDeviceUtil.isRecommendationsEligible(target).flatMap { isEnabledForRecosSetting =>
PushDeviceUtil.isTopicsEligible(target).map { topicSettingEnabled =>
val isEnabledForTopics =
topicSettingEnabled && target.params(
PushFeatureSwitchParams.EnableFrsTweetCandidatesTopicSetting)
(isEnabledForRecosSetting || isEnabledForTopics) && enableFrsCandidates
}
}
}
}

View File

@ -0,0 +1,107 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base._
import com.twitter.frigate.common.candidate._
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushParams
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future
object GenericCandidates {
type Target =
TargetUser
with UserDetails
with TargetDecider
with TargetABDecider
with TweetImpressionHistory
with HTLVisitHistory
with MaxTweetAge
with NewUserDetails
with FrigateHistory
with TargetWithSeedUsers
}
case class GenericCandidateAdaptor(
genericCandidates: CandidateSource[GenericCandidates.Target, Candidate],
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
stats: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
override val name: String = genericCandidates.name
private def generateTweetFavCandidate(
_target: Target,
_tweetId: Long,
_socialContextActions: Seq[SocialContextAction],
socialContextActionsAllTypes: Seq[SocialContextAction],
_tweetyPieResult: Option[TweetyPieResult]
): RawCandidate = {
new RawCandidate with TweetFavoriteCandidate {
override val socialContextActions = _socialContextActions
override val socialContextAllTypeActions =
socialContextActionsAllTypes
val tweetId = _tweetId
val target = _target
val tweetyPieResult = _tweetyPieResult
}
}
private def generateTweetRetweetCandidate(
_target: Target,
_tweetId: Long,
_socialContextActions: Seq[SocialContextAction],
socialContextActionsAllTypes: Seq[SocialContextAction],
_tweetyPieResult: Option[TweetyPieResult]
): RawCandidate = {
new RawCandidate with TweetRetweetCandidate {
override val socialContextActions = _socialContextActions
override val socialContextAllTypeActions = socialContextActionsAllTypes
val tweetId = _tweetId
val target = _target
val tweetyPieResult = _tweetyPieResult
}
}
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
genericCandidates.get(inputTarget).map { candidatesOpt =>
candidatesOpt
.map { candidates =>
val candidatesSeq =
candidates.collect {
case tweetRetweet: TweetRetweetCandidate
if inputTarget.params(PushParams.MRTweetRetweetRecsParam) =>
generateTweetRetweetCandidate(
inputTarget,
tweetRetweet.tweetId,
tweetRetweet.socialContextActions,
tweetRetweet.socialContextAllTypeActions,
tweetRetweet.tweetyPieResult)
case tweetFavorite: TweetFavoriteCandidate
if inputTarget.params(PushParams.MRTweetFavRecsParam) =>
generateTweetFavCandidate(
inputTarget,
tweetFavorite.tweetId,
tweetFavorite.socialContextActions,
tweetFavorite.socialContextAllTypeActions,
tweetFavorite.tweetyPieResult)
}
candidatesSeq.foreach { candidate =>
stats.counter(s"${candidate.commonRecType}_count").incr()
}
candidatesSeq
}
}
}
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
PushDeviceUtil.isRecommendationsEligible(target).map { isAvailable =>
isAvailable && target.params(PushParams.GenericCandidateAdaptorDecider)
}
}
}

View File

@ -0,0 +1,280 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.finagle.stats.Stat
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.HighQualityCandidateGroupEnum
import com.twitter.frigate.pushservice.params.HighQualityCandidateGroupEnum._
import com.twitter.frigate.pushservice.params.PushConstants.targetUserAgeFeatureName
import com.twitter.frigate.pushservice.params.PushConstants.targetUserPreferredLanguage
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
import com.twitter.frigate.pushservice.predicate.TargetPredicates
import com.twitter.frigate.pushservice.util.MediaCRT
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.pushservice.util.TopicsUtil
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.interests.thriftscala.InterestId.SemanticCore
import com.twitter.interests.thriftscala.UserInterests
import com.twitter.language.normalization.UserDisplayLanguage
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweet
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
import com.twitter.util.Future
object HighQualityTweetsHelper {
def getFollowedTopics(
target: Target,
interestsWithLookupContextStore: ReadableStore[
InterestsLookupRequestWithContext,
UserInterests
],
followedTopicsStats: Stat
): Future[Seq[Long]] = {
TopicsUtil
.getTopicsFollowedByUser(target, interestsWithLookupContextStore, followedTopicsStats).map {
userInterestsOpt =>
val userInterests = userInterestsOpt.getOrElse(Seq.empty)
val extractedTopicIds = userInterests.flatMap {
_.interestId match {
case SemanticCore(semanticCore) => Some(semanticCore.id)
case _ => None
}
}
extractedTopicIds
}
}
def getTripQueries(
target: Target,
enabledGroups: Set[HighQualityCandidateGroupEnum.Value],
interestsWithLookupContextStore: ReadableStore[
InterestsLookupRequestWithContext,
UserInterests
],
sourceIds: Seq[String],
stat: Stat
): Future[Set[TripDomain]] = {
val followedTopicIdsSetFut: Future[Set[Long]] = if (enabledGroups.contains(Topic)) {
getFollowedTopics(target, interestsWithLookupContextStore, stat).map(topicIds =>
topicIds.toSet)
} else {
Future.value(Set.empty)
}
Future
.join(target.featureMap, target.inferredUserDeviceLanguage, followedTopicIdsSetFut).map {
case (
featureMap,
deviceLanguageOpt,
followedTopicIds
) =>
val ageBucketOpt = if (enabledGroups.contains(AgeBucket)) {
featureMap.categoricalFeatures.get(targetUserAgeFeatureName)
} else {
None
}
val languageOptions: Set[Option[String]] = if (enabledGroups.contains(Language)) {
val userPreferredLanguages = featureMap.sparseBinaryFeatures
.getOrElse(targetUserPreferredLanguage, Set.empty[String])
if (userPreferredLanguages.nonEmpty) {
userPreferredLanguages.map(lang => Some(UserDisplayLanguage.toTweetLanguage(lang)))
} else {
Set(deviceLanguageOpt.map(UserDisplayLanguage.toTweetLanguage))
}
} else Set(None)
val followedTopicOptions: Set[Option[Long]] = if (followedTopicIds.nonEmpty) {
followedTopicIds.map(topic => Some(topic))
} else Set(None)
val tripQueries = followedTopicOptions.flatMap { topicOption =>
languageOptions.flatMap { languageOption =>
sourceIds.map { sourceId =>
TripDomain(
sourceId = sourceId,
language = languageOption,
placeId = None,
topicId = topicOption,
gender = None,
ageBucket = ageBucketOpt
)
}
}
}
tripQueries
}
}
}
case class HighQualityTweetsAdaptor(
tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
interestsWithLookupContextStore: ReadableStore[InterestsLookupRequestWithContext, UserInterests],
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
globalStats: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
override def name: String = this.getClass.getSimpleName
private val stats = globalStats.scope("HighQualityCandidateAdaptor")
private val followedTopicsStats = stats.stat("followed_topics")
private val missingResponseCounter = stats.counter("missing_respond_counter")
private val crtFatigueCounter = stats.counter("fatigue_by_crt")
private val fallbackRequestsCounter = stats.counter("fallback_requests")
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
PushDeviceUtil.isRecommendationsEligible(target).map {
_ && target.params(FS.HighQualityCandidatesEnableCandidateSource)
}
}
private val highQualityCandidateFrequencyPredicate = {
TargetPredicates
.pushRecTypeFatiguePredicate(
CommonRecommendationType.TripHqTweet,
FS.HighQualityTweetsPushInterval,
FS.MaxHighQualityTweetsPushGivenInterval,
stats
)
}
private def getTripCandidatesStrato(
target: Target
): Future[Map[Long, Set[TripDomain]]] = {
val tripQueriesF: Future[Set[TripDomain]] = HighQualityTweetsHelper.getTripQueries(
target = target,
enabledGroups = target.params(FS.HighQualityCandidatesEnableGroups).toSet,
interestsWithLookupContextStore = interestsWithLookupContextStore,
sourceIds = target.params(FS.TripTweetCandidateSourceIds),
stat = followedTopicsStats
)
lazy val fallbackTripQueriesFut: Future[Set[TripDomain]] =
if (target.params(FS.HighQualityCandidatesEnableFallback))
HighQualityTweetsHelper.getTripQueries(
target = target,
enabledGroups = target.params(FS.HighQualityCandidatesFallbackEnabledGroups).toSet,
interestsWithLookupContextStore = interestsWithLookupContextStore,
sourceIds = target.params(FS.HighQualityCandidatesFallbackSourceIds),
stat = followedTopicsStats
)
else Future.value(Set.empty)
val initialTweetsFut: Future[Map[TripDomain, Seq[TripTweet]]] = tripQueriesF.flatMap {
tripQueries => getTripTweetsByDomains(tripQueries)
}
val tweetsByDomainFut: Future[Map[TripDomain, Seq[TripTweet]]] =
if (target.params(FS.HighQualityCandidatesEnableFallback)) {
initialTweetsFut.flatMap { candidates =>
val minCandidatesForFallback: Int =
target.params(FS.HighQualityCandidatesMinNumOfCandidatesToFallback)
val validCandidates = candidates.filter(_._2.size >= minCandidatesForFallback)
if (validCandidates.nonEmpty) {
Future.value(validCandidates)
} else {
fallbackTripQueriesFut.flatMap { fallbackTripDomains =>
fallbackRequestsCounter.incr(fallbackTripDomains.size)
getTripTweetsByDomains(fallbackTripDomains)
}
}
}
} else {
initialTweetsFut
}
val numOfCandidates: Int = target.params(FS.HighQualityCandidatesNumberOfCandidates)
tweetsByDomainFut.map(tweetsByDomain => reformatDomainTweetMap(tweetsByDomain, numOfCandidates))
}
private def getTripTweetsByDomains(
tripQueries: Set[TripDomain]
): Future[Map[TripDomain, Seq[TripTweet]]] = {
Future.collect(tripTweetCandidateStore.multiGet(tripQueries)).map { response =>
response
.filter(p => p._2.exists(_.tweets.nonEmpty))
.mapValues(_.map(_.tweets).getOrElse(Seq.empty))
}
}
private def reformatDomainTweetMap(
tweetsByDomain: Map[TripDomain, Seq[TripTweet]],
numOfCandidates: Int
): Map[Long, Set[TripDomain]] = tweetsByDomain
.flatMap {
case (tripDomain, tripTweets) =>
tripTweets
.sortBy(_.score)(Ordering[Double].reverse)
.take(numOfCandidates)
.map { tweet => (tweet.tweetId, tripDomain) }
}.groupBy(_._1).mapValues(_.map(_._2).toSet)
private def buildRawCandidate(
target: Target,
tweetyPieResult: TweetyPieResult,
tripDomain: Option[scala.collection.Set[TripDomain]]
): RawCandidate = {
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
inputTarget = target,
id = tweetyPieResult.tweet.id,
mediaCRT = MediaCRT(
CommonRecommendationType.TripHqTweet,
CommonRecommendationType.TripHqTweet,
CommonRecommendationType.TripHqTweet
),
result = Some(tweetyPieResult),
tripTweetDomain = tripDomain
)
}
private def getTweetyPieResults(
target: Target,
tweetToTripDomain: Map[Long, Set[TripDomain]]
): Future[Map[Long, Option[TweetyPieResult]]] = {
Future.collect((if (target.params(FS.EnableVFInTweetypie)) {
tweetyPieStore
} else {
tweetyPieStoreNoVF
}).multiGet(tweetToTripDomain.keySet))
}
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
for {
tweetsToTripDomainMap <- getTripCandidatesStrato(target)
tweetyPieResults <- getTweetyPieResults(target, tweetsToTripDomainMap)
} yield {
val candidates = tweetyPieResults.flatMap {
case (tweetId, tweetyPieResultOpt) =>
tweetyPieResultOpt.map(buildRawCandidate(target, _, tweetsToTripDomainMap.get(tweetId)))
}
if (candidates.nonEmpty) {
highQualityCandidateFrequencyPredicate(Seq(target))
.map(_.head)
.map { isTargetFatigueEligible =>
if (isTargetFatigueEligible) Some(candidates)
else {
crtFatigueCounter.incr()
None
}
}
Some(candidates.toSeq)
} else {
missingResponseCounter.incr()
None
}
}
}
}

View File

@ -0,0 +1,152 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base.ListPushCandidate
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.predicate.TargetPredicates
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.geoduck.service.thriftscala.LocationResponse
import com.twitter.interests_discovery.thriftscala.DisplayLocation
import com.twitter.interests_discovery.thriftscala.NonPersonalizedRecommendedLists
import com.twitter.interests_discovery.thriftscala.RecommendedListsRequest
import com.twitter.interests_discovery.thriftscala.RecommendedListsResponse
import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future
case class ListsToRecommendCandidateAdaptor(
listRecommendationsStore: ReadableStore[String, NonPersonalizedRecommendedLists],
geoDuckV2Store: ReadableStore[Long, LocationResponse],
idsStore: ReadableStore[RecommendedListsRequest, RecommendedListsResponse],
globalStats: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
override val name: String = this.getClass.getSimpleName
private[this] val stats = globalStats.scope(name)
private[this] val noLocationCodeCounter = stats.counter("no_location_code")
private[this] val noCandidatesCounter = stats.counter("no_candidates_for_geo")
private[this] val disablePopGeoListsCounter = stats.counter("disable_pop_geo_lists")
private[this] val disableIDSListsCounter = stats.counter("disable_ids_lists")
private def getListCandidate(
targetUser: Target,
_listId: Long
): RawCandidate with ListPushCandidate = {
new RawCandidate with ListPushCandidate {
override val listId: Long = _listId
override val commonRecType: CommonRecommendationType = CommonRecommendationType.List
override val target: Target = targetUser
}
}
private def getListsRecommendedFromHistory(
target: Target
): Future[Seq[Long]] = {
target.history.map { history =>
history.sortedHistory.flatMap {
case (_, notif) if notif.commonRecommendationType == List =>
notif.listNotification.map(_.listId)
case _ => None
}
}
}
private def getIDSListRecs(
target: Target,
historicalListIds: Seq[Long]
): Future[Seq[Long]] = {
val request = RecommendedListsRequest(
target.targetId,
DisplayLocation.ListDiscoveryPage,
Some(historicalListIds)
)
if (target.params(PushFeatureSwitchParams.EnableIDSListRecommendations)) {
idsStore.get(request).map {
case Some(response) =>
response.channels.map(_.id)
case _ => Nil
}
} else {
disableIDSListsCounter.incr()
Future.Nil
}
}
private def getPopGeoLists(
target: Target,
historicalListIds: Seq[Long]
): Future[Seq[Long]] = {
if (target.params(PushFeatureSwitchParams.EnablePopGeoListRecommendations)) {
geoDuckV2Store.get(target.targetId).flatMap {
case Some(locationResponse) if locationResponse.geohash.isDefined =>
val geoHashLength =
target.params(PushFeatureSwitchParams.ListRecommendationsGeoHashLength)
val geoHash = locationResponse.geohash.get.take(geoHashLength)
listRecommendationsStore
.get(s"geohash_$geoHash")
.map {
case Some(recommendedLists) =>
recommendedLists.recommendedListsByAlgo.flatMap { topLists =>
topLists.lists.collect {
case list if !historicalListIds.contains(list.listId) => list.listId
}
}
case _ => Nil
}
case _ =>
noLocationCodeCounter.incr()
Future.Nil
}
} else {
disablePopGeoListsCounter.incr()
Future.Nil
}
}
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
getListsRecommendedFromHistory(target).flatMap { historicalListIds =>
Future
.join(
getPopGeoLists(target, historicalListIds),
getIDSListRecs(target, historicalListIds)
)
.map {
case (popGeoListsIds, idsListIds) =>
val candidates = (idsListIds ++ popGeoListsIds).map(getListCandidate(target, _))
Some(candidates)
case _ =>
noCandidatesCounter.incr()
None
}
}
}
private val pushCapFatiguePredicate = TargetPredicates.pushRecTypeFatiguePredicate(
CommonRecommendationType.List,
PushFeatureSwitchParams.ListRecommendationsPushInterval,
PushFeatureSwitchParams.MaxListRecommendationsPushGivenInterval,
stats,
)
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
val isNotFatigued = pushCapFatiguePredicate.apply(Seq(target)).map(_.head)
Future
.join(
PushDeviceUtil.isRecommendationsEligible(target),
isNotFatigued
).map {
case (userRecommendationsEligible, isUnderCAP) =>
userRecommendationsEligible && isUnderCAP && target.params(
PushFeatureSwitchParams.EnableListRecommendations)
}
}
}

View File

@ -0,0 +1,54 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.geoduck.service.thriftscala.LocationResponse
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
import com.twitter.content_mixer.thriftscala.ContentMixerRequest
import com.twitter.content_mixer.thriftscala.ContentMixerResponse
import com.twitter.geoduck.common.thriftscala.Location
import com.twitter.hermit.pop_geo.thriftscala.PopTweetsInPlace
import com.twitter.recommendation.interests.discovery.core.model.InterestDomain
class LoggedOutPushCandidateSourceGenerator(
tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
geoDuckV2Store: ReadableStore[Long, LocationResponse],
safeCachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
cachedTweetyPieStoreV2NoVF: ReadableStore[Long, TweetyPieResult],
cachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
contentMixerStore: ReadableStore[ContentMixerRequest, ContentMixerResponse],
softUserLocationStore: ReadableStore[Long, Location],
topTweetsByGeoStore: ReadableStore[InterestDomain[String], Map[String, List[(Long, Double)]]],
topTweetsByGeoV2VersionedStore: ReadableStore[String, PopTweetsInPlace],
)(
implicit val globalStats: StatsReceiver) {
val sources: Seq[CandidateSource[Target, RawCandidate] with CandidateSourceEligible[
Target,
RawCandidate
]] = {
Seq(
TripGeoCandidatesAdaptor(
tripTweetCandidateStore,
contentMixerStore,
safeCachedTweetyPieStoreV2,
cachedTweetyPieStoreV2NoVF,
globalStats
),
TopTweetsByGeoAdaptor(
geoDuckV2Store,
softUserLocationStore,
topTweetsByGeoStore,
topTweetsByGeoV2VersionedStore,
cachedTweetyPieStoreV2,
cachedTweetyPieStoreV2NoVF,
globalStats
)
)
}
}

View File

@ -0,0 +1,101 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base.DiscoverTwitterCandidate
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
import com.twitter.frigate.pushservice.predicate.DiscoverTwitterPredicate
import com.twitter.frigate.pushservice.predicate.TargetPredicates
import com.twitter.frigate.pushservice.util.PushAppPermissionUtil
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.thriftscala.{CommonRecommendationType => CRT}
import com.twitter.util.Future
class OnboardingPushCandidateAdaptor(
globalStats: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
override val name: String = this.getClass.getSimpleName
private[this] val stats = globalStats.scope(name)
private[this] val requestNum = stats.counter("request_num")
private[this] val addressBookCandNum = stats.counter("address_book_cand_num")
private[this] val completeOnboardingCandNum = stats.counter("complete_onboarding_cand_num")
private def generateOnboardingPushRawCandidate(
_target: Target,
_commonRecType: CRT
): RawCandidate = {
new RawCandidate with DiscoverTwitterCandidate {
override val target = _target
override val commonRecType = _commonRecType
}
}
private def getEligibleCandsForTarget(
target: Target
): Future[Option[Seq[RawCandidate]]] = {
val addressBookFatigue =
TargetPredicates
.pushRecTypeFatiguePredicate(
CRT.AddressBookUploadPush,
FS.FatigueForOnboardingPushes,
FS.MaxOnboardingPushInInterval,
stats)(Seq(target)).map(_.head)
val completeOnboardingFatigue =
TargetPredicates
.pushRecTypeFatiguePredicate(
CRT.CompleteOnboardingPush,
FS.FatigueForOnboardingPushes,
FS.MaxOnboardingPushInInterval,
stats)(Seq(target)).map(_.head)
Future
.join(
target.appPermissions,
addressBookFatigue,
completeOnboardingFatigue
).map {
case (appPermissionOpt, addressBookPredicate, completeOnboardingPredicate) =>
val addressBookUploaded =
PushAppPermissionUtil.hasTargetUploadedAddressBook(appPermissionOpt)
val abUploadCandidate =
if (!addressBookUploaded && addressBookPredicate && target.params(
FS.EnableAddressBookPush)) {
addressBookCandNum.incr()
Some(generateOnboardingPushRawCandidate(target, CRT.AddressBookUploadPush))
} else if (!addressBookUploaded && (completeOnboardingPredicate ||
target.params(FS.DisableOnboardingPushFatigue)) && target.params(
FS.EnableCompleteOnboardingPush)) {
completeOnboardingCandNum.incr()
Some(generateOnboardingPushRawCandidate(target, CRT.CompleteOnboardingPush))
} else None
val allCandidates =
Seq(abUploadCandidate).filter(_.isDefined).flatten
if (allCandidates.nonEmpty) Some(allCandidates) else None
}
}
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
requestNum.incr()
val minDurationForMRElapsed =
DiscoverTwitterPredicate
.minDurationElapsedSinceLastMrPushPredicate(
name,
FS.MrMinDurationSincePushForOnboardingPushes,
stats)(Seq(inputTarget)).map(_.head)
minDurationForMRElapsed.flatMap { minDurationElapsed =>
if (minDurationElapsed) getEligibleCandsForTarget(inputTarget) else Future.None
}
}
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
PushDeviceUtil
.isRecommendationsEligible(target).map(_ && target.params(FS.EnableOnboardingPushes))
}
}

View File

@ -0,0 +1,162 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.content_mixer.thriftscala.ContentMixerRequest
import com.twitter.content_mixer.thriftscala.ContentMixerResponse
import com.twitter.explore_ranker.thriftscala.ExploreRankerRequest
import com.twitter.explore_ranker.thriftscala.ExploreRankerResponse
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base._
import com.twitter.frigate.common.candidate._
import com.twitter.frigate.common.store.RecentTweetsQuery
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.store._
import com.twitter.geoduck.common.thriftscala.Location
import com.twitter.geoduck.service.thriftscala.LocationResponse
import com.twitter.hermit.pop_geo.thriftscala.PopTweetsInPlace
import com.twitter.hermit.predicate.socialgraph.RelationEdge
import com.twitter.hermit.store.tweetypie.UserTweet
import com.twitter.interests.thriftscala.UserInterests
import com.twitter.interests_discovery.thriftscala.NonPersonalizedRecommendedLists
import com.twitter.interests_discovery.thriftscala.RecommendedListsRequest
import com.twitter.interests_discovery.thriftscala.RecommendedListsResponse
import com.twitter.recommendation.interests.discovery.core.model.InterestDomain
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
/**
* PushCandidateSourceGenerator generates candidate source list for a given Target user
*/
class PushCandidateSourceGenerator(
earlybirdCandidates: CandidateSource[EarlybirdCandidateSource.Query, EarlybirdCandidate],
userTweetEntityGraphCandidates: CandidateSource[UserTweetEntityGraphCandidates.Target, Candidate],
cachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
safeCachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult],
userTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
safeUserTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
cachedTweetyPieStoreV2NoVF: ReadableStore[Long, TweetyPieResult],
edgeStore: ReadableStore[RelationEdge, Boolean],
interestsLookupStore: ReadableStore[InterestsLookupRequestWithContext, UserInterests],
uttEntityHydrationStore: UttEntityHydrationStore,
geoDuckV2Store: ReadableStore[Long, LocationResponse],
topTweetsByGeoStore: ReadableStore[InterestDomain[String], Map[String, List[(Long, Double)]]],
topTweetsByGeoV2VersionedStore: ReadableStore[String, PopTweetsInPlace],
tweetImpressionsStore: TweetImpressionsStore,
recommendedTrendsCandidateSource: RecommendedTrendsCandidateSource,
recentTweetsByAuthorStore: ReadableStore[RecentTweetsQuery, Seq[Seq[Long]]],
topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse],
crMixerStore: CrMixerTweetStore,
contentMixerStore: ReadableStore[ContentMixerRequest, ContentMixerResponse],
exploreRankerStore: ReadableStore[ExploreRankerRequest, ExploreRankerResponse],
softUserLocationStore: ReadableStore[Long, Location],
tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
listRecsStore: ReadableStore[String, NonPersonalizedRecommendedLists],
idsStore: ReadableStore[RecommendedListsRequest, RecommendedListsResponse]
)(
implicit val globalStats: StatsReceiver) {
private val earlyBirdFirstDegreeCandidateAdaptor = EarlyBirdFirstDegreeCandidateAdaptor(
earlybirdCandidates,
cachedTweetyPieStoreV2,
cachedTweetyPieStoreV2NoVF,
userTweetTweetyPieStore,
PushFeatureSwitchParams.NumberOfMaxEarlybirdInNetworkCandidatesParam,
globalStats
)
private val frsTweetCandidateAdaptor = FRSTweetCandidateAdaptor(
crMixerStore,
cachedTweetyPieStoreV2,
cachedTweetyPieStoreV2NoVF,
userTweetTweetyPieStore,
uttEntityHydrationStore,
topicSocialProofServiceStore,
globalStats
)
private val contentRecommenderMixerAdaptor = ContentRecommenderMixerAdaptor(
crMixerStore,
safeCachedTweetyPieStoreV2,
edgeStore,
topicSocialProofServiceStore,
uttEntityHydrationStore,
globalStats
)
private val tripGeoCandidatesAdaptor = TripGeoCandidatesAdaptor(
tripTweetCandidateStore,
contentMixerStore,
safeCachedTweetyPieStoreV2,
cachedTweetyPieStoreV2NoVF,
globalStats
)
val sources: Seq[
CandidateSource[Target, RawCandidate] with CandidateSourceEligible[
Target,
RawCandidate
]
] = {
Seq(
earlyBirdFirstDegreeCandidateAdaptor,
GenericCandidateAdaptor(
userTweetEntityGraphCandidates,
cachedTweetyPieStoreV2,
cachedTweetyPieStoreV2NoVF,
globalStats.scope("UserTweetEntityGraphCandidates")
),
new OnboardingPushCandidateAdaptor(globalStats),
TopTweetsByGeoAdaptor(
geoDuckV2Store,
softUserLocationStore,
topTweetsByGeoStore,
topTweetsByGeoV2VersionedStore,
cachedTweetyPieStoreV2,
cachedTweetyPieStoreV2NoVF,
globalStats
),
frsTweetCandidateAdaptor,
TopTweetImpressionsCandidateAdaptor(
recentTweetsByAuthorStore,
cachedTweetyPieStoreV2,
cachedTweetyPieStoreV2NoVF,
tweetImpressionsStore,
globalStats
),
TrendsCandidatesAdaptor(
softUserLocationStore,
recommendedTrendsCandidateSource,
safeCachedTweetyPieStoreV2,
cachedTweetyPieStoreV2NoVF,
safeUserTweetTweetyPieStore,
globalStats
),
contentRecommenderMixerAdaptor,
tripGeoCandidatesAdaptor,
HighQualityTweetsAdaptor(
tripTweetCandidateStore,
interestsLookupStore,
cachedTweetyPieStoreV2,
cachedTweetyPieStoreV2NoVF,
globalStats
),
ExploreVideoTweetCandidateAdaptor(
exploreRankerStore,
cachedTweetyPieStoreV2,
globalStats
),
ListsToRecommendCandidateAdaptor(
listRecsStore,
geoDuckV2Store,
idsStore,
globalStats
)
)
}
}

View File

@ -0,0 +1,326 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.conversions.DurationOps._
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base.TopTweetImpressionsCandidate
import com.twitter.frigate.common.store.RecentTweetsQuery
import com.twitter.frigate.common.util.SnowflakeUtils
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
import com.twitter.frigate.pushservice.store.TweetImpressionsStore
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.FutureOps
import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future
case class TweetImpressionsCandidate(
tweetId: Long,
tweetyPieResultOpt: Option[TweetyPieResult],
impressionsCountOpt: Option[Long])
case class TopTweetImpressionsCandidateAdaptor(
recentTweetsFromTflockStore: ReadableStore[RecentTweetsQuery, Seq[Seq[Long]]],
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
tweetImpressionsStore: TweetImpressionsStore,
globalStats: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
private val stats = globalStats.scope("TopTweetImpressionsAdaptor")
private val tweetImpressionsCandsStat = stats.stat("top_tweet_impressions_cands_dist")
private val eligibleUsersCounter = stats.counter("eligible_users")
private val noneligibleUsersCounter = stats.counter("noneligible_users")
private val meetsMinTweetsRequiredCounter = stats.counter("meets_min_tweets_required")
private val belowMinTweetsRequiredCounter = stats.counter("below_min_tweets_required")
private val aboveMaxInboundFavoritesCounter = stats.counter("above_max_inbound_favorites")
private val meetsImpressionsRequiredCounter = stats.counter("meets_impressions_required")
private val belowImpressionsRequiredCounter = stats.counter("below_impressions_required")
private val meetsFavoritesThresholdCounter = stats.counter("meets_favorites_threshold")
private val aboveFavoritesThresholdCounter = stats.counter("above_favorites_threshold")
private val emptyImpressionsMapCounter = stats.counter("empty_impressions_map")
private val tflockResultsStat = stats.stat("tflock", "results")
private val emptyTflockResult = stats.counter("tflock", "empty_result")
private val nonEmptyTflockResult = stats.counter("tflock", "non_empty_result")
private val originalTweetsStat = stats.stat("tweets", "original_tweets")
private val retweetsStat = stats.stat("tweets", "retweets")
private val allRetweetsOnlyCounter = stats.counter("tweets", "all_retweets_only")
private val allOriginalTweetsOnlyCounter = stats.counter("tweets", "all_original_tweets_only")
private val emptyTweetypieMap = stats.counter("", "empty_tweetypie_map")
private val emptyTweetyPieResult = stats.stat("", "empty_tweetypie_result")
private val allEmptyTweetypieResults = stats.counter("", "all_empty_tweetypie_results")
private val eligibleUsersAfterImpressionsFilter =
stats.counter("eligible_users_after_impressions_filter")
private val eligibleUsersAfterFavoritesFilter =
stats.counter("eligible_users_after_favorites_filter")
private val eligibleUsersWithEligibleTweets =
stats.counter("eligible_users_with_eligible_tweets")
private val eligibleTweetCands = stats.stat("eligible_tweet_cands")
private val getCandsRequestCounter =
stats.counter("top_tweet_impressions_get_request")
override val name: String = this.getClass.getSimpleName
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
getCandsRequestCounter.incr()
val eligibleCandidatesFut = getTweetImpressionsCandidates(inputTarget)
eligibleCandidatesFut.map { eligibleCandidates =>
if (eligibleCandidates.nonEmpty) {
eligibleUsersWithEligibleTweets.incr()
eligibleTweetCands.add(eligibleCandidates.size)
val candidate = getMostImpressionsTweet(eligibleCandidates)
Some(
Seq(
generateTopTweetImpressionsCandidate(
inputTarget,
candidate.tweetId,
candidate.tweetyPieResultOpt,
candidate.impressionsCountOpt.getOrElse(0L))))
} else None
}
}
private def getTweetImpressionsCandidates(
inputTarget: Target
): Future[Seq[TweetImpressionsCandidate]] = {
val originalTweets = getRecentOriginalTweetsForUser(inputTarget)
originalTweets.flatMap { tweetyPieResultsMap =>
val numDaysSearchForOriginalTweets =
inputTarget.params(FS.TopTweetImpressionsOriginalTweetsNumDaysSearch)
val moreRecentTweetIds =
getMoreRecentTweetIds(tweetyPieResultsMap.keySet.toSeq, numDaysSearchForOriginalTweets)
val isEligible = isEligibleUser(inputTarget, tweetyPieResultsMap, moreRecentTweetIds)
if (isEligible) filterByEligibility(inputTarget, tweetyPieResultsMap, moreRecentTweetIds)
else Future.Nil
}
}
private def getRecentOriginalTweetsForUser(
targetUser: Target
): Future[Map[Long, TweetyPieResult]] = {
val tweetyPieResultsMapFut = getTflockStoreResults(targetUser).flatMap { recentTweetIds =>
FutureOps.mapCollect((targetUser.params(FS.EnableVFInTweetypie) match {
case true => tweetyPieStore
case false => tweetyPieStoreNoVF
}).multiGet(recentTweetIds.toSet))
}
tweetyPieResultsMapFut.map { tweetyPieResultsMap =>
if (tweetyPieResultsMap.isEmpty) {
emptyTweetypieMap.incr()
Map.empty
} else removeRetweets(tweetyPieResultsMap)
}
}
private def getTflockStoreResults(targetUser: Target): Future[Seq[Long]] = {
val maxResults = targetUser.params(FS.TopTweetImpressionsRecentTweetsByAuthorStoreMaxResults)
val maxAge = targetUser.params(FS.TopTweetImpressionsTotalFavoritesLimitNumDaysSearch)
val recentTweetsQuery =
RecentTweetsQuery(
userIds = Seq(targetUser.targetId),
maxResults = maxResults,
maxAge = maxAge.days
)
recentTweetsFromTflockStore
.get(recentTweetsQuery).map {
case Some(tweetIdsAll) =>
val tweetIds = tweetIdsAll.headOption.getOrElse(Seq.empty)
val numTweets = tweetIds.size
if (numTweets > 0) {
tflockResultsStat.add(numTweets)
nonEmptyTflockResult.incr()
} else emptyTflockResult.incr()
tweetIds
case _ => Nil
}
}
private def removeRetweets(
tweetyPieResultsMap: Map[Long, Option[TweetyPieResult]]
): Map[Long, TweetyPieResult] = {
val nonEmptyTweetyPieResults: Map[Long, TweetyPieResult] = tweetyPieResultsMap.collect {
case (key, Some(value)) => (key, value)
}
emptyTweetyPieResult.add(tweetyPieResultsMap.size - nonEmptyTweetyPieResults.size)
if (nonEmptyTweetyPieResults.nonEmpty) {
val originalTweets = nonEmptyTweetyPieResults.filter {
case (_, tweetyPieResult) =>
tweetyPieResult.sourceTweet.isEmpty
}
val numOriginalTweets = originalTweets.size
val numRetweets = nonEmptyTweetyPieResults.size - originalTweets.size
originalTweetsStat.add(numOriginalTweets)
retweetsStat.add(numRetweets)
if (numRetweets == 0) allOriginalTweetsOnlyCounter.incr()
if (numOriginalTweets == 0) allRetweetsOnlyCounter.incr()
originalTweets
} else {
allEmptyTweetypieResults.incr()
Map.empty
}
}
private def getMoreRecentTweetIds(
tweetIds: Seq[Long],
numDays: Int
): Seq[Long] = {
tweetIds.filter { tweetId =>
SnowflakeUtils.isRecent(tweetId, numDays.days)
}
}
private def isEligibleUser(
inputTarget: Target,
tweetyPieResults: Map[Long, TweetyPieResult],
recentTweetIds: Seq[Long]
): Boolean = {
val minNumTweets = inputTarget.params(FS.TopTweetImpressionsMinNumOriginalTweets)
lazy val totalFavoritesLimit =
inputTarget.params(FS.TopTweetImpressionsTotalInboundFavoritesLimit)
if (recentTweetIds.size >= minNumTweets) {
meetsMinTweetsRequiredCounter.incr()
val isUnderLimit = isUnderTotalInboundFavoritesLimit(tweetyPieResults, totalFavoritesLimit)
if (isUnderLimit) eligibleUsersCounter.incr()
else {
aboveMaxInboundFavoritesCounter.incr()
noneligibleUsersCounter.incr()
}
isUnderLimit
} else {
belowMinTweetsRequiredCounter.incr()
noneligibleUsersCounter.incr()
false
}
}
private def getFavoriteCounts(
tweetyPieResult: TweetyPieResult
): Long = tweetyPieResult.tweet.counts.flatMap(_.favoriteCount).getOrElse(0L)
private def isUnderTotalInboundFavoritesLimit(
tweetyPieResults: Map[Long, TweetyPieResult],
totalFavoritesLimit: Long
): Boolean = {
val favoritesIterator = tweetyPieResults.valuesIterator.map(getFavoriteCounts)
val totalInboundFavorites = favoritesIterator.sum
totalInboundFavorites <= totalFavoritesLimit
}
def filterByEligibility(
inputTarget: Target,
tweetyPieResults: Map[Long, TweetyPieResult],
tweetIds: Seq[Long]
): Future[Seq[TweetImpressionsCandidate]] = {
lazy val minNumImpressions: Long = inputTarget.params(FS.TopTweetImpressionsMinRequired)
lazy val maxNumLikes: Long = inputTarget.params(FS.TopTweetImpressionsMaxFavoritesPerTweet)
for {
filteredImpressionsMap <- getFilteredImpressionsMap(tweetIds, minNumImpressions)
tweetIdsFilteredByFavorites <-
getTweetIdsFilteredByFavorites(filteredImpressionsMap.keySet, tweetyPieResults, maxNumLikes)
} yield {
if (filteredImpressionsMap.nonEmpty) eligibleUsersAfterImpressionsFilter.incr()
if (tweetIdsFilteredByFavorites.nonEmpty) eligibleUsersAfterFavoritesFilter.incr()
val candidates = tweetIdsFilteredByFavorites.map { tweetId =>
TweetImpressionsCandidate(
tweetId,
tweetyPieResults.get(tweetId),
filteredImpressionsMap.get(tweetId))
}
tweetImpressionsCandsStat.add(candidates.length)
candidates
}
}
private def getFilteredImpressionsMap(
tweetIds: Seq[Long],
minNumImpressions: Long
): Future[Map[Long, Long]] = {
getImpressionsCounts(tweetIds).map { impressionsMap =>
if (impressionsMap.isEmpty) emptyImpressionsMapCounter.incr()
impressionsMap.filter {
case (_, numImpressions) =>
val isValid = numImpressions >= minNumImpressions
if (isValid) {
meetsImpressionsRequiredCounter.incr()
} else {
belowImpressionsRequiredCounter.incr()
}
isValid
}
}
}
private def getTweetIdsFilteredByFavorites(
filteredTweetIds: Set[Long],
tweetyPieResults: Map[Long, TweetyPieResult],
maxNumLikes: Long
): Future[Seq[Long]] = {
val filteredByFavoritesTweetIds = filteredTweetIds.filter { tweetId =>
val tweetyPieResultOpt = tweetyPieResults.get(tweetId)
val isValid = tweetyPieResultOpt.exists { tweetyPieResult =>
getFavoriteCounts(tweetyPieResult) <= maxNumLikes
}
if (isValid) meetsFavoritesThresholdCounter.incr()
else aboveFavoritesThresholdCounter.incr()
isValid
}
Future(filteredByFavoritesTweetIds.toSeq)
}
private def getMostImpressionsTweet(
filteredResults: Seq[TweetImpressionsCandidate]
): TweetImpressionsCandidate = {
val maxImpressions: Long = filteredResults.map {
_.impressionsCountOpt.getOrElse(0L)
}.max
val mostImpressionsCandidates: Seq[TweetImpressionsCandidate] =
filteredResults.filter(_.impressionsCountOpt.getOrElse(0L) == maxImpressions)
mostImpressionsCandidates.maxBy(_.tweetId)
}
private def getImpressionsCounts(
tweetIds: Seq[Long]
): Future[Map[Long, Long]] = {
val impressionCountMap = tweetIds.map { tweetId =>
tweetId -> tweetImpressionsStore
.getCounts(tweetId).map(_.getOrElse(0L))
}.toMap
Future.collect(impressionCountMap)
}
private def generateTopTweetImpressionsCandidate(
inputTarget: Target,
_tweetId: Long,
result: Option[TweetyPieResult],
_impressionsCount: Long
): RawCandidate = {
new RawCandidate with TopTweetImpressionsCandidate {
override val target: Target = inputTarget
override val tweetId: Long = _tweetId
override val tweetyPieResult: Option[TweetyPieResult] = result
override val impressionsCount: Long = _impressionsCount
}
}
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
val enabledTopTweetImpressionsNotification =
target.params(FS.EnableTopTweetImpressionsNotification)
PushDeviceUtil
.isRecommendationsEligible(target).map(_ && enabledTopTweetImpressionsNotification)
}
}

View File

@ -0,0 +1,413 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.finagle.stats.Counter
import com.twitter.finagle.stats.Stat
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base.TweetCandidate
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.model.PushTypes
import com.twitter.frigate.pushservice.params.PopGeoTweetVersion
import com.twitter.frigate.pushservice.params.PushParams
import com.twitter.frigate.pushservice.params.TopTweetsForGeoCombination
import com.twitter.frigate.pushservice.params.TopTweetsForGeoRankingFunction
import com.twitter.frigate.pushservice.params.{PushFeatureSwitchParams => FS}
import com.twitter.frigate.pushservice.predicate.DiscoverTwitterPredicate
import com.twitter.frigate.pushservice.predicate.TargetPredicates
import com.twitter.frigate.pushservice.util.MediaCRT
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.geoduck.common.thriftscala.{Location => GeoLocation}
import com.twitter.geoduck.service.thriftscala.LocationResponse
import com.twitter.gizmoduck.thriftscala.UserType
import com.twitter.hermit.pop_geo.thriftscala.PopTweetsInPlace
import com.twitter.recommendation.interests.discovery.core.model.InterestDomain
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.FutureOps
import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future
import com.twitter.util.Time
import scala.collection.Map
case class PlaceTweetScore(place: String, tweetId: Long, score: Double) {
def toTweetScore: (Long, Double) = (tweetId, score)
}
case class TopTweetsByGeoAdaptor(
geoduckStoreV2: ReadableStore[Long, LocationResponse],
softUserGeoLocationStore: ReadableStore[Long, GeoLocation],
topTweetsByGeoStore: ReadableStore[InterestDomain[String], Map[String, List[(Long, Double)]]],
topTweetsByGeoStoreV2: ReadableStore[String, PopTweetsInPlace],
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
globalStats: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
override def name: String = this.getClass.getSimpleName
private[this] val stats = globalStats.scope("TopTweetsByGeoAdaptor")
private[this] val noGeohashUserCounter: Counter = stats.counter("users_with_no_geohash_counter")
private[this] val incomingRequestCounter: Counter = stats.counter("incoming_request_counter")
private[this] val incomingLoggedOutRequestCounter: Counter =
stats.counter("incoming_logged_out_request_counter")
private[this] val loggedOutRawCandidatesCounter =
stats.counter("logged_out_raw_candidates_counter")
private[this] val emptyLoggedOutRawCandidatesCounter =
stats.counter("logged_out_empty_raw_candidates")
private[this] val outputTopTweetsByGeoCounter: Stat =
stats.stat("output_top_tweets_by_geo_counter")
private[this] val loggedOutPopByGeoV2CandidatesCounter: Counter =
stats.counter("logged_out_pop_by_geo_candidates")
private[this] val dormantUsersSince14DaysCounter: Counter =
stats.counter("dormant_user_since_14_days_counter")
private[this] val dormantUsersSince30DaysCounter: Counter =
stats.counter("dormant_user_since_30_days_counter")
private[this] val nonDormantUsersSince14DaysCounter: Counter =
stats.counter("non_dormant_user_since_14_days_counter")
private[this] val topTweetsByGeoTake100Counter: Counter =
stats.counter("top_tweets_by_geo_take_100_counter")
private[this] val combinationRequestsCounter =
stats.scope("combination_method_request_counter")
private[this] val popGeoTweetVersionCounter =
stats.scope("popgeo_tweet_version_counter")
private[this] val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
val MaxGeoHashSize = 4
private def constructKeys(
geohash: Option[String],
accountCountryCode: Option[String],
keyLengths: Seq[Int],
version: PopGeoTweetVersion.Value
): Set[String] = {
val geohashKeys = geohash match {
case Some(hash) => keyLengths.map { version + "_geohash_" + hash.take(_) }
case _ => Seq.empty
}
val accountCountryCodeKeys =
accountCountryCode.toSeq.map(version + "_country_" + _.toUpperCase)
(geohashKeys ++ accountCountryCodeKeys).toSet
}
def convertToPlaceTweetScore(
popTweetsInPlace: Seq[PopTweetsInPlace]
): Seq[PlaceTweetScore] = {
popTweetsInPlace.flatMap {
case p =>
p.popTweets.map {
case popTweet => PlaceTweetScore(p.place, popTweet.tweetId, popTweet.score)
}
}
}
def sortGeoHashTweets(
placeTweetScores: Seq[PlaceTweetScore],
rankingFunction: TopTweetsForGeoRankingFunction.Value
): Seq[PlaceTweetScore] = {
rankingFunction match {
case TopTweetsForGeoRankingFunction.Score =>
placeTweetScores.sortBy(_.score)(Ordering[Double].reverse)
case TopTweetsForGeoRankingFunction.GeohashLengthAndThenScore =>
placeTweetScores
.sortBy(row => (row.place.length, row.score))(Ordering[(Int, Double)].reverse)
}
}
def getResultsForLambdaStore(
inputTarget: Target,
geohash: Option[String],
store: ReadableStore[String, PopTweetsInPlace],
topk: Int,
version: PopGeoTweetVersion.Value
): Future[Seq[(Long, Double)]] = {
inputTarget.accountCountryCode.flatMap { countryCode =>
val keys = {
if (inputTarget.params(FS.EnableCountryCodeBackoffTopTweetsByGeo))
constructKeys(geohash, countryCode, inputTarget.params(FS.GeoHashLengthList), version)
else
constructKeys(geohash, None, inputTarget.params(FS.GeoHashLengthList), version)
}
FutureOps
.mapCollect(store.multiGet(keys)).map {
case geohashTweetMap =>
val popTweets =
geohashTweetMap.values.flatten.toSeq
val results = sortGeoHashTweets(
convertToPlaceTweetScore(popTweets),
inputTarget.params(FS.RankingFunctionForTopTweetsByGeo))
.map(_.toTweetScore).take(topk)
results
}
}
}
def getPopGeoTweetsForLoggedOutUsers(
inputTarget: Target,
store: ReadableStore[String, PopTweetsInPlace]
): Future[Seq[(Long, Double)]] = {
inputTarget.countryCode.flatMap { countryCode =>
val keys = constructKeys(None, countryCode, Seq(4), PopGeoTweetVersion.Prod)
FutureOps.mapCollect(store.multiGet(keys)).map {
case tweetMap =>
val tweets = tweetMap.values.flatten.toSeq
loggedOutPopByGeoV2CandidatesCounter.incr(tweets.size)
val popTweets = sortGeoHashTweets(
convertToPlaceTweetScore(tweets),
TopTweetsForGeoRankingFunction.Score).map(_.toTweetScore)
popTweets
}
}
}
def getRankedTweets(
inputTarget: Target,
geohash: Option[String]
): Future[Seq[(Long, Double)]] = {
val MaxTopTweetsByGeoCandidatesToTake =
inputTarget.params(FS.MaxTopTweetsByGeoCandidatesToTake)
val scoringFn: String = inputTarget.params(FS.ScoringFuncForTopTweetsByGeo)
val combinationMethod = inputTarget.params(FS.TopTweetsByGeoCombinationParam)
val popGeoTweetVersion = inputTarget.params(FS.PopGeoTweetVersionParam)
inputTarget.isHeavyUserState.map { isHeavyUser =>
stats
.scope(combinationMethod.toString).scope(popGeoTweetVersion.toString).scope(
"IsHeavyUser_" + isHeavyUser.toString).counter().incr()
}
combinationRequestsCounter.scope(combinationMethod.toString).counter().incr()
popGeoTweetVersionCounter.scope(popGeoTweetVersion.toString).counter().incr()
lazy val geoStoreResults = if (geohash.isDefined) {
val hash = geohash.get.take(MaxGeoHashSize)
topTweetsByGeoStore
.get(
InterestDomain[String](hash)
)
.map {
case Some(scoringFnToTweetsMapOpt) =>
val tweetsWithScore = scoringFnToTweetsMapOpt
.getOrElse(scoringFn, List.empty)
val sortedResults = sortGeoHashTweets(
tweetsWithScore.map {
case (tweetId, score) => PlaceTweetScore(hash, tweetId, score)
},
TopTweetsForGeoRankingFunction.Score
).map(_.toTweetScore).take(
MaxTopTweetsByGeoCandidatesToTake
)
sortedResults
case _ => Seq.empty
}
} else Future.value(Seq.empty)
lazy val versionPopGeoTweetResults =
getResultsForLambdaStore(
inputTarget,
geohash,
topTweetsByGeoStoreV2,
MaxTopTweetsByGeoCandidatesToTake,
popGeoTweetVersion
)
combinationMethod match {
case TopTweetsForGeoCombination.Default => geoStoreResults
case TopTweetsForGeoCombination.AccountsTweetFavAsBackfill =>
Future.join(geoStoreResults, versionPopGeoTweetResults).map {
case (geoStoreTweets, versionPopGeoTweets) =>
(geoStoreTweets ++ versionPopGeoTweets).take(MaxTopTweetsByGeoCandidatesToTake)
}
case TopTweetsForGeoCombination.AccountsTweetFavIntermixed =>
Future.join(geoStoreResults, versionPopGeoTweetResults).map {
case (geoStoreTweets, versionPopGeoTweets) =>
CandidateSource.interleaveSeqs(Seq(geoStoreTweets, versionPopGeoTweets))
}
}
}
override def get(inputTarget: Target): Future[Option[Seq[RawCandidate]]] = {
if (inputTarget.isLoggedOutUser) {
incomingLoggedOutRequestCounter.incr()
val rankedTweets = getPopGeoTweetsForLoggedOutUsers(inputTarget, topTweetsByGeoStoreV2)
val rawCandidates = {
rankedTweets.map { rt =>
FutureOps
.mapCollect(
tweetyPieStore
.multiGet(rt.map { case (tweetId, _) => tweetId }.toSet))
.map { tweetyPieResultMap =>
val results = buildTopTweetsByGeoRawCandidates(
inputTarget,
None,
tweetyPieResultMap
)
if (results.isEmpty) {
emptyLoggedOutRawCandidatesCounter.incr()
}
loggedOutRawCandidatesCounter.incr(results.size)
Some(results)
}
}.flatten
}
rawCandidates
} else {
incomingRequestCounter.incr()
getGeoHashForUsers(inputTarget).flatMap { geohash =>
if (geohash.isEmpty) noGeohashUserCounter.incr()
getRankedTweets(inputTarget, geohash).map { rt =>
if (rt.size == 100) {
topTweetsByGeoTake100Counter.incr(1)
}
FutureOps
.mapCollect((inputTarget.params(FS.EnableVFInTweetypie) match {
case true => tweetyPieStore
case false => tweetyPieStoreNoVF
}).multiGet(rt.map { case (tweetId, _) => tweetId }.toSet))
.map { tweetyPieResultMap =>
Some(
buildTopTweetsByGeoRawCandidates(
inputTarget,
None,
filterOutReplyTweet(
tweetyPieResultMap,
nonReplyTweetsCounter
)
)
)
}
}.flatten
}
}
}
private def getGeoHashForUsers(
inputTarget: Target
): Future[Option[String]] = {
inputTarget.targetUser.flatMap {
case Some(user) =>
user.userType match {
case UserType.Soft =>
softUserGeoLocationStore
.get(inputTarget.targetId)
.map(_.flatMap(_.geohash.flatMap(_.stringGeohash)))
case _ =>
geoduckStoreV2.get(inputTarget.targetId).map(_.flatMap(_.geohash))
}
case None => Future.None
}
}
private def buildTopTweetsByGeoRawCandidates(
target: PushTypes.Target,
locationName: Option[String],
topTweets: Map[Long, Option[TweetyPieResult]]
): Seq[RawCandidate with TweetCandidate] = {
val candidates = topTweets.map { tweetIdTweetyPieResultMap =>
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
inputTarget = target,
id = tweetIdTweetyPieResultMap._1,
mediaCRT = MediaCRT(
CommonRecommendationType.GeoPopTweet,
CommonRecommendationType.GeoPopTweet,
CommonRecommendationType.GeoPopTweet
),
result = tweetIdTweetyPieResultMap._2,
localizedEntity = None
)
}.toSeq
outputTopTweetsByGeoCounter.add(candidates.length)
candidates
}
private val topTweetsByGeoFrequencyPredicate = {
TargetPredicates
.pushRecTypeFatiguePredicate(
CommonRecommendationType.GeoPopTweet,
FS.TopTweetsByGeoPushInterval,
FS.MaxTopTweetsByGeoPushGivenInterval,
stats
)
}
def getAvailabilityForDormantUser(target: Target): Future[Boolean] = {
lazy val isDormantUserNotFatigued = topTweetsByGeoFrequencyPredicate(Seq(target)).map(_.head)
lazy val enableTopTweetsByGeoForDormantUsers =
target.params(FS.EnableTopTweetsByGeoCandidatesForDormantUsers)
target.lastHTLVisitTimestamp.flatMap {
case Some(lastHTLTimestamp) =>
val minTimeSinceLastLogin =
target.params(FS.MinimumTimeSinceLastLoginForGeoPopTweetPush).ago
val timeSinceInactive = target.params(FS.TimeSinceLastLoginForGeoPopTweetPush).ago
val lastActiveTimestamp = Time.fromMilliseconds(lastHTLTimestamp)
if (lastActiveTimestamp > minTimeSinceLastLogin) {
nonDormantUsersSince14DaysCounter.incr()
Future.False
} else {
dormantUsersSince14DaysCounter.incr()
isDormantUserNotFatigued.map { isUserNotFatigued =>
lastActiveTimestamp < timeSinceInactive &&
enableTopTweetsByGeoForDormantUsers &&
isUserNotFatigued
}
}
case _ =>
dormantUsersSince30DaysCounter.incr()
isDormantUserNotFatigued.map { isUserNotFatigued =>
enableTopTweetsByGeoForDormantUsers && isUserNotFatigued
}
}
}
def getAvailabilityForPlaybookSetUp(target: Target): Future[Boolean] = {
lazy val enableTopTweetsByGeoForNewUsers = target.params(FS.EnableTopTweetsByGeoCandidates)
val isTargetEligibleForMrFatigueCheck = target.isAccountAtleastNDaysOld(
target.params(FS.MrMinDurationSincePushForTopTweetsByGeoPushes))
val isMrFatigueCheckEnabled =
target.params(FS.EnableMrMinDurationSinceMrPushFatigue)
val applyPredicateForTopTweetsByGeo =
if (isMrFatigueCheckEnabled) {
if (isTargetEligibleForMrFatigueCheck) {
DiscoverTwitterPredicate
.minDurationElapsedSinceLastMrPushPredicate(
name,
FS.MrMinDurationSincePushForTopTweetsByGeoPushes,
stats
).andThen(
topTweetsByGeoFrequencyPredicate
)(Seq(target)).map(_.head)
} else {
Future.False
}
} else {
topTweetsByGeoFrequencyPredicate(Seq(target)).map(_.head)
}
applyPredicateForTopTweetsByGeo.map { predicateResult =>
predicateResult && enableTopTweetsByGeoForNewUsers
}
}
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
if (target.isLoggedOutUser) {
Future.True
} else {
PushDeviceUtil
.isRecommendationsEligible(target).map(
_ && target.params(PushParams.PopGeoCandidatesDecider)).flatMap { isAvailable =>
if (isAvailable) {
Future
.join(getAvailabilityForDormantUser(target), getAvailabilityForPlaybookSetUp(target))
.map {
case (isAvailableForDormantUser, isAvailableForPlaybook) =>
isAvailableForDormantUser || isAvailableForPlaybook
case _ => false
}
} else Future.False
}
}
}
}

View File

@ -0,0 +1,215 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.events.recos.thriftscala.DisplayLocation
import com.twitter.events.recos.thriftscala.TrendsContext
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.base.TrendTweetCandidate
import com.twitter.frigate.common.base.TrendsCandidate
import com.twitter.frigate.common.candidate.RecommendedTrendsCandidateSource
import com.twitter.frigate.common.candidate.RecommendedTrendsCandidateSource.Query
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.adaptor.TrendsCandidatesAdaptor._
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.params.PushParams
import com.twitter.frigate.pushservice.predicate.TargetPredicates
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.geoduck.common.thriftscala.Location
import com.twitter.gizmoduck.thriftscala.UserType
import com.twitter.hermit.store.tweetypie.UserTweet
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.util.Future
import scala.collection.Map
object TrendsCandidatesAdaptor {
type TweetId = Long
type EventId = Long
}
case class TrendsCandidatesAdaptor(
softUserGeoLocationStore: ReadableStore[Long, Location],
recommendedTrendsCandidateSource: RecommendedTrendsCandidateSource,
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
safeUserTweetTweetyPieStore: ReadableStore[UserTweet, TweetyPieResult],
statsReceiver: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
override val name = this.getClass.getSimpleName
private val trendAdaptorStats = statsReceiver.scope("TrendsCandidatesAdaptor")
private val trendTweetCandidateNumber = trendAdaptorStats.counter("trend_tweet_candidate")
private val nonReplyTweetsCounter = trendAdaptorStats.counter("non_reply_tweets")
private def getQuery(target: Target): Future[Query] = {
def getUserCountryCode(target: Target): Future[Option[String]] = {
target.targetUser.flatMap {
case Some(user) if user.userType == UserType.Soft =>
softUserGeoLocationStore
.get(user.id)
.map(_.flatMap(_.simpleRgcResult.flatMap(_.countryCodeAlpha2)))
case _ => target.accountCountryCode
}
}
for {
countryCode <- getUserCountryCode(target)
inferredLanguage <- target.inferredUserDeviceLanguage
} yield {
Query(
userId = target.targetId,
displayLocation = DisplayLocation.MagicRecs,
languageCode = inferredLanguage,
countryCode = countryCode,
maxResults = target.params(PushFeatureSwitchParams.MaxRecommendedTrendsToQuery)
)
}
}
/**
* Query candidates only if sent at most [[PushFeatureSwitchParams.MaxTrendTweetNotificationsInDuration]]
* trend tweet notifications in [[PushFeatureSwitchParams.TrendTweetNotificationsFatigueDuration]]
*/
val trendTweetFatiguePredicate = TargetPredicates.pushRecTypeFatiguePredicate(
CommonRecommendationType.TrendTweet,
PushFeatureSwitchParams.TrendTweetNotificationsFatigueDuration,
PushFeatureSwitchParams.MaxTrendTweetNotificationsInDuration,
trendAdaptorStats
)
private val recommendedTrendsWithTweetsCandidateSource: CandidateSource[
Target,
RawCandidate with TrendsCandidate
] = recommendedTrendsCandidateSource
.convert[Target, TrendsCandidate](
getQuery,
recommendedTrendsCandidateSource.identityCandidateMapper
)
.batchMapValues[Target, RawCandidate with TrendsCandidate](
trendsCandidatesToTweetCandidates(_, _, getTweetyPieResults))
private def getTweetyPieResults(
tweetIds: Seq[TweetId],
target: Target
): Future[Map[TweetId, TweetyPieResult]] = {
if (target.params(PushFeatureSwitchParams.EnableSafeUserTweetTweetypieStore)) {
Future
.collect(
safeUserTweetTweetyPieStore.multiGet(
tweetIds.toSet.map(UserTweet(_, Some(target.targetId))))).map {
_.collect {
case (userTweet, Some(tweetyPieResult)) => userTweet.tweetId -> tweetyPieResult
}
}
} else {
Future
.collect((target.params(PushFeatureSwitchParams.EnableVFInTweetypie) match {
case true => tweetyPieStore
case false => tweetyPieStoreNoVF
}).multiGet(tweetIds.toSet)).map { tweetyPieResultMap =>
filterOutReplyTweet(tweetyPieResultMap, nonReplyTweetsCounter).collect {
case (tweetId, Some(tweetyPieResult)) => tweetId -> tweetyPieResult
}
}
}
}
/**
*
* @param _target: [[Target]] object representing notificaion recipient user
* @param trendsCandidates: Sequence of [[TrendsCandidate]] returned from ERS
* @return: Seq of trends candidates expanded to associated tweets.
*/
private def trendsCandidatesToTweetCandidates(
_target: Target,
trendsCandidates: Seq[TrendsCandidate],
getTweetyPieResults: (Seq[TweetId], Target) => Future[Map[TweetId, TweetyPieResult]]
): Future[Seq[RawCandidate with TrendsCandidate]] = {
def generateTrendTweetCandidates(
trendCandidate: TrendsCandidate,
tweetyPieResults: Map[TweetId, TweetyPieResult]
) = {
val tweetIds = trendCandidate.context.curatedRepresentativeTweets.getOrElse(Seq.empty) ++
trendCandidate.context.algoRepresentativeTweets.getOrElse(Seq.empty)
tweetIds.flatMap { tweetId =>
tweetyPieResults.get(tweetId).map { _tweetyPieResult =>
new RawCandidate with TrendTweetCandidate {
override val trendId: String = trendCandidate.trendId
override val trendName: String = trendCandidate.trendName
override val landingUrl: String = trendCandidate.landingUrl
override val timeBoundedLandingUrl: Option[String] =
trendCandidate.timeBoundedLandingUrl
override val context: TrendsContext = trendCandidate.context
override val tweetyPieResult: Option[TweetyPieResult] = Some(_tweetyPieResult)
override val tweetId: TweetId = _tweetyPieResult.tweet.id
override val target: Target = _target
}
}
}
}
// collect all tweet ids associated with all trends
val allTweetIds = trendsCandidates.flatMap { trendsCandidate =>
val context = trendsCandidate.context
context.curatedRepresentativeTweets.getOrElse(Seq.empty) ++
context.algoRepresentativeTweets.getOrElse(Seq.empty)
}
getTweetyPieResults(allTweetIds, _target)
.map { tweetIdToTweetyPieResult =>
val trendTweetCandidates = trendsCandidates.flatMap { trendCandidate =>
val allTrendTweetCandidates = generateTrendTweetCandidates(
trendCandidate,
tweetIdToTweetyPieResult
)
val (tweetCandidatesFromCuratedTrends, tweetCandidatesFromNonCuratedTrends) =
allTrendTweetCandidates.partition(_.isCuratedTrend)
tweetCandidatesFromCuratedTrends.filter(
_.target.params(PushFeatureSwitchParams.EnableCuratedTrendTweets)) ++
tweetCandidatesFromNonCuratedTrends.filter(
_.target.params(PushFeatureSwitchParams.EnableNonCuratedTrendTweets))
}
trendTweetCandidateNumber.incr(trendTweetCandidates.size)
trendTweetCandidates
}
}
/**
*
* @param target: [[Target]] user
* @return: true if customer is eligible to receive trend tweet notifications
*
*/
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
PushDeviceUtil
.isRecommendationsEligible(target)
.map(target.params(PushParams.TrendsCandidateDecider) && _)
}
override def get(target: Target): Future[Option[Seq[RawCandidate with TrendsCandidate]]] = {
recommendedTrendsWithTweetsCandidateSource
.get(target)
.flatMap {
case Some(candidates) if candidates.nonEmpty =>
trendTweetFatiguePredicate(Seq(target))
.map(_.head)
.map { isTargetFatigueEligible =>
if (isTargetFatigueEligible) Some(candidates)
else None
}
case _ => Future.None
}
}
}

View File

@ -0,0 +1,188 @@
package com.twitter.frigate.pushservice.adaptor
import com.twitter.content_mixer.thriftscala.ContentMixerProductResponse
import com.twitter.content_mixer.thriftscala.ContentMixerRequest
import com.twitter.content_mixer.thriftscala.ContentMixerResponse
import com.twitter.content_mixer.thriftscala.NotificationsTripTweetsProductContext
import com.twitter.content_mixer.thriftscala.Product
import com.twitter.content_mixer.thriftscala.ProductContext
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.base.CandidateSource
import com.twitter.frigate.common.base.CandidateSourceEligible
import com.twitter.frigate.common.predicate.CommonOutNetworkTweetCandidatesSourcePredicates.filterOutReplyTweet
import com.twitter.frigate.pushservice.model.PushTypes.RawCandidate
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.params.PushParams
import com.twitter.frigate.pushservice.util.MediaCRT
import com.twitter.frigate.pushservice.util.PushAdaptorUtil
import com.twitter.frigate.pushservice.util.PushDeviceUtil
import com.twitter.frigate.thriftscala.CommonRecommendationType
import com.twitter.geoduck.util.country.CountryInfo
import com.twitter.product_mixer.core.thriftscala.ClientContext
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
import com.twitter.util.Future
case class TripGeoCandidatesAdaptor(
tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets],
contentMixerStore: ReadableStore[ContentMixerRequest, ContentMixerResponse],
tweetyPieStore: ReadableStore[Long, TweetyPieResult],
tweetyPieStoreNoVF: ReadableStore[Long, TweetyPieResult],
statsReceiver: StatsReceiver)
extends CandidateSource[Target, RawCandidate]
with CandidateSourceEligible[Target, RawCandidate] {
override def name: String = this.getClass.getSimpleName
private val stats = statsReceiver.scope(name.stripSuffix("$"))
private val contentMixerRequests = stats.counter("getTripCandidatesContentMixerRequests")
private val loggedOutTripTweetIds = stats.counter("logged_out_trip_tweet_ids_count")
private val loggedOutRawCandidates = stats.counter("logged_out_raw_candidates_count")
private val rawCandidates = stats.counter("raw_candidates_count")
private val loggedOutEmptyplaceId = stats.counter("logged_out_empty_place_id_count")
private val loggedOutPlaceId = stats.counter("logged_out_place_id_count")
private val nonReplyTweetsCounter = stats.counter("non_reply_tweets")
override def isCandidateSourceAvailable(target: Target): Future[Boolean] = {
if (target.isLoggedOutUser) {
Future.True
} else {
for {
isRecommendationsSettingEnabled <- PushDeviceUtil.isRecommendationsEligible(target)
inferredLanguage <- target.inferredUserDeviceLanguage
} yield {
isRecommendationsSettingEnabled &&
inferredLanguage.nonEmpty &&
target.params(PushParams.TripGeoTweetCandidatesDecider)
}
}
}
private def buildRawCandidate(target: Target, tweetyPieResult: TweetyPieResult): RawCandidate = {
PushAdaptorUtil.generateOutOfNetworkTweetCandidates(
inputTarget = target,
id = tweetyPieResult.tweet.id,
mediaCRT = MediaCRT(
CommonRecommendationType.TripGeoTweet,
CommonRecommendationType.TripGeoTweet,
CommonRecommendationType.TripGeoTweet
),
result = Some(tweetyPieResult),
localizedEntity = None
)
}
override def get(target: Target): Future[Option[Seq[RawCandidate]]] = {
if (target.isLoggedOutUser) {
for {
tripTweetIds <- getTripCandidatesForLoggedOutTarget(target)
tweetyPieResults <- Future.collect(tweetyPieStoreNoVF.multiGet(tripTweetIds))
} yield {
val candidates = tweetyPieResults.values.flatten.map(buildRawCandidate(target, _))
if (candidates.nonEmpty) {
loggedOutRawCandidates.incr(candidates.size)
Some(candidates.toSeq)
} else None
}
} else {
for {
tripTweetIds <- getTripCandidatesContentMixer(target)
tweetyPieResults <-
Future.collect((target.params(PushFeatureSwitchParams.EnableVFInTweetypie) match {
case true => tweetyPieStore
case false => tweetyPieStoreNoVF
}).multiGet(tripTweetIds))
} yield {
val nonReplyTweets = filterOutReplyTweet(tweetyPieResults, nonReplyTweetsCounter)
val candidates = nonReplyTweets.values.flatten.map(buildRawCandidate(target, _))
if (candidates.nonEmpty && target.params(
PushFeatureSwitchParams.TripTweetCandidateReturnEnable)) {
rawCandidates.incr(candidates.size)
Some(candidates.toSeq)
} else None
}
}
}
private def getTripCandidatesContentMixer(
target: Target
): Future[Set[Long]] = {
contentMixerRequests.incr()
Future
.join(
target.inferredUserDeviceLanguage,
target.deviceInfo
)
.flatMap {
case (languageOpt, deviceInfoOpt) =>
contentMixerStore
.get(
ContentMixerRequest(
clientContext = ClientContext(
userId = Some(target.targetId),
languageCode = languageOpt,
userAgent = deviceInfoOpt.flatMap(_.guessedPrimaryDeviceUserAgent.map(_.toString))
),
product = Product.NotificationsTripTweets,
productContext = Some(
ProductContext.NotificationsTripTweetsProductContext(
NotificationsTripTweetsProductContext()
)),
cursor = None,
maxResults =
Some(target.params(PushFeatureSwitchParams.TripTweetMaxTotalCandidates))
)
).map {
_.map { rawResponse =>
val tripResponse =
rawResponse.contentMixerProductResponse
.asInstanceOf[
ContentMixerProductResponse.NotificationsTripTweetsProductResponse]
.notificationsTripTweetsProductResponse
tripResponse.results.map(_.tweetResult.tweetId).toSet
}.getOrElse(Set.empty)
}
}
}
private def getTripCandidatesForLoggedOutTarget(
target: Target
): Future[Set[Long]] = {
Future.join(target.targetLanguage, target.countryCode).flatMap {
case (Some(lang), Some(country)) =>
val placeId = CountryInfo.lookupByCode(country).map(_.placeIdLong)
if (placeId.nonEmpty) {
loggedOutPlaceId.incr()
} else {
loggedOutEmptyplaceId.incr()
}
val tripSource = "TOP_GEO_V3_LR"
val tripQuery = TripDomain(
sourceId = tripSource,
language = Some(lang),
placeId = placeId,
topicId = None
)
val response = tripTweetCandidateStore.get(tripQuery)
val tripTweetIds =
response.map { res =>
if (res.isDefined) {
res.get.tweets
.sortBy(_.score)(Ordering[Double].reverse).map(_.tweetId).toSet
} else {
Set.empty[Long]
}
}
tripTweetIds.map { ids => loggedOutTripTweetIds.incr(ids.size) }
tripTweetIds
case (_, _) => Future.value(Set.empty)
}
}
}

View File

@ -0,0 +1,461 @@
package com.twitter.frigate.pushservice.config
import com.twitter.abdecider.LoggingABDecider
import com.twitter.abuse.detection.scoring.thriftscala.TweetScoringRequest
import com.twitter.abuse.detection.scoring.thriftscala.TweetScoringResponse
import com.twitter.audience_rewards.thriftscala.HasSuperFollowingRelationshipRequest
import com.twitter.channels.common.thriftscala.ApiList
import com.twitter.datatools.entityservice.entities.sports.thriftscala._
import com.twitter.decider.Decider
import com.twitter.discovery.common.configapi.ConfigParamsBuilder
import com.twitter.escherbird.common.thriftscala.QualifiedId
import com.twitter.escherbird.metadata.thriftscala.EntityMegadata
import com.twitter.eventbus.client.EventBusPublisher
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.finagle.thrift.ClientId
import com.twitter.frigate.common.base._
import com.twitter.frigate.common.candidate._
import com.twitter.frigate.common.history._
import com.twitter.frigate.common.ml.base._
import com.twitter.frigate.common.ml.feature._
import com.twitter.frigate.common.store._
import com.twitter.frigate.common.store.deviceinfo.DeviceInfo
import com.twitter.frigate.common.store.interests.InterestsLookupRequestWithContext
import com.twitter.frigate.common.store.interests.UserId
import com.twitter.frigate.common.util._
import com.twitter.frigate.data_pipeline.features_common._
import com.twitter.frigate.data_pipeline.thriftscala.UserHistoryKey
import com.twitter.frigate.data_pipeline.thriftscala.UserHistoryValue
import com.twitter.frigate.dau_model.thriftscala.DauProbability
import com.twitter.frigate.magic_events.thriftscala.FanoutEvent
import com.twitter.frigate.pushcap.thriftscala.PushcapUserHistory
import com.twitter.frigate.pushservice.ml._
import com.twitter.frigate.pushservice.params.DeciderKey
import com.twitter.frigate.pushservice.params.PushFeatureSwitchParams
import com.twitter.frigate.pushservice.params.PushFeatureSwitches
import com.twitter.frigate.pushservice.params.PushParams
import com.twitter.frigate.pushservice.send_handler.SendHandlerPushCandidateHydrator
import com.twitter.frigate.pushservice.refresh_handler.PushCandidateHydrator
import com.twitter.frigate.pushservice.store._
import com.twitter.frigate.pushservice.store.{Ibis2Store => PushIbis2Store}
import com.twitter.frigate.pushservice.take.NotificationServiceRequest
import com.twitter.frigate.pushservice.thriftscala.PushRequestScribe
import com.twitter.frigate.scribe.thriftscala.NotificationScribe
import com.twitter.frigate.thriftscala._
import com.twitter.frigate.user_states.thriftscala.MRUserHmmState
import com.twitter.geoduck.common.thriftscala.{Location => GeoLocation}
import com.twitter.geoduck.service.thriftscala.LocationResponse
import com.twitter.gizmoduck.thriftscala.User
import com.twitter.hermit.pop_geo.thriftscala.PopTweetsInPlace
import com.twitter.hermit.predicate.socialgraph.RelationEdge
import com.twitter.hermit.predicate.tweetypie.Perspective
import com.twitter.hermit.predicate.tweetypie.UserTweet
import com.twitter.hermit.store.semantic_core.SemanticEntityForQuery
import com.twitter.hermit.store.tweetypie.{UserTweet => TweetyPieUserTweet}
import com.twitter.hermit.stp.thriftscala.STPResult
import com.twitter.hss.api.thriftscala.UserHealthSignalResponse
import com.twitter.interests.thriftscala.InterestId
import com.twitter.interests.thriftscala.{UserInterests => Interests}
import com.twitter.interests_discovery.thriftscala.NonPersonalizedRecommendedLists
import com.twitter.interests_discovery.thriftscala.RecommendedListsRequest
import com.twitter.interests_discovery.thriftscala.RecommendedListsResponse
import com.twitter.livevideo.timeline.domain.v2.{Event => LiveEvent}
import com.twitter.ml.api.thriftscala.{DataRecord => ThriftDataRecord}
import com.twitter.ml.featurestore.lib.dynamic.DynamicFeatureStoreClient
import com.twitter.notificationservice.genericfeedbackstore.FeedbackPromptValue
import com.twitter.notificationservice.genericfeedbackstore.GenericFeedbackStore
import com.twitter.notificationservice.scribe.manhattan.GenericNotificationsFeedbackRequest
import com.twitter.notificationservice.thriftscala.CaretFeedbackDetails
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationResponse
import com.twitter.nrel.heavyranker.CandidateFeatureHydrator
import com.twitter.nrel.heavyranker.{FeatureHydrator => MRFeatureHydrator}
import com.twitter.nrel.heavyranker.{TargetFeatureHydrator => RelevanceTargetFeatureHydrator}
import com.twitter.onboarding.task.service.thriftscala.FatigueFlowEnrollment
import com.twitter.permissions_storage.thriftscala.AppPermission
import com.twitter.recommendation.interests.discovery.core.model.InterestDomain
import com.twitter.recos.user_tweet_entity_graph.thriftscala.RecommendTweetEntityRequest
import com.twitter.recos.user_tweet_entity_graph.thriftscala.RecommendTweetEntityResponse
import com.twitter.recos.user_user_graph.thriftscala.RecommendUserRequest
import com.twitter.recos.user_user_graph.thriftscala.RecommendUserResponse
import com.twitter.rux.common.strato.thriftscala.UserTargetingProperty
import com.twitter.scio.nsfw_user_segmentation.thriftscala.NSFWProducer
import com.twitter.scio.nsfw_user_segmentation.thriftscala.NSFWUserSegmentation
import com.twitter.search.common.features.thriftscala.ThriftSearchResultFeatures
import com.twitter.search.earlybird.thriftscala.EarlybirdRequest
import com.twitter.search.earlybird.thriftscala.ThriftSearchResult
import com.twitter.service.gen.scarecrow.thriftscala.Event
import com.twitter.service.gen.scarecrow.thriftscala.TieredActionResult
import com.twitter.service.metastore.gen.thriftscala.Location
import com.twitter.service.metastore.gen.thriftscala.UserLanguages
import com.twitter.servo.decider.DeciderGateBuilder
import com.twitter.simclusters_v2.thriftscala.SimClustersInferredEntities
import com.twitter.stitch.tweetypie.TweetyPie.TweetyPieResult
import com.twitter.storehaus.ReadableStore
import com.twitter.strato.columns.frigate.logged_out_web_notifications.thriftscala.LOWebNotificationMetadata
import com.twitter.strato.columns.notifications.thriftscala.SourceDestUserRequest
import com.twitter.strato.client.{UserId => StratoUserId}
import com.twitter.timelines.configapi
import com.twitter.timelines.configapi.CompositeConfig
import com.twitter.timelinescorer.thriftscala.v1.ScoredTweet
import com.twitter.topiclisting.TopicListing
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripDomain
import com.twitter.trends.trip_v1.trip_tweets.thriftscala.TripTweets
import com.twitter.tsp.thriftscala.TopicSocialProofRequest
import com.twitter.tsp.thriftscala.TopicSocialProofResponse
import com.twitter.ubs.thriftscala.SellerTrack
import com.twitter.ubs.thriftscala.AudioSpace
import com.twitter.ubs.thriftscala.Participants
import com.twitter.ubs.thriftscala.SellerApplicationState
import com.twitter.user_session_store.thriftscala.UserSession
import com.twitter.util.Duration
import com.twitter.util.Future
import com.twitter.wtf.scalding.common.thriftscala.UserFeatures
trait Config {
self =>
def isServiceLocal: Boolean
def localConfigRepoPath: String
def inMemCacheOff: Boolean
def historyStore: PushServiceHistoryStore
def emailHistoryStore: PushServiceHistoryStore
def strongTiesStore: ReadableStore[Long, STPResult]
def safeUserStore: ReadableStore[Long, User]
def deviceInfoStore: ReadableStore[Long, DeviceInfo]
def edgeStore: ReadableStore[RelationEdge, Boolean]
def socialGraphServiceProcessStore: ReadableStore[RelationEdge, Boolean]
def userUtcOffsetStore: ReadableStore[Long, Duration]
def cachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult]
def safeCachedTweetyPieStoreV2: ReadableStore[Long, TweetyPieResult]
def userTweetTweetyPieStore: ReadableStore[TweetyPieUserTweet, TweetyPieResult]
def safeUserTweetTweetyPieStore: ReadableStore[TweetyPieUserTweet, TweetyPieResult]
def cachedTweetyPieStoreV2NoVF: ReadableStore[Long, TweetyPieResult]
def tweetContentFeatureCacheStore: ReadableStore[Long, ThriftDataRecord]
def scarecrowCheckEventStore: ReadableStore[Event, TieredActionResult]
def userTweetPerspectiveStore: ReadableStore[UserTweet, Perspective]
def userCountryStore: ReadableStore[Long, Location]
def pushInfoStore: ReadableStore[Long, UserForPushTargeting]
def loggedOutPushInfoStore: ReadableStore[Long, LOWebNotificationMetadata]
def tweetImpressionStore: ReadableStore[Long, Seq[Long]]
def audioSpaceStore: ReadableStore[String, AudioSpace]
def basketballGameScoreStore: ReadableStore[QualifiedId, BasketballGameLiveUpdate]
def baseballGameScoreStore: ReadableStore[QualifiedId, BaseballGameLiveUpdate]
def cricketMatchScoreStore: ReadableStore[QualifiedId, CricketMatchLiveUpdate]
def soccerMatchScoreStore: ReadableStore[QualifiedId, SoccerMatchLiveUpdate]
def nflGameScoreStore: ReadableStore[QualifiedId, NflFootballGameLiveUpdate]
def topicSocialProofServiceStore: ReadableStore[TopicSocialProofRequest, TopicSocialProofResponse]
def spaceDeviceFollowStore: ReadableStore[SourceDestUserRequest, Boolean]
def audioSpaceParticipantsStore: ReadableStore[String, Participants]
def notificationServiceSender: ReadableStore[
NotificationServiceRequest,
CreateGenericNotificationResponse
]
def ocfFatigueStore: ReadableStore[OCFHistoryStoreKey, FatigueFlowEnrollment]
def dauProbabilityStore: ReadableStore[Long, DauProbability]
def hydratedLabeledPushRecsStore: ReadableStore[UserHistoryKey, UserHistoryValue]
def userHTLLastVisitStore: ReadableStore[Long, Seq[Long]]
def userLanguagesStore: ReadableStore[Long, UserLanguages]
def topTweetsByGeoStore: ReadableStore[InterestDomain[String], Map[String, List[
(Long, Double)
]]]
def topTweetsByGeoV2VersionedStore: ReadableStore[String, PopTweetsInPlace]
lazy val pushRecItemStore: ReadableStore[PushRecItemsKey, RecItems] = PushRecItemStore(
hydratedLabeledPushRecsStore
)
lazy val labeledPushRecsVerifyingStore: ReadableStore[
LabeledPushRecsVerifyingStoreKey,
LabeledPushRecsVerifyingStoreResponse
] =
LabeledPushRecsVerifyingStore(
hydratedLabeledPushRecsStore,
historyStore
)
lazy val labeledPushRecsDecideredStore: ReadableStore[LabeledPushRecsStoreKey, UserHistoryValue] =
LabeledPushRecsDecideredStore(
labeledPushRecsVerifyingStore,
useHydratedLabeledSendsForFeaturesDeciderKey,
verifyHydratedLabeledSendsForFeaturesDeciderKey
)
def onlineUserHistoryStore: ReadableStore[OnlineUserHistoryKey, UserHistoryValue]
def nsfwConsumerStore: ReadableStore[Long, NSFWUserSegmentation]
def nsfwProducerStore: ReadableStore[Long, NSFWProducer]
def popGeoLists: ReadableStore[String, NonPersonalizedRecommendedLists]
def listAPIStore: ReadableStore[Long, ApiList]
def openedPushByHourAggregatedStore: ReadableStore[Long, Map[Int, Int]]
def userHealthSignalStore: ReadableStore[Long, UserHealthSignalResponse]
def reactivatedUserInfoStore: ReadableStore[Long, String]
def weightedOpenOrNtabClickModelScorer: PushMLModelScorer
def optoutModelScorer: PushMLModelScorer
def filteringModelScorer: PushMLModelScorer
def recentFollowsStore: ReadableStore[Long, Seq[Long]]
def geoDuckV2Store: ReadableStore[UserId, LocationResponse]
def realGraphScoresTop500InStore: ReadableStore[Long, Map[Long, Double]]
def tweetEntityGraphStore: ReadableStore[
RecommendTweetEntityRequest,
RecommendTweetEntityResponse
]
def userUserGraphStore: ReadableStore[RecommendUserRequest, RecommendUserResponse]
def userFeaturesStore: ReadableStore[Long, UserFeatures]
def userTargetingPropertyStore: ReadableStore[Long, UserTargetingProperty]
def timelinesUserSessionStore: ReadableStore[Long, UserSession]
def optOutUserInterestsStore: ReadableStore[UserId, Seq[InterestId]]
def ntabCaretFeedbackStore: ReadableStore[GenericNotificationsFeedbackRequest, Seq[
CaretFeedbackDetails
]]
def genericFeedbackStore: ReadableStore[FeedbackRequest, Seq[
FeedbackPromptValue
]]
def genericNotificationFeedbackStore: GenericFeedbackStore
def semanticCoreMegadataStore: ReadableStore[
SemanticEntityForQuery,
EntityMegadata
]
def tweetHealthScoreStore: ReadableStore[TweetScoringRequest, TweetScoringResponse]
def earlybirdFeatureStore: ReadableStore[Long, ThriftSearchResultFeatures]
def earlybirdFeatureBuilder: FeatureBuilder[Long]
// Feature builders
def tweetAuthorLocationFeatureBuilder: FeatureBuilder[Location]
def tweetAuthorLocationFeatureBuilderById: FeatureBuilder[Long]
def socialContextActionsFeatureBuilder: FeatureBuilder[SocialContextActions]
def tweetContentFeatureBuilder: FeatureBuilder[Long]
def tweetAuthorRecentRealGraphFeatureBuilder: FeatureBuilder[RealGraphEdge]
def socialContextRecentRealGraphFeatureBuilder: FeatureBuilder[Set[RealGraphEdge]]
def tweetSocialProofFeatureBuilder: FeatureBuilder[TweetSocialProofKey]
def targetUserFullRealGraphFeatureBuilder: FeatureBuilder[TargetFullRealGraphFeatureKey]
def postProcessingFeatureBuilder: PostProcessingFeatureBuilder
def mrOfflineUserCandidateSparseAggregatesFeatureBuilder: FeatureBuilder[
OfflineSparseAggregateKey
]
def mrOfflineUserAggregatesFeatureBuilder: FeatureBuilder[Long]
def mrOfflineUserCandidateAggregatesFeatureBuilder: FeatureBuilder[OfflineAggregateKey]
def tweetAnnotationsFeatureBuilder: FeatureBuilder[Long]
def targetUserMediaRepresentationFeatureBuilder: FeatureBuilder[Long]
def targetLevelFeatureBuilder: FeatureBuilder[MrRequestContextForFeatureStore]
def candidateLevelFeatureBuilder: FeatureBuilder[EntityRequestContextForFeatureStore]
def targetFeatureHydrator: RelevanceTargetFeatureHydrator
def useHydratedLabeledSendsForFeaturesDeciderKey: String =
DeciderKey.useHydratedLabeledSendsForFeaturesDeciderKey.toString
def verifyHydratedLabeledSendsForFeaturesDeciderKey: String =
DeciderKey.verifyHydratedLabeledSendsForFeaturesDeciderKey.toString
def lexServiceStore: ReadableStore[EventRequest, LiveEvent]
def userMediaRepresentationStore: ReadableStore[Long, UserMediaRepresentation]
def producerMediaRepresentationStore: ReadableStore[Long, UserMediaRepresentation]
def mrUserStatePredictionStore: ReadableStore[Long, MRUserHmmState]
def pushcapDynamicPredictionStore: ReadableStore[Long, PushcapUserHistory]
def earlybirdCandidateSource: EarlybirdCandidateSource
def earlybirdSearchStore: ReadableStore[EarlybirdRequest, Seq[ThriftSearchResult]]
def earlybirdSearchDest: String
def pushserviceThriftClientId: ClientId
def simClusterToEntityStore: ReadableStore[Int, SimClustersInferredEntities]
def fanoutMetadataStore: ReadableStore[(Long, Long), FanoutEvent]
/**
* PostRanking Feature Store Client
*/
def postRankingFeatureStoreClient: DynamicFeatureStoreClient[MrRequestContextForFeatureStore]
/**
* ReadableStore to fetch [[UserInterests]] from INTS service
*/
def interestsWithLookupContextStore: ReadableStore[InterestsLookupRequestWithContext, Interests]
/**
*
* @return: [[TopicListing]] object to fetch paused topics and scope from productId
*/
def topicListing: TopicListing
/**
*
* @return: [[UttEntityHydrationStore]] object
*/
def uttEntityHydrationStore: UttEntityHydrationStore
def appPermissionStore: ReadableStore[(Long, (String, String)), AppPermission]
lazy val userTweetEntityGraphCandidates: UserTweetEntityGraphCandidates =
UserTweetEntityGraphCandidates(
cachedTweetyPieStoreV2,
tweetEntityGraphStore,
PushParams.UTEGTweetCandidateSourceParam,
PushFeatureSwitchParams.NumberOfMaxUTEGCandidatesQueriedParam,
PushParams.AllowOneSocialProofForTweetInUTEGParam,
PushParams.OutNetworkTweetsOnlyForUTEGParam,
PushFeatureSwitchParams.MaxTweetAgeParam
)(statsReceiver)
def pushSendEventBusPublisher: EventBusPublisher[NotificationScribe]
// miscs.
def isProd: Boolean
implicit def statsReceiver: StatsReceiver
def decider: Decider
def abDecider: LoggingABDecider
def casLock: CasLock
def pushIbisV2Store: PushIbis2Store
// scribe
def notificationScribe(data: NotificationScribe): Unit
def requestScribe(data: PushRequestScribe): Unit
def init(): Future[Unit] = Future.Done
def configParamsBuilder: ConfigParamsBuilder
def candidateFeatureHydrator: CandidateFeatureHydrator
def featureHydrator: MRFeatureHydrator
def candidateHydrator: PushCandidateHydrator
def sendHandlerCandidateHydrator: SendHandlerPushCandidateHydrator
lazy val overridesConfig: configapi.Config = {
val pushFeatureSwitchConfigs: configapi.Config = PushFeatureSwitches(
deciderGateBuilder = new DeciderGateBuilder(decider),
statsReceiver = statsReceiver
).config
new CompositeConfig(Seq(pushFeatureSwitchConfigs))
}
def realTimeClientEventStore: RealTimeClientEventStore
def inlineActionHistoryStore: ReadableStore[Long, Seq[(Long, String)]]
def softUserGeoLocationStore: ReadableStore[Long, GeoLocation]
def tweetTranslationStore: ReadableStore[TweetTranslationStore.Key, TweetTranslationStore.Value]
def tripTweetCandidateStore: ReadableStore[TripDomain, TripTweets]
def softUserFollowingStore: ReadableStore[User, Seq[Long]]
def superFollowEligibilityUserStore: ReadableStore[Long, Boolean]
def superFollowCreatorTweetCountStore: ReadableStore[StratoUserId, Int]
def hasSuperFollowingRelationshipStore: ReadableStore[
HasSuperFollowingRelationshipRequest,
Boolean
]
def superFollowApplicationStatusStore: ReadableStore[(Long, SellerTrack), SellerApplicationState]
def recentHistoryCacheClient: RecentHistoryCacheClient
def openAppUserStore: ReadableStore[Long, Boolean]
def loggedOutHistoryStore: PushServiceHistoryStore
def idsStore: ReadableStore[RecommendedListsRequest, RecommendedListsResponse]
def htlScoreStore(userId: Long): ReadableStore[Long, ScoredTweet]
}

View File

@ -0,0 +1,16 @@
package com.twitter.frigate.pushservice.config
import com.twitter.frigate.common.util.Experiments
object ExperimentsWithStats {
/**
* Add an experiment here to collect detailed pushservice stats.
*
* ! Important !
* Keep this set small and remove experiments when you don't need the stats anymore.
*/
final val PushExperiments: Set[String] = Set(
Experiments.MRAndroidInlineActionHoldback.exptName,
)
}

View File

@ -0,0 +1,230 @@
package com.twitter.frigate.pushservice.config
import com.twitter.abdecider.LoggingABDecider
import com.twitter.bijection.scrooge.BinaryScalaCodec
import com.twitter.bijection.Base64String
import com.twitter.bijection.Injection
import com.twitter.conversions.DurationOps._
import com.twitter.decider.Decider
import com.twitter.featureswitches.v2.FeatureSwitches
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.finagle.thrift.ClientId
import com.twitter.finagle.thrift.RichClientParam
import com.twitter.finagle.util.DefaultTimer
import com.twitter.frigate.common.config.RateLimiterGenerator
import com.twitter.frigate.common.filter.DynamicRequestMeterFilter
import com.twitter.frigate.common.history.ManhattanHistoryStore
import com.twitter.frigate.common.history.InvalidatingAfterWritesPushServiceHistoryStore
import com.twitter.frigate.common.history.ManhattanKVHistoryStore
import com.twitter.frigate.common.history.PushServiceHistoryStore
import com.twitter.frigate.common.history.SimplePushServiceHistoryStore
import com.twitter.frigate.common.util._
import com.twitter.frigate.data_pipeline.features_common.FeatureStoreUtil
import com.twitter.frigate.data_pipeline.features_common.TargetLevelFeaturesConfig
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.DeciderKey
import com.twitter.frigate.pushservice.params.PushQPSLimitConstants
import com.twitter.frigate.pushservice.params.PushServiceTunableKeys
import com.twitter.frigate.pushservice.params.ShardParams
import com.twitter.frigate.pushservice.store.PushIbis2Store
import com.twitter.frigate.pushservice.thriftscala.PushRequestScribe
import com.twitter.frigate.scribe.thriftscala.NotificationScribe
import com.twitter.ibis2.service.thriftscala.Ibis2Service
import com.twitter.logging.Logger
import com.twitter.notificationservice.api.thriftscala.DeleteCurrentTimelineForUserRequest
import com.twitter.notificationservice.api.thriftscala.NotificationApi
import com.twitter.notificationservice.api.thriftscala.NotificationApi$FinagleClient
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationRequest
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationResponse
import com.twitter.notificationservice.thriftscala.DeleteGenericNotificationRequest
import com.twitter.notificationservice.thriftscala.NotificationService
import com.twitter.notificationservice.thriftscala.NotificationService$FinagleClient
import com.twitter.servo.decider.DeciderGateBuilder
import com.twitter.util.tunable.TunableMap
import com.twitter.util.Future
import com.twitter.util.Timer
case class ProdConfig(
override val isServiceLocal: Boolean,
override val localConfigRepoPath: String,
override val inMemCacheOff: Boolean,
override val decider: Decider,
override val abDecider: LoggingABDecider,
override val featureSwitches: FeatureSwitches,
override val shardParams: ShardParams,
override val serviceIdentifier: ServiceIdentifier,
override val tunableMap: TunableMap,
)(
implicit val statsReceiver: StatsReceiver)
extends {
// Due to trait initialization logic in Scala, any abstract members declared in Config or
// DeployConfig should be declared in this block. Otherwise the abstract member might initialize to
// null if invoked before object creation finishing.
val log = Logger("ProdConfig")
// Deciders
val isPushserviceCanaryDeepbirdv2CanaryClusterEnabled = decider
.feature(DeciderKey.enablePushserviceDeepbirdv2CanaryClusterDeciderKey.toString).isAvailable
// Client ids
val notifierThriftClientId = ClientId("frigate-notifier.prod")
val loggedOutNotifierThriftClientId = ClientId("frigate-logged-out-notifier.prod")
val pushserviceThriftClientId: ClientId = ClientId("frigate-pushservice.prod")
// Dests
val frigateHistoryCacheDest = "/s/cache/frigate_history"
val memcacheCASDest = "/s/cache/magic_recs_cas:twemcaches"
val historyStoreMemcacheDest =
"/srv#/prod/local/cache/magic_recs_history:twemcaches"
val deepbirdv2PredictionServiceDest =
if (serviceIdentifier.service.equals("frigate-pushservice-canary") &&
isPushserviceCanaryDeepbirdv2CanaryClusterEnabled)
"/s/frigate/deepbirdv2-magicrecs-canary"
else "/s/frigate/deepbirdv2-magicrecs"
override val fanoutMetadataColumn = "frigate/magicfanout/prod/mh/fanoutMetadata"
override val timer: Timer = DefaultTimer
override val featureStoreUtil = FeatureStoreUtil.withParams(Some(serviceIdentifier))
override val targetLevelFeaturesConfig = TargetLevelFeaturesConfig()
val pushServiceMHCacheDest = "/s/cache/pushservice_mh"
val pushServiceCoreSvcsCacheDest = "/srv#/prod/local/cache/pushservice_core_svcs"
val userTweetEntityGraphDest = "/s/cassowary/user_tweet_entity_graph"
val userUserGraphDest = "/s/cassowary/user_user_graph"
val lexServiceDest = "/s/live-video/timeline-thrift"
val entityGraphCacheDest = "/s/cache/pushservice_entity_graph"
override val pushIbisV2Store = {
val service = Finagle.readOnlyThriftService(
"ibis-v2-service",
"/s/ibis2/ibis2",
statsReceiver,
notifierThriftClientId,
requestTimeout = 3.seconds,
tries = 3,
mTLSServiceIdentifier = Some(serviceIdentifier)
)
// according to ibis team, it is safe to retry on timeout, write & channel closed exceptions.
val pushIbisClient = new Ibis2Service.FinagledClient(
new DynamicRequestMeterFilter(
tunableMap(PushServiceTunableKeys.IbisQpsLimitTunableKey),
RateLimiterGenerator.asTuple(_, shardParams.numShards, 20),
PushQPSLimitConstants.IbisOrNTabQPSForRFPH
)(timer).andThen(service),
RichClientParam(serviceName = "ibis-v2-service")
)
PushIbis2Store(pushIbisClient)
}
val notificationServiceClient: NotificationService$FinagleClient = {
val service = Finagle.readWriteThriftService(
"notificationservice",
"/s/notificationservice/notificationservice",
statsReceiver,
pushserviceThriftClientId,
requestTimeout = 10.seconds,
mTLSServiceIdentifier = Some(serviceIdentifier)
)
new NotificationService.FinagledClient(
new DynamicRequestMeterFilter(
tunableMap(PushServiceTunableKeys.NtabQpsLimitTunableKey),
RateLimiterGenerator.asTuple(_, shardParams.numShards, 20),
PushQPSLimitConstants.IbisOrNTabQPSForRFPH)(timer).andThen(service),
RichClientParam(serviceName = "notificationservice")
)
}
val notificationServiceApiClient: NotificationApi$FinagleClient = {
val service = Finagle.readWriteThriftService(
"notificationservice-api",
"/s/notificationservice/notificationservice-api:thrift",
statsReceiver,
pushserviceThriftClientId,
requestTimeout = 10.seconds,
mTLSServiceIdentifier = Some(serviceIdentifier)
)
new NotificationApi.FinagledClient(
new DynamicRequestMeterFilter(
tunableMap(PushServiceTunableKeys.NtabQpsLimitTunableKey),
RateLimiterGenerator.asTuple(_, shardParams.numShards, 20),
PushQPSLimitConstants.IbisOrNTabQPSForRFPH)(timer).andThen(service),
RichClientParam(serviceName = "notificationservice-api")
)
}
val mrRequestScriberNode = "mr_request_scribe"
val loggedOutMrRequestScriberNode = "lo_mr_request_scribe"
override val pushSendEventStreamName = "frigate_pushservice_send_event_prod"
} with DeployConfig {
// Scribe
private val notificationScribeLog = Logger("notification_scribe")
private val notificationScribeInjection: Injection[NotificationScribe, String] = BinaryScalaCodec(
NotificationScribe
) andThen Injection.connect[Array[Byte], Base64String, String]
override def notificationScribe(data: NotificationScribe): Unit = {
val logEntry: String = notificationScribeInjection(data)
notificationScribeLog.info(logEntry)
}
// History Store - Invalidates cached history after writes
override val historyStore = new InvalidatingAfterWritesPushServiceHistoryStore(
ManhattanHistoryStore(notificationHistoryStore, statsReceiver),
recentHistoryCacheClient,
new DeciderGateBuilder(decider)
.idGate(DeciderKey.enableInvalidatingCachedHistoryStoreAfterWrites)
)
override val emailHistoryStore: PushServiceHistoryStore = {
statsReceiver.scope("frigate_email_history").counter("request").incr()
new SimplePushServiceHistoryStore(emailNotificationHistoryStore)
}
override val loggedOutHistoryStore =
new InvalidatingAfterWritesPushServiceHistoryStore(
ManhattanKVHistoryStore(
manhattanKVLoggedOutHistoryStoreEndpoint,
"frigate_notification_logged_out_history"),
recentHistoryCacheClient,
new DeciderGateBuilder(decider)
.idGate(DeciderKey.enableInvalidatingCachedLoggedOutHistoryStoreAfterWrites)
)
private val requestScribeLog = Logger("request_scribe")
private val requestScribeInjection: Injection[PushRequestScribe, String] = BinaryScalaCodec(
PushRequestScribe
) andThen Injection.connect[Array[Byte], Base64String, String]
override def requestScribe(data: PushRequestScribe): Unit = {
val logEntry: String = requestScribeInjection(data)
requestScribeLog.info(logEntry)
}
// generic notification server
override def notificationServiceSend(
target: Target,
request: CreateGenericNotificationRequest
): Future[CreateGenericNotificationResponse] =
notificationServiceClient.createGenericNotification(request)
// generic notification server
override def notificationServiceDelete(
request: DeleteGenericNotificationRequest
): Future[Unit] = notificationServiceClient.deleteGenericNotification(request)
// NTab-api
override def notificationServiceDeleteTimeline(
request: DeleteCurrentTimelineForUserRequest
): Future[Unit] = notificationServiceApiClient.deleteCurrentTimelineForUser(request)
}

View File

@ -0,0 +1,193 @@
package com.twitter.frigate.pushservice.config
import com.twitter.abdecider.LoggingABDecider
import com.twitter.conversions.DurationOps._
import com.twitter.decider.Decider
import com.twitter.featureswitches.v2.FeatureSwitches
import com.twitter.finagle.mtls.authentication.ServiceIdentifier
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.finagle.thrift.ClientId
import com.twitter.finagle.thrift.RichClientParam
import com.twitter.finagle.util.DefaultTimer
import com.twitter.frigate.common.config.RateLimiterGenerator
import com.twitter.frigate.common.filter.DynamicRequestMeterFilter
import com.twitter.frigate.common.history.InvalidatingAfterWritesPushServiceHistoryStore
import com.twitter.frigate.common.history.ManhattanHistoryStore
import com.twitter.frigate.common.history.ManhattanKVHistoryStore
import com.twitter.frigate.common.history.ReadOnlyHistoryStore
import com.twitter.frigate.common.history.PushServiceHistoryStore
import com.twitter.frigate.common.history.SimplePushServiceHistoryStore
import com.twitter.frigate.common.util.Finagle
import com.twitter.frigate.data_pipeline.features_common.FeatureStoreUtil
import com.twitter.frigate.data_pipeline.features_common.TargetLevelFeaturesConfig
import com.twitter.frigate.pushservice.model.PushTypes.Target
import com.twitter.frigate.pushservice.params.DeciderKey
import com.twitter.frigate.pushservice.params.PushQPSLimitConstants
import com.twitter.frigate.pushservice.params.PushServiceTunableKeys
import com.twitter.frigate.pushservice.params.ShardParams
import com.twitter.frigate.pushservice.store._
import com.twitter.frigate.pushservice.thriftscala.PushRequestScribe
import com.twitter.frigate.scribe.thriftscala.NotificationScribe
import com.twitter.ibis2.service.thriftscala.Ibis2Service
import com.twitter.logging.Logger
import com.twitter.notificationservice.api.thriftscala.DeleteCurrentTimelineForUserRequest
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationRequest
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationResponse
import com.twitter.notificationservice.thriftscala.CreateGenericNotificationResponseType
import com.twitter.notificationservice.thriftscala.DeleteGenericNotificationRequest
import com.twitter.notificationservice.thriftscala.NotificationService
import com.twitter.notificationservice.thriftscala.NotificationService$FinagleClient
import com.twitter.servo.decider.DeciderGateBuilder
import com.twitter.util.tunable.TunableMap
import com.twitter.util.Future
import com.twitter.util.Timer
case class StagingConfig(
override val isServiceLocal: Boolean,
override val localConfigRepoPath: String,
override val inMemCacheOff: Boolean,
override val decider: Decider,
override val abDecider: LoggingABDecider,
override val featureSwitches: FeatureSwitches,
override val shardParams: ShardParams,
override val serviceIdentifier: ServiceIdentifier,
override val tunableMap: TunableMap,
)(
implicit val statsReceiver: StatsReceiver)
extends {
// Due to trait initialization logic in Scala, any abstract members declared in Config or
// DeployConfig should be declared in this block. Otherwise the abstract member might initialize to
// null if invoked before object creation finishing.
val log = Logger("StagingConfig")
// Client ids
val notifierThriftClientId = ClientId("frigate-notifier.dev")
val loggedOutNotifierThriftClientId = ClientId("frigate-logged-out-notifier.dev")
val pushserviceThriftClientId: ClientId = ClientId("frigate-pushservice.staging")
override val fanoutMetadataColumn = "frigate/magicfanout/staging/mh/fanoutMetadata"
// dest
val frigateHistoryCacheDest = "/srv#/test/local/cache/twemcache_frigate_history"
val memcacheCASDest = "/srv#/test/local/cache/twemcache_magic_recs_cas_dev:twemcaches"
val pushServiceMHCacheDest = "/srv#/test/local/cache/twemcache_pushservice_test"
val entityGraphCacheDest = "/srv#/test/local/cache/twemcache_pushservice_test"
val pushServiceCoreSvcsCacheDest = "/srv#/test/local/cache/twemcache_pushservice_core_svcs_test"
val historyStoreMemcacheDest = "/srv#/test/local/cache/twemcache_eventstream_test:twemcaches"
val userTweetEntityGraphDest = "/cluster/local/cassowary/staging/user_tweet_entity_graph"
val userUserGraphDest = "/cluster/local/cassowary/staging/user_user_graph"
val lexServiceDest = "/srv#/staging/local/live-video/timeline-thrift"
val deepbirdv2PredictionServiceDest = "/cluster/local/frigate/staging/deepbirdv2-magicrecs"
override val featureStoreUtil = FeatureStoreUtil.withParams(Some(serviceIdentifier))
override val targetLevelFeaturesConfig = TargetLevelFeaturesConfig()
val mrRequestScriberNode = "validation_mr_request_scribe"
val loggedOutMrRequestScriberNode = "lo_mr_request_scribe"
override val timer: Timer = DefaultTimer
override val pushSendEventStreamName = "frigate_pushservice_send_event_staging"
override val pushIbisV2Store = {
val service = Finagle.readWriteThriftService(
"ibis-v2-service",
"/s/ibis2/ibis2",
statsReceiver,
notifierThriftClientId,
requestTimeout = 6.seconds,
mTLSServiceIdentifier = Some(serviceIdentifier)
)
val pushIbisClient = new Ibis2Service.FinagledClient(
new DynamicRequestMeterFilter(
tunableMap(PushServiceTunableKeys.IbisQpsLimitTunableKey),
RateLimiterGenerator.asTuple(_, shardParams.numShards, 20),
PushQPSLimitConstants.IbisOrNTabQPSForRFPH
)(timer).andThen(service),
RichClientParam(serviceName = "ibis-v2-service")
)
StagingIbis2Store(PushIbis2Store(pushIbisClient))
}
val notificationServiceClient: NotificationService$FinagleClient = {
val service = Finagle.readWriteThriftService(
"notificationservice",
"/s/notificationservice/notificationservice",
statsReceiver,
pushserviceThriftClientId,
requestTimeout = 10.seconds,
mTLSServiceIdentifier = Some(serviceIdentifier)
)
new NotificationService.FinagledClient(
new DynamicRequestMeterFilter(
tunableMap(PushServiceTunableKeys.NtabQpsLimitTunableKey),
RateLimiterGenerator.asTuple(_, shardParams.numShards, 20),
PushQPSLimitConstants.IbisOrNTabQPSForRFPH)(timer).andThen(service),
RichClientParam(serviceName = "notificationservice")
)
}
} with DeployConfig {
// Scribe
private val notificationScribeLog = Logger("StagingNotificationScribe")
override def notificationScribe(data: NotificationScribe): Unit = {
notificationScribeLog.info(data.toString)
}
private val requestScribeLog = Logger("StagingRequestScribe")
override def requestScribe(data: PushRequestScribe): Unit = {
requestScribeLog.info(data.toString)
}
// history store
override val historyStore = new InvalidatingAfterWritesPushServiceHistoryStore(
ReadOnlyHistoryStore(
ManhattanHistoryStore(notificationHistoryStore, statsReceiver)
),
recentHistoryCacheClient,
new DeciderGateBuilder(decider)
.idGate(DeciderKey.enableInvalidatingCachedHistoryStoreAfterWrites)
)
override val emailHistoryStore: PushServiceHistoryStore = new SimplePushServiceHistoryStore(
emailNotificationHistoryStore)
// history store
override val loggedOutHistoryStore =
new InvalidatingAfterWritesPushServiceHistoryStore(
ReadOnlyHistoryStore(
ManhattanKVHistoryStore(
manhattanKVLoggedOutHistoryStoreEndpoint,
"frigate_notification_logged_out_history")),
recentHistoryCacheClient,
new DeciderGateBuilder(decider)
.idGate(DeciderKey.enableInvalidatingCachedLoggedOutHistoryStoreAfterWrites)
)
override def notificationServiceSend(
target: Target,
request: CreateGenericNotificationRequest
): Future[CreateGenericNotificationResponse] =
target.isTeamMember.flatMap { isTeamMember =>
if (isTeamMember) {
notificationServiceClient.createGenericNotification(request)
} else {
log.info(s"Mock creating generic notification $request for user: ${target.targetId}")
Future.value(
CreateGenericNotificationResponse(CreateGenericNotificationResponseType.Success)
)
}
}
override def notificationServiceDelete(
request: DeleteGenericNotificationRequest
): Future[Unit] = Future.Unit
override def notificationServiceDeleteTimeline(
request: DeleteCurrentTimelineForUserRequest
): Future[Unit] = Future.Unit
}

View File

@ -0,0 +1,23 @@
package com.twitter.frigate.pushservice.config.mlconfig
import com.twitter.cortex.deepbird.thriftjava.DeepbirdPredictionService
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.frigate.common.ml.prediction.DeepbirdPredictionEngineServiceStore
import com.twitter.nrel.heavyranker.PushDBv2PredictionServiceStore
object DeepbirdV2ModelConfig {
def buildPredictionServiceScoreStore(
predictionServiceClient: DeepbirdPredictionService.ServiceToClient,
serviceName: String
)(
implicit statsReceiver: StatsReceiver
): PushDBv2PredictionServiceStore = {
val stats = statsReceiver.scope(serviceName)
val serviceStats = statsReceiver.scope("dbv2PredictionServiceStore")
new PushDBv2PredictionServiceStore(
DeepbirdPredictionEngineServiceStore(predictionServiceClient, batchSize = Some(32))(stats)
)(serviceStats)
}
}

Some files were not shown because too many files have changed in this diff Show More