mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-01-09 18:59:25 +01:00
Compare commits
No commits in common. "31e82d6474cf47b3695bf919c44c94d146192a03" and "6e5c875a69b5dc400302e42a3d0b2cfe509c71b6" have entirely different histories.
31e82d6474
...
6e5c875a69
navi
README.md
dr_transform
navi
segdense
unified_user_actions
adapter/src
main/scala/com/twitter/unified_user_actions/adapter/behavioral_client_event
BUILDBaseBCEAdapter.scalaBehavioralClientEventAdapter.scalaImpressionBCEAdapter.scalaProfileImpressionBCEAdapter.scalaTweetImpressionBCEAdapter.scala
test/scala/com/twitter/unified_user_actions/adapter
service/src/main/scala/com/twitter/unified_user_actions/service
@ -31,11 +31,6 @@ In navi/navi, you can run the following commands:
|
|||||||
- `scripts/run_onnx.sh` for [Onnx](https://onnx.ai/)
|
- `scripts/run_onnx.sh` for [Onnx](https://onnx.ai/)
|
||||||
|
|
||||||
Do note that you need to create a models directory and create some versions, preferably using epoch time, e.g., `1679693908377`.
|
Do note that you need to create a models directory and create some versions, preferably using epoch time, e.g., `1679693908377`.
|
||||||
so the models structure looks like:
|
|
||||||
models/
|
|
||||||
-web_click
|
|
||||||
- 1809000
|
|
||||||
- 1809010
|
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
You can adapt the above scripts to build using Cargo.
|
You can adapt the above scripts to build using Cargo.
|
||||||
|
@ -3,6 +3,7 @@ name = "dr_transform"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
[dependencies]
|
[dependencies]
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
@ -11,6 +12,7 @@ bpr_thrift = { path = "../thrift_bpr_adapter/thrift/"}
|
|||||||
segdense = { path = "../segdense/"}
|
segdense = { path = "../segdense/"}
|
||||||
thrift = "0.17.0"
|
thrift = "0.17.0"
|
||||||
ndarray = "0.15"
|
ndarray = "0.15"
|
||||||
|
ort = {git ="https://github.com/pykeio/ort.git", tag="v1.14.2"}
|
||||||
base64 = "0.20.0"
|
base64 = "0.20.0"
|
||||||
npyz = "0.7.2"
|
npyz = "0.7.2"
|
||||||
log = "0.4.17"
|
log = "0.4.17"
|
||||||
@ -19,11 +21,6 @@ prometheus = "0.13.1"
|
|||||||
once_cell = "1.17.0"
|
once_cell = "1.17.0"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
itertools = "0.10.5"
|
itertools = "0.10.5"
|
||||||
anyhow = "1.0.70"
|
|
||||||
[target.'cfg(not(target_os="linux"))'.dependencies]
|
|
||||||
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling"], tag="v1.14.6"}
|
|
||||||
[target.'cfg(target_os="linux")'.dependencies]
|
|
||||||
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling", "tensorrt", "cuda", "copy-dylibs"], tag="v1.14.6"}
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = "0.3.0"
|
criterion = "0.3.0"
|
||||||
|
|
||||||
|
@ -44,6 +44,5 @@ pub struct RenamedFeatures {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse(json_str: &str) -> Result<AllConfig, Error> {
|
pub fn parse(json_str: &str) -> Result<AllConfig, Error> {
|
||||||
let all_config: AllConfig = serde_json::from_str(json_str)?;
|
serde_json::from_str(json_str)
|
||||||
Ok(all_config)
|
|
||||||
}
|
}
|
||||||
|
@ -2,9 +2,6 @@ use std::collections::BTreeSet;
|
|||||||
use std::fmt::{self, Debug, Display};
|
use std::fmt::{self, Debug, Display};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
||||||
use crate::all_config;
|
|
||||||
use crate::all_config::AllConfig;
|
|
||||||
use anyhow::{bail, Context};
|
|
||||||
use bpr_thrift::data::DataRecord;
|
use bpr_thrift::data::DataRecord;
|
||||||
use bpr_thrift::prediction_service::BatchPredictionRequest;
|
use bpr_thrift::prediction_service::BatchPredictionRequest;
|
||||||
use bpr_thrift::tensor::GeneralTensor;
|
use bpr_thrift::tensor::GeneralTensor;
|
||||||
@ -19,6 +16,8 @@ use segdense::util;
|
|||||||
use thrift::protocol::{TBinaryInputProtocol, TSerializable};
|
use thrift::protocol::{TBinaryInputProtocol, TSerializable};
|
||||||
use thrift::transport::TBufferChannel;
|
use thrift::transport::TBufferChannel;
|
||||||
|
|
||||||
|
use crate::{all_config, all_config::AllConfig};
|
||||||
|
|
||||||
pub fn log_feature_match(
|
pub fn log_feature_match(
|
||||||
dr: &DataRecord,
|
dr: &DataRecord,
|
||||||
seg_dense_config: &DensificationTransformSpec,
|
seg_dense_config: &DensificationTransformSpec,
|
||||||
@ -29,24 +28,20 @@ pub fn log_feature_match(
|
|||||||
|
|
||||||
for (feature_id, feature_value) in dr.continuous_features.as_ref().unwrap() {
|
for (feature_id, feature_value) in dr.continuous_features.as_ref().unwrap() {
|
||||||
debug!(
|
debug!(
|
||||||
"{} - Continous Datarecord => Feature ID: {}, Feature value: {}",
|
"{dr_type} - Continuous Datarecord => Feature ID: {feature_id}, Feature value: {feature_value}"
|
||||||
dr_type, feature_id, feature_value
|
|
||||||
);
|
);
|
||||||
for input_feature in &seg_dense_config.cont.input_features {
|
for input_feature in &seg_dense_config.cont.input_features {
|
||||||
if input_feature.feature_id == *feature_id {
|
if input_feature.feature_id == *feature_id {
|
||||||
debug!("Matching input feature: {:?}", input_feature)
|
debug!("Matching input feature: {input_feature:?}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for feature_id in dr.binary_features.as_ref().unwrap() {
|
for feature_id in dr.binary_features.as_ref().unwrap() {
|
||||||
debug!(
|
debug!("{dr_type} - Binary Datarecord => Feature ID: {feature_id}");
|
||||||
"{} - Binary Datarecord => Feature ID: {}",
|
|
||||||
dr_type, feature_id
|
|
||||||
);
|
|
||||||
for input_feature in &seg_dense_config.binary.input_features {
|
for input_feature in &seg_dense_config.binary.input_features {
|
||||||
if input_feature.feature_id == *feature_id {
|
if input_feature.feature_id == *feature_id {
|
||||||
debug!("Found input feature: {:?}", input_feature)
|
debug!("Found input feature: {input_feature:?}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -95,19 +90,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
model_version: &str,
|
model_version: &str,
|
||||||
reporting_feature_ids: Vec<(i64, &str)>,
|
reporting_feature_ids: Vec<(i64, &str)>,
|
||||||
register_metric_fn: Option<impl Fn(&HistogramVec)>,
|
register_metric_fn: Option<impl Fn(&HistogramVec)>,
|
||||||
) -> anyhow::Result<BatchPredictionRequestToTorchTensorConverter> {
|
) -> BatchPredictionRequestToTorchTensorConverter {
|
||||||
let all_config_path = format!("{}/{}/all_config.json", model_dir, model_version);
|
let all_config_path = format!("{model_dir}/{model_version}/all_config.json");
|
||||||
let seg_dense_config_path = format!(
|
let seg_dense_config_path =
|
||||||
"{}/{}/segdense_transform_spec_home_recap_2022.json",
|
format!("{model_dir}/{model_version}/segdense_transform_spec_home_recap_2022.json");
|
||||||
model_dir, model_version
|
let seg_dense_config = util::load_config(&seg_dense_config_path);
|
||||||
);
|
|
||||||
let seg_dense_config = util::load_config(&seg_dense_config_path)?;
|
|
||||||
let all_config = all_config::parse(
|
let all_config = all_config::parse(
|
||||||
&fs::read_to_string(&all_config_path)
|
&fs::read_to_string(&all_config_path)
|
||||||
.with_context(|| "error loading all_config.json - ")?,
|
.unwrap_or_else(|error| panic!("error loading all_config.json - {error}")),
|
||||||
)?;
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let feature_mapper = util::load_from_parsed_config(seg_dense_config.clone())?;
|
let feature_mapper = util::load_from_parsed_config_ref(&seg_dense_config);
|
||||||
|
|
||||||
let user_embedding_feature_id = Self::get_feature_id(
|
let user_embedding_feature_id = Self::get_feature_id(
|
||||||
&all_config
|
&all_config
|
||||||
@ -137,11 +131,11 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
let (discrete_feature_metrics, continuous_feature_metrics) = METRICS.get_or_init(|| {
|
let (discrete_feature_metrics, continuous_feature_metrics) = METRICS.get_or_init(|| {
|
||||||
let discrete = HistogramVec::new(
|
let discrete = HistogramVec::new(
|
||||||
HistogramOpts::new(":navi:feature_id:discrete", "Discrete Feature ID values")
|
HistogramOpts::new(":navi:feature_id:discrete", "Discrete Feature ID values")
|
||||||
.buckets(Vec::from(&[
|
.buckets(Vec::from([
|
||||||
0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0,
|
0.0f64, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0,
|
||||||
120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0,
|
120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0,
|
||||||
300.0, 500.0, 1000.0, 10000.0, 100000.0,
|
300.0, 500.0, 1000.0, 10000.0, 100000.0,
|
||||||
] as &'static [f64])),
|
])),
|
||||||
&["feature_id"],
|
&["feature_id"],
|
||||||
)
|
)
|
||||||
.expect("metric cannot be created");
|
.expect("metric cannot be created");
|
||||||
@ -150,18 +144,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
":navi:feature_id:continuous",
|
":navi:feature_id:continuous",
|
||||||
"continuous Feature ID values",
|
"continuous Feature ID values",
|
||||||
)
|
)
|
||||||
.buckets(Vec::from(&[
|
.buckets(Vec::from([
|
||||||
0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0,
|
0.0f64, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0,
|
||||||
130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0,
|
120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0,
|
||||||
1000.0, 10000.0, 100000.0,
|
500.0, 1000.0, 10000.0, 100000.0,
|
||||||
] as &'static [f64])),
|
])),
|
||||||
&["feature_id"],
|
&["feature_id"],
|
||||||
)
|
)
|
||||||
.expect("metric cannot be created");
|
.expect("metric cannot be created");
|
||||||
register_metric_fn.map(|r| {
|
if let Some(r) = register_metric_fn {
|
||||||
r(&discrete);
|
r(&discrete);
|
||||||
r(&continuous);
|
r(&continuous);
|
||||||
});
|
}
|
||||||
(discrete, continuous)
|
(discrete, continuous)
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -170,16 +164,13 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
|
|
||||||
for (feature_id, feature_type) in reporting_feature_ids.iter() {
|
for (feature_id, feature_type) in reporting_feature_ids.iter() {
|
||||||
match *feature_type {
|
match *feature_type {
|
||||||
"discrete" => discrete_features_to_report.insert(feature_id.clone()),
|
"discrete" => discrete_features_to_report.insert(*feature_id),
|
||||||
"continuous" => continuous_features_to_report.insert(feature_id.clone()),
|
"continuous" => continuous_features_to_report.insert(*feature_id),
|
||||||
_ => bail!(
|
_ => panic!("Invalid feature type {feature_type} for reporting metrics!"),
|
||||||
"Invalid feature type {} for reporting metrics!",
|
|
||||||
feature_type
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(BatchPredictionRequestToTorchTensorConverter {
|
BatchPredictionRequestToTorchTensorConverter {
|
||||||
all_config,
|
all_config,
|
||||||
seg_dense_config,
|
seg_dense_config,
|
||||||
all_config_path,
|
all_config_path,
|
||||||
@ -192,7 +183,7 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
continuous_features_to_report,
|
continuous_features_to_report,
|
||||||
discrete_feature_metrics,
|
discrete_feature_metrics,
|
||||||
continuous_feature_metrics,
|
continuous_feature_metrics,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_feature_id(feature_name: &str, seg_dense_config: &Root) -> i64 {
|
fn get_feature_id(feature_name: &str, seg_dense_config: &Root) -> i64 {
|
||||||
@ -227,45 +218,43 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
let mut working_set = vec![0 as f32; total_size];
|
let mut working_set = vec![0 as f32; total_size];
|
||||||
let mut bpr_start = 0;
|
let mut bpr_start = 0;
|
||||||
for (bpr, &bpr_end) in bprs.iter().zip(batch_size) {
|
for (bpr, &bpr_end) in bprs.iter().zip(batch_size) {
|
||||||
if bpr.common_features.is_some() {
|
if bpr.common_features.is_some()
|
||||||
if bpr.common_features.as_ref().unwrap().tensors.is_some() {
|
&& bpr.common_features.as_ref().unwrap().tensors.is_some()
|
||||||
if bpr
|
&& bpr
|
||||||
.common_features
|
.common_features
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.tensors
|
.tensors
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.contains_key(&feature_id)
|
.contains_key(&feature_id)
|
||||||
|
{
|
||||||
|
let source_tensor = bpr
|
||||||
|
.common_features
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.tensors
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.get(&feature_id)
|
||||||
|
.unwrap();
|
||||||
|
let tensor = match source_tensor {
|
||||||
|
GeneralTensor::FloatTensor(float_tensor) =>
|
||||||
|
//Tensor::of_slice(
|
||||||
{
|
{
|
||||||
let source_tensor = bpr
|
float_tensor
|
||||||
.common_features
|
.floats
|
||||||
.as_ref()
|
.iter()
|
||||||
.unwrap()
|
.map(|x| x.into_inner() as f32)
|
||||||
.tensors
|
.collect::<Vec<_>>()
|
||||||
.as_ref()
|
}
|
||||||
.unwrap()
|
_ => vec![0 as f32; cols],
|
||||||
.get(&feature_id)
|
};
|
||||||
.unwrap();
|
|
||||||
let tensor = match source_tensor {
|
|
||||||
GeneralTensor::FloatTensor(float_tensor) =>
|
|
||||||
//Tensor::of_slice(
|
|
||||||
{
|
|
||||||
float_tensor
|
|
||||||
.floats
|
|
||||||
.iter()
|
|
||||||
.map(|x| x.into_inner() as f32)
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
}
|
|
||||||
_ => vec![0 as f32; cols],
|
|
||||||
};
|
|
||||||
|
|
||||||
// since the tensor is found in common feature, add it in all batches
|
// since the tensor is found in common feature, add it in all batches
|
||||||
for row in bpr_start..bpr_end {
|
for row in bpr_start..bpr_end {
|
||||||
for col in 0..cols {
|
for col in 0..cols {
|
||||||
working_set[row * cols + col] = tensor[col];
|
working_set[row * cols + col] = tensor[col];
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -309,9 +298,9 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
// (INT64 --> INT64, DataRecord.discrete_feature)
|
// (INT64 --> INT64, DataRecord.discrete_feature)
|
||||||
fn get_continuous(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
|
fn get_continuous(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
|
||||||
// These need to be part of model schema
|
// These need to be part of model schema
|
||||||
let rows: usize = batch_ends[batch_ends.len() - 1];
|
let rows = batch_ends[batch_ends.len() - 1];
|
||||||
let cols: usize = 5293;
|
let cols = 5293;
|
||||||
let full_size: usize = rows * cols;
|
let full_size = rows * cols;
|
||||||
let default_val = f32::NAN;
|
let default_val = f32::NAN;
|
||||||
|
|
||||||
let mut tensor = vec![default_val; full_size];
|
let mut tensor = vec![default_val; full_size];
|
||||||
@ -336,18 +325,15 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
for feature in common_features {
|
for feature in common_features {
|
||||||
match self.feature_mapper.get(feature.0) {
|
if let Some(f_info) = self.feature_mapper.get(feature.0) {
|
||||||
Some(f_info) => {
|
let idx = f_info.index_within_tensor as usize;
|
||||||
let idx = f_info.index_within_tensor as usize;
|
if idx < cols {
|
||||||
if idx < cols {
|
// Set value in each row
|
||||||
// Set value in each row
|
for r in bpr_start..bpr_end {
|
||||||
for r in bpr_start..bpr_end {
|
let flat_index = r * cols + idx;
|
||||||
let flat_index: usize = r * cols + idx;
|
tensor[flat_index] = feature.1.into_inner() as f32;
|
||||||
tensor[flat_index] = feature.1.into_inner() as f32;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => (),
|
|
||||||
}
|
}
|
||||||
if self.continuous_features_to_report.contains(feature.0) {
|
if self.continuous_features_to_report.contains(feature.0) {
|
||||||
self.continuous_feature_metrics
|
self.continuous_feature_metrics
|
||||||
@ -363,28 +349,24 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
|
|
||||||
// Process the batch of datarecords
|
// Process the batch of datarecords
|
||||||
for r in bpr_start..bpr_end {
|
for r in bpr_start..bpr_end {
|
||||||
let dr: &DataRecord =
|
let dr: &DataRecord = &bpr.individual_features_list[r - bpr_start];
|
||||||
&bpr.individual_features_list[usize::try_from(r - bpr_start).unwrap()];
|
|
||||||
if dr.continuous_features.is_some() {
|
if dr.continuous_features.is_some() {
|
||||||
for feature in dr.continuous_features.as_ref().unwrap() {
|
for feature in dr.continuous_features.as_ref().unwrap() {
|
||||||
match self.feature_mapper.get(&feature.0) {
|
if let Some(f_info) = self.feature_mapper.get(feature.0) {
|
||||||
Some(f_info) => {
|
let idx = f_info.index_within_tensor as usize;
|
||||||
let idx = f_info.index_within_tensor as usize;
|
let flat_index = r * cols + idx;
|
||||||
let flat_index: usize = r * cols + idx;
|
if flat_index < tensor.len() && idx < cols {
|
||||||
if flat_index < tensor.len() && idx < cols {
|
tensor[flat_index] = feature.1.into_inner() as f32;
|
||||||
tensor[flat_index] = feature.1.into_inner() as f32;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
None => (),
|
|
||||||
}
|
}
|
||||||
if self.continuous_features_to_report.contains(feature.0) {
|
if self.continuous_features_to_report.contains(feature.0) {
|
||||||
self.continuous_feature_metrics
|
self.continuous_feature_metrics
|
||||||
.with_label_values(&[feature.0.to_string().as_str()])
|
.with_label_values(&[feature.0.to_string().as_str()])
|
||||||
.observe(feature.1.into_inner() as f64)
|
.observe(feature.1.into_inner())
|
||||||
} else if self.discrete_features_to_report.contains(feature.0) {
|
} else if self.discrete_features_to_report.contains(feature.0) {
|
||||||
self.discrete_feature_metrics
|
self.discrete_feature_metrics
|
||||||
.with_label_values(&[feature.0.to_string().as_str()])
|
.with_label_values(&[feature.0.to_string().as_str()])
|
||||||
.observe(feature.1.into_inner() as f64)
|
.observe(feature.1.into_inner())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -401,10 +383,10 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
|
|
||||||
fn get_binary(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
|
fn get_binary(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
|
||||||
// These need to be part of model schema
|
// These need to be part of model schema
|
||||||
let rows: usize = batch_ends[batch_ends.len() - 1];
|
let rows = batch_ends[batch_ends.len() - 1];
|
||||||
let cols: usize = 149;
|
let cols = 149;
|
||||||
let full_size: usize = rows * cols;
|
let full_size = rows * cols;
|
||||||
let default_val: i64 = 0;
|
let default_val = 0;
|
||||||
|
|
||||||
let mut v = vec![default_val; full_size];
|
let mut v = vec![default_val; full_size];
|
||||||
|
|
||||||
@ -428,18 +410,15 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
for feature in common_features {
|
for feature in common_features {
|
||||||
match self.feature_mapper.get(feature) {
|
if let Some(f_info) = self.feature_mapper.get(feature) {
|
||||||
Some(f_info) => {
|
let idx = f_info.index_within_tensor as usize;
|
||||||
let idx = f_info.index_within_tensor as usize;
|
if idx < cols {
|
||||||
if idx < cols {
|
// Set value in each row
|
||||||
// Set value in each row
|
for r in bpr_start..bpr_end {
|
||||||
for r in bpr_start..bpr_end {
|
let flat_index = r * cols + idx;
|
||||||
let flat_index: usize = r * cols + idx;
|
v[flat_index] = 1;
|
||||||
v[flat_index] = 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => (),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -449,13 +428,10 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
let dr: &DataRecord = &bpr.individual_features_list[r - bpr_start];
|
let dr: &DataRecord = &bpr.individual_features_list[r - bpr_start];
|
||||||
if dr.binary_features.is_some() {
|
if dr.binary_features.is_some() {
|
||||||
for feature in dr.binary_features.as_ref().unwrap() {
|
for feature in dr.binary_features.as_ref().unwrap() {
|
||||||
match self.feature_mapper.get(&feature) {
|
if let Some(f_info) = self.feature_mapper.get(feature) {
|
||||||
Some(f_info) => {
|
let idx = f_info.index_within_tensor as usize;
|
||||||
let idx = f_info.index_within_tensor as usize;
|
let flat_index = r * cols + idx;
|
||||||
let flat_index: usize = r * cols + idx;
|
v[flat_index] = 1;
|
||||||
v[flat_index] = 1;
|
|
||||||
}
|
|
||||||
None => (),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -472,10 +448,10 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn get_discrete(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
|
fn get_discrete(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
|
||||||
// These need to be part of model schema
|
// These need to be part of model schema
|
||||||
let rows: usize = batch_ends[batch_ends.len() - 1];
|
let rows = batch_ends[batch_ends.len() - 1];
|
||||||
let cols: usize = 320;
|
let cols = 320;
|
||||||
let full_size: usize = rows * cols;
|
let full_size = rows * cols;
|
||||||
let default_val: i64 = 0;
|
let default_val = 0;
|
||||||
|
|
||||||
let mut v = vec![default_val; full_size];
|
let mut v = vec![default_val; full_size];
|
||||||
|
|
||||||
@ -499,18 +475,15 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
for feature in common_features {
|
for feature in common_features {
|
||||||
match self.feature_mapper.get(feature.0) {
|
if let Some(f_info) = self.feature_mapper.get(feature.0) {
|
||||||
Some(f_info) => {
|
let idx = f_info.index_within_tensor as usize;
|
||||||
let idx = f_info.index_within_tensor as usize;
|
if idx < cols {
|
||||||
if idx < cols {
|
// Set value in each row
|
||||||
// Set value in each row
|
for r in bpr_start..bpr_end {
|
||||||
for r in bpr_start..bpr_end {
|
let flat_index = r * cols + idx;
|
||||||
let flat_index: usize = r * cols + idx;
|
v[flat_index] = *feature.1;
|
||||||
v[flat_index] = *feature.1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => (),
|
|
||||||
}
|
}
|
||||||
if self.discrete_features_to_report.contains(feature.0) {
|
if self.discrete_features_to_report.contains(feature.0) {
|
||||||
self.discrete_feature_metrics
|
self.discrete_feature_metrics
|
||||||
@ -522,18 +495,15 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||||||
|
|
||||||
// Process the batch of datarecords
|
// Process the batch of datarecords
|
||||||
for r in bpr_start..bpr_end {
|
for r in bpr_start..bpr_end {
|
||||||
let dr: &DataRecord = &bpr.individual_features_list[usize::try_from(r).unwrap()];
|
let dr: &DataRecord = &bpr.individual_features_list[r];
|
||||||
if dr.discrete_features.is_some() {
|
if dr.discrete_features.is_some() {
|
||||||
for feature in dr.discrete_features.as_ref().unwrap() {
|
for feature in dr.discrete_features.as_ref().unwrap() {
|
||||||
match self.feature_mapper.get(&feature.0) {
|
if let Some(f_info) = self.feature_mapper.get(feature.0) {
|
||||||
Some(f_info) => {
|
let idx = f_info.index_within_tensor as usize;
|
||||||
let idx = f_info.index_within_tensor as usize;
|
let flat_index = r * cols + idx;
|
||||||
let flat_index: usize = r * cols + idx;
|
if flat_index < v.len() && idx < cols {
|
||||||
if flat_index < v.len() && idx < cols {
|
v[flat_index] = *feature.1;
|
||||||
v[flat_index] = *feature.1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
None => (),
|
|
||||||
}
|
}
|
||||||
if self.discrete_features_to_report.contains(feature.0) {
|
if self.discrete_features_to_report.contains(feature.0) {
|
||||||
self.discrete_feature_metrics
|
self.discrete_feature_metrics
|
||||||
@ -599,7 +569,7 @@ impl Converter for BatchPredictionRequestToTorchTensorConverter {
|
|||||||
.map(|bpr| bpr.individual_features_list.len())
|
.map(|bpr| bpr.individual_features_list.len())
|
||||||
.scan(0usize, |acc, e| {
|
.scan(0usize, |acc, e| {
|
||||||
//running total
|
//running total
|
||||||
*acc = *acc + e;
|
*acc += e;
|
||||||
Some(*acc)
|
Some(*acc)
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
@ -3,4 +3,3 @@ pub mod converter;
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
pub extern crate ort;
|
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "navi"
|
name = "navi"
|
||||||
version = "2.0.45"
|
version = "2.0.42"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "navi"
|
name = "navi"
|
||||||
@ -15,19 +16,12 @@ required-features=["torch"]
|
|||||||
name = "navi_onnx"
|
name = "navi_onnx"
|
||||||
path = "src/bin/navi_onnx.rs"
|
path = "src/bin/navi_onnx.rs"
|
||||||
required-features=["onnx"]
|
required-features=["onnx"]
|
||||||
[[bin]]
|
|
||||||
name = "navi_onnx_test"
|
|
||||||
path = "src/bin/bin_tests/navi_onnx_test.rs"
|
|
||||||
[[bin]]
|
|
||||||
name = "navi_torch_test"
|
|
||||||
path = "src/bin/bin_tests/navi_torch_test.rs"
|
|
||||||
required-features=["torch"]
|
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default=[]
|
default=[]
|
||||||
navi_console=[]
|
navi_console=[]
|
||||||
torch=["tch"]
|
torch=["tch"]
|
||||||
onnx=[]
|
onnx=["ort"]
|
||||||
tf=["tensorflow"]
|
tf=["tensorflow"]
|
||||||
[dependencies]
|
[dependencies]
|
||||||
itertools = "0.10.5"
|
itertools = "0.10.5"
|
||||||
@ -53,7 +47,6 @@ parking_lot = "0.12.1"
|
|||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
rand_pcg = "0.3.1"
|
rand_pcg = "0.3.1"
|
||||||
random = "0.12.2"
|
random = "0.12.2"
|
||||||
x509-parser = "0.15.0"
|
|
||||||
sha256 = "1.0.3"
|
sha256 = "1.0.3"
|
||||||
tonic = { version = "0.6.2", features=['compression', 'tls'] }
|
tonic = { version = "0.6.2", features=['compression', 'tls'] }
|
||||||
tokio = { version = "1.17.0", features = ["macros", "rt-multi-thread", "fs", "process"] }
|
tokio = { version = "1.17.0", features = ["macros", "rt-multi-thread", "fs", "process"] }
|
||||||
@ -62,12 +55,16 @@ npyz = "0.7.3"
|
|||||||
base64 = "0.21.0"
|
base64 = "0.21.0"
|
||||||
histogram = "0.6.9"
|
histogram = "0.6.9"
|
||||||
tch = {version = "0.10.3", optional = true}
|
tch = {version = "0.10.3", optional = true}
|
||||||
tensorflow = { version = "0.18.0", optional = true }
|
tensorflow = { version = "0.20.0", optional = true }
|
||||||
once_cell = {version = "1.17.1"}
|
once_cell = {version = "1.17.1"}
|
||||||
ndarray = "0.15"
|
ndarray = "0.15"
|
||||||
serde = "1.0.154"
|
serde = "1.0.154"
|
||||||
serde_json = "1.0.94"
|
serde_json = "1.0.94"
|
||||||
dr_transform = { path = "../dr_transform"}
|
dr_transform = { path = "../dr_transform"}
|
||||||
|
[target.'cfg(not(target_os="linux"))'.dependencies]
|
||||||
|
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling"], optional = true, tag="v1.14.2"}
|
||||||
|
[target.'cfg(target_os="linux")'.dependencies]
|
||||||
|
ort = {git ="https://github.com/pykeio/ort.git", features=["profiling", "tensorrt", "cuda", "copy-dylibs"], optional = true, tag="v1.14.2"}
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
tonic-build = {version = "0.6.2", features=['prost', "compression"] }
|
tonic-build = {version = "0.6.2", features=['prost', "compression"] }
|
||||||
[profile.release]
|
[profile.release]
|
||||||
@ -77,5 +74,3 @@ ndarray-rand = "0.14.0"
|
|||||||
tokio-test = "*"
|
tokio-test = "*"
|
||||||
assert_cmd = "2.0"
|
assert_cmd = "2.0"
|
||||||
criterion = "0.4.0"
|
criterion = "0.4.0"
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,7 +122,7 @@ enum FullTypeId {
|
|||||||
// TFT_TENSOR[TFT_INT32, TFT_UNKNOWN]
|
// TFT_TENSOR[TFT_INT32, TFT_UNKNOWN]
|
||||||
// is a Tensor of int32 element type and unknown shape.
|
// is a Tensor of int32 element type and unknown shape.
|
||||||
//
|
//
|
||||||
// TODO(mdan): Define TFT_SHAPE and add more examples.
|
// TODO: Define TFT_SHAPE and add more examples.
|
||||||
TFT_TENSOR = 1000;
|
TFT_TENSOR = 1000;
|
||||||
|
|
||||||
// Array (or tensorflow::TensorList in the variant type registry).
|
// Array (or tensorflow::TensorList in the variant type registry).
|
||||||
@ -178,7 +178,7 @@ enum FullTypeId {
|
|||||||
// object (for now).
|
// object (for now).
|
||||||
|
|
||||||
// The bool element type.
|
// The bool element type.
|
||||||
// TODO(mdan): Quantized types, legacy representations (e.g. ref)
|
// TODO
|
||||||
TFT_BOOL = 200;
|
TFT_BOOL = 200;
|
||||||
// Integer element types.
|
// Integer element types.
|
||||||
TFT_UINT8 = 201;
|
TFT_UINT8 = 201;
|
||||||
@ -195,7 +195,7 @@ enum FullTypeId {
|
|||||||
TFT_DOUBLE = 211;
|
TFT_DOUBLE = 211;
|
||||||
TFT_BFLOAT16 = 215;
|
TFT_BFLOAT16 = 215;
|
||||||
// Complex element types.
|
// Complex element types.
|
||||||
// TODO(mdan): Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
|
// TODO: Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
|
||||||
TFT_COMPLEX64 = 212;
|
TFT_COMPLEX64 = 212;
|
||||||
TFT_COMPLEX128 = 213;
|
TFT_COMPLEX128 = 213;
|
||||||
// The string element type.
|
// The string element type.
|
||||||
@ -240,7 +240,7 @@ enum FullTypeId {
|
|||||||
// ownership is in the true sense: "the op argument representing the lock is
|
// ownership is in the true sense: "the op argument representing the lock is
|
||||||
// available".
|
// available".
|
||||||
// Mutex locks are the dynamic counterpart of control dependencies.
|
// Mutex locks are the dynamic counterpart of control dependencies.
|
||||||
// TODO(mdan): Properly document this thing.
|
// TODO: Properly document this thing.
|
||||||
//
|
//
|
||||||
// Parametrization: TFT_MUTEX_LOCK[].
|
// Parametrization: TFT_MUTEX_LOCK[].
|
||||||
TFT_MUTEX_LOCK = 10202;
|
TFT_MUTEX_LOCK = 10202;
|
||||||
@ -271,6 +271,6 @@ message FullTypeDef {
|
|||||||
oneof attr {
|
oneof attr {
|
||||||
string s = 3;
|
string s = 3;
|
||||||
int64 i = 4;
|
int64 i = 4;
|
||||||
// TODO(mdan): list/tensor, map? Need to reconcile with TFT_RECORD, etc.
|
// TODO: list/tensor, map? Need to reconcile with TFT_RECORD, etc.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ message FunctionDefLibrary {
|
|||||||
// with a value. When a GraphDef has a call to a function, it must
|
// with a value. When a GraphDef has a call to a function, it must
|
||||||
// have binding for every attr defined in the signature.
|
// have binding for every attr defined in the signature.
|
||||||
//
|
//
|
||||||
// TODO(zhifengc):
|
// TODO:
|
||||||
// * device spec, etc.
|
// * device spec, etc.
|
||||||
message FunctionDef {
|
message FunctionDef {
|
||||||
// The definition of the function's name, arguments, return values,
|
// The definition of the function's name, arguments, return values,
|
||||||
|
@ -61,7 +61,7 @@ message NodeDef {
|
|||||||
// one of the names from the corresponding OpDef's attr field).
|
// one of the names from the corresponding OpDef's attr field).
|
||||||
// The values must have a type matching the corresponding OpDef
|
// The values must have a type matching the corresponding OpDef
|
||||||
// attr's type field.
|
// attr's type field.
|
||||||
// TODO(josh11b): Add some examples here showing best practices.
|
// TODO: Add some examples here showing best practices.
|
||||||
map<string, AttrValue> attr = 5;
|
map<string, AttrValue> attr = 5;
|
||||||
|
|
||||||
message ExperimentalDebugInfo {
|
message ExperimentalDebugInfo {
|
||||||
|
@ -96,7 +96,7 @@ message OpDef {
|
|||||||
// Human-readable description.
|
// Human-readable description.
|
||||||
string description = 4;
|
string description = 4;
|
||||||
|
|
||||||
// TODO(josh11b): bool is_optional?
|
// TODO: bool is_optional?
|
||||||
|
|
||||||
// --- Constraints ---
|
// --- Constraints ---
|
||||||
// These constraints are only in effect if specified. Default is no
|
// These constraints are only in effect if specified. Default is no
|
||||||
@ -139,7 +139,7 @@ message OpDef {
|
|||||||
// taking input from multiple devices with a tree of aggregate ops
|
// taking input from multiple devices with a tree of aggregate ops
|
||||||
// that aggregate locally within each device (and possibly within
|
// that aggregate locally within each device (and possibly within
|
||||||
// groups of nearby devices) before communicating.
|
// groups of nearby devices) before communicating.
|
||||||
// TODO(josh11b): Implement that optimization.
|
// TODO: Implement that optimization.
|
||||||
bool is_aggregate = 16; // for things like add
|
bool is_aggregate = 16; // for things like add
|
||||||
|
|
||||||
// Other optimizations go here, like
|
// Other optimizations go here, like
|
||||||
|
@ -53,7 +53,7 @@ message MemoryStats {
|
|||||||
|
|
||||||
// Time/size stats recorded for a single execution of a graph node.
|
// Time/size stats recorded for a single execution of a graph node.
|
||||||
message NodeExecStats {
|
message NodeExecStats {
|
||||||
// TODO(tucker): Use some more compact form of node identity than
|
// TODO: Use some more compact form of node identity than
|
||||||
// the full string name. Either all processes should agree on a
|
// the full string name. Either all processes should agree on a
|
||||||
// global id (cost_id?) for each node, or we should use a hash of
|
// global id (cost_id?) for each node, or we should use a hash of
|
||||||
// the name.
|
// the name.
|
||||||
|
@ -16,7 +16,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framewo
|
|||||||
message TensorProto {
|
message TensorProto {
|
||||||
DataType dtype = 1;
|
DataType dtype = 1;
|
||||||
|
|
||||||
// Shape of the tensor. TODO(touts): sort out the 0-rank issues.
|
// Shape of the tensor. TODO: sort out the 0-rank issues.
|
||||||
TensorShapeProto tensor_shape = 2;
|
TensorShapeProto tensor_shape = 2;
|
||||||
|
|
||||||
// Only one of the representations below is set, one of "tensor_contents" and
|
// Only one of the representations below is set, one of "tensor_contents" and
|
||||||
|
@ -532,7 +532,7 @@ message ConfigProto {
|
|||||||
|
|
||||||
// We removed the flag client_handles_error_formatting. Marking the tag
|
// We removed the flag client_handles_error_formatting. Marking the tag
|
||||||
// number as reserved.
|
// number as reserved.
|
||||||
// TODO(shikharagarwal): Should we just remove this tag so that it can be
|
// TODO: Should we just remove this tag so that it can be
|
||||||
// used in future for other purpose?
|
// used in future for other purpose?
|
||||||
reserved 2;
|
reserved 2;
|
||||||
|
|
||||||
@ -576,7 +576,7 @@ message ConfigProto {
|
|||||||
// - If isolate_session_state is true, session states are isolated.
|
// - If isolate_session_state is true, session states are isolated.
|
||||||
// - If isolate_session_state is false, session states are shared.
|
// - If isolate_session_state is false, session states are shared.
|
||||||
//
|
//
|
||||||
// TODO(b/129330037): Add a single API that consistently treats
|
// TODO: Add a single API that consistently treats
|
||||||
// isolate_session_state and ClusterSpec propagation.
|
// isolate_session_state and ClusterSpec propagation.
|
||||||
bool share_session_state_in_clusterspec_propagation = 8;
|
bool share_session_state_in_clusterspec_propagation = 8;
|
||||||
|
|
||||||
@ -704,7 +704,7 @@ message ConfigProto {
|
|||||||
|
|
||||||
// Options for a single Run() call.
|
// Options for a single Run() call.
|
||||||
message RunOptions {
|
message RunOptions {
|
||||||
// TODO(pbar) Turn this into a TraceOptions proto which allows
|
// TODO Turn this into a TraceOptions proto which allows
|
||||||
// tracing to be controlled in a more orthogonal manner?
|
// tracing to be controlled in a more orthogonal manner?
|
||||||
enum TraceLevel {
|
enum TraceLevel {
|
||||||
NO_TRACE = 0;
|
NO_TRACE = 0;
|
||||||
@ -781,7 +781,7 @@ message RunMetadata {
|
|||||||
repeated GraphDef partition_graphs = 3;
|
repeated GraphDef partition_graphs = 3;
|
||||||
|
|
||||||
message FunctionGraphs {
|
message FunctionGraphs {
|
||||||
// TODO(nareshmodi): Include some sort of function/cache-key identifier?
|
// TODO: Include some sort of function/cache-key identifier?
|
||||||
repeated GraphDef partition_graphs = 1;
|
repeated GraphDef partition_graphs = 1;
|
||||||
|
|
||||||
GraphDef pre_optimization_graph = 2;
|
GraphDef pre_optimization_graph = 2;
|
||||||
|
@ -194,7 +194,7 @@ service CoordinationService {
|
|||||||
|
|
||||||
// Report error to the task. RPC sets the receiving instance of coordination
|
// Report error to the task. RPC sets the receiving instance of coordination
|
||||||
// service agent to error state permanently.
|
// service agent to error state permanently.
|
||||||
// TODO(b/195990880): Consider splitting this into a different RPC service.
|
// TODO: Consider splitting this into a different RPC service.
|
||||||
rpc ReportErrorToAgent(ReportErrorToAgentRequest)
|
rpc ReportErrorToAgent(ReportErrorToAgentRequest)
|
||||||
returns (ReportErrorToAgentResponse);
|
returns (ReportErrorToAgentResponse);
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ message DebugTensorWatch {
|
|||||||
// are to be debugged, the callers of Session::Run() must use distinct
|
// are to be debugged, the callers of Session::Run() must use distinct
|
||||||
// debug_urls to make sure that the streamed or dumped events do not overlap
|
// debug_urls to make sure that the streamed or dumped events do not overlap
|
||||||
// among the invocations.
|
// among the invocations.
|
||||||
// TODO(cais): More visible documentation of this in g3docs.
|
// TODO: More visible documentation of this in g3docs.
|
||||||
repeated string debug_urls = 4;
|
repeated string debug_urls = 4;
|
||||||
|
|
||||||
// Do not error out if debug op creation fails (e.g., due to dtype
|
// Do not error out if debug op creation fails (e.g., due to dtype
|
||||||
|
@ -12,7 +12,7 @@ option java_package = "org.tensorflow.util";
|
|||||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
|
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
|
||||||
|
|
||||||
// Available modes for extracting debugging information from a Tensor.
|
// Available modes for extracting debugging information from a Tensor.
|
||||||
// TODO(cais): Document the detailed column names and semantics in a separate
|
// TODO: Document the detailed column names and semantics in a separate
|
||||||
// markdown file once the implementation settles.
|
// markdown file once the implementation settles.
|
||||||
enum TensorDebugMode {
|
enum TensorDebugMode {
|
||||||
UNSPECIFIED = 0;
|
UNSPECIFIED = 0;
|
||||||
@ -223,7 +223,7 @@ message DebuggedDevice {
|
|||||||
// A debugger-generated ID for the device. Guaranteed to be unique within
|
// A debugger-generated ID for the device. Guaranteed to be unique within
|
||||||
// the scope of the debugged TensorFlow program, including single-host and
|
// the scope of the debugged TensorFlow program, including single-host and
|
||||||
// multi-host settings.
|
// multi-host settings.
|
||||||
// TODO(cais): Test the uniqueness guarantee in multi-host settings.
|
// TODO: Test the uniqueness guarantee in multi-host settings.
|
||||||
int32 device_id = 2;
|
int32 device_id = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -264,7 +264,7 @@ message Execution {
|
|||||||
// field with the DebuggedDevice messages.
|
// field with the DebuggedDevice messages.
|
||||||
repeated int32 output_tensor_device_ids = 9;
|
repeated int32 output_tensor_device_ids = 9;
|
||||||
|
|
||||||
// TODO(cais): When backporting to V1 Session.run() support, add more fields
|
// TODO support, add more fields
|
||||||
// such as fetches and feeds.
|
// such as fetches and feeds.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
|
|||||||
|
|
||||||
// Used to serialize and transmit tensorflow::Status payloads through
|
// Used to serialize and transmit tensorflow::Status payloads through
|
||||||
// grpc::Status `error_details` since grpc::Status lacks payload API.
|
// grpc::Status `error_details` since grpc::Status lacks payload API.
|
||||||
// TODO(b/204231601): Use GRPC API once supported.
|
// TODO: Use GRPC API once supported.
|
||||||
message GrpcPayloadContainer {
|
message GrpcPayloadContainer {
|
||||||
map<string, bytes> payloads = 1;
|
map<string, bytes> payloads = 1;
|
||||||
}
|
}
|
||||||
|
@ -172,7 +172,7 @@ message WaitQueueDoneRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message WaitQueueDoneResponse {
|
message WaitQueueDoneResponse {
|
||||||
// TODO(nareshmodi): Consider adding NodeExecStats here to be able to
|
// TODO: Consider adding NodeExecStats here to be able to
|
||||||
// propagate some stats.
|
// propagate some stats.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ message ExtendSessionRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message ExtendSessionResponse {
|
message ExtendSessionResponse {
|
||||||
// TODO(mrry): Return something about the operation?
|
// TODO: Return something about the operation?
|
||||||
|
|
||||||
// The new version number for the extended graph, to be used in the next call
|
// The new version number for the extended graph, to be used in the next call
|
||||||
// to ExtendSession.
|
// to ExtendSession.
|
||||||
|
@ -176,7 +176,7 @@ message SavedBareConcreteFunction {
|
|||||||
// allows the ConcreteFunction to be called with nest structure inputs. This
|
// allows the ConcreteFunction to be called with nest structure inputs. This
|
||||||
// field may not be populated. If this field is absent, the concrete function
|
// field may not be populated. If this field is absent, the concrete function
|
||||||
// can only be called with flat inputs.
|
// can only be called with flat inputs.
|
||||||
// TODO(b/169361281): support calling saved ConcreteFunction with structured
|
// TODO: support calling saved ConcreteFunction with structured
|
||||||
// inputs in C++ SavedModel API.
|
// inputs in C++ SavedModel API.
|
||||||
FunctionSpec function_spec = 4;
|
FunctionSpec function_spec = 4;
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
|
|||||||
|
|
||||||
// Special header that is associated with a bundle.
|
// Special header that is associated with a bundle.
|
||||||
//
|
//
|
||||||
// TODO(zongheng,zhifengc): maybe in the future, we can add information about
|
// TODO: maybe in the future, we can add information about
|
||||||
// which binary produced this checkpoint, timestamp, etc. Sometime, these can be
|
// which binary produced this checkpoint, timestamp, etc. Sometime, these can be
|
||||||
// valuable debugging information. And if needed, these can be used as defensive
|
// valuable debugging information. And if needed, these can be used as defensive
|
||||||
// information ensuring reader (binary version) of the checkpoint and the writer
|
// information ensuring reader (binary version) of the checkpoint and the writer
|
||||||
|
@ -188,7 +188,7 @@ message DeregisterGraphRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message DeregisterGraphResponse {
|
message DeregisterGraphResponse {
|
||||||
// TODO(mrry): Optionally add summary stats for the graph.
|
// TODO: Optionally add summary stats for the graph.
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
@ -294,7 +294,7 @@ message RunGraphResponse {
|
|||||||
|
|
||||||
// If the request asked for execution stats, the cost graph, or the partition
|
// If the request asked for execution stats, the cost graph, or the partition
|
||||||
// graphs, these are returned here.
|
// graphs, these are returned here.
|
||||||
// TODO(suharshs): Package these in a RunMetadata instead.
|
// TODO: Package these in a RunMetadata instead.
|
||||||
StepStats step_stats = 2;
|
StepStats step_stats = 2;
|
||||||
CostGraphDef cost_graph = 3;
|
CostGraphDef cost_graph = 3;
|
||||||
repeated GraphDef partition_graph = 4;
|
repeated GraphDef partition_graph = 4;
|
||||||
|
@ -13,5 +13,5 @@ message LogMetadata {
|
|||||||
SamplingConfig sampling_config = 2;
|
SamplingConfig sampling_config = 2;
|
||||||
// List of tags used to load the relevant MetaGraphDef from SavedModel.
|
// List of tags used to load the relevant MetaGraphDef from SavedModel.
|
||||||
repeated string saved_model_tags = 3;
|
repeated string saved_model_tags = 3;
|
||||||
// TODO(b/33279154): Add more metadata as mentioned in the bug.
|
// TODO: Add more metadata as mentioned in the bug.
|
||||||
}
|
}
|
||||||
|
@ -58,7 +58,7 @@ message FileSystemStoragePathSourceConfig {
|
|||||||
|
|
||||||
// A single servable name/base_path pair to monitor.
|
// A single servable name/base_path pair to monitor.
|
||||||
// DEPRECATED: Use 'servables' instead.
|
// DEPRECATED: Use 'servables' instead.
|
||||||
// TODO(b/30898016): Stop using these fields, and ultimately remove them here.
|
// TODO: Stop using these fields, and ultimately remove them here.
|
||||||
string servable_name = 1 [deprecated = true];
|
string servable_name = 1 [deprecated = true];
|
||||||
string base_path = 2 [deprecated = true];
|
string base_path = 2 [deprecated = true];
|
||||||
|
|
||||||
@ -76,7 +76,7 @@ message FileSystemStoragePathSourceConfig {
|
|||||||
// check for a version to appear later.)
|
// check for a version to appear later.)
|
||||||
// DEPRECATED: Use 'servable_versions_always_present' instead, which includes
|
// DEPRECATED: Use 'servable_versions_always_present' instead, which includes
|
||||||
// this behavior.
|
// this behavior.
|
||||||
// TODO(b/30898016): Remove 2019-10-31 or later.
|
// TODO: Remove 2019-10-31 or later.
|
||||||
bool fail_if_zero_versions_at_startup = 4 [deprecated = true];
|
bool fail_if_zero_versions_at_startup = 4 [deprecated = true];
|
||||||
|
|
||||||
// If true, the servable is always expected to exist on the underlying
|
// If true, the servable is always expected to exist on the underlying
|
||||||
|
@ -9,7 +9,7 @@ import "tensorflow_serving/config/logging_config.proto";
|
|||||||
option cc_enable_arenas = true;
|
option cc_enable_arenas = true;
|
||||||
|
|
||||||
// The type of model.
|
// The type of model.
|
||||||
// TODO(b/31336131): DEPRECATED.
|
// TODO: DEPRECATED.
|
||||||
enum ModelType {
|
enum ModelType {
|
||||||
MODEL_TYPE_UNSPECIFIED = 0 [deprecated = true];
|
MODEL_TYPE_UNSPECIFIED = 0 [deprecated = true];
|
||||||
TENSORFLOW = 1 [deprecated = true];
|
TENSORFLOW = 1 [deprecated = true];
|
||||||
@ -31,7 +31,7 @@ message ModelConfig {
|
|||||||
string base_path = 2;
|
string base_path = 2;
|
||||||
|
|
||||||
// Type of model.
|
// Type of model.
|
||||||
// TODO(b/31336131): DEPRECATED. Please use 'model_platform' instead.
|
// TODO: DEPRECATED. Please use 'model_platform' instead.
|
||||||
ModelType model_type = 3 [deprecated = true];
|
ModelType model_type = 3 [deprecated = true];
|
||||||
|
|
||||||
// Type of model (e.g. "tensorflow").
|
// Type of model (e.g. "tensorflow").
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
#RUST_LOG=debug LD_LIBRARY_PATH=so/onnx/lib target/release/navi_onnx --port 30 --num-worker-threads 8 --intra-op-parallelism 8 --inter-op-parallelism 8 \
|
#RUST_LOG=debug LD_LIBRARY_PATH=so/onnx/lib target/release/navi_onnx --port 30 --num-worker-threads 8 --intra-op-parallelism 8 --inter-op-parallelism 8 \
|
||||||
RUST_LOG=info LD_LIBRARY_PATH=so/onnx/lib cargo run --bin navi_onnx --features onnx -- \
|
RUST_LOG=info LD_LIBRARY_PATH=so/onnx/lib cargo run --bin navi_onnx --features onnx -- \
|
||||||
--port 8030 --num-worker-threads 8 \
|
--port 30 --num-worker-threads 8 --intra-op-parallelism 8 --inter-op-parallelism 8 \
|
||||||
--model-check-interval-secs 30 \
|
--model-check-interval-secs 30 \
|
||||||
|
--model-dir models/int8 \
|
||||||
|
--output caligrated_probabilities \
|
||||||
|
--input "" \
|
||||||
--modelsync-cli "echo" \
|
--modelsync-cli "echo" \
|
||||||
--onnx-ep-options use_arena=true \
|
--onnx-ep-options use_arena=true
|
||||||
--model-dir models/prod_home --output caligrated_probabilities --input "" --intra-op-parallelism 8 --inter-op-parallelism 8 --max-batch-size 1 --batch-time-out-millis 1 \
|
|
||||||
--model-dir models/prod_home1 --output caligrated_probabilities --input "" --intra-op-parallelism 8 --inter-op-parallelism 8 --max-batch-size 1 --batch-time-out-millis 1 \
|
|
||||||
|
@ -1,24 +1,11 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use log::info;
|
|
||||||
use navi::cli_args::{ARGS, MODEL_SPECS};
|
use navi::cli_args::{ARGS, MODEL_SPECS};
|
||||||
use navi::onnx_model::onnx::OnnxModel;
|
use navi::onnx_model::onnx::OnnxModel;
|
||||||
use navi::{bootstrap, metrics};
|
use navi::{bootstrap, metrics};
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
info!("global: {:?}", ARGS.onnx_global_thread_pool_options);
|
assert_eq!(MODEL_SPECS.len(), ARGS.inter_op_parallelism.len());
|
||||||
let assert_session_params = if ARGS.onnx_global_thread_pool_options.is_empty() {
|
|
||||||
// std::env::set_var("OMP_NUM_THREADS", "1");
|
|
||||||
info!("now we use per session thread pool");
|
|
||||||
MODEL_SPECS.len()
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
info!("now we use global thread pool");
|
|
||||||
0
|
|
||||||
};
|
|
||||||
assert_eq!(assert_session_params, ARGS.inter_op_parallelism.len());
|
|
||||||
assert_eq!(assert_session_params, ARGS.inter_op_parallelism.len());
|
|
||||||
|
|
||||||
metrics::register_custom_metrics();
|
metrics::register_custom_metrics();
|
||||||
bootstrap::bootstrap(OnnxModel::new)
|
bootstrap::bootstrap(OnnxModel::new)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use log::{info, warn};
|
use log::{info, warn};
|
||||||
use x509_parser::{prelude::{parse_x509_pem}, parse_x509_certificate};
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tonic::{
|
use tonic::{
|
||||||
@ -28,7 +27,6 @@ use crate::cli_args::{ARGS, INPUTS, OUTPUTS};
|
|||||||
use crate::metrics::{
|
use crate::metrics::{
|
||||||
NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
|
NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
|
||||||
NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR,
|
NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR,
|
||||||
CERT_EXPIRY_EPOCH
|
|
||||||
};
|
};
|
||||||
use crate::predict_service::{Model, PredictService};
|
use crate::predict_service::{Model, PredictService};
|
||||||
use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version;
|
use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version;
|
||||||
@ -209,9 +207,6 @@ impl<T: Model> PredictionService for PredictService<T> {
|
|||||||
PredictResult::DropDueToOverload => Err(Status::resource_exhausted("")),
|
PredictResult::DropDueToOverload => Err(Status::resource_exhausted("")),
|
||||||
PredictResult::ModelNotFound(idx) => {
|
PredictResult::ModelNotFound(idx) => {
|
||||||
Err(Status::not_found(format!("model index {}", idx)))
|
Err(Status::not_found(format!("model index {}", idx)))
|
||||||
},
|
|
||||||
PredictResult::ModelNotReady(idx) => {
|
|
||||||
Err(Status::unavailable(format!("model index {}", idx)))
|
|
||||||
}
|
}
|
||||||
PredictResult::ModelVersionNotFound(idx, version) => Err(
|
PredictResult::ModelVersionNotFound(idx, version) => Err(
|
||||||
Status::not_found(format!("model index:{}, version {}", idx, version)),
|
Status::not_found(format!("model index:{}, version {}", idx, version)),
|
||||||
@ -235,12 +230,6 @@ impl<T: Model> PredictionService for PredictService<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// A function that takes a timestamp as input and returns a ticker stream
|
|
||||||
fn report_expiry(expiry_time: i64) {
|
|
||||||
info!("Certificate expires at epoch: {:?}", expiry_time);
|
|
||||||
CERT_EXPIRY_EPOCH.set(expiry_time as i64);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
||||||
info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
|
info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
|
||||||
//we follow SemVer. So here we assume MAJOR.MINOR.PATCH
|
//we follow SemVer. So here we assume MAJOR.MINOR.PATCH
|
||||||
@ -257,7 +246,6 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
tokio::runtime::Builder::new_multi_thread()
|
tokio::runtime::Builder::new_multi_thread()
|
||||||
.thread_name("async worker")
|
.thread_name("async worker")
|
||||||
.worker_threads(ARGS.num_worker_threads)
|
.worker_threads(ARGS.num_worker_threads)
|
||||||
@ -275,21 +263,6 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
|||||||
let mut builder = if ARGS.ssl_dir.is_empty() {
|
let mut builder = if ARGS.ssl_dir.is_empty() {
|
||||||
Server::builder()
|
Server::builder()
|
||||||
} else {
|
} else {
|
||||||
// Read the pem file as a string
|
|
||||||
let pem_str = std::fs::read_to_string(format!("{}/server.crt", ARGS.ssl_dir)).unwrap();
|
|
||||||
let res = parse_x509_pem(&pem_str.as_bytes());
|
|
||||||
match res {
|
|
||||||
Ok((rem, pem_2)) => {
|
|
||||||
assert!(rem.is_empty());
|
|
||||||
assert_eq!(pem_2.label, String::from("CERTIFICATE"));
|
|
||||||
let res_x509 = parse_x509_certificate(&pem_2.contents);
|
|
||||||
info!("Certificate label: {}", pem_2.label);
|
|
||||||
assert!(res_x509.is_ok());
|
|
||||||
report_expiry(res_x509.unwrap().1.validity().not_after.timestamp());
|
|
||||||
},
|
|
||||||
_ => panic!("PEM parsing failed: {:?}", res),
|
|
||||||
}
|
|
||||||
|
|
||||||
let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir))
|
let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir))
|
||||||
.await
|
.await
|
||||||
.expect("can't find key file");
|
.expect("can't find key file");
|
||||||
@ -305,7 +278,7 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
|||||||
let identity = Identity::from_pem(pem.clone(), key);
|
let identity = Identity::from_pem(pem.clone(), key);
|
||||||
let client_ca_cert = Certificate::from_pem(pem.clone());
|
let client_ca_cert = Certificate::from_pem(pem.clone());
|
||||||
let tls = ServerTlsConfig::new()
|
let tls = ServerTlsConfig::new()
|
||||||
.identity(identity)
|
.identity(identity)
|
||||||
.client_ca_root(client_ca_cert);
|
.client_ca_root(client_ca_cert);
|
||||||
Server::builder()
|
Server::builder()
|
||||||
.tls_config(tls)
|
.tls_config(tls)
|
||||||
|
@ -87,11 +87,13 @@ pub struct Args {
|
|||||||
pub intra_op_parallelism: Vec<String>,
|
pub intra_op_parallelism: Vec<String>,
|
||||||
#[clap(
|
#[clap(
|
||||||
long,
|
long,
|
||||||
|
default_value = "14",
|
||||||
help = "number of threads to parallelize computations of the graph"
|
help = "number of threads to parallelize computations of the graph"
|
||||||
)]
|
)]
|
||||||
pub inter_op_parallelism: Vec<String>,
|
pub inter_op_parallelism: Vec<String>,
|
||||||
#[clap(
|
#[clap(
|
||||||
long,
|
long,
|
||||||
|
default_value = "serving_default",
|
||||||
help = "signature of a serving. only TF"
|
help = "signature of a serving. only TF"
|
||||||
)]
|
)]
|
||||||
pub serving_sig: Vec<String>,
|
pub serving_sig: Vec<String>,
|
||||||
@ -105,12 +107,10 @@ pub struct Args {
|
|||||||
help = "max warmup records to use. warmup only implemented for TF"
|
help = "max warmup records to use. warmup only implemented for TF"
|
||||||
)]
|
)]
|
||||||
pub max_warmup_records: usize,
|
pub max_warmup_records: usize,
|
||||||
#[clap(long, value_parser = Args::parse_key_val::<String, String>, value_delimiter=',')]
|
|
||||||
pub onnx_global_thread_pool_options: Vec<(String, String)>,
|
|
||||||
#[clap(
|
#[clap(
|
||||||
long,
|
long,
|
||||||
default_value = "true",
|
default_value = "true",
|
||||||
help = "when to use graph parallelization. only for ONNX"
|
help = "when to use graph parallelization. only for ONNX"
|
||||||
)]
|
)]
|
||||||
pub onnx_use_parallel_mode: String,
|
pub onnx_use_parallel_mode: String,
|
||||||
// #[clap(long, default_value = "false")]
|
// #[clap(long, default_value = "false")]
|
||||||
|
@ -146,7 +146,6 @@ pub enum PredictResult {
|
|||||||
Ok(Vec<TensorScores>, i64),
|
Ok(Vec<TensorScores>, i64),
|
||||||
DropDueToOverload,
|
DropDueToOverload,
|
||||||
ModelNotFound(usize),
|
ModelNotFound(usize),
|
||||||
ModelNotReady(usize),
|
|
||||||
ModelVersionNotFound(usize, i64),
|
ModelVersionNotFound(usize, i64),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,9 +171,6 @@ lazy_static! {
|
|||||||
&["model_name"]
|
&["model_name"]
|
||||||
)
|
)
|
||||||
.expect("metric can be created");
|
.expect("metric can be created");
|
||||||
pub static ref CERT_EXPIRY_EPOCH: IntGauge =
|
|
||||||
IntGauge::new(":navi:cert_expiry_epoch", "Timestamp when the current cert expires")
|
|
||||||
.expect("metric can be created");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register_custom_metrics() {
|
pub fn register_custom_metrics() {
|
||||||
@ -252,10 +249,6 @@ pub fn register_custom_metrics() {
|
|||||||
REGISTRY
|
REGISTRY
|
||||||
.register(Box::new(CONVERTER_TIME_COLLECTOR.clone()))
|
.register(Box::new(CONVERTER_TIME_COLLECTOR.clone()))
|
||||||
.expect("collector can be registered");
|
.expect("collector can be registered");
|
||||||
REGISTRY
|
|
||||||
.register(Box::new(CERT_EXPIRY_EPOCH.clone()))
|
|
||||||
.expect("collector can be registered");
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register_dynamic_metrics(c: &HistogramVec) {
|
pub fn register_dynamic_metrics(c: &HistogramVec) {
|
||||||
|
@ -13,22 +13,21 @@ pub mod onnx {
|
|||||||
use dr_transform::converter::{BatchPredictionRequestToTorchTensorConverter, Converter};
|
use dr_transform::converter::{BatchPredictionRequestToTorchTensorConverter, Converter};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use log::{debug, info};
|
use log::{debug, info};
|
||||||
use dr_transform::ort::environment::Environment;
|
use ort::environment::Environment;
|
||||||
use dr_transform::ort::session::Session;
|
use ort::session::Session;
|
||||||
use dr_transform::ort::tensor::InputTensor;
|
use ort::tensor::InputTensor;
|
||||||
use dr_transform::ort::{ExecutionProvider, GraphOptimizationLevel, SessionBuilder};
|
use ort::{ExecutionProvider, GraphOptimizationLevel, SessionBuilder};
|
||||||
use dr_transform::ort::LoggingLevel;
|
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::fmt::{Debug, Display};
|
use std::fmt::{Debug, Display};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::{fmt, fs};
|
use std::{fmt, fs};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
pub static ref ENVIRONMENT: Arc<Environment> = Arc::new(
|
pub static ref ENVIRONMENT: Arc<Environment> = Arc::new(
|
||||||
Environment::builder()
|
Environment::builder()
|
||||||
.with_name("onnx home")
|
.with_name("onnx home")
|
||||||
.with_log_level(LoggingLevel::Error)
|
.with_log_level(ort::LoggingLevel::Error)
|
||||||
.with_global_thread_pool(ARGS.onnx_global_thread_pool_options.clone())
|
|
||||||
.build()
|
.build()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
);
|
);
|
||||||
@ -102,30 +101,23 @@ pub mod onnx {
|
|||||||
let meta_info = format!("{}/{}/{}", ARGS.model_dir[idx], version, META_INFO);
|
let meta_info = format!("{}/{}/{}", ARGS.model_dir[idx], version, META_INFO);
|
||||||
let mut builder = SessionBuilder::new(&ENVIRONMENT)?
|
let mut builder = SessionBuilder::new(&ENVIRONMENT)?
|
||||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||||
.with_parallel_execution(ARGS.onnx_use_parallel_mode == "true")?;
|
.with_parallel_execution(ARGS.onnx_use_parallel_mode == "true")?
|
||||||
if ARGS.onnx_global_thread_pool_options.is_empty() {
|
.with_inter_threads(
|
||||||
builder = builder
|
utils::get_config_or(
|
||||||
.with_inter_threads(
|
model_config,
|
||||||
utils::get_config_or(
|
"inter_op_parallelism",
|
||||||
model_config,
|
&ARGS.inter_op_parallelism[idx],
|
||||||
"inter_op_parallelism",
|
)
|
||||||
&ARGS.inter_op_parallelism[idx],
|
.parse()?,
|
||||||
)
|
)?
|
||||||
.parse()?,
|
.with_intra_threads(
|
||||||
)?
|
utils::get_config_or(
|
||||||
.with_intra_threads(
|
model_config,
|
||||||
utils::get_config_or(
|
"intra_op_parallelism",
|
||||||
model_config,
|
&ARGS.intra_op_parallelism[idx],
|
||||||
"intra_op_parallelism",
|
)
|
||||||
&ARGS.intra_op_parallelism[idx],
|
.parse()?,
|
||||||
)
|
)?
|
||||||
.parse()?,
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
builder = builder.with_disable_per_session_threads()?;
|
|
||||||
}
|
|
||||||
builder = builder
|
|
||||||
.with_memory_pattern(ARGS.onnx_use_memory_pattern == "true")?
|
.with_memory_pattern(ARGS.onnx_use_memory_pattern == "true")?
|
||||||
.with_execution_providers(&OnnxModel::ep_choices())?;
|
.with_execution_providers(&OnnxModel::ep_choices())?;
|
||||||
match &ARGS.profiling {
|
match &ARGS.profiling {
|
||||||
@ -189,7 +181,7 @@ pub mod onnx {
|
|||||||
&version,
|
&version,
|
||||||
reporting_feature_ids,
|
reporting_feature_ids,
|
||||||
Some(metrics::register_dynamic_metrics),
|
Some(metrics::register_dynamic_metrics),
|
||||||
)?),
|
)),
|
||||||
};
|
};
|
||||||
onnx_model.warmup()?;
|
onnx_model.warmup()?;
|
||||||
Ok(onnx_model)
|
Ok(onnx_model)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use arrayvec::ArrayVec;
|
use arrayvec::ArrayVec;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use log::{error, info};
|
use log::{error, info, warn};
|
||||||
use std::fmt::{Debug, Display};
|
use std::fmt::{Debug, Display};
|
||||||
use std::string::String;
|
use std::string::String;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@ -24,7 +24,7 @@ use serde_json::{self, Value};
|
|||||||
|
|
||||||
pub trait Model: Send + Sync + Display + Debug + 'static {
|
pub trait Model: Send + Sync + Display + Debug + 'static {
|
||||||
fn warmup(&self) -> Result<()>;
|
fn warmup(&self) -> Result<()>;
|
||||||
//TODO: refactor this to return vec<vec<TensorScores>>, i.e.
|
//TODO: refactor this to return Vec<Vec<TensorScores>>, i.e.
|
||||||
//we have the underlying runtime impl to split the response to each client.
|
//we have the underlying runtime impl to split the response to each client.
|
||||||
//It will eliminate some inefficient memory copy in onnx_model.rs as well as simplify code
|
//It will eliminate some inefficient memory copy in onnx_model.rs as well as simplify code
|
||||||
fn do_predict(
|
fn do_predict(
|
||||||
@ -179,17 +179,17 @@ impl<T: Model> PredictService<T> {
|
|||||||
//initialize the latest version array
|
//initialize the latest version array
|
||||||
let mut cur_versions = vec!["".to_owned(); MODEL_SPECS.len()];
|
let mut cur_versions = vec!["".to_owned(); MODEL_SPECS.len()];
|
||||||
loop {
|
loop {
|
||||||
|
let config = utils::read_config(&meta_file).unwrap_or_else(|e| {
|
||||||
|
warn!("config file {} not found due to: {}", meta_file, e);
|
||||||
|
Value::Null
|
||||||
|
});
|
||||||
info!("***polling for models***"); //nice deliminter
|
info!("***polling for models***"); //nice deliminter
|
||||||
|
info!("config:{}", config);
|
||||||
if let Some(ref cli) = ARGS.modelsync_cli {
|
if let Some(ref cli) = ARGS.modelsync_cli {
|
||||||
if let Err(e) = call_external_modelsync(cli, &cur_versions).await {
|
if let Err(e) = call_external_modelsync(cli, &cur_versions).await {
|
||||||
error!("model sync cli running error:{}", e)
|
error!("model sync cli running error:{}", e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let config = utils::read_config(&meta_file).unwrap_or_else(|e| {
|
|
||||||
info!("config file {} not found due to: {}", meta_file, e);
|
|
||||||
Value::Null
|
|
||||||
});
|
|
||||||
info!("config:{}", config);
|
|
||||||
for (idx, cur_version) in cur_versions.iter_mut().enumerate() {
|
for (idx, cur_version) in cur_versions.iter_mut().enumerate() {
|
||||||
let model_dir = &ARGS.model_dir[idx];
|
let model_dir = &ARGS.model_dir[idx];
|
||||||
PredictService::scan_load_latest_model_from_model_dir(
|
PredictService::scan_load_latest_model_from_model_dir(
|
||||||
@ -222,39 +222,33 @@ impl<T: Model> PredictService<T> {
|
|||||||
.map(|b| b.parse().unwrap())
|
.map(|b| b.parse().unwrap())
|
||||||
.collect::<Vec<u64>>();
|
.collect::<Vec<u64>>();
|
||||||
let no_msg_wait_millis = *batch_time_out_millis.iter().min().unwrap();
|
let no_msg_wait_millis = *batch_time_out_millis.iter().min().unwrap();
|
||||||
let mut all_model_predictors: ArrayVec::<ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS> =
|
let mut all_model_predictors =
|
||||||
(0 ..MAX_NUM_MODELS).map( |_| ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new()).collect();
|
ArrayVec::<ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS>::new();
|
||||||
loop {
|
loop {
|
||||||
let msg = rx.try_recv();
|
let msg = rx.try_recv();
|
||||||
let no_more_msg = match msg {
|
let no_more_msg = match msg {
|
||||||
Ok(PredictMessage::Predict(model_spec_at, version, val, resp, ts)) => {
|
Ok(PredictMessage::Predict(model_spec_at, version, val, resp, ts)) => {
|
||||||
if let Some(model_predictors) = all_model_predictors.get_mut(model_spec_at) {
|
if let Some(model_predictors) = all_model_predictors.get_mut(model_spec_at) {
|
||||||
if model_predictors.is_empty() {
|
match version {
|
||||||
resp.send(PredictResult::ModelNotReady(model_spec_at))
|
None => model_predictors[0].push(val, resp, ts),
|
||||||
.unwrap_or_else(|e| error!("cannot send back model not ready error: {:?}", e));
|
Some(the_version) => match model_predictors
|
||||||
}
|
.iter_mut()
|
||||||
else {
|
.find(|x| x.model.version() == the_version)
|
||||||
match version {
|
{
|
||||||
None => model_predictors[0].push(val, resp, ts),
|
None => resp
|
||||||
Some(the_version) => match model_predictors
|
.send(PredictResult::ModelVersionNotFound(
|
||||||
.iter_mut()
|
model_spec_at,
|
||||||
.find(|x| x.model.version() == the_version)
|
the_version,
|
||||||
{
|
))
|
||||||
None => resp
|
.unwrap_or_else(|e| {
|
||||||
.send(PredictResult::ModelVersionNotFound(
|
error!("cannot send back version error: {:?}", e)
|
||||||
model_spec_at,
|
}),
|
||||||
the_version,
|
Some(predictor) => predictor.push(val, resp, ts),
|
||||||
))
|
},
|
||||||
.unwrap_or_else(|e| {
|
|
||||||
error!("cannot send back version error: {:?}", e)
|
|
||||||
}),
|
|
||||||
Some(predictor) => predictor.push(val, resp, ts),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
resp.send(PredictResult::ModelNotFound(model_spec_at))
|
resp.send(PredictResult::ModelNotFound(model_spec_at))
|
||||||
.unwrap_or_else(|e| error!("cannot send back model not found error: {:?}", e))
|
.unwrap_or_else(|e| error!("cannot send back model error: {:?}", e))
|
||||||
}
|
}
|
||||||
MPSC_CHANNEL_SIZE.dec();
|
MPSC_CHANNEL_SIZE.dec();
|
||||||
false
|
false
|
||||||
@ -272,23 +266,27 @@ impl<T: Model> PredictService<T> {
|
|||||||
queue_reset_ts: Instant::now(),
|
queue_reset_ts: Instant::now(),
|
||||||
queue_earliest_rq_ts: Instant::now(),
|
queue_earliest_rq_ts: Instant::now(),
|
||||||
};
|
};
|
||||||
assert!(idx < all_model_predictors.len());
|
if idx < all_model_predictors.len() {
|
||||||
metrics::NEW_MODEL_SNAPSHOT
|
metrics::NEW_MODEL_SNAPSHOT
|
||||||
.with_label_values(&[&MODEL_SPECS[idx]])
|
.with_label_values(&[&MODEL_SPECS[idx]])
|
||||||
.inc();
|
.inc();
|
||||||
|
|
||||||
//we can do this since the vector is small
|
|
||||||
let predictors = &mut all_model_predictors[idx];
|
|
||||||
if predictors.len() == 0 {
|
|
||||||
info!("now we serve new model: {}", predictor.model);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
info!("now we serve updated model: {}", predictor.model);
|
info!("now we serve updated model: {}", predictor.model);
|
||||||
|
//we can do this since the vector is small
|
||||||
|
let predictors = &mut all_model_predictors[idx];
|
||||||
|
if predictors.len() == ARGS.versions_per_model {
|
||||||
|
predictors.remove(predictors.len() - 1);
|
||||||
|
}
|
||||||
|
predictors.insert(0, predictor);
|
||||||
|
} else {
|
||||||
|
info!("now we serve new model: {:}", predictor.model);
|
||||||
|
let mut predictors =
|
||||||
|
ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new();
|
||||||
|
predictors.push(predictor);
|
||||||
|
all_model_predictors.push(predictors);
|
||||||
|
//check the invariant that we always push the last model to the end
|
||||||
|
assert_eq!(all_model_predictors.len(), idx + 1)
|
||||||
}
|
}
|
||||||
if predictors.len() == ARGS.versions_per_model {
|
|
||||||
predictors.remove(predictors.len() - 1);
|
|
||||||
}
|
|
||||||
predictors.insert(0, predictor);
|
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
Err(TryRecvError::Empty) => true,
|
Err(TryRecvError::Empty) => true,
|
||||||
|
@ -3,9 +3,9 @@ name = "segdense"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
env_logger = "0.10.0"
|
|
||||||
serde = { version = "1.0.104", features = ["derive"] }
|
serde = { version = "1.0.104", features = ["derive"] }
|
||||||
serde_json = "1.0.48"
|
serde_json = "1.0.48"
|
||||||
log = "0.4.17"
|
log = "0.4.17"
|
||||||
|
@ -5,49 +5,39 @@ use std::fmt::Display;
|
|||||||
*/
|
*/
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum SegDenseError {
|
pub enum SegDenseError {
|
||||||
IoError(std::io::Error),
|
IoError(std::io::Error),
|
||||||
Json(serde_json::Error),
|
Json(serde_json::Error),
|
||||||
JsonMissingRoot,
|
JsonMissingRoot,
|
||||||
JsonMissingObject,
|
JsonMissingObject,
|
||||||
JsonMissingArray,
|
JsonMissingArray,
|
||||||
JsonArraySize,
|
JsonArraySize,
|
||||||
JsonMissingInputFeature,
|
JsonMissingInputFeature,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for SegDenseError {
|
impl Display for SegDenseError {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
|
SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
|
||||||
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
|
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
|
||||||
SegDenseError::JsonMissingRoot => {
|
SegDenseError::JsonMissingRoot => write!(f, "{}", "SegDense JSON: Root Node note found!"),
|
||||||
write!(f, "{}", "SegDense JSON: Root Node note found!")
|
SegDenseError::JsonMissingObject => write!(f, "{}", "SegDense JSON: Object note found!"),
|
||||||
}
|
SegDenseError::JsonMissingArray => write!(f, "{}", "SegDense JSON: Array Node note found!"),
|
||||||
SegDenseError::JsonMissingObject => {
|
SegDenseError::JsonArraySize => write!(f, "{}", "SegDense JSON: Array size not as expected!"),
|
||||||
write!(f, "{}", "SegDense JSON: Object note found!")
|
SegDenseError::JsonMissingInputFeature => write!(f, "{}", "SegDense JSON: Missing input feature!"),
|
||||||
}
|
|
||||||
SegDenseError::JsonMissingArray => {
|
|
||||||
write!(f, "{}", "SegDense JSON: Array Node note found!")
|
|
||||||
}
|
|
||||||
SegDenseError::JsonArraySize => {
|
|
||||||
write!(f, "{}", "SegDense JSON: Array size not as expected!")
|
|
||||||
}
|
|
||||||
SegDenseError::JsonMissingInputFeature => {
|
|
||||||
write!(f, "{}", "SegDense JSON: Missing input feature!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::error::Error for SegDenseError {}
|
impl std::error::Error for SegDenseError {}
|
||||||
|
|
||||||
impl From<std::io::Error> for SegDenseError {
|
impl From<std::io::Error> for SegDenseError {
|
||||||
fn from(err: std::io::Error) -> Self {
|
fn from(err: std::io::Error) -> Self {
|
||||||
SegDenseError::IoError(err)
|
SegDenseError::IoError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<serde_json::Error> for SegDenseError {
|
impl From<serde_json::Error> for SegDenseError {
|
||||||
fn from(err: serde_json::Error) -> Self {
|
fn from(err: serde_json::Error) -> Self {
|
||||||
SegDenseError::Json(err)
|
SegDenseError::Json(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
pub mod error;
|
pub mod error;
|
||||||
|
pub mod segdense_transform_spec_home_recap_2022;
|
||||||
pub mod mapper;
|
pub mod mapper;
|
||||||
pub mod segdense_transform_spec_home_recap_2022;
|
pub mod util;
|
||||||
pub mod util;
|
|
@ -5,18 +5,19 @@ use segdense::error::SegDenseError;
|
|||||||
use segdense::util;
|
use segdense::util;
|
||||||
|
|
||||||
fn main() -> Result<(), SegDenseError> {
|
fn main() -> Result<(), SegDenseError> {
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
let args: Vec<String> = env::args().collect();
|
let args: Vec<String> = env::args().collect();
|
||||||
|
|
||||||
|
let schema_file_name: &str = if args.len() == 1 {
|
||||||
|
"json/compact.json"
|
||||||
|
} else {
|
||||||
|
&args[1]
|
||||||
|
};
|
||||||
|
|
||||||
let schema_file_name: &str = if args.len() == 1 {
|
let json_str = fs::read_to_string(schema_file_name)?;
|
||||||
"json/compact.json"
|
|
||||||
} else {
|
|
||||||
&args[1]
|
|
||||||
};
|
|
||||||
|
|
||||||
let json_str = fs::read_to_string(schema_file_name)?;
|
util::safe_load_config(&json_str)?;
|
||||||
|
|
||||||
util::safe_load_config(&json_str)?;
|
Ok(())
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,13 +19,13 @@ pub struct FeatureMapper {
|
|||||||
impl FeatureMapper {
|
impl FeatureMapper {
|
||||||
pub fn new() -> FeatureMapper {
|
pub fn new() -> FeatureMapper {
|
||||||
FeatureMapper {
|
FeatureMapper {
|
||||||
map: HashMap::new(),
|
map: HashMap::new()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait MapWriter {
|
pub trait MapWriter {
|
||||||
fn set(&mut self, feature_id: i64, info: FeatureInfo);
|
fn set(&mut self, feature_id: i64, info: FeatureInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait MapReader {
|
pub trait MapReader {
|
||||||
|
@ -164,6 +164,7 @@ pub struct ComplexFeatureTypeTransformSpec {
|
|||||||
pub tensor_shape: Vec<i64>,
|
pub tensor_shape: Vec<i64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct InputFeatureMapRecord {
|
pub struct InputFeatureMapRecord {
|
||||||
|
@ -1,23 +1,23 @@
|
|||||||
use log::debug;
|
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
use log::{debug};
|
||||||
|
|
||||||
use serde_json::{Map, Value};
|
use serde_json::{Value, Map};
|
||||||
|
|
||||||
use crate::error::SegDenseError;
|
use crate::error::SegDenseError;
|
||||||
use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter};
|
use crate::mapper::{FeatureMapper, FeatureInfo, MapWriter};
|
||||||
use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
|
use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
|
||||||
|
|
||||||
pub fn load_config(file_name: &str) -> Result<seg_dense::Root, SegDenseError> {
|
pub fn load_config(file_name: &str) -> seg_dense::Root {
|
||||||
let json_str = fs::read_to_string(file_name)?;
|
let json_str = fs::read_to_string(file_name).expect(
|
||||||
// &format!("Unable to load segdense file {}", file_name));
|
&format!("Unable to load segdense file {}", file_name));
|
||||||
let seg_dense_config = parse(&json_str)?;
|
let seg_dense_config = parse(&json_str).expect(
|
||||||
// &format!("Unable to parse segdense file {}", file_name));
|
&format!("Unable to parse segdense file {}", file_name));
|
||||||
Ok(seg_dense_config)
|
return seg_dense_config;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {
|
pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {
|
||||||
let root: seg_dense::Root = serde_json::from_str(json_str)?;
|
let root: seg_dense::Root = serde_json::from_str(json_str)?;
|
||||||
Ok(root)
|
return Ok(root);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -44,8 +44,15 @@ pub fn safe_load_config(json_str: &str) -> Result<FeatureMapper, SegDenseError>
|
|||||||
load_from_parsed_config(root)
|
load_from_parsed_config(root)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn load_from_parsed_config_ref(root: &seg_dense::Root) -> FeatureMapper {
|
||||||
|
load_from_parsed_config(root.clone()).unwrap_or_else(
|
||||||
|
|error| panic!("Error loading all_config.json - {}", error))
|
||||||
|
}
|
||||||
|
|
||||||
// Perf note : make 'root' un-owned
|
// Perf note : make 'root' un-owned
|
||||||
pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper, SegDenseError> {
|
pub fn load_from_parsed_config(root: seg_dense::Root) ->
|
||||||
|
Result<FeatureMapper, SegDenseError> {
|
||||||
|
|
||||||
let v = root.input_features_map;
|
let v = root.input_features_map;
|
||||||
|
|
||||||
// Do error check
|
// Do error check
|
||||||
@ -79,7 +86,7 @@ pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper, S
|
|||||||
Some(info) => {
|
Some(info) => {
|
||||||
debug!("{:?}", info);
|
debug!("{:?}", info);
|
||||||
fm.set(feature_id, info)
|
fm.set(feature_id, info)
|
||||||
}
|
},
|
||||||
None => (),
|
None => (),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -87,22 +94,19 @@ pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper, S
|
|||||||
Ok(fm)
|
Ok(fm)
|
||||||
}
|
}
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn add_feature_info_to_mapper(
|
fn add_feature_info_to_mapper(feature_mapper: &mut FeatureMapper, input_features: &Vec<InputFeature>) {
|
||||||
feature_mapper: &mut FeatureMapper,
|
|
||||||
input_features: &Vec<InputFeature>,
|
|
||||||
) {
|
|
||||||
for input_feature in input_features.iter() {
|
for input_feature in input_features.iter() {
|
||||||
let feature_id = input_feature.feature_id;
|
let feature_id = input_feature.feature_id;
|
||||||
let feature_info = to_feature_info(input_feature);
|
let feature_info = to_feature_info(input_feature);
|
||||||
|
|
||||||
match feature_info {
|
match feature_info {
|
||||||
Some(info) => {
|
Some(info) => {
|
||||||
debug!("{:?}", info);
|
debug!("{:?}", info);
|
||||||
feature_mapper.set(feature_id, info)
|
feature_mapper.set(feature_id, info)
|
||||||
|
},
|
||||||
|
None => (),
|
||||||
}
|
}
|
||||||
None => (),
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<FeatureInfo> {
|
pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<FeatureInfo> {
|
||||||
@ -135,7 +139,7 @@ pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<Featur
|
|||||||
2 => 0,
|
2 => 0,
|
||||||
3 => 2,
|
3 => 2,
|
||||||
_ => -1,
|
_ => -1,
|
||||||
},
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if input_feature.index < 0 {
|
if input_feature.index < 0 {
|
||||||
@ -152,3 +156,4 @@ pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<Featur
|
|||||||
index_within_tensor: input_feature.index,
|
index_within_tensor: input_feature.index,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
13
unified_user_actions/adapter/src/main/scala/com/twitter/unified_user_actions/adapter/behavioral_client_event/BUILD
Normal file
13
unified_user_actions/adapter/src/main/scala/com/twitter/unified_user_actions/adapter/behavioral_client_event/BUILD
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
scala_library(
|
||||||
|
sources = [
|
||||||
|
"*.scala",
|
||||||
|
],
|
||||||
|
tags = ["bazel-compatible"],
|
||||||
|
dependencies = [
|
||||||
|
"client-events/thrift/src/thrift/storage/twitter/behavioral_event:behavioral_event-scala",
|
||||||
|
"kafka/finagle-kafka/finatra-kafka/src/main/scala",
|
||||||
|
"unified_user_actions/adapter/src/main/scala/com/twitter/unified_user_actions/adapter:base",
|
||||||
|
"unified_user_actions/adapter/src/main/scala/com/twitter/unified_user_actions/adapter/common",
|
||||||
|
"unified_user_actions/thrift/src/main/thrift/com/twitter/unified_user_actions:unified_user_actions-scala",
|
||||||
|
],
|
||||||
|
)
|
96
unified_user_actions/adapter/src/main/scala/com/twitter/unified_user_actions/adapter/behavioral_client_event/BaseBCEAdapter.scala
Normal file
96
unified_user_actions/adapter/src/main/scala/com/twitter/unified_user_actions/adapter/behavioral_client_event/BaseBCEAdapter.scala
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
package com.twitter.unified_user_actions.adapter.behavioral_client_event
|
||||||
|
|
||||||
|
import com.twitter.client_event_entities.serverside_context_key.latest.thriftscala.FlattenedServersideContextKey
|
||||||
|
import com.twitter.storage.behavioral_event.thriftscala.EventLogContext
|
||||||
|
import com.twitter.storage.behavioral_event.thriftscala.FlattenedEventLog
|
||||||
|
import com.twitter.unified_user_actions.adapter.common.AdapterUtils
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ActionType
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.BreadcrumbTweet
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ClientEventNamespace
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.EventMetadata
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.Item
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ProductSurface
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ProductSurfaceInfo
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.SourceLineage
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.UnifiedUserAction
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.UserIdentifier
|
||||||
|
|
||||||
|
case class ProductSurfaceRelated(
|
||||||
|
productSurface: Option[ProductSurface],
|
||||||
|
productSurfaceInfo: Option[ProductSurfaceInfo])
|
||||||
|
|
||||||
|
trait BaseBCEAdapter {
|
||||||
|
def toUUA(e: FlattenedEventLog): Seq[UnifiedUserAction]
|
||||||
|
|
||||||
|
protected def getUserIdentifier(c: EventLogContext): UserIdentifier =
|
||||||
|
UserIdentifier(
|
||||||
|
userId = c.userId,
|
||||||
|
guestIdMarketing = c.guestIdMarketing
|
||||||
|
)
|
||||||
|
|
||||||
|
protected def getEventMetadata(e: FlattenedEventLog): EventMetadata =
|
||||||
|
EventMetadata(
|
||||||
|
sourceLineage = SourceLineage.BehavioralClientEvents,
|
||||||
|
sourceTimestampMs =
|
||||||
|
e.context.driftAdjustedEventCreatedAtMs.getOrElse(e.context.eventCreatedAtMs),
|
||||||
|
receivedTimestampMs = AdapterUtils.currentTimestampMs,
|
||||||
|
// Client UI language or from Gizmoduck which is what user set in Twitter App.
|
||||||
|
// Please see more at https://sourcegraph.twitter.biz/git.twitter.biz/source/-/blob/finatra-internal/international/src/main/scala/com/twitter/finatra/international/LanguageIdentifier.scala
|
||||||
|
// The format should be ISO 639-1.
|
||||||
|
language = e.context.languageCode.map(AdapterUtils.normalizeLanguageCode),
|
||||||
|
// Country code could be IP address (geoduck) or User registration country (gizmoduck) and the former takes precedence.
|
||||||
|
// We don’t know exactly which one is applied, unfortunately,
|
||||||
|
// see https://sourcegraph.twitter.biz/git.twitter.biz/source/-/blob/finatra-internal/international/src/main/scala/com/twitter/finatra/international/CountryIdentifier.scala
|
||||||
|
// The format should be ISO_3166-1_alpha-2.
|
||||||
|
countryCode = e.context.countryCode.map(AdapterUtils.normalizeCountryCode),
|
||||||
|
clientAppId = e.context.clientApplicationId,
|
||||||
|
clientVersion = e.context.clientVersion,
|
||||||
|
clientPlatform = e.context.clientPlatform,
|
||||||
|
viewHierarchy = e.v1ViewTypeHierarchy,
|
||||||
|
clientEventNamespace = Some(
|
||||||
|
ClientEventNamespace(
|
||||||
|
page = e.page,
|
||||||
|
section = e.section,
|
||||||
|
element = e.element,
|
||||||
|
action = e.actionName,
|
||||||
|
subsection = e.subsection
|
||||||
|
)),
|
||||||
|
breadcrumbViews = e.v1BreadcrumbViewTypeHierarchy,
|
||||||
|
breadcrumbTweets = e.v1BreadcrumbTweetIds.map { breadcrumbs =>
|
||||||
|
breadcrumbs.map { breadcrumb =>
|
||||||
|
BreadcrumbTweet(
|
||||||
|
tweetId = breadcrumb.serversideContextId.toLong,
|
||||||
|
sourceComponent = breadcrumb.sourceComponent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
protected def getBreadcrumbTweetIds(
|
||||||
|
breadcrumbTweetIds: Option[Seq[FlattenedServersideContextKey]]
|
||||||
|
): Seq[BreadcrumbTweet] =
|
||||||
|
breadcrumbTweetIds
|
||||||
|
.getOrElse(Nil).map(breadcrumb => {
|
||||||
|
BreadcrumbTweet(
|
||||||
|
tweetId = breadcrumb.serversideContextId.toLong,
|
||||||
|
sourceComponent = breadcrumb.sourceComponent)
|
||||||
|
})
|
||||||
|
|
||||||
|
protected def getBreadcrumbViews(breadcrumbView: Option[Seq[String]]): Seq[String] =
|
||||||
|
breadcrumbView.getOrElse(Nil)
|
||||||
|
|
||||||
|
protected def getUnifiedUserAction(
|
||||||
|
event: FlattenedEventLog,
|
||||||
|
actionType: ActionType,
|
||||||
|
item: Item,
|
||||||
|
productSurface: Option[ProductSurface] = None,
|
||||||
|
productSurfaceInfo: Option[ProductSurfaceInfo] = None
|
||||||
|
): UnifiedUserAction =
|
||||||
|
UnifiedUserAction(
|
||||||
|
userIdentifier = getUserIdentifier(event.context),
|
||||||
|
actionType = actionType,
|
||||||
|
item = item,
|
||||||
|
eventMetadata = getEventMetadata(event),
|
||||||
|
productSurface = productSurface,
|
||||||
|
productSurfaceInfo = productSurfaceInfo
|
||||||
|
)
|
||||||
|
}
|
@ -0,0 +1,39 @@
|
|||||||
|
package com.twitter.unified_user_actions.adapter.behavioral_client_event
|
||||||
|
|
||||||
|
import com.twitter.finagle.stats.NullStatsReceiver
|
||||||
|
import com.twitter.finagle.stats.StatsReceiver
|
||||||
|
import com.twitter.finatra.kafka.serde.UnKeyed
|
||||||
|
import com.twitter.storage.behavioral_event.thriftscala.FlattenedEventLog
|
||||||
|
import com.twitter.unified_user_actions.adapter.AbstractAdapter
|
||||||
|
import com.twitter.unified_user_actions.thriftscala._
|
||||||
|
|
||||||
|
class BehavioralClientEventAdapter
|
||||||
|
extends AbstractAdapter[FlattenedEventLog, UnKeyed, UnifiedUserAction] {
|
||||||
|
|
||||||
|
import BehavioralClientEventAdapter._
|
||||||
|
|
||||||
|
override def adaptOneToKeyedMany(
|
||||||
|
input: FlattenedEventLog,
|
||||||
|
statsReceiver: StatsReceiver = NullStatsReceiver
|
||||||
|
): Seq[(UnKeyed, UnifiedUserAction)] =
|
||||||
|
adaptEvent(input).map { e => (UnKeyed, e) }
|
||||||
|
}
|
||||||
|
|
||||||
|
object BehavioralClientEventAdapter {
|
||||||
|
def adaptEvent(e: FlattenedEventLog): Seq[UnifiedUserAction] =
|
||||||
|
// See go/bcecoverage for event namespaces, usage and coverage details
|
||||||
|
Option(e)
|
||||||
|
.map { e =>
|
||||||
|
(e.page, e.actionName) match {
|
||||||
|
case (Some("tweet_details"), Some("impress")) =>
|
||||||
|
TweetImpressionBCEAdapter.TweetDetails.toUUA(e)
|
||||||
|
case (Some("fullscreen_video"), Some("impress")) =>
|
||||||
|
TweetImpressionBCEAdapter.FullscreenVideo.toUUA(e)
|
||||||
|
case (Some("fullscreen_image"), Some("impress")) =>
|
||||||
|
TweetImpressionBCEAdapter.FullscreenImage.toUUA(e)
|
||||||
|
case (Some("profile"), Some("impress")) =>
|
||||||
|
ProfileImpressionBCEAdapter.Profile.toUUA(e)
|
||||||
|
case _ => Nil
|
||||||
|
}
|
||||||
|
}.getOrElse(Nil)
|
||||||
|
}
|
34
unified_user_actions/adapter/src/main/scala/com/twitter/unified_user_actions/adapter/behavioral_client_event/ImpressionBCEAdapter.scala
Normal file
34
unified_user_actions/adapter/src/main/scala/com/twitter/unified_user_actions/adapter/behavioral_client_event/ImpressionBCEAdapter.scala
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package com.twitter.unified_user_actions.adapter.behavioral_client_event
|
||||||
|
|
||||||
|
import com.twitter.client.behavioral_event.action.impress.latest.thriftscala.Impress
|
||||||
|
import com.twitter.client_event_entities.serverside_context_key.latest.thriftscala.FlattenedServersideContextKey
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.Item
|
||||||
|
|
||||||
|
trait ImpressionBCEAdapter extends BaseBCEAdapter {
|
||||||
|
type ImpressedItem <: Item
|
||||||
|
|
||||||
|
def getImpressedItem(
|
||||||
|
context: FlattenedServersideContextKey,
|
||||||
|
impression: Impress
|
||||||
|
): ImpressedItem
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The start time of an impression in milliseconds since epoch. In BCE, the impression
|
||||||
|
* tracking clock will start immediately after the page is visible with no initial delay.
|
||||||
|
*/
|
||||||
|
def getImpressedStartTimestamp(impression: Impress): Long =
|
||||||
|
impression.visibilityPctDwellStartMs
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The end time of an impression in milliseconds since epoch. In BCE, the impression
|
||||||
|
* tracking clock will end before the user exit the page.
|
||||||
|
*/
|
||||||
|
def getImpressedEndTimestamp(impression: Impress): Long =
|
||||||
|
impression.visibilityPctDwellEndMs
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The UI component that hosted the impressed item.
|
||||||
|
*/
|
||||||
|
def getImpressedUISourceComponent(context: FlattenedServersideContextKey): String =
|
||||||
|
context.sourceComponent
|
||||||
|
}
|
@ -0,0 +1,52 @@
|
|||||||
|
package com.twitter.unified_user_actions.adapter.behavioral_client_event
|
||||||
|
|
||||||
|
import com.twitter.client.behavioral_event.action.impress.latest.thriftscala.Impress
|
||||||
|
import com.twitter.client_event_entities.serverside_context_key.latest.thriftscala.FlattenedServersideContextKey
|
||||||
|
import com.twitter.storage.behavioral_event.thriftscala.FlattenedEventLog
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ActionType
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ClientProfileV2Impression
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.Item
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ProductSurface
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ProfileActionInfo
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ProfileInfo
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.UnifiedUserAction
|
||||||
|
|
||||||
|
object ProfileImpressionBCEAdapter {
|
||||||
|
val Profile = new ProfileImpressionBCEAdapter()
|
||||||
|
}
|
||||||
|
|
||||||
|
class ProfileImpressionBCEAdapter extends ImpressionBCEAdapter {
|
||||||
|
override type ImpressedItem = Item.ProfileInfo
|
||||||
|
|
||||||
|
override def toUUA(e: FlattenedEventLog): Seq[UnifiedUserAction] =
|
||||||
|
(e.v2Impress, e.v1UserIds) match {
|
||||||
|
case (Some(v2Impress), Some(v1UserIds)) =>
|
||||||
|
v1UserIds.map { user =>
|
||||||
|
getUnifiedUserAction(
|
||||||
|
event = e,
|
||||||
|
actionType = ActionType.ClientProfileV2Impression,
|
||||||
|
item = getImpressedItem(user, v2Impress),
|
||||||
|
productSurface = Some(ProductSurface.ProfilePage)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
case _ => Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
override def getImpressedItem(
|
||||||
|
context: FlattenedServersideContextKey,
|
||||||
|
impression: Impress
|
||||||
|
): ImpressedItem =
|
||||||
|
Item.ProfileInfo(
|
||||||
|
ProfileInfo(
|
||||||
|
actionProfileId = context.serversideContextId.toLong,
|
||||||
|
profileActionInfo = Some(
|
||||||
|
ProfileActionInfo.ClientProfileV2Impression(
|
||||||
|
ClientProfileV2Impression(
|
||||||
|
impressStartTimestampMs = getImpressedStartTimestamp(impression),
|
||||||
|
impressEndTimestampMs = getImpressedEndTimestamp(impression),
|
||||||
|
sourceComponent = getImpressedUISourceComponent(context)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
))
|
||||||
|
}
|
@ -0,0 +1,84 @@
|
|||||||
|
package com.twitter.unified_user_actions.adapter.behavioral_client_event
|
||||||
|
|
||||||
|
import com.twitter.client.behavioral_event.action.impress.latest.thriftscala.Impress
|
||||||
|
import com.twitter.client_event_entities.serverside_context_key.latest.thriftscala.FlattenedServersideContextKey
|
||||||
|
import com.twitter.storage.behavioral_event.thriftscala.FlattenedEventLog
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ActionType
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ClientTweetV2Impression
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.Item
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.ProductSurface
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.TweetActionInfo
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.TweetInfo
|
||||||
|
import com.twitter.unified_user_actions.thriftscala.UnifiedUserAction
|
||||||
|
|
||||||
|
object TweetImpressionBCEAdapter {
|
||||||
|
val TweetDetails = new TweetImpressionBCEAdapter(ActionType.ClientTweetV2Impression)
|
||||||
|
val FullscreenVideo = new TweetImpressionBCEAdapter(
|
||||||
|
ActionType.ClientTweetVideoFullscreenV2Impression)
|
||||||
|
val FullscreenImage = new TweetImpressionBCEAdapter(
|
||||||
|
ActionType.ClientTweetImageFullscreenV2Impression)
|
||||||
|
}
|
||||||
|
|
||||||
|
class TweetImpressionBCEAdapter(actionType: ActionType) extends ImpressionBCEAdapter {
|
||||||
|
override type ImpressedItem = Item.TweetInfo
|
||||||
|
|
||||||
|
override def toUUA(e: FlattenedEventLog): Seq[UnifiedUserAction] =
|
||||||
|
(actionType, e.v2Impress, e.v1TweetIds, e.v1BreadcrumbTweetIds) match {
|
||||||
|
case (ActionType.ClientTweetV2Impression, Some(v2Impress), Some(v1TweetIds), _) =>
|
||||||
|
toUUAEvents(e, v2Impress, v1TweetIds)
|
||||||
|
case (
|
||||||
|
ActionType.ClientTweetVideoFullscreenV2Impression,
|
||||||
|
Some(v2Impress),
|
||||||
|
_,
|
||||||
|
Some(v1BreadcrumbTweetIds)) =>
|
||||||
|
toUUAEvents(e, v2Impress, v1BreadcrumbTweetIds)
|
||||||
|
case (
|
||||||
|
ActionType.ClientTweetImageFullscreenV2Impression,
|
||||||
|
Some(v2Impress),
|
||||||
|
_,
|
||||||
|
Some(v1BreadcrumbTweetIds)) =>
|
||||||
|
toUUAEvents(e, v2Impress, v1BreadcrumbTweetIds)
|
||||||
|
case _ => Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
private def toUUAEvents(
|
||||||
|
e: FlattenedEventLog,
|
||||||
|
v2Impress: Impress,
|
||||||
|
v1TweetIds: Seq[FlattenedServersideContextKey]
|
||||||
|
): Seq[UnifiedUserAction] =
|
||||||
|
v1TweetIds.map { tweet =>
|
||||||
|
getUnifiedUserAction(
|
||||||
|
event = e,
|
||||||
|
actionType = actionType,
|
||||||
|
item = getImpressedItem(tweet, v2Impress),
|
||||||
|
productSurface = getProductSurfaceRelated.productSurface,
|
||||||
|
productSurfaceInfo = getProductSurfaceRelated.productSurfaceInfo
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def getImpressedItem(
|
||||||
|
context: FlattenedServersideContextKey,
|
||||||
|
impression: Impress
|
||||||
|
): ImpressedItem =
|
||||||
|
Item.TweetInfo(
|
||||||
|
TweetInfo(
|
||||||
|
actionTweetId = context.serversideContextId.toLong,
|
||||||
|
tweetActionInfo = Some(
|
||||||
|
TweetActionInfo.ClientTweetV2Impression(
|
||||||
|
ClientTweetV2Impression(
|
||||||
|
impressStartTimestampMs = getImpressedStartTimestamp(impression),
|
||||||
|
impressEndTimestampMs = getImpressedEndTimestamp(impression),
|
||||||
|
sourceComponent = getImpressedUISourceComponent(context)
|
||||||
|
)
|
||||||
|
))
|
||||||
|
))
|
||||||
|
|
||||||
|
private def getProductSurfaceRelated: ProductSurfaceRelated =
|
||||||
|
actionType match {
|
||||||
|
case ActionType.ClientTweetV2Impression =>
|
||||||
|
ProductSurfaceRelated(
|
||||||
|
productSurface = Some(ProductSurface.TweetDetailsPage),
|
||||||
|
productSurfaceInfo = None)
|
||||||
|
case _ => ProductSurfaceRelated(productSurface = None, productSurfaceInfo = None)
|
||||||
|
}
|
||||||
|
}
|
139
unified_user_actions/adapter/src/test/scala/com/twitter/unified_user_actions/adapter/BehavioralClientEventAdapterSpec.scala
Normal file
139
unified_user_actions/adapter/src/test/scala/com/twitter/unified_user_actions/adapter/BehavioralClientEventAdapterSpec.scala
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package com.twitter.unified_user_actions.adapter
|
||||||
|
|
||||||
|
import com.twitter.inject.Test
|
||||||
|
import com.twitter.storage.behavioral_event.thriftscala.FlattenedEventLog
|
||||||
|
import com.twitter.unified_user_actions.adapter.TestFixtures.BCEFixture
|
||||||
|
import com.twitter.unified_user_actions.adapter.behavioral_client_event.BehavioralClientEventAdapter
|
||||||
|
import com.twitter.unified_user_actions.thriftscala._
|
||||||
|
import com.twitter.util.Time
|
||||||
|
import org.scalatest.prop.TableDrivenPropertyChecks
|
||||||
|
|
||||||
|
class BehavioralClientEventAdapterSpec extends Test with TableDrivenPropertyChecks {
|
||||||
|
|
||||||
|
test("basic event conversion should be correct") {
|
||||||
|
new BCEFixture {
|
||||||
|
Time.withTimeAt(frozenTime) { _ =>
|
||||||
|
val tests = Table(
|
||||||
|
("event", "expected", "description"),
|
||||||
|
(
|
||||||
|
makeBCEEvent(),
|
||||||
|
makeUUAImpressEvent(productSurface = Some(ProductSurface.TweetDetailsPage)),
|
||||||
|
"tweet_details conversion"),
|
||||||
|
(makeBCEProfileImpressEvent(), makeUUAProfileImpressEvent(), "profile conversion"),
|
||||||
|
(
|
||||||
|
makeBCEVideoFullscreenImpressEvent(),
|
||||||
|
makeUUAVideoFullscreenImpressEvent(),
|
||||||
|
"fullscreen_video conversion"),
|
||||||
|
(
|
||||||
|
makeBCEImageFullscreenImpressEvent(),
|
||||||
|
makeUUAImageFullscreenImpressEvent(),
|
||||||
|
"fullscreen_image conversion"),
|
||||||
|
)
|
||||||
|
forEvery(tests) { (input: FlattenedEventLog, expected: UnifiedUserAction, desc: String) =>
|
||||||
|
assert(Seq(expected) === BehavioralClientEventAdapter.adaptEvent(input), desc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test(
|
||||||
|
"tweet_details is NOT missing productSurface[Info] when empty breadcrumb components and breadcrumbs tweets id") {
|
||||||
|
new BCEFixture {
|
||||||
|
Time.withTimeAt(frozenTime) { _ =>
|
||||||
|
val input = makeBCEEvent(v1BreadcrumbViewTypeHierarchy = None, v1BreadcrumbTweetIds = None)
|
||||||
|
val expected =
|
||||||
|
makeUUAImpressEvent(
|
||||||
|
productSurface = Some(ProductSurface.TweetDetailsPage),
|
||||||
|
breadcrumbViews = None,
|
||||||
|
breadcrumbTweets = None)
|
||||||
|
val actual = BehavioralClientEventAdapter.adaptEvent(input)
|
||||||
|
|
||||||
|
assert(Seq(expected) === actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("tweet_details is not missing productSurface[Info] when only breadcrumb tweets is empty") {
|
||||||
|
new BCEFixture {
|
||||||
|
Time.withTimeAt(frozenTime) { _ =>
|
||||||
|
val input = makeBCEEvent(v1BreadcrumbTweetIds = None)
|
||||||
|
val expected = makeUUAImpressEvent(
|
||||||
|
productSurface = Some(ProductSurface.TweetDetailsPage),
|
||||||
|
breadcrumbViews = Some(viewBreadcrumbs),
|
||||||
|
breadcrumbTweets = None
|
||||||
|
)
|
||||||
|
val actual = BehavioralClientEventAdapter.adaptEvent(input)
|
||||||
|
|
||||||
|
assert(Seq(expected) === actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("unsupported events should be skipped") {
|
||||||
|
new BCEFixture {
|
||||||
|
val unsupportedPage = "unsupported_page"
|
||||||
|
val unsupportedAction = "unsupported_action"
|
||||||
|
val supportedNamespaces = Table(
|
||||||
|
("page", "actions"),
|
||||||
|
("tweet_details", Seq("impress")),
|
||||||
|
("profile", Seq("impress")),
|
||||||
|
)
|
||||||
|
|
||||||
|
forAll(supportedNamespaces) { (page: String, actions: Seq[String]) =>
|
||||||
|
actions.foreach { supportedAction =>
|
||||||
|
assert(
|
||||||
|
BehavioralClientEventAdapter
|
||||||
|
.adaptEvent(
|
||||||
|
makeBCEEvent(
|
||||||
|
currentPage = Some(unsupportedPage),
|
||||||
|
actionName = Some(supportedAction))).isEmpty)
|
||||||
|
|
||||||
|
assert(BehavioralClientEventAdapter
|
||||||
|
.adaptEvent(
|
||||||
|
makeBCEEvent(currentPage = Some(page), actionName = Some(unsupportedAction))).isEmpty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("event w/ missing info should be skipped") {
|
||||||
|
new BCEFixture {
|
||||||
|
val eventsWithMissingInfo = Table(
|
||||||
|
("event", "description"),
|
||||||
|
(null.asInstanceOf[FlattenedEventLog], "null event"),
|
||||||
|
(makeBCEEvent(v2Impress = None), "impression event missing v2Impress"),
|
||||||
|
(makeBCEEvent(v1TweetIds = None), "tweet event missing v1TweetIds"),
|
||||||
|
(makeBCEProfileImpressEvent(v1UserIds = None), "profile event missing v1UserIds"),
|
||||||
|
(
|
||||||
|
makeBCEVideoFullscreenImpressEvent(v1BreadcrumbTweetIds = None),
|
||||||
|
"fullscreen_video event missing v1BreadcrumbTweetIds"),
|
||||||
|
(
|
||||||
|
makeBCEImageFullscreenImpressEvent(v1BreadcrumbTweetIds = None),
|
||||||
|
"fullscreen_image event missing v1BreadcrumbTweetIds"),
|
||||||
|
)
|
||||||
|
|
||||||
|
forEvery(eventsWithMissingInfo) { (event: FlattenedEventLog, desc: String) =>
|
||||||
|
assert(
|
||||||
|
BehavioralClientEventAdapter
|
||||||
|
.adaptEvent(event).isEmpty,
|
||||||
|
desc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("use eventCreateAtMs when driftAdjustedTimetampMs is empty") {
|
||||||
|
new BCEFixture {
|
||||||
|
Time.withTimeAt(frozenTime) { _ =>
|
||||||
|
val input = makeBCEEvent(
|
||||||
|
context = makeBCEContext(driftAdjustedEventCreatedAtMs = None)
|
||||||
|
)
|
||||||
|
val expected = makeUUAImpressEvent(
|
||||||
|
createTs = eventCreatedTime,
|
||||||
|
productSurface = Some(ProductSurface.TweetDetailsPage))
|
||||||
|
val actual = BehavioralClientEventAdapter.adaptEvent(input)
|
||||||
|
|
||||||
|
assert(Seq(expected) === actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
25
unified_user_actions/service/src/main/scala/com/twitter/unified_user_actions/service/BehavioralClientEventService.scala
Normal file
25
unified_user_actions/service/src/main/scala/com/twitter/unified_user_actions/service/BehavioralClientEventService.scala
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package com.twitter.unified_user_actions.service;
|
||||||
|
|
||||||
|
import com.twitter.finatra.decider.modules.DeciderModule
|
||||||
|
import com.twitter.finatra.kafka.serde.UnKeyed
|
||||||
|
import com.twitter.inject.server.TwitterServer
|
||||||
|
import com.twitter.kafka.client.processor.AtLeastOnceProcessor
|
||||||
|
import com.twitter.storage.behavioral_event.thriftscala.FlattenedEventLog
|
||||||
|
import com.twitter.unified_user_actions.service.module.KafkaProcessorBehavioralClientEventModule
|
||||||
|
|
||||||
|
object BehavioralClientEventServiceMain extends BehavioralClientEventService
|
||||||
|
|
||||||
|
class BehavioralClientEventService extends TwitterServer {
|
||||||
|
override val modules = Seq(
|
||||||
|
KafkaProcessorBehavioralClientEventModule,
|
||||||
|
DeciderModule
|
||||||
|
)
|
||||||
|
|
||||||
|
override protected def setup(): Unit = {}
|
||||||
|
|
||||||
|
override protected def start(): Unit = {
|
||||||
|
val processor = injector.instance[AtLeastOnceProcessor[UnKeyed, FlattenedEventLog]]
|
||||||
|
closeOnExit(processor)
|
||||||
|
processor.start()
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,87 @@
|
|||||||
|
package com.twitter.unified_user_actions.service.module
|
||||||
|
|
||||||
|
import com.google.inject.Provides
|
||||||
|
import com.twitter.decider.Decider
|
||||||
|
import com.twitter.finagle.stats.StatsReceiver
|
||||||
|
import com.twitter.finatra.kafka.serde.UnKeyed
|
||||||
|
import com.twitter.finatra.kafka.serde.UnKeyedSerde
|
||||||
|
import com.twitter.inject.annotations.Flag
|
||||||
|
import com.twitter.inject.TwitterModule
|
||||||
|
import com.twitter.kafka.client.processor.AtLeastOnceProcessor
|
||||||
|
import com.twitter.storage.behavioral_event.thriftscala.FlattenedEventLog
|
||||||
|
import com.twitter.unified_user_actions.adapter.behavioral_client_event.BehavioralClientEventAdapter
|
||||||
|
import com.twitter.unified_user_actions.kafka.CompressionTypeFlag
|
||||||
|
import com.twitter.unified_user_actions.kafka.serde.NullableScalaSerdes
|
||||||
|
import com.twitter.util.Duration
|
||||||
|
import com.twitter.util.StorageUnit
|
||||||
|
import com.twitter.util.logging.Logging
|
||||||
|
import javax.inject.Singleton
|
||||||
|
|
||||||
|
object KafkaProcessorBehavioralClientEventModule extends TwitterModule with Logging {
|
||||||
|
override def modules = Seq(FlagsModule)
|
||||||
|
|
||||||
|
private val adapter: BehavioralClientEventAdapter = new BehavioralClientEventAdapter
|
||||||
|
private final val processorName: String = "uuaProcessor"
|
||||||
|
|
||||||
|
@Provides
|
||||||
|
@Singleton
|
||||||
|
def providesKafkaProcessor(
|
||||||
|
decider: Decider,
|
||||||
|
@Flag(FlagsModule.cluster) cluster: String,
|
||||||
|
@Flag(FlagsModule.kafkaSourceCluster) kafkaSourceCluster: String,
|
||||||
|
@Flag(FlagsModule.kafkaDestCluster) kafkaDestCluster: String,
|
||||||
|
@Flag(FlagsModule.kafkaSourceTopic) kafkaSourceTopic: String,
|
||||||
|
@Flag(FlagsModule.kafkaSinkTopics) kafkaSinkTopics: Seq[String],
|
||||||
|
@Flag(FlagsModule.kafkaGroupId) kafkaGroupId: String,
|
||||||
|
@Flag(FlagsModule.kafkaProducerClientId) kafkaProducerClientId: String,
|
||||||
|
@Flag(FlagsModule.kafkaMaxPendingRequests) kafkaMaxPendingRequests: Int,
|
||||||
|
@Flag(FlagsModule.kafkaWorkerThreads) kafkaWorkerThreads: Int,
|
||||||
|
@Flag(FlagsModule.commitInterval) commitInterval: Duration,
|
||||||
|
@Flag(FlagsModule.maxPollRecords) maxPollRecords: Int,
|
||||||
|
@Flag(FlagsModule.maxPollInterval) maxPollInterval: Duration,
|
||||||
|
@Flag(FlagsModule.sessionTimeout) sessionTimeout: Duration,
|
||||||
|
@Flag(FlagsModule.fetchMax) fetchMax: StorageUnit,
|
||||||
|
@Flag(FlagsModule.batchSize) batchSize: StorageUnit,
|
||||||
|
@Flag(FlagsModule.linger) linger: Duration,
|
||||||
|
@Flag(FlagsModule.bufferMem) bufferMem: StorageUnit,
|
||||||
|
@Flag(FlagsModule.compressionType) compressionTypeFlag: CompressionTypeFlag,
|
||||||
|
@Flag(FlagsModule.retries) retries: Int,
|
||||||
|
@Flag(FlagsModule.retryBackoff) retryBackoff: Duration,
|
||||||
|
@Flag(FlagsModule.requestTimeout) requestTimeout: Duration,
|
||||||
|
@Flag(FlagsModule.enableTrustStore) enableTrustStore: Boolean,
|
||||||
|
@Flag(FlagsModule.trustStoreLocation) trustStoreLocation: String,
|
||||||
|
statsReceiver: StatsReceiver,
|
||||||
|
): AtLeastOnceProcessor[UnKeyed, FlattenedEventLog] = {
|
||||||
|
KafkaProcessorProvider.provideDefaultAtLeastOnceProcessor(
|
||||||
|
name = processorName,
|
||||||
|
kafkaSourceCluster = kafkaSourceCluster,
|
||||||
|
kafkaGroupId = kafkaGroupId,
|
||||||
|
kafkaSourceTopic = kafkaSourceTopic,
|
||||||
|
sourceKeyDeserializer = UnKeyedSerde.deserializer,
|
||||||
|
sourceValueDeserializer = NullableScalaSerdes
|
||||||
|
.Thrift[FlattenedEventLog](statsReceiver.counter("deserializerErrors")).deserializer,
|
||||||
|
commitInterval = commitInterval,
|
||||||
|
maxPollRecords = maxPollRecords,
|
||||||
|
maxPollInterval = maxPollInterval,
|
||||||
|
sessionTimeout = sessionTimeout,
|
||||||
|
fetchMax = fetchMax,
|
||||||
|
processorMaxPendingRequests = kafkaMaxPendingRequests,
|
||||||
|
processorWorkerThreads = kafkaWorkerThreads,
|
||||||
|
adapter = adapter,
|
||||||
|
kafkaSinkTopics = kafkaSinkTopics,
|
||||||
|
kafkaDestCluster = kafkaDestCluster,
|
||||||
|
kafkaProducerClientId = kafkaProducerClientId,
|
||||||
|
batchSize = batchSize,
|
||||||
|
linger = linger,
|
||||||
|
bufferMem = bufferMem,
|
||||||
|
compressionType = compressionTypeFlag.compressionType,
|
||||||
|
retries = retries,
|
||||||
|
retryBackoff = retryBackoff,
|
||||||
|
requestTimeout = requestTimeout,
|
||||||
|
statsReceiver = statsReceiver,
|
||||||
|
trustStoreLocationOpt = if (enableTrustStore) Some(trustStoreLocation) else None,
|
||||||
|
decider = decider,
|
||||||
|
zone = ZoneFiltering.zoneMapping(cluster),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user