[docx] split commit for file 2000

Signed-off-by: Ari Archer <ari.web.xyz@gmail.com>
Ari Archer 2024-01-23 19:09:02 +02:00
parent 65c3a3fe90
commit 2488f40edf
No known key found for this signature in database
GPG Key ID: A50D5B4B599AF8A2
400 changed files with 0 additions and 14990 deletions

BIN
navi/navi/src/bin/navi.docx Normal file

Binary file not shown.


@@ -1,47 +0,0 @@
use anyhow::Result;
use log::info;
use navi::cli_args::{ARGS, MODEL_SPECS};
use navi::cores::validator::validatior::cli_validator;
use navi::tf_model::tf::TFModel;
use navi::{bootstrap, metrics};
use sha256::digest;
fn main() -> Result<()> {
env_logger::init();
cli_validator::validate_input_args();
//only validate this for tf, as other runtimes don't have serving signatures
assert_eq!(MODEL_SPECS.len(), ARGS.serving_sig.len());
metrics::register_custom_metrics();
//load all the custom op libraries - comma separated
if let Some(ref customops_lib) = ARGS.customops_lib {
for op_lib in customops_lib.split(",") {
load_custom_op(op_lib);
}
}
// versioning the customop so library
bootstrap::bootstrap(TFModel::new)
}
fn load_custom_op(lib_path: &str) {
let res = tensorflow::Library::load(lib_path);
info!("{} load status:{:?}", lib_path, res);
let customop_version_num = get_custom_op_version(lib_path);
// Last OP version is recorded
metrics::CUSTOMOP_VERSION.set(customop_version_num);
}
//fn get_custom_op_version(customops_lib: &String) -> i64 {
fn get_custom_op_version(customops_lib: &str) -> i64 {
let customop_bytes = std::fs::read(customops_lib).unwrap(); // Vec<u8>
let customop_hash = digest(customop_bytes.as_slice());
//convert the last 4 hex digits to a version number, since prometheus metrics don't support strings; the total space is 16^4 == 65536
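//e.g. a hash ending in "1a2f" maps to version number 0x1a2f == 6703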
let customop_version_num =
i64::from_str_radix(&customop_hash[customop_hash.len() - 4..], 16).unwrap();
info!(
"customop hash: {}, version_number: {}",
customop_hash, customop_version_num
);
customop_version_num
}

Binary file not shown.


@@ -1,24 +0,0 @@
use anyhow::Result;
use log::info;
use navi::cli_args::{ARGS, MODEL_SPECS};
use navi::onnx_model::onnx::OnnxModel;
use navi::{bootstrap, metrics};
fn main() -> Result<()> {
env_logger::init();
info!("global: {:?}", ARGS.onnx_global_thread_pool_options);
let assert_session_params = if ARGS.onnx_global_thread_pool_options.is_empty() {
// std::env::set_var("OMP_NUM_THREADS", "1");
info!("now we use per session thread pool");
MODEL_SPECS.len()
}
else {
info!("now we use global thread pool");
0
};
assert_eq!(assert_session_params, ARGS.inter_op_parallelism.len());
assert_eq!(assert_session_params, ARGS.intra_op_parallelism.len());
metrics::register_custom_metrics();
bootstrap::bootstrap(OnnxModel::new)
}

Binary file not shown.


@@ -1,19 +0,0 @@
use anyhow::Result;
use log::info;
use navi::cli_args::ARGS;
use navi::metrics;
use navi::torch_model::torch::TorchModel;
fn main() -> Result<()> {
env_logger::init();
//torch only has global threadpool settings, whereas tf has per-model threadpool settings
assert_eq!(1, ARGS.inter_op_parallelism.len());
assert_eq!(1, ARGS.intra_op_parallelism.len());
//TODO: for now we assume each model's output has only 1 tensor.
//this will be lifted once torch_model properly implements mtl outputs
tch::set_num_interop_threads(ARGS.inter_op_parallelism[0].parse()?);
tch::set_num_threads(ARGS.intra_op_parallelism[0].parse()?);
info!("torch custom ops not used for now");
metrics::register_custom_metrics();
navi::bootstrap::bootstrap(TorchModel::new)
}

Binary file not shown.


@@ -1,326 +0,0 @@
use anyhow::Result;
use log::{info, warn};
use x509_parser::{prelude::{parse_x509_pem}, parse_x509_certificate};
use std::collections::HashMap;
use tokio::time::Instant;
use tonic::{
Request,
Response, Status, transport::{Certificate, Identity, Server, ServerTlsConfig},
};
// protobuf related
use crate::tf_proto::tensorflow_serving::{
ClassificationRequest, ClassificationResponse, GetModelMetadataRequest,
GetModelMetadataResponse, MultiInferenceRequest, MultiInferenceResponse, PredictRequest,
PredictResponse, RegressionRequest, RegressionResponse,
};
use crate::{kf_serving::{
grpc_inference_service_server::GrpcInferenceService, ModelInferRequest, ModelInferResponse,
ModelMetadataRequest, ModelMetadataResponse, ModelReadyRequest, ModelReadyResponse,
ServerLiveRequest, ServerLiveResponse, ServerMetadataRequest, ServerMetadataResponse,
ServerReadyRequest, ServerReadyResponse,
}, ModelFactory, tf_proto::tensorflow_serving::prediction_service_server::{
PredictionService, PredictionServiceServer,
}, VERSION, NAME};
use crate::PredictResult;
use crate::cli_args::{ARGS, INPUTS, OUTPUTS};
use crate::metrics::{
NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR,
CERT_EXPIRY_EPOCH
};
use crate::predict_service::{Model, PredictService};
use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version;
use crate::tf_proto::tensorflow_serving::ModelSpec;
#[derive(Debug)]
pub enum TensorInputEnum {
String(Vec<Vec<u8>>),
Int(Vec<i32>),
Int64(Vec<i64>),
Float(Vec<f32>),
Double(Vec<f64>),
Boolean(Vec<bool>),
}
#[derive(Debug)]
pub struct TensorInput {
pub tensor_data: TensorInputEnum,
pub name: String,
pub dims: Option<Vec<i64>>,
}
impl TensorInput {
pub fn new(tensor_data: TensorInputEnum, name: String, dims: Option<Vec<i64>>) -> TensorInput {
TensorInput {
tensor_data,
name,
dims,
}
}
}
impl TensorInputEnum {
#[inline(always)]
pub(crate) fn extend(&mut self, another: TensorInputEnum) {
match (self, another) {
(Self::String(input), Self::String(ex)) => input.extend(ex),
(Self::Int(input), Self::Int(ex)) => input.extend(ex),
(Self::Int64(input), Self::Int64(ex)) => input.extend(ex),
(Self::Float(input), Self::Float(ex)) => input.extend(ex),
(Self::Double(input), Self::Double(ex)) => input.extend(ex),
(Self::Boolean(input), Self::Boolean(ex)) => input.extend(ex),
x => panic!("input enum type not matched. input:{:?}, ex:{:?}", x.0, x.1),
}
}
#[inline(always)]
pub(crate) fn merge_batch(input_tensors: Vec<Vec<TensorInput>>) -> Vec<TensorInput> {
input_tensors
.into_iter()
.reduce(|mut acc, e| {
for (i, ext) in acc.iter_mut().zip(e) {
i.tensor_data.extend(ext.tensor_data);
}
acc
})
.unwrap() //invariant: we expect there's always rows in input_tensors
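//e.g. merging two requests' inputs [[String(a), Int(x)], [String(b), Int(y)]]
//yields [String(a++b), Int(x++y)]: tensors at the same position are concatenated in request order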
}
}
///entry point for tfServing gRPC
#[tonic::async_trait]
impl<T: Model> GrpcInferenceService for PredictService<T> {
async fn server_live(
&self,
_request: Request<ServerLiveRequest>,
) -> Result<Response<ServerLiveResponse>, Status> {
unimplemented!()
}
async fn server_ready(
&self,
_request: Request<ServerReadyRequest>,
) -> Result<Response<ServerReadyResponse>, Status> {
unimplemented!()
}
async fn model_ready(
&self,
_request: Request<ModelReadyRequest>,
) -> Result<Response<ModelReadyResponse>, Status> {
unimplemented!()
}
async fn server_metadata(
&self,
_request: Request<ServerMetadataRequest>,
) -> Result<Response<ServerMetadataResponse>, Status> {
unimplemented!()
}
async fn model_metadata(
&self,
_request: Request<ModelMetadataRequest>,
) -> Result<Response<ModelMetadataResponse>, Status> {
unimplemented!()
}
async fn model_infer(
&self,
_request: Request<ModelInferRequest>,
) -> Result<Response<ModelInferResponse>, Status> {
unimplemented!()
}
}
#[tonic::async_trait]
impl<T: Model> PredictionService for PredictService<T> {
async fn classify(
&self,
_request: Request<ClassificationRequest>,
) -> Result<Response<ClassificationResponse>, Status> {
unimplemented!()
}
async fn regress(
&self,
_request: Request<RegressionRequest>,
) -> Result<Response<RegressionResponse>, Status> {
unimplemented!()
}
async fn predict(
&self,
request: Request<PredictRequest>,
) -> Result<Response<PredictResponse>, Status> {
NUM_REQUESTS_RECEIVED.inc();
let start = Instant::now();
let mut req = request.into_inner();
let (model_spec, version) = req.take_model_spec();
NUM_REQUESTS_RECEIVED_BY_MODEL
.with_label_values(&[&model_spec])
.inc();
let idx = PredictService::<T>::get_model_index(&model_spec).ok_or_else(|| {
Status::failed_precondition(format!("model spec not found:{}", model_spec))
})?;
let input_spec = match INPUTS[idx].get() {
Some(input) => input,
_ => return Err(Status::not_found(format!("model input spec {}", idx))),
};
let input_val = req.take_input_vals(input_spec);
self.predict(idx, version, input_val, start)
.await
.map_or_else(
|e| {
NUM_REQUESTS_FAILED.inc();
NUM_REQUESTS_FAILED_BY_MODEL
.with_label_values(&[&model_spec])
.inc();
Err(Status::internal(e.to_string()))
},
|res| {
RESPONSE_TIME_COLLECTOR
.with_label_values(&[&model_spec])
.observe(start.elapsed().as_millis() as f64);
match res {
PredictResult::Ok(tensors, version) => {
let mut outputs = HashMap::new();
NUM_PREDICTIONS.with_label_values(&[&model_spec]).inc();
//FIXME: uncomment when prediction scores are normal
// PREDICTION_SCORE_SUM
// .with_label_values(&[&model_spec])
// .inc_by(tensors[0]as f64);
for (tp, output_name) in tensors
.into_iter()
.map(|tensor| tensor.create_tensor_proto())
.zip(OUTPUTS[idx].iter())
{
outputs.insert(output_name.to_owned(), tp);
}
let reply = PredictResponse {
model_spec: Some(ModelSpec {
version_choice: Some(Version(version)),
..Default::default()
}),
outputs,
};
Ok(Response::new(reply))
}
PredictResult::DropDueToOverload => Err(Status::resource_exhausted("")),
PredictResult::ModelNotFound(idx) => {
Err(Status::not_found(format!("model index {}", idx)))
},
PredictResult::ModelNotReady(idx) => {
Err(Status::unavailable(format!("model index {}", idx)))
}
PredictResult::ModelVersionNotFound(idx, version) => Err(
Status::not_found(format!("model index:{}, version {}", idx, version)),
),
}
},
)
}
async fn multi_inference(
&self,
_request: Request<MultiInferenceRequest>,
) -> Result<Response<MultiInferenceResponse>, Status> {
unimplemented!()
}
async fn get_model_metadata(
&self,
_request: Request<GetModelMetadataRequest>,
) -> Result<Response<GetModelMetadataResponse>, Status> {
unimplemented!()
}
}
// Logs the certificate expiry timestamp and records it in the CERT_EXPIRY_EPOCH gauge
fn report_expiry(expiry_time: i64) {
info!("Certificate expires at epoch: {:?}", expiry_time);
CERT_EXPIRY_EPOCH.set(expiry_time as i64);
}
pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
//we follow SemVer. So here we assume MAJOR.MINOR.PATCH
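//e.g. version "2.3.4" is reported as 2 * 1_000_000 + 3 * 1_000 + 4 == 2_003_004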
let parts = VERSION
.split(".")
.map(|v| v.parse::<i64>())
.collect::<std::result::Result<Vec<_>, _>>()?;
if let [major, minor, patch] = &parts[..] {
NAVI_VERSION.set(major * 1000_000 + minor * 1000 + patch);
} else {
warn!(
"version {} doesn't follow SemVer conversion of MAJOR.MINOR.PATCH",
VERSION
);
}
tokio::runtime::Builder::new_multi_thread()
.thread_name("async worker")
.worker_threads(ARGS.num_worker_threads)
.max_blocking_threads(ARGS.max_blocking_threads)
.enable_all()
.build()
.unwrap()
.block_on(async {
#[cfg(feature = "navi_console")]
console_subscriber::init();
let addr = format!("0.0.0.0:{}", ARGS.port).parse()?;
let ps = PredictService::init(model_factory).await;
let mut builder = if ARGS.ssl_dir.is_empty() {
Server::builder()
} else {
// Read the pem file as a string
let pem_str = std::fs::read_to_string(format!("{}/server.crt", ARGS.ssl_dir)).unwrap();
let res = parse_x509_pem(&pem_str.as_bytes());
match res {
Ok((rem, pem_2)) => {
assert!(rem.is_empty());
assert_eq!(pem_2.label, String::from("CERTIFICATE"));
let res_x509 = parse_x509_certificate(&pem_2.contents);
info!("Certificate label: {}", pem_2.label);
assert!(res_x509.is_ok());
report_expiry(res_x509.unwrap().1.validity().not_after.timestamp());
},
_ => panic!("PEM parsing failed: {:?}", res),
}
let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir))
.await
.expect("can't find key file");
let crt = tokio::fs::read(format!("{}/server.crt", ARGS.ssl_dir))
.await
.expect("can't find crt file");
let chain = tokio::fs::read(format!("{}/server.chain", ARGS.ssl_dir))
.await
.expect("can't find chain file");
let mut pem = Vec::new();
pem.extend(crt);
pem.extend(chain);
let identity = Identity::from_pem(pem.clone(), key);
let client_ca_cert = Certificate::from_pem(pem.clone());
let tls = ServerTlsConfig::new()
.identity(identity)
.client_ca_root(client_ca_cert);
Server::builder()
.tls_config(tls)
.expect("fail to config SSL")
};
info!(
"Prometheus server started: 0.0.0.0: {}",
ARGS.prometheus_port
);
let ps_server = builder
.add_service(PredictionServiceServer::new(ps).accept_gzip().send_gzip())
.serve(addr);
info!("Prediction server started: {}", addr);
ps_server.await.map_err(anyhow::Error::msg)
})
}

BIN
navi/navi/src/cli_args.docx Normal file

Binary file not shown.


@@ -1,236 +0,0 @@
use crate::{MAX_NUM_INPUTS, MAX_NUM_MODELS, MAX_NUM_OUTPUTS};
use arrayvec::ArrayVec;
use clap::Parser;
use log::info;
use once_cell::sync::OnceCell;
use std::error::Error;
use time::OffsetDateTime;
use time::format_description::well_known::Rfc3339;
#[derive(Parser, Debug, Clone)]
///Navi is configured through CLI arguments (for now), defined below.
//TODO: use clap_serde to make it config file driven
pub struct Args {
#[clap(short, long, help = "gRPC port Navi runs ons")]
pub port: i32,
#[clap(long, default_value_t = 9000, help = "prometheus metrics port")]
pub prometheus_port: u16,
#[clap(
short,
long,
default_value_t = 1,
help = "number of worker threads for tokio async runtime"
)]
pub num_worker_threads: usize,
#[clap(
long,
default_value_t = 14,
help = "number of blocking threads in tokio blocking thread pool"
)]
pub max_blocking_threads: usize,
#[clap(long, default_value = "16", help = "maximum batch size for a batch")]
pub max_batch_size: Vec<String>,
#[clap(
short,
long,
default_value = "2",
help = "max wait time for accumulating a batch"
)]
pub batch_time_out_millis: Vec<String>,
#[clap(
long,
default_value_t = 90,
help = "threshold to start dropping batches under stress"
)]
pub batch_drop_millis: u64,
#[clap(
long,
default_value_t = 300,
help = "polling interval for new version of a model and META.json config"
)]
pub model_check_interval_secs: u64,
#[clap(
short,
long,
default_value = "models/pvideo/",
help = "root directory for models"
)]
pub model_dir: Vec<String>,
#[clap(
long,
help = "directory containing META.json config. separate from model_dir to facilitate remote config management"
)]
pub meta_json_dir: Option<String>,
#[clap(short, long, default_value = "", help = "directory for ssl certs")]
pub ssl_dir: String,
#[clap(
long,
help = "call out to external process to check model updates. custom logic can be written to pull from hdfs, gcs etc"
)]
pub modelsync_cli: Option<String>,
#[clap(
long,
default_value_t = 1,
help = "specify how many versions Navi retains in memory. good for cases of rolling model upgrade"
)]
pub versions_per_model: usize,
#[clap(
short,
long,
help = "most runtimes support loading ops custom writen. currently only implemented for TF"
)]
pub customops_lib: Option<String>,
#[clap(
long,
default_value = "8",
help = "number of threads to paralleling computations inside an op"
)]
pub intra_op_parallelism: Vec<String>,
#[clap(
long,
help = "number of threads to parallelize computations of the graph"
)]
pub inter_op_parallelism: Vec<String>,
#[clap(
long,
help = "signature of a serving. only TF"
)]
pub serving_sig: Vec<String>,
#[clap(long, default_value = "examples", help = "name of each input tensor")]
pub input: Vec<String>,
#[clap(long, default_value = "output_0", help = "name of each output tensor")]
pub output: Vec<String>,
#[clap(
long,
default_value_t = 500,
help = "max warmup records to use. warmup only implemented for TF"
)]
pub max_warmup_records: usize,
#[clap(long, value_parser = Args::parse_key_val::<String, String>, value_delimiter=',')]
pub onnx_global_thread_pool_options: Vec<(String, String)>,
#[clap(
long,
default_value = "true",
help = "when to use graph parallelization. only for ONNX"
)]
pub onnx_use_parallel_mode: String,
// #[clap(long, default_value = "false")]
// pub onnx_use_onednn: String,
#[clap(
long,
default_value = "true",
help = "trace internal memory allocation and generate bulk memory allocations. only for ONNX. turn if off if batch size dynamic"
)]
pub onnx_use_memory_pattern: String,
#[clap(long, value_parser = Args::parse_key_val::<String, String>, value_delimiter=',')]
pub onnx_ep_options: Vec<(String, String)>,
#[clap(long, help = "choice of gpu EPs for ONNX: cuda or tensorrt")]
pub onnx_gpu_ep: Option<String>,
#[clap(
long,
default_value = "home",
help = "converter for various input formats"
)]
pub onnx_use_converter: Option<String>,
#[clap(
long,
help = "whether to enable runtime profiling. only implemented for ONNX for now"
)]
pub profiling: Option<String>,
#[clap(
long,
default_value = "",
help = "metrics reporting for discrete features. only for Home converter for now"
)]
pub onnx_report_discrete_feature_ids: Vec<String>,
#[clap(
long,
default_value = "",
help = "metrics reporting for continuous features. only for Home converter for now"
)]
pub onnx_report_continuous_feature_ids: Vec<String>,
}
impl Args {
pub fn get_model_specs(model_dir: Vec<String>) -> Vec<String> {
let model_specs = model_dir
.iter()
//let it panic if any model_dir entry is wrong
.map(|dir| {
dir.trim_end_matches('/')
.rsplit_once('/')
.unwrap()
.1
.to_owned()
})
.collect();
info!("all model_specs: {:?}", model_specs);
model_specs
}
pub fn version_str_to_epoch(dt_str: &str) -> Result<i64, anyhow::Error> {
dt_str
.parse::<i64>()
.or_else(|_| {
let ts = OffsetDateTime::parse(dt_str, &Rfc3339)
.map(|d| (d.unix_timestamp_nanos()/1_000_000) as i64);
if ts.is_ok() {
info!("original version {} -> {}", dt_str, ts.unwrap());
}
ts
})
.map_err(anyhow::Error::msg)
}
/// Parse a single key-value pair
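/// e.g. parse_key_val::<String, String>("key=value") yields ("key", "value")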
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error + Send + Sync + 'static>>
where
T: std::str::FromStr,
T::Err: Error + Send + Sync + 'static,
U: std::str::FromStr,
U::Err: Error + Send + Sync + 'static,
{
let pos = s
.find('=')
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{}`", s))?;
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
}
}
lazy_static! {
pub static ref ARGS: Args = Args::parse();
pub static ref MODEL_SPECS: ArrayVec<String, MAX_NUM_MODELS> = {
let mut specs = ArrayVec::<String, MAX_NUM_MODELS>::new();
Args::get_model_specs(ARGS.model_dir.clone())
.into_iter()
.for_each(|m| specs.push(m));
specs
};
pub static ref INPUTS: ArrayVec<OnceCell<ArrayVec<String, MAX_NUM_INPUTS>>, MAX_NUM_MODELS> = {
let mut inputs =
ArrayVec::<OnceCell<ArrayVec<String, MAX_NUM_INPUTS>>, MAX_NUM_MODELS>::new();
for (idx, o) in ARGS.input.iter().enumerate() {
if o.trim().is_empty() {
info!("input spec is empty for model {}, auto detect later", idx);
inputs.push(OnceCell::new());
} else {
inputs.push(OnceCell::with_value(
o.split(",")
.map(|s| s.to_owned())
.collect::<ArrayVec<String, MAX_NUM_INPUTS>>(),
));
}
}
info!("all inputs:{:?}", inputs);
inputs
};
pub static ref OUTPUTS: ArrayVec<ArrayVec<String, MAX_NUM_OUTPUTS>, MAX_NUM_MODELS> = {
let mut outputs = ArrayVec::<ArrayVec<String, MAX_NUM_OUTPUTS>, MAX_NUM_MODELS>::new();
for o in ARGS.output.iter() {
outputs.push(
o.split(",")
.map(|s| s.to_owned())
.collect::<ArrayVec<String, MAX_NUM_OUTPUTS>>(),
);
}
info!("all outputs:{:?}", outputs);
outputs
};
}
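A minimal sketch (not part of this commit, added for illustration) of how the per-model --model-dir flags map to model spec names; the directory names and the test module below are hypothetical, and clap's Parser::parse_from is only used to build an Args value in isolation:
#[cfg(test)]
mod model_spec_sketch {
    use super::*;
    use clap::Parser;
    #[test]
    fn model_specs_are_the_last_path_component() {
        // hypothetical model directories; get_model_specs trims a trailing '/'
        // and keeps the final path component as the model spec name
        let args = Args::parse_from([
            "navi",
            "--port", "9090",
            "--model-dir", "models/recap/",
            "--model-dir", "models/home/",
        ]);
        assert_eq!(
            Args::get_model_specs(args.model_dir),
            vec!["recap".to_string(), "home".to_string()]
        );
    }
}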

Binary file not shown.


@@ -1,22 +0,0 @@
pub mod validatior {
pub mod cli_validator {
use crate::cli_args::{ARGS, MODEL_SPECS};
pub fn validate_input_args() {
assert_eq!(MODEL_SPECS.len(), ARGS.inter_op_parallelism.len());
assert_eq!(MODEL_SPECS.len(), ARGS.intra_op_parallelism.len());
//TODO: for now we assume each model's output has only 1 tensor.
//this will be lifted once tf_model properly implements mtl outputs
//assert_eq!(OUTPUTS.len(), OUTPUTS.iter().fold(0usize, |a, b| a+b.len()));
}
pub fn validate_ps_model_args() {
assert!(ARGS.versions_per_model <= 2);
assert!(ARGS.versions_per_model >= 1);
assert_eq!(MODEL_SPECS.len(), ARGS.input.len());
assert_eq!(MODEL_SPECS.len(), ARGS.model_dir.len());
assert_eq!(MODEL_SPECS.len(), ARGS.max_batch_size.len());
assert_eq!(MODEL_SPECS.len(), ARGS.batch_time_out_millis.len());
}
}
}

BIN
navi/navi/src/lib.docx Normal file

Binary file not shown.


@@ -1,215 +0,0 @@
#[macro_use]
extern crate lazy_static;
extern crate core;
use serde_json::Value;
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;
use std::ops::Deref;
use itertools::Itertools;
use crate::bootstrap::TensorInput;
use crate::predict_service::Model;
use crate::tf_proto::{DataType, TensorProto};
pub mod batch;
pub mod bootstrap;
pub mod cli_args;
pub mod metrics;
pub mod onnx_model;
pub mod predict_service;
pub mod tf_model;
pub mod torch_model;
pub mod cores {
pub mod validator;
}
pub mod tf_proto {
tonic::include_proto!("tensorflow");
pub mod tensorflow_serving {
tonic::include_proto!("tensorflow.serving");
}
}
pub mod kf_serving {
tonic::include_proto!("inference");
}
#[cfg(test)]
mod tests {
use crate::cli_args::Args;
#[test]
fn test_version_string_to_epoch() {
assert_eq!(
Args::version_str_to_epoch("2022-12-20T10:18:53.000Z").unwrap_or(-1),
1671531533000
);
assert_eq!(Args::version_str_to_epoch("1203444").unwrap_or(-1), 1203444);
}
}
mod utils {
use crate::cli_args::{ARGS, MODEL_SPECS};
use anyhow::Result;
use log::info;
use serde_json::Value;
pub fn read_config(meta_file: &String) -> Result<Value> {
let json = std::fs::read_to_string(meta_file)?;
let v: Value = serde_json::from_str(&json)?;
Ok(v)
}
pub fn get_config_or_else<F>(model_config: &Value, key: &str, default: F) -> String
where
F: FnOnce() -> String,
{
match model_config[key] {
Value::String(ref v) => {
info!("from model_config: {}={}", key, v);
v.to_string()
}
Value::Number(ref num) => {
info!(
"from model_config: {}={} (turn number into a string)",
key, num
);
num.to_string()
}
_ => {
let d = default();
info!("from default: {}={}", key, d);
d
}
}
}
pub fn get_config_or(model_config: &Value, key: &str, default: &str) -> String {
get_config_or_else(model_config, key, || default.to_string())
}
pub fn get_meta_dir() -> &'static str {
ARGS.meta_json_dir
.as_ref()
.map(|s| s.as_str())
.unwrap_or_else(|| {
let model_dir = &ARGS.model_dir[0];
let meta_dir = &model_dir[0..model_dir.rfind(&MODEL_SPECS[0]).unwrap()];
info!(
"no meta_json_dir specified, hence derive from first model dir:{}->{}",
model_dir, meta_dir
);
meta_dir
})
}
}
pub type SerializedInput = Vec<u8>;
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
pub const NAME: &str = env!("CARGO_PKG_NAME");
pub type ModelFactory<T> = fn(usize, String, &Value) -> anyhow::Result<T>;
pub const MAX_NUM_MODELS: usize = 16;
pub const MAX_NUM_OUTPUTS: usize = 30;
pub const MAX_NUM_INPUTS: usize = 120;
pub const META_INFO: &str = "META.json";
//use a heap-allocated generic type here so that both
//the Tensorflow & Pytorch implementations can return their Tensor wrapped in a Box
//without an extra memcopy to Vec
pub type TensorReturn<T> = Box<dyn Deref<Target = [T]>>;
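//e.g. a plain Vec<f32> already satisfies this, as does any runtime tensor type that derefs to [f32]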
//returned tensor may be int64 i.e., a list of relevant ad ids
pub enum TensorReturnEnum {
FloatTensorReturn(TensorReturn<f32>),
StringTensorReturn(TensorReturn<String>),
Int64TensorReturn(TensorReturn<i64>),
Int32TensorReturn(TensorReturn<i32>),
}
impl TensorReturnEnum {
#[inline(always)]
pub fn slice(&self, start: usize, end: usize) -> TensorScores {
match self {
TensorReturnEnum::FloatTensorReturn(f32_return) => {
TensorScores::Float32TensorScores(f32_return[start..end].to_vec())
}
TensorReturnEnum::Int64TensorReturn(i64_return) => {
TensorScores::Int64TensorScores(i64_return[start..end].to_vec())
}
TensorReturnEnum::Int32TensorReturn(i32_return) => {
TensorScores::Int32TensorScores(i32_return[start..end].to_vec())
}
TensorReturnEnum::StringTensorReturn(str_return) => {
TensorScores::StringTensorScores(str_return[start..end].to_vec())
}
}
}
}
#[derive(Debug)]
pub enum PredictResult {
Ok(Vec<TensorScores>, i64),
DropDueToOverload,
ModelNotFound(usize),
ModelNotReady(usize),
ModelVersionNotFound(usize, i64),
}
#[derive(Debug)]
pub enum TensorScores {
Float32TensorScores(Vec<f32>),
Int64TensorScores(Vec<i64>),
Int32TensorScores(Vec<i32>),
StringTensorScores(Vec<String>),
}
impl TensorScores {
pub fn create_tensor_proto(self) -> TensorProto {
match self {
TensorScores::Float32TensorScores(f32_tensor) => TensorProto {
dtype: DataType::DtFloat as i32,
float_val: f32_tensor,
..Default::default()
},
TensorScores::Int64TensorScores(i64_tensor) => TensorProto {
dtype: DataType::DtInt64 as i32,
int64_val: i64_tensor,
..Default::default()
},
TensorScores::Int32TensorScores(i32_tensor) => TensorProto {
dtype: DataType::DtInt32 as i32,
int_val: i32_tensor,
..Default::default()
},
TensorScores::StringTensorScores(str_tensor) => TensorProto {
dtype: DataType::DtString as i32,
string_val: str_tensor.into_iter().map(|s| s.into_bytes()).collect_vec(),
..Default::default()
},
}
}
pub fn len(&self) -> usize {
match &self {
TensorScores::Float32TensorScores(t) => t.len(),
TensorScores::Int64TensorScores(t) => t.len(),
TensorScores::Int32TensorScores(t) => t.len(),
TensorScores::StringTensorScores(t) => t.len(),
}
}
}
#[derive(Debug)]
pub enum PredictMessage<T: Model> {
Predict(
usize,
Option<i64>,
Vec<TensorInput>,
Sender<PredictResult>,
Instant,
),
UpsertModel(T),
/*
#[allow(dead_code)]
DeleteModel(usize),
*/
}
#[derive(Debug)]
pub struct Callback(Sender<PredictResult>, usize);
pub const MAX_VERSIONS_PER_MODEL: usize = 2;

BIN
navi/navi/src/metrics.docx Normal file

Binary file not shown.


@@ -1,297 +0,0 @@
use log::error;
use prometheus::{
CounterVec, HistogramOpts, HistogramVec, IntCounter, IntCounterVec, IntGauge, IntGaugeVec,
Opts, Registry,
};
use warp::{Rejection, Reply};
use crate::{NAME, VERSION};
lazy_static! {
pub static ref REGISTRY: Registry = Registry::new();
pub static ref NUM_REQUESTS_RECEIVED: IntCounter =
IntCounter::new(":navi:num_requests", "Number of Requests Received")
.expect("metric can be created");
pub static ref NUM_REQUESTS_FAILED: IntCounter = IntCounter::new(
":navi:num_requests_failed",
"Number of Request Inference Failed"
)
.expect("metric can be created");
pub static ref NUM_REQUESTS_DROPPED: IntCounter = IntCounter::new(
":navi:num_requests_dropped",
"Number of Oneshot Receivers Dropped"
)
.expect("metric can be created");
pub static ref NUM_BATCHES_DROPPED: IntCounter = IntCounter::new(
":navi:num_batches_dropped",
"Number of Batches Proactively Dropped"
)
.expect("metric can be created");
pub static ref NUM_BATCH_PREDICTION: IntCounter =
IntCounter::new(":navi:num_batch_prediction", "Number of batch prediction")
.expect("metric can be created");
pub static ref BATCH_SIZE: IntGauge =
IntGauge::new(":navi:batch_size", "Size of current batch").expect("metric can be created");
pub static ref NAVI_VERSION: IntGauge =
IntGauge::new(":navi:navi_version", "navi's current version")
.expect("metric can be created");
pub static ref RESPONSE_TIME_COLLECTOR: HistogramVec = HistogramVec::new(
HistogramOpts::new(":navi:response_time", "Response Time in ms").buckets(Vec::from(&[
0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0,
140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0, 1000.0
]
as &'static [f64])),
&["model_name"]
)
.expect("metric can be created");
pub static ref NUM_PREDICTIONS: IntCounterVec = IntCounterVec::new(
Opts::new(
":navi:num_predictions",
"Number of predictions made by model"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref PREDICTION_SCORE_SUM: CounterVec = CounterVec::new(
Opts::new(
":navi:prediction_score_sum",
"Sum of prediction score made by model"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref NEW_MODEL_SNAPSHOT: IntCounterVec = IntCounterVec::new(
Opts::new(
":navi:new_model_snapshot",
"Load a new version of model snapshot"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref MODEL_SNAPSHOT_VERSION: IntGaugeVec = IntGaugeVec::new(
Opts::new(
":navi:model_snapshot_version",
"Record model snapshot version"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref NUM_REQUESTS_RECEIVED_BY_MODEL: IntCounterVec = IntCounterVec::new(
Opts::new(
":navi:num_requests_by_model",
"Number of Requests Received by model"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref NUM_REQUESTS_FAILED_BY_MODEL: IntCounterVec = IntCounterVec::new(
Opts::new(
":navi:num_requests_failed_by_model",
"Number of Request Inference Failed by model"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref NUM_REQUESTS_DROPPED_BY_MODEL: IntCounterVec = IntCounterVec::new(
Opts::new(
":navi:num_requests_dropped_by_model",
"Number of Oneshot Receivers Dropped by model"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref NUM_BATCHES_DROPPED_BY_MODEL: IntCounterVec = IntCounterVec::new(
Opts::new(
":navi:num_batches_dropped_by_model",
"Number of Batches Proactively Dropped by model"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref INFERENCE_FAILED_REQUESTS_BY_MODEL: IntCounterVec = IntCounterVec::new(
Opts::new(
":navi:inference_failed_requests_by_model",
"Number of failed inference requests by model"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref NUM_PREDICTION_BY_MODEL: IntCounterVec = IntCounterVec::new(
Opts::new(
":navi:num_prediction_by_model",
"Number of prediction by model"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref NUM_BATCH_PREDICTION_BY_MODEL: IntCounterVec = IntCounterVec::new(
Opts::new(
":navi:num_batch_prediction_by_model",
"Number of batch prediction by model"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref BATCH_SIZE_BY_MODEL: IntGaugeVec = IntGaugeVec::new(
Opts::new(
":navi:batch_size_by_model",
"Size of current batch by model"
),
&["model_name"]
)
.expect("metric can be created");
pub static ref CUSTOMOP_VERSION: IntGauge =
IntGauge::new(":navi:customop_version", "The hashed Custom OP Version")
.expect("metric can be created");
pub static ref MPSC_CHANNEL_SIZE: IntGauge =
IntGauge::new(":navi:mpsc_channel_size", "The mpsc channel's request size")
.expect("metric can be created");
pub static ref BLOCKING_REQUEST_NUM: IntGauge = IntGauge::new(
":navi:blocking_request_num",
"The (batch) request waiting or being executed"
)
.expect("metric can be created");
pub static ref MODEL_INFERENCE_TIME_COLLECTOR: HistogramVec = HistogramVec::new(
HistogramOpts::new(":navi:model_inference_time", "Model inference time in ms").buckets(
Vec::from(&[
0.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 65.0,
70.0, 75.0, 80.0, 85.0, 90.0, 100.0, 110.0, 120.0, 130.0, 140.0, 150.0, 160.0,
170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0, 1000.0
] as &'static [f64])
),
&["model_name"]
)
.expect("metric can be created");
pub static ref CONVERTER_TIME_COLLECTOR: HistogramVec = HistogramVec::new(
HistogramOpts::new(":navi:converter_time", "converter time in microseconds").buckets(
Vec::from(&[
0.0, 500.0, 1000.0, 1500.0, 2000.0, 2500.0, 3000.0, 3500.0, 4000.0, 4500.0, 5000.0,
5500.0, 6000.0, 6500.0, 7000.0, 20000.0
] as &'static [f64])
),
&["model_name"]
)
.expect("metric can be created");
pub static ref CERT_EXPIRY_EPOCH: IntGauge =
IntGauge::new(":navi:cert_expiry_epoch", "Timestamp when the current cert expires")
.expect("metric can be created");
}
pub fn register_custom_metrics() {
REGISTRY
.register(Box::new(NUM_REQUESTS_RECEIVED.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_REQUESTS_FAILED.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_REQUESTS_DROPPED.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(RESPONSE_TIME_COLLECTOR.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NAVI_VERSION.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(BATCH_SIZE.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_BATCH_PREDICTION.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_BATCHES_DROPPED.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_PREDICTIONS.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(PREDICTION_SCORE_SUM.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NEW_MODEL_SNAPSHOT.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(MODEL_SNAPSHOT_VERSION.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_REQUESTS_RECEIVED_BY_MODEL.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_REQUESTS_FAILED_BY_MODEL.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_REQUESTS_DROPPED_BY_MODEL.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_BATCHES_DROPPED_BY_MODEL.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(INFERENCE_FAILED_REQUESTS_BY_MODEL.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_PREDICTION_BY_MODEL.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(NUM_BATCH_PREDICTION_BY_MODEL.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(BATCH_SIZE_BY_MODEL.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(CUSTOMOP_VERSION.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(MPSC_CHANNEL_SIZE.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(BLOCKING_REQUEST_NUM.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(MODEL_INFERENCE_TIME_COLLECTOR.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(CONVERTER_TIME_COLLECTOR.clone()))
.expect("collector can be registered");
REGISTRY
.register(Box::new(CERT_EXPIRY_EPOCH.clone()))
.expect("collector can be registered");
}
pub fn register_dynamic_metrics(c: &HistogramVec) {
REGISTRY
.register(Box::new(c.clone()))
.expect("dynamic metric collector cannot be registered");
}
pub async fn metrics_handler() -> Result<impl Reply, Rejection> {
use prometheus::Encoder;
let encoder = prometheus::TextEncoder::new();
let mut buffer = Vec::new();
if let Err(e) = encoder.encode(&REGISTRY.gather(), &mut buffer) {
error!("could not encode custom metrics: {}", e);
};
let mut res = match String::from_utf8(buffer) {
Ok(v) => format!("#{}:{}\n{}", NAME, VERSION, v),
Err(e) => {
error!("custom metrics could not be from_utf8'd: {}", e);
String::default()
}
};
buffer = Vec::new();
if let Err(e) = encoder.encode(&prometheus::gather(), &mut buffer) {
error!("could not encode prometheus metrics: {}", e);
};
let res_custom = match String::from_utf8(buffer) {
Ok(v) => v,
Err(e) => {
error!("prometheus metrics could not be from_utf8'd: {}", e);
String::default()
}
};
res.push_str(&res_custom);
Ok(res)
}

Binary file not shown.


@@ -1,275 +0,0 @@
#[cfg(feature = "onnx")]
pub mod onnx {
use crate::TensorReturnEnum;
use crate::bootstrap::{TensorInput, TensorInputEnum};
use crate::cli_args::{
Args, ARGS, INPUTS, MODEL_SPECS, OUTPUTS,
};
use crate::metrics::{self, CONVERTER_TIME_COLLECTOR};
use crate::predict_service::Model;
use crate::{MAX_NUM_INPUTS, MAX_NUM_OUTPUTS, META_INFO, utils};
use anyhow::Result;
use arrayvec::ArrayVec;
use dr_transform::converter::{BatchPredictionRequestToTorchTensorConverter, Converter};
use itertools::Itertools;
use log::{debug, info};
use dr_transform::ort::environment::Environment;
use dr_transform::ort::session::Session;
use dr_transform::ort::tensor::InputTensor;
use dr_transform::ort::{ExecutionProvider, GraphOptimizationLevel, SessionBuilder};
use dr_transform::ort::LoggingLevel;
use serde_json::Value;
use std::fmt::{Debug, Display};
use std::sync::Arc;
use std::{fmt, fs};
use tokio::time::Instant;
lazy_static! {
pub static ref ENVIRONMENT: Arc<Environment> = Arc::new(
Environment::builder()
.with_name("onnx home")
.with_log_level(LoggingLevel::Error)
.with_global_thread_pool(ARGS.onnx_global_thread_pool_options.clone())
.build()
.unwrap()
);
}
#[derive(Debug)]
pub struct OnnxModel {
pub session: Session,
pub model_idx: usize,
pub version: i64,
pub export_dir: String,
pub output_filters: ArrayVec<usize, MAX_NUM_OUTPUTS>,
pub input_converter: Box<dyn Converter>,
}
impl Display for OnnxModel {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"idx: {}, onnx model_name:{}, version:{}, output_filters:{:?}, converter:{:}",
self.model_idx,
MODEL_SPECS[self.model_idx],
self.version,
self.output_filters,
self.input_converter
)
}
}
impl Drop for OnnxModel {
fn drop(&mut self) {
if ARGS.profiling != None {
self.session.end_profiling().map_or_else(
|e| {
info!("end profiling with some error:{:?}", e);
},
|f| {
info!("profiling ended with file:{}", f);
},
);
}
}
}
impl OnnxModel {
fn get_output_filters(session: &Session, idx: usize) -> ArrayVec<usize, MAX_NUM_OUTPUTS> {
OUTPUTS[idx]
.iter()
.map(|output| session.outputs.iter().position(|o| o.name == *output))
.flatten()
.collect::<ArrayVec<usize, MAX_NUM_OUTPUTS>>()
}
#[cfg(target_os = "linux")]
fn ep_choices() -> Vec<ExecutionProvider> {
match ARGS.onnx_gpu_ep.as_ref().map(|e| e.as_str()) {
Some("onednn") => vec![Self::ep_with_options(ExecutionProvider::onednn())],
Some("tensorrt") => vec![Self::ep_with_options(ExecutionProvider::tensorrt())],
Some("cuda") => vec![Self::ep_with_options(ExecutionProvider::cuda())],
_ => vec![Self::ep_with_options(ExecutionProvider::cpu())],
}
}
fn ep_with_options(mut ep: ExecutionProvider) -> ExecutionProvider {
for (ref k, ref v) in ARGS.onnx_ep_options.clone() {
ep = ep.with(k, v);
info!("setting option:{} -> {} and now ep is:{:?}", k, v, ep);
}
ep
}
#[cfg(target_os = "macos")]
fn ep_choices() -> Vec<ExecutionProvider> {
vec![Self::ep_with_options(ExecutionProvider::cpu())]
}
pub fn new(idx: usize, version: String, model_config: &Value) -> Result<OnnxModel> {
let export_dir = format!("{}/{}/model.onnx", ARGS.model_dir[idx], version);
let meta_info = format!("{}/{}/{}", ARGS.model_dir[idx], version, META_INFO);
let mut builder = SessionBuilder::new(&ENVIRONMENT)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_parallel_execution(ARGS.onnx_use_parallel_mode == "true")?;
if ARGS.onnx_global_thread_pool_options.is_empty() {
builder = builder
.with_inter_threads(
utils::get_config_or(
model_config,
"inter_op_parallelism",
&ARGS.inter_op_parallelism[idx],
)
.parse()?,
)?
.with_intra_threads(
utils::get_config_or(
model_config,
"intra_op_parallelism",
&ARGS.intra_op_parallelism[idx],
)
.parse()?,
)?;
}
else {
builder = builder.with_disable_per_session_threads()?;
}
builder = builder
.with_memory_pattern(ARGS.onnx_use_memory_pattern == "true")?
.with_execution_providers(&OnnxModel::ep_choices())?;
match &ARGS.profiling {
Some(p) => {
debug!("Enable profiling, writing to {}", *p);
builder = builder.with_profiling(p)?
}
_ => {}
}
let session = builder.with_model_from_file(&export_dir)?;
info!(
"inputs: {:?}, outputs: {:?}",
session.inputs.iter().format(","),
session.outputs.iter().format(",")
);
fs::read_to_string(&meta_info)
.ok()
.map(|info| info!("meta info:{}", info));
let output_filters = OnnxModel::get_output_filters(&session, idx);
let mut reporting_feature_ids: Vec<(i64, &str)> = vec![];
let input_spec_cell = &INPUTS[idx];
if input_spec_cell.get().is_none() {
let input_spec = session
.inputs
.iter()
.map(|input| input.name.clone())
.collect::<ArrayVec<String, MAX_NUM_INPUTS>>();
input_spec_cell.set(input_spec.clone()).map_or_else(
|_| info!("unable to set the input_spec for model {}", idx),
|_| info!("auto detect and set the inputs: {:?}", input_spec),
);
}
ARGS.onnx_report_discrete_feature_ids
.iter()
.for_each(|ids| {
ids.split(",")
.filter(|s| !s.is_empty())
.map(|s| s.parse::<i64>().unwrap())
.for_each(|id| reporting_feature_ids.push((id, "discrete")))
});
ARGS.onnx_report_continuous_feature_ids
.iter()
.for_each(|ids| {
ids.split(",")
.filter(|s| !s.is_empty())
.map(|s| s.parse::<i64>().unwrap())
.for_each(|id| reporting_feature_ids.push((id, "continuous")))
});
let onnx_model = OnnxModel {
session,
model_idx: idx,
version: Args::version_str_to_epoch(&version)?,
export_dir,
output_filters,
input_converter: Box::new(BatchPredictionRequestToTorchTensorConverter::new(
&ARGS.model_dir[idx],
&version,
reporting_feature_ids,
Some(metrics::register_dynamic_metrics),
)?),
};
onnx_model.warmup()?;
Ok(onnx_model)
}
}
///Currently we assume the input is just one string tensor.
///The string tensor will be converted to the actual raw tensors.
///The converter we are using is very specific to Home.
///It reads a BatchDataRecord thrift and decodes it into a batch of raw input tensors.
///Navi will then do server-side batching and feed it to the ONNX runtime.
impl Model for OnnxModel {
//TODO: implement a generic online warmup for all runtimes
fn warmup(&self) -> Result<()> {
Ok(())
}
#[inline(always)]
fn do_predict(
&self,
input_tensors: Vec<Vec<TensorInput>>,
_: u64,
) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>) {
let batched_tensors = TensorInputEnum::merge_batch(input_tensors);
let (inputs, batch_ends): (Vec<Vec<InputTensor>>, Vec<Vec<usize>>) = batched_tensors
.into_iter()
.map(|batched_tensor| {
match batched_tensor.tensor_data {
TensorInputEnum::String(t) if ARGS.onnx_use_converter.is_some() => {
let start = Instant::now();
let (inputs, batch_ends) = self.input_converter.convert(t);
// info!("batch_ends:{:?}", batch_ends);
CONVERTER_TIME_COLLECTOR
.with_label_values(&[&MODEL_SPECS[self.model_idx()]])
.observe(
start.elapsed().as_micros() as f64
/ (*batch_ends.last().unwrap() as f64),
);
(inputs, batch_ends)
}
_ => unimplemented!(),
}
})
.unzip();
//invariant: we only support a single string input for now. will relax later
assert_eq!(inputs.len(), 1);
let output_tensors = self
.session
.run(inputs.into_iter().flatten().collect::<Vec<_>>())
.unwrap();
self.output_filters
.iter()
.map(|&idx| {
let mut size = 1usize;
let output = output_tensors[idx].try_extract::<f32>().unwrap();
for &dim in self.session.outputs[idx].dimensions.iter().flatten() {
size *= dim as usize;
}
let tensor_ends = batch_ends[0]
.iter()
.map(|&batch| batch * size)
.collect::<Vec<_>>();
(
//only works for batch major
//TODO: to_vec() obviously wasteful, especially for large batches(GPU) . Will refactor to
//break up output and return Vec<Vec<TensorScore>> here
TensorReturnEnum::FloatTensorReturn(Box::new(
output.view().as_slice().unwrap().to_vec(),
)),
tensor_ends,
)
})
.unzip()
}
#[inline(always)]
fn model_idx(&self) -> usize {
self.model_idx
}
#[inline(always)]
fn version(&self) -> i64 {
self.version
}
}
}

Binary file not shown.


@@ -1,315 +0,0 @@
use anyhow::{anyhow, Result};
use arrayvec::ArrayVec;
use itertools::Itertools;
use log::{error, info};
use std::fmt::{Debug, Display};
use std::string::String;
use std::sync::Arc;
use std::time::Duration;
use tokio::process::Command;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{mpsc, oneshot};
use tokio::time::{Instant, sleep};
use warp::Filter;
use crate::batch::BatchPredictor;
use crate::bootstrap::TensorInput;
use crate::{MAX_NUM_MODELS, MAX_VERSIONS_PER_MODEL, META_INFO, metrics, ModelFactory, PredictMessage, PredictResult, TensorReturnEnum, utils};
use crate::cli_args::{ARGS, MODEL_SPECS};
use crate::cores::validator::validatior::cli_validator;
use crate::metrics::MPSC_CHANNEL_SIZE;
use serde_json::{self, Value};
pub trait Model: Send + Sync + Display + Debug + 'static {
fn warmup(&self) -> Result<()>;
//TODO: refactor this to return vec<vec<TensorScores>>, i.e.
//we have the underlying runtime impl to split the response to each client.
//It will eliminate some inefficient memory copy in onnx_model.rs as well as simplify code
fn do_predict(
&self,
input_tensors: Vec<Vec<TensorInput>>,
total_len: u64,
) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>);
fn model_idx(&self) -> usize;
fn version(&self) -> i64;
}
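//each runtime plugs in by implementing Model and handing its constructor
//(e.g. TFModel::new, OnnxModel::new or TorchModel::new) to bootstrap::bootstrap as the ModelFactory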
#[derive(Debug)]
pub struct PredictService<T: Model> {
tx: Sender<PredictMessage<T>>,
}
impl<T: Model> PredictService<T> {
pub async fn init(model_factory: ModelFactory<T>) -> Self {
cli_validator::validate_ps_model_args();
let (tx, rx) = mpsc::channel(32_000);
tokio::spawn(PredictService::tf_queue_manager(rx));
tokio::spawn(PredictService::model_watcher_latest(
model_factory,
tx.clone(),
));
let metrics_route = warp::path!("metrics").and_then(metrics::metrics_handler);
let metric_server = warp::serve(metrics_route).run(([0, 0, 0, 0], ARGS.prometheus_port));
tokio::spawn(metric_server);
PredictService { tx }
}
#[inline(always)]
pub async fn predict(
&self,
idx: usize,
version: Option<i64>,
val: Vec<TensorInput>,
ts: Instant,
) -> Result<PredictResult> {
let (tx, rx) = oneshot::channel();
if let Err(e) = self
.tx
.clone()
.send(PredictMessage::Predict(idx, version, val, tx, ts))
.await
{
error!("mpsc send error:{}", e);
Err(anyhow!(e))
} else {
MPSC_CHANNEL_SIZE.inc();
rx.await.map_err(anyhow::Error::msg)
}
}
async fn load_latest_model_from_model_dir(
model_factory: ModelFactory<T>,
model_config: &Value,
tx: Sender<PredictMessage<T>>,
idx: usize,
max_version: String,
latest_version: &mut String,
) {
match model_factory(idx, max_version.clone(), model_config) {
Ok(tf_model) => tx
.send(PredictMessage::UpsertModel(tf_model))
.await
.map_or_else(
|e| error!("send UpsertModel error: {}", e),
|_| *latest_version = max_version,
),
Err(e) => {
error!("skip loading model due to failure: {:?}", e);
}
}
}
async fn scan_load_latest_model_from_model_dir(
model_factory: ModelFactory<T>,
model_config: &Value,
tx: Sender<PredictMessage<T>>,
model_idx: usize,
cur_version: &mut String,
) -> Result<()> {
let model_dir = &ARGS.model_dir[model_idx];
let next_version = utils::get_config_or_else(model_config, "version", || {
info!("no version found, hence use max version");
std::fs::read_dir(model_dir)
.map_err(|e| format!("read dir error:{}", e))
.and_then(|paths| {
paths
.into_iter()
.flat_map(|p| {
p.map_err(|e| error!("dir entry error: {}", e))
.and_then(|dir| {
dir.file_name()
.into_string()
.map_err(|e| error!("osstring error: {:?}", e))
})
.ok()
})
.filter(|f| !f.to_lowercase().contains(&META_INFO.to_lowercase()))
.max()
.ok_or_else(|| "no dir found hence no max".to_owned())
})
.unwrap_or_else(|e| {
error!(
"can't get the max version hence return cur_version, error is: {}",
e
);
cur_version.to_string()
})
});
//as long as the next version doesn't match the currently maintained version, we reload
if next_version.ne(cur_version) {
info!("reload the version: {}->{}", cur_version, next_version);
PredictService::load_latest_model_from_model_dir(
model_factory,
model_config,
tx,
model_idx,
next_version,
cur_version,
)
.await;
}
Ok(())
}
async fn model_watcher_latest(model_factory: ModelFactory<T>, tx: Sender<PredictMessage<T>>) {
async fn call_external_modelsync(cli: &str, cur_versions: &Vec<String>) -> Result<()> {
let mut args = cli.split_whitespace();
let mut cmd = Command::new(args.next().ok_or(anyhow!("model sync cli empty"))?);
let extr_args = MODEL_SPECS
.iter()
.zip(cur_versions)
.flat_map(|(spec, version)| vec!["--model-spec", spec, "--cur-version", version])
.collect_vec();
info!("run model sync: {} with extra args: {:?}", cli, extr_args);
let output = cmd.args(args).args(extr_args).output().await?;
info!("model sync stdout:{}", String::from_utf8(output.stdout)?);
info!("model sync stderr:{}", String::from_utf8(output.stderr)?);
if output.status.success() {
Ok(())
} else {
Err(anyhow!(
"model sync failed with status: {:?}!",
output.status
))
}
}
let meta_dir = utils::get_meta_dir();
let meta_file = format!("{}{}", meta_dir, META_INFO);
//initialize the latest version array
let mut cur_versions = vec!["".to_owned(); MODEL_SPECS.len()];
loop {
info!("***polling for models***"); //nice deliminter
if let Some(ref cli) = ARGS.modelsync_cli {
if let Err(e) = call_external_modelsync(cli, &cur_versions).await {
error!("model sync cli running error:{}", e)
}
}
let config = utils::read_config(&meta_file).unwrap_or_else(|e| {
info!("config file {} not found due to: {}", meta_file, e);
Value::Null
});
info!("config:{}", config);
for (idx, cur_version) in cur_versions.iter_mut().enumerate() {
let model_dir = &ARGS.model_dir[idx];
PredictService::scan_load_latest_model_from_model_dir(
model_factory,
&config[&MODEL_SPECS[idx]],
tx.clone(),
idx,
cur_version,
)
.await
.map_or_else(
|e| error!("scanned {}, error {:?}", model_dir, e),
|_| info!("scanned {}, latest_version: {}", model_dir, cur_version),
);
}
sleep(Duration::from_secs(ARGS.model_check_interval_secs)).await;
}
}
async fn tf_queue_manager(mut rx: Receiver<PredictMessage<T>>) {
// Start receiving messages
info!("setting up queue manager");
let max_batch_size = ARGS
.max_batch_size
.iter()
.map(|b| b.parse().unwrap())
.collect::<Vec<usize>>();
let batch_time_out_millis = ARGS
.batch_time_out_millis
.iter()
.map(|b| b.parse().unwrap())
.collect::<Vec<u64>>();
let no_msg_wait_millis = *batch_time_out_millis.iter().min().unwrap();
let mut all_model_predictors: ArrayVec<ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS> =
(0..MAX_NUM_MODELS)
.map(|_| ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new())
.collect();
loop {
let msg = rx.try_recv();
let no_more_msg = match msg {
Ok(PredictMessage::Predict(model_spec_at, version, val, resp, ts)) => {
if let Some(model_predictors) = all_model_predictors.get_mut(model_spec_at) {
if model_predictors.is_empty() {
resp.send(PredictResult::ModelNotReady(model_spec_at))
.unwrap_or_else(|e| error!("cannot send back model not ready error: {:?}", e));
}
else {
match version {
None => model_predictors[0].push(val, resp, ts),
Some(the_version) => match model_predictors
.iter_mut()
.find(|x| x.model.version() == the_version)
{
None => resp
.send(PredictResult::ModelVersionNotFound(
model_spec_at,
the_version,
))
.unwrap_or_else(|e| {
error!("cannot send back version error: {:?}", e)
}),
Some(predictor) => predictor.push(val, resp, ts),
},
}
}
} else {
resp.send(PredictResult::ModelNotFound(model_spec_at))
.unwrap_or_else(|e| error!("cannot send back model not found error: {:?}", e))
}
MPSC_CHANNEL_SIZE.dec();
false
}
Ok(PredictMessage::UpsertModel(tf_model)) => {
let idx = tf_model.model_idx();
let predictor = BatchPredictor {
model: Arc::new(tf_model),
input_tensors: Vec::with_capacity(max_batch_size[idx]),
callbacks: Vec::with_capacity(max_batch_size[idx]),
cur_batch_size: 0,
max_batch_size: max_batch_size[idx],
batch_time_out_millis: batch_time_out_millis[idx],
//initialize to be current time
queue_reset_ts: Instant::now(),
queue_earliest_rq_ts: Instant::now(),
};
assert!(idx < all_model_predictors.len());
metrics::NEW_MODEL_SNAPSHOT
.with_label_values(&[&MODEL_SPECS[idx]])
.inc();
//we can do this since the vector is small
let predictors = &mut all_model_predictors[idx];
if predictors.len() == 0 {
info!("now we serve new model: {}", predictor.model);
}
else {
info!("now we serve updated model: {}", predictor.model);
}
if predictors.len() == ARGS.versions_per_model {
predictors.remove(predictors.len() - 1);
}
predictors.insert(0, predictor);
false
}
Err(TryRecvError::Empty) => true,
Err(TryRecvError::Disconnected) => true,
};
for predictor in all_model_predictors.iter_mut().flatten() {
//if predictor batch queue not empty and times out or no more msg in the queue, flush
if (!predictor.input_tensors.is_empty() && (predictor.duration_past(predictor.batch_time_out_millis) || no_more_msg))
//if batch queue reaches limit, flush
|| predictor.cur_batch_size >= predictor.max_batch_size
{
predictor.batch_predict();
}
}
if no_more_msg {
sleep(Duration::from_millis(no_msg_wait_millis)).await;
}
}
}
#[inline(always)]
pub fn get_model_index(model_spec: &str) -> Option<usize> {
MODEL_SPECS.iter().position(|m| m == model_spec)
}
}

BIN
navi/navi/src/tf_model.docx Normal file

Binary file not shown.


@@ -1,492 +0,0 @@
#[cfg(feature = "tf")]
pub mod tf {
use arrayvec::ArrayVec;
use itertools::Itertools;
use log::{debug, error, info, warn};
use prost::Message;
use std::fmt;
use std::fmt::Display;
use std::string::String;
use tensorflow::io::{RecordReader, RecordReadError};
use tensorflow::Operation;
use tensorflow::SavedModelBundle;
use tensorflow::SessionOptions;
use tensorflow::SessionRunArgs;
use tensorflow::Tensor;
use tensorflow::{DataType, FetchToken, Graph, TensorInfo, TensorType};
use std::thread::sleep;
use std::time::Duration;
use crate::cli_args::{Args, ARGS, INPUTS, MODEL_SPECS, OUTPUTS};
use crate::tf_proto::tensorflow_serving::prediction_log::LogType;
use crate::tf_proto::tensorflow_serving::{PredictionLog, PredictLog};
use crate::tf_proto::ConfigProto;
use anyhow::{Context, Result};
use serde_json::Value;
use crate::TensorReturnEnum;
use crate::bootstrap::{TensorInput, TensorInputEnum};
use crate::metrics::{
INFERENCE_FAILED_REQUESTS_BY_MODEL, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
};
use crate::predict_service::Model;
use crate::{MAX_NUM_INPUTS, utils};
#[derive(Debug)]
pub enum TFTensorEnum {
String(Tensor<String>),
Int(Tensor<i32>),
Int64(Tensor<i64>),
Float(Tensor<f32>),
Double(Tensor<f64>),
Boolean(Tensor<bool>),
}
#[derive(Debug)]
pub struct TFModel {
pub model_idx: usize,
pub bundle: SavedModelBundle,
pub input_names: ArrayVec<String, MAX_NUM_INPUTS>,
pub input_info: Vec<TensorInfo>,
pub input_ops: Vec<Operation>,
pub output_names: Vec<String>,
pub output_info: Vec<TensorInfo>,
pub output_ops: Vec<Operation>,
pub export_dir: String,
pub version: i64,
pub inter_op: i32,
pub intra_op: i32,
}
impl Display for TFModel {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"idx: {}, tensorflow model_name:{}, export_dir:{}, version:{}, inter:{}, intra:{}",
self.model_idx,
MODEL_SPECS[self.model_idx],
self.export_dir,
self.version,
self.inter_op,
self.intra_op
)
}
}
impl TFModel {
pub fn new(idx: usize, version: String, model_config: &Value) -> Result<TFModel> {
// Create input variables for our addition
let config = ConfigProto {
intra_op_parallelism_threads: utils::get_config_or(
model_config,
"intra_op_parallelism",
&ARGS.intra_op_parallelism[idx],
)
.parse()?,
inter_op_parallelism_threads: utils::get_config_or(
model_config,
"inter_op_parallelism",
&ARGS.inter_op_parallelism[idx],
)
.parse()?,
..Default::default()
};
let mut buf = Vec::new();
buf.reserve(config.encoded_len());
config.encode(&mut buf).unwrap();
let mut opts = SessionOptions::new();
opts.set_config(&buf)?;
let export_dir = format!("{}/{}", ARGS.model_dir[idx], version);
let mut graph = Graph::new();
let bundle = SavedModelBundle::load(&opts, ["serve"], &mut graph, &export_dir)
.context("error load model")?;
let signature = bundle
.meta_graph_def()
.get_signature(&ARGS.serving_sig[idx])
.context("error finding signature")?;
let input_names = INPUTS[idx]
.get_or_init(|| {
let input_spec = signature
.inputs()
.iter()
.map(|p| p.0.clone())
.collect::<ArrayVec<String, MAX_NUM_INPUTS>>();
info!(
"input not set from cli, now we set from model metadata:{:?}",
input_spec
);
input_spec
})
.clone();
let input_info = input_names
.iter()
.map(|i| {
signature
.get_input(i)
.context("error finding input op info")
.unwrap()
.clone()
})
.collect_vec();
let input_ops = input_info
.iter()
.map(|i| {
graph
.operation_by_name_required(&i.name().name)
.context("error finding input op")
.unwrap()
})
.collect_vec();
info!("Model Input size: {}", input_info.len());
let output_names = OUTPUTS[idx].to_vec().clone();
let output_info = output_names
.iter()
.map(|o| {
signature
.get_output(o)
.context("error finding output op info")
.unwrap()
.clone()
})
.collect_vec();
let output_ops = output_info
.iter()
.map(|o| {
graph
.operation_by_name_required(&o.name().name)
.context("error finding output op")
.unwrap()
})
.collect_vec();
let tf_model = TFModel {
model_idx: idx,
bundle,
input_names,
input_info,
input_ops,
output_names,
output_info,
output_ops,
export_dir,
version: Args::version_str_to_epoch(&version)?,
inter_op: config.inter_op_parallelism_threads,
intra_op: config.intra_op_parallelism_threads,
};
tf_model.warmup()?;
Ok(tf_model)
}
#[inline(always)]
fn get_tftensor_dimensions<T>(
t: &[T],
input_size: u64,
batch_size: u64,
input_dims: Option<Vec<i64>>,
) -> Vec<u64> {
// if input size is 1, we just specify a single dimension for the outgoing tensor, matching the
// size of the input tensor. This is for backwards compatibility with existing Navi clients
// which specify the input as a single string tensor (like tfexample) and use batching support.
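// e.g. a float input of 8 values with input_size > 1 and batch_size 4 gets dims [4, 2];
// with input_size == 1 the outgoing tensor keeps a single dimension, [8]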
let mut dims = vec![];
if input_size > 1 {
if batch_size == 1 && input_dims.is_some() {
// client side batching is enabled?
input_dims
.unwrap()
.iter()
.for_each(|axis| dims.push(*axis as u64));
} else {
dims.push(batch_size);
dims.push(t.len() as u64 / batch_size);
}
} else {
dims.push(t.len() as u64);
}
dims
}
fn convert_to_tftensor_enum(
input: TensorInput,
input_size: u64,
batch_size: u64,
) -> TFTensorEnum {
match input.tensor_data {
TensorInputEnum::String(t) => {
let strings = t
.into_iter()
.map(|x| unsafe { String::from_utf8_unchecked(x) })
.collect_vec();
TFTensorEnum::String(
Tensor::new(&TFModel::get_tftensor_dimensions(
strings.as_slice(),
input_size,
batch_size,
input.dims,
))
.with_values(strings.as_slice())
.unwrap(),
)
}
TensorInputEnum::Int(t) => TFTensorEnum::Int(
Tensor::new(&TFModel::get_tftensor_dimensions(
t.as_slice(),
input_size,
batch_size,
input.dims,
))
.with_values(t.as_slice())
.unwrap(),
),
TensorInputEnum::Int64(t) => TFTensorEnum::Int64(
Tensor::new(&TFModel::get_tftensor_dimensions(
t.as_slice(),
input_size,
batch_size,
input.dims,
))
.with_values(t.as_slice())
.unwrap(),
),
TensorInputEnum::Float(t) => TFTensorEnum::Float(
Tensor::new(&TFModel::get_tftensor_dimensions(
t.as_slice(),
input_size,
batch_size,
input.dims,
))
.with_values(t.as_slice())
.unwrap(),
),
TensorInputEnum::Double(t) => TFTensorEnum::Double(
Tensor::new(&TFModel::get_tftensor_dimensions(
t.as_slice(),
input_size,
batch_size,
input.dims,
))
.with_values(t.as_slice())
.unwrap(),
),
TensorInputEnum::Boolean(t) => TFTensorEnum::Boolean(
Tensor::new(&TFModel::get_tftensor_dimensions(
t.as_slice(),
input_size,
batch_size,
input.dims,
))
.with_values(t.as_slice())
.unwrap(),
),
}
}
fn fetch_output<T: TensorType>(
args: &mut SessionRunArgs,
token_output: &FetchToken,
batch_size: u64,
output_size: u64,
) -> (Tensor<T>, u64) {
let tensor_output = args.fetch::<T>(*token_output).expect("fetch output failed");
let mut tensor_width = tensor_output.dims()[1];
if batch_size == 1 && output_size > 1 {
tensor_width = tensor_output.dims().iter().product();
}
(tensor_output, tensor_width)
}
}
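// A minimal sketch (added for illustration, not part of the original source) of the
// dimension rules implemented by get_tftensor_dimensions above. The sample values are
// made up; the assertions just restate the function's documented behavior.
#[cfg(test)]
mod dimension_sketch {
    use super::TFModel;

    #[test]
    fn single_input_uses_flat_length() {
        // One input tensor: the outgoing dims are just the flat length of the data.
        let data = vec![0f32; 8];
        assert_eq!(
            TFModel::get_tftensor_dimensions(&data, 1, 4, None),
            vec![8u64]
        );
    }

    #[test]
    fn multi_input_server_side_batching() {
        // Multiple inputs, no client-supplied dims: [batch_size, len / batch_size].
        let data = vec![0f32; 8];
        assert_eq!(
            TFModel::get_tftensor_dimensions(&data, 2, 4, None),
            vec![4u64, 2u64]
        );
    }

    #[test]
    fn multi_input_client_side_batching() {
        // batch_size == 1 with explicit dims: the client-provided shape is used verbatim.
        let data = vec![0f32; 6];
        assert_eq!(
            TFModel::get_tftensor_dimensions(&data, 2, 1, Some(vec![2, 3])),
            vec![2u64, 3u64]
        );
    }
}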
impl Model for TFModel {
fn warmup(&self) -> Result<()> {
// warm up
let warmup_file = format!(
"{}/assets.extra/tf_serving_warmup_requests",
self.export_dir
);
if std::path::Path::new(&warmup_file).exists() {
use std::io::Cursor;
info!(
"found warmup assets in {}, now perform warming up",
warmup_file
);
let f = std::fs::File::open(warmup_file).context("cannot open warmup file")?;
// let mut buf = Vec::new();
let read = std::io::BufReader::new(f);
let mut reader = RecordReader::new(read);
let mut warmup_cnt = 0;
loop {
let next = reader.read_next_owned();
match next {
Ok(res) => match res {
Some(vec) => {
// info!("read one tfRecord");
match PredictionLog::decode(&mut Cursor::new(vec))
.context("can't parse PredictionLog")?
{
PredictionLog {
log_metadata: _,
log_type:
Some(LogType::PredictLog(PredictLog {
request: Some(mut req),
response: _,
})),
} => {
if warmup_cnt == ARGS.max_warmup_records {
//warm up to max_warmup_records records
warn!(
"reached max warmup {} records, exit warmup for {}",
ARGS.max_warmup_records,
MODEL_SPECS[self.model_idx]
);
break;
}
self.do_predict(
vec![req.take_input_vals(&self.input_names)],
1,
);
sleep(Duration::from_millis(100));
warmup_cnt += 1;
}
_ => error!("skipping malformed record in warmup file"),
}
}
None => {
info!("end of warmup file, warmed up with records: {}", warmup_cnt);
break;
}
},
Err(RecordReadError::CorruptFile)
| Err(RecordReadError::IoError { .. }) => {
error!("error reading tfrecord from warmup file, skipping");
}
_ => {}
}
}
}
Ok(())
}
#[inline(always)]
fn do_predict(
&self,
input_tensors: Vec<Vec<TensorInput>>,
batch_size: u64,
) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>) {
// let mut batch_ends = input_tensors.iter().map(|t| t.len()).collect::<Vec<usize>>();
let output_size = self.output_names.len() as u64;
let input_size = self.input_names.len() as u64;
debug!(
"Request for Tensorflow with batch size: {} and input_size: {}",
batch_size, input_size
);
// build a set of input TF tensors
let batch_end = (1..=input_tensors.len()).collect_vec();
let mut batch_ends = vec![batch_end; output_size as usize];
let batched_tensors = TensorInputEnum::merge_batch(input_tensors)
.into_iter()
.enumerate()
.map(|(_, i)| TFModel::convert_to_tftensor_enum(i, input_size, batch_size))
.collect_vec();
let mut args = SessionRunArgs::new();
for (index, tf_tensor) in batched_tensors.iter().enumerate() {
match tf_tensor {
TFTensorEnum::String(inner) => args.add_feed(&self.input_ops[index], 0, inner),
TFTensorEnum::Int(inner) => args.add_feed(&self.input_ops[index], 0, inner),
TFTensorEnum::Int64(inner) => args.add_feed(&self.input_ops[index], 0, inner),
TFTensorEnum::Float(inner) => args.add_feed(&self.input_ops[index], 0, inner),
TFTensorEnum::Double(inner) => args.add_feed(&self.input_ops[index], 0, inner),
TFTensorEnum::Boolean(inner) => args.add_feed(&self.input_ops[index], 0, inner),
}
}
// For output ops we receive the same op object by name; the actual tensor tokens live at different offsets.
// Since indices are ordered, it's important to specify the output flags to Navi in the same order.
let token_outputs = self
.output_ops
.iter()
.enumerate()
.map(|(idx, op)| args.request_fetch(op, idx as i32))
.collect_vec();
match self.bundle.session.run(&mut args) {
Ok(_) => (),
Err(e) => {
NUM_REQUESTS_FAILED.inc_by(batch_size);
NUM_REQUESTS_FAILED_BY_MODEL
.with_label_values(&[&MODEL_SPECS[self.model_idx]])
.inc_by(batch_size);
INFERENCE_FAILED_REQUESTS_BY_MODEL
.with_label_values(&[&MODEL_SPECS[self.model_idx]])
.inc_by(batch_size);
panic!("{model}: {e:?}", model = MODEL_SPECS[self.model_idx], e = e);
}
}
let mut predict_return = vec![];
// Check the output.
for (index, token_output) in token_outputs.iter().enumerate() {
// same ops, with type info at different offsets.
let (res, width) = match self.output_ops[index].output_type(index) {
DataType::Float => {
let (tensor_output, tensor_width) =
TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
(
TensorReturnEnum::FloatTensorReturn(Box::new(tensor_output)),
tensor_width,
)
}
DataType::Int64 => {
let (tensor_output, tensor_width) =
TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
(
TensorReturnEnum::Int64TensorReturn(Box::new(tensor_output)),
tensor_width,
)
}
DataType::Int32 => {
let (tensor_output, tensor_width) =
TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
(
TensorReturnEnum::Int32TensorReturn(Box::new(tensor_output)),
tensor_width,
)
}
DataType::String => {
let (tensor_output, tensor_width) =
TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
(
TensorReturnEnum::StringTensorReturn(Box::new(tensor_output)),
tensor_width,
)
}
_ => panic!("Unsupported return type!"),
};
let width = width as usize;
for b in batch_ends[index].iter_mut() {
*b *= width;
}
predict_return.push(res)
}
//TODO: remove in the future
//TODO: support actual mtl model outputs
(predict_return, batch_ends)
}
#[inline(always)]
fn model_idx(&self) -> usize {
self.model_idx
}
#[inline(always)]
fn version(&self) -> i64 {
self.version
}
}
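// Sketch (added for illustration, not part of the original source) of how do_predict's
// batch_ends are derived: for N requests and an output tensor of width W per request,
// the ends are [W, 2*W, ..., N*W], i.e. each request's exclusive end offset in the
// flattened output buffer.
#[cfg(test)]
mod batch_ends_sketch {
    #[test]
    fn ends_scale_with_output_width() {
        let num_requests = 3usize;
        let width = 4usize;
        let ends: Vec<usize> = (1..=num_requests).map(|i| i * width).collect();
        assert_eq!(ends, vec![4, 8, 12]);
    }
}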
}

Binary file not shown.

View File

@ -1,183 +0,0 @@
#[cfg(feature = "torch")]
pub mod torch {
use std::fmt;
use std::fmt::Display;
use std::string::String;
use crate::TensorReturnEnum;
use crate::SerializedInput;
use crate::bootstrap::TensorInput;
use crate::cli_args::{Args, ARGS, MODEL_SPECS};
use crate::metrics;
use crate::metrics::{
INFERENCE_FAILED_REQUESTS_BY_MODEL, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
};
use crate::predict_service::Model;
use anyhow::Result;
use dr_transform::converter::BatchPredictionRequestToTorchTensorConverter;
use dr_transform::converter::Converter;
use serde_json::Value;
use tch::Tensor;
use tch::{kind, CModule, IValue};
#[derive(Debug)]
pub struct TorchModel {
pub model_idx: usize,
pub version: i64,
pub module: CModule,
pub export_dir: String,
// FIXME: make this Box<Option<..>> so input converter can be optional.
// Also consider adding output_converter.
pub input_converter: Box<dyn Converter>,
}
impl Display for TorchModel {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"idx: {}, torch model_name:{}, version:{}",
self.model_idx, MODEL_SPECS[self.model_idx], self.version
)
}
}
impl TorchModel {
pub fn new(idx: usize, version: String, _model_config: &Value) -> Result<TorchModel> {
let export_dir = format!("{}/{}/model.pt", ARGS.model_dir[idx], version);
let model = CModule::load(&export_dir).unwrap();
let torch_model = TorchModel {
model_idx: idx,
version: Args::version_str_to_epoch(&version)?,
module: model,
export_dir,
//TODO: move converter lookup into a registry.
input_converter: Box::new(BatchPredictionRequestToTorchTensorConverter::new(
&ARGS.model_dir[idx].as_str(),
version.as_str(),
vec![],
Some(&metrics::register_dynamic_metrics),
)),
};
torch_model.warmup()?;
Ok(torch_model)
}
#[inline(always)]
pub fn decode_to_inputs(bytes: SerializedInput) -> Vec<Tensor> {
//FIXME: for now we generate 6 random tensors as inputs to unblock end-to-end testing;
//when Shajan's decoder is ready we will swap it in
let row = bytes.len() as i64;
let t1 = Tensor::randn(&[row, 5293], kind::FLOAT_CPU); //continuous
let t2 = Tensor::randint(10, &[row, 149], kind::INT64_CPU); //binary
let t3 = Tensor::randint(10, &[row, 320], kind::INT64_CPU); //discrete
let t4 = Tensor::randn(&[row, 200], kind::FLOAT_CPU); //user_embedding
let t5 = Tensor::randn(&[row, 200], kind::FLOAT_CPU); //user_eng_embedding
let t6 = Tensor::randn(&[row, 200], kind::FLOAT_CPU); //author_embedding
vec![t1, t2, t3, t4, t5, t6]
}
#[inline(always)]
pub fn output_to_vec(res: IValue, dst: &mut Vec<f32>) {
match res {
IValue::Tensor(tensor) => TorchModel::tensors_to_vec(&[tensor], dst),
IValue::Tuple(ivalues) => {
TorchModel::tensors_to_vec(&TorchModel::ivalues_to_tensors(ivalues), dst)
}
_ => panic!("we only support output as a single tensor or a tuple of tensors"),
}
}
#[inline(always)]
pub fn tensor_flatten_size(t: &Tensor) -> usize {
t.size().into_iter().product::<i64>() as usize
}
#[inline(always)]
pub fn tensor_to_vec<T: kind::Element>(res: &Tensor) -> Vec<T> {
let size = TorchModel::tensor_flatten_size(res);
let mut res_f32: Vec<T> = Vec::with_capacity(size);
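// Relies on copy_data below to fill all `size` elements; set_len exposes uninitialized memory until then.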
unsafe {
res_f32.set_len(size);
}
res.copy_data(res_f32.as_mut_slice(), size);
// println!("Copied tensor:{}, {:?}", res_f32.len(), res_f32);
res_f32
}
#[inline(always)]
pub fn tensors_to_vec(tensors: &[Tensor], dst: &mut Vec<f32>) {
let mut offset = dst.len();
tensors.iter().for_each(|t| {
let size = TorchModel::tensor_flatten_size(t);
let next_size = offset + size;
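// Note: set_len requires dst to already have capacity for next_size elements; copy_data then fills dst[offset..next_size].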
unsafe {
dst.set_len(next_size);
}
t.copy_data(&mut dst[offset..], size);
offset = next_size;
});
}
pub fn ivalues_to_tensors(ivalues: Vec<IValue>) -> Vec<Tensor> {
ivalues
.into_iter()
.map(|t| {
if let IValue::Tensor(vanilla_t) = t {
vanilla_t
} else {
panic!("not a tensor")
}
})
.collect::<Vec<Tensor>>()
}
}
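// Sketch (added for illustration, not part of the original source): tensors_to_vec
// flattens a slice of output tensors into one contiguous f32 buffer, back to back.
// The shapes below are arbitrary sample values.
#[cfg(test)]
mod flatten_sketch {
    use super::TorchModel;
    use tch::{kind, Tensor};

    #[test]
    fn flattens_two_tensors_back_to_back() {
        let t1 = Tensor::randn(&[2, 3], kind::FLOAT_CPU);
        let t2 = Tensor::randn(&[4], kind::FLOAT_CPU);
        // The destination must have enough capacity up front; tensors_to_vec uses set_len.
        let mut dst: Vec<f32> = Vec::with_capacity(16);
        TorchModel::tensors_to_vec(&[t1, t2], &mut dst);
        assert_eq!(dst.len(), 10); // 2*3 + 4 values copied back to back
    }
}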
impl Model for TorchModel {
fn warmup(&self) -> Result<()> {
Ok(())
}
//TODO: the torch runtime needs some refactoring to expose a generic interface
#[inline(always)]
fn do_predict(
&self,
input_tensors: Vec<Vec<TensorInput>>,
total_len: u64,
) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>) {
let mut buf: Vec<f32> = Vec::with_capacity(10_000);
let mut batch_ends = vec![0usize; input_tensors.len()];
for (i, batch_bytes_in_request) in input_tensors.into_iter().enumerate() {
for _ in batch_bytes_in_request.into_iter() {
//FIXME: for now use some hack
let model_input = TorchModel::decode_to_inputs(vec![0u8; 30]); //self.input_converter.convert(bytes);
let input_batch_tensors = model_input
.into_iter()
.map(IValue::Tensor)
.collect::<Vec<IValue>>();
// match self.module.forward_is(&input_batch_tensors) {
match self.module.method_is("forward_serve", &input_batch_tensors) {
Ok(res) => TorchModel::output_to_vec(res, &mut buf),
Err(e) => {
NUM_REQUESTS_FAILED.inc_by(total_len);
NUM_REQUESTS_FAILED_BY_MODEL
.with_label_values(&[&MODEL_SPECS[self.model_idx]])
.inc_by(total_len);
INFERENCE_FAILED_REQUESTS_BY_MODEL
.with_label_values(&[&MODEL_SPECS[self.model_idx]])
.inc_by(total_len);
panic!("{model}: {e:?}", model = MODEL_SPECS[self.model_idx], e = e);
}
}
}
batch_ends[i] = buf.len();
}
(
vec![TensorReturnEnum::FloatTensorReturn(Box::new(buf))],
vec![batch_ends],
)
}
#[inline(always)]
fn model_idx(&self) -> usize {
self.model_idx
}
#[inline(always)]
fn version(&self) -> i64 {
self.version
}
}
}

BIN
navi/segdense/Cargo.docx Normal file

Binary file not shown.

View File

@ -1,11 +0,0 @@
[package]
name = "segdense"
version = "0.1.0"
edition = "2021"
[dependencies]
env_logger = "0.10.0"
serde = { version = "1.0.104", features = ["derive"] }
serde_json = "1.0.48"
log = "0.4.17"

Binary file not shown.

View File

@ -1,53 +0,0 @@
use std::fmt::Display;
/**
* Custom error
*/
#[derive(Debug)]
pub enum SegDenseError {
IoError(std::io::Error),
Json(serde_json::Error),
JsonMissingRoot,
JsonMissingObject,
JsonMissingArray,
JsonArraySize,
JsonMissingInputFeature,
}
impl Display for SegDenseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
SegDenseError::JsonMissingRoot => {
write!(f, "SegDense JSON: Root node not found!")
}
SegDenseError::JsonMissingObject => {
write!(f, "SegDense JSON: Object not found!")
}
SegDenseError::JsonMissingArray => {
write!(f, "SegDense JSON: Array node not found!")
}
SegDenseError::JsonArraySize => {
write!(f, "SegDense JSON: Array size not as expected!")
}
SegDenseError::JsonMissingInputFeature => {
write!(f, "SegDense JSON: Missing input feature!")
}
}
}
}
impl std::error::Error for SegDenseError {}
impl From<std::io::Error> for SegDenseError {
fn from(err: std::io::Error) -> Self {
SegDenseError::IoError(err)
}
}
impl From<serde_json::Error> for SegDenseError {
fn from(err: serde_json::Error) -> Self {
SegDenseError::Json(err)
}
}
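// Sketch (added for illustration, not part of the original source): the From impls above
// let callers use `?` on both I/O and JSON errors inside a function returning SegDenseError.
// The path argument here is a hypothetical example.
#[allow(dead_code)]
fn read_and_parse(path: &str) -> Result<serde_json::Value, SegDenseError> {
    let text = std::fs::read_to_string(path)?; // std::io::Error -> SegDenseError
    let value = serde_json::from_str(&text)?; // serde_json::Error -> SegDenseError
    Ok(value)
}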

BIN
navi/segdense/src/lib.docx Normal file

Binary file not shown.

View File

@ -1,4 +0,0 @@
pub mod error;
pub mod mapper;
pub mod segdense_transform_spec_home_recap_2022;
pub mod util;

BIN
navi/segdense/src/main.docx Normal file

Binary file not shown.

View File

@ -1,22 +0,0 @@
use std::env;
use std::fs;
use segdense::error::SegDenseError;
use segdense::util;
fn main() -> Result<(), SegDenseError> {
env_logger::init();
let args: Vec<String> = env::args().collect();
let schema_file_name: &str = if args.len() == 1 {
"json/compact.json"
} else {
&args[1]
};
let json_str = fs::read_to_string(schema_file_name)?;
util::safe_load_config(&json_str)?;
Ok(())
}

Binary file not shown.

View File

@ -1,45 +0,0 @@
use std::collections::HashMap;
#[derive(Debug)]
pub struct FeatureInfo {
pub tensor_index: i8,
pub index_within_tensor: i64,
}
pub static NULL_INFO: FeatureInfo = FeatureInfo {
tensor_index: -1,
index_within_tensor: -1,
};
#[derive(Debug, Default)]
pub struct FeatureMapper {
map: HashMap<i64, FeatureInfo>,
}
impl FeatureMapper {
pub fn new() -> FeatureMapper {
FeatureMapper {
map: HashMap::new(),
}
}
}
pub trait MapWriter {
fn set(&mut self, feature_id: i64, info: FeatureInfo);
}
pub trait MapReader {
fn get(&self, feature_id: &i64) -> Option<&FeatureInfo>;
}
impl MapWriter for FeatureMapper {
fn set(&mut self, feature_id: i64, info: FeatureInfo) {
self.map.insert(feature_id, info);
}
}
impl MapReader for FeatureMapper {
fn get(&self, feature_id: &i64) -> Option<&FeatureInfo> {
self.map.get(feature_id)
}
}
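// Sketch (added for illustration, not part of the original source): a basic set/get
// round trip through the MapWriter and MapReader traits defined above.
#[cfg(test)]
mod mapper_sketch {
    use super::{FeatureInfo, FeatureMapper, MapReader, MapWriter};

    #[test]
    fn set_then_get() {
        let mut mapper = FeatureMapper::new();
        mapper.set(
            42,
            FeatureInfo {
                tensor_index: 1,
                index_within_tensor: 3,
            },
        );
        let info = mapper.get(&42).expect("feature 42 was just inserted");
        assert_eq!(info.tensor_index, 1);
        assert_eq!(info.index_within_tensor, 3);
    }
}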

View File

@ -1,182 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Root {
#[serde(rename = "common_prefix")]
pub common_prefix: String,
#[serde(rename = "densification_transform_spec")]
pub densification_transform_spec: DensificationTransformSpec,
#[serde(rename = "identity_transform_spec")]
pub identity_transform_spec: Vec<IdentityTransformSpec>,
#[serde(rename = "complex_feature_type_transform_spec")]
pub complex_feature_type_transform_spec: Vec<ComplexFeatureTypeTransformSpec>,
#[serde(rename = "input_features_map")]
pub input_features_map: Value,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DensificationTransformSpec {
pub discrete: Discrete,
pub cont: Cont,
pub binary: Binary,
pub string: Value, // Use StringType
pub blob: Blob,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Discrete {
pub tag: String,
#[serde(rename = "generic_feature_type")]
pub generic_feature_type: i64,
#[serde(rename = "feature_identifier")]
pub feature_identifier: String,
#[serde(rename = "fixed_length")]
pub fixed_length: i64,
#[serde(rename = "default_value")]
pub default_value: DefaultValue,
#[serde(rename = "input_features")]
pub input_features: Vec<InputFeature>,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DefaultValue {
#[serde(rename = "type")]
pub type_field: String,
pub value: String,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InputFeature {
#[serde(rename = "feature_id")]
pub feature_id: i64,
#[serde(rename = "full_feature_name")]
pub full_feature_name: String,
#[serde(rename = "feature_type")]
pub feature_type: i64,
pub index: i64,
#[serde(rename = "maybe_exclude")]
pub maybe_exclude: bool,
pub tag: String,
#[serde(rename = "added_at")]
pub added_at: i64,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Cont {
pub tag: String,
#[serde(rename = "generic_feature_type")]
pub generic_feature_type: i64,
#[serde(rename = "feature_identifier")]
pub feature_identifier: String,
#[serde(rename = "fixed_length")]
pub fixed_length: i64,
#[serde(rename = "default_value")]
pub default_value: DefaultValue,
#[serde(rename = "input_features")]
pub input_features: Vec<InputFeature>,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Binary {
pub tag: String,
#[serde(rename = "generic_feature_type")]
pub generic_feature_type: i64,
#[serde(rename = "feature_identifier")]
pub feature_identifier: String,
#[serde(rename = "fixed_length")]
pub fixed_length: i64,
#[serde(rename = "default_value")]
pub default_value: DefaultValue,
#[serde(rename = "input_features")]
pub input_features: Vec<InputFeature>,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StringType {
pub tag: String,
#[serde(rename = "generic_feature_type")]
pub generic_feature_type: i64,
#[serde(rename = "feature_identifier")]
pub feature_identifier: String,
#[serde(rename = "fixed_length")]
pub fixed_length: i64,
#[serde(rename = "default_value")]
pub default_value: DefaultValue,
#[serde(rename = "input_features")]
pub input_features: Vec<InputFeature>,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Blob {
pub tag: String,
#[serde(rename = "generic_feature_type")]
pub generic_feature_type: i64,
#[serde(rename = "feature_identifier")]
pub feature_identifier: String,
#[serde(rename = "fixed_length")]
pub fixed_length: i64,
#[serde(rename = "default_value")]
pub default_value: DefaultValue,
#[serde(rename = "input_features")]
pub input_features: Vec<Value>,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IdentityTransformSpec {
#[serde(rename = "feature_id")]
pub feature_id: i64,
#[serde(rename = "full_feature_name")]
pub full_feature_name: String,
#[serde(rename = "feature_type")]
pub feature_type: i64,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ComplexFeatureTypeTransformSpec {
#[serde(rename = "feature_id")]
pub feature_id: i64,
#[serde(rename = "full_feature_name")]
pub full_feature_name: String,
#[serde(rename = "feature_type")]
pub feature_type: i64,
pub index: i64,
#[serde(rename = "maybe_exclude")]
pub maybe_exclude: bool,
pub tag: String,
#[serde(rename = "tensor_data_type")]
pub tensor_data_type: Option<i64>,
#[serde(rename = "added_at")]
pub added_at: i64,
#[serde(rename = "tensor_shape")]
#[serde(default)]
pub tensor_shape: Vec<i64>,
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InputFeatureMapRecord {
#[serde(rename = "feature_id")]
pub feature_id: i64,
#[serde(rename = "full_feature_name")]
pub full_feature_name: String,
#[serde(rename = "feature_type")]
pub feature_type: i64,
pub index: i64,
#[serde(rename = "maybe_exclude")]
pub maybe_exclude: bool,
pub tag: String,
#[serde(rename = "added_at")]
pub added_at: i64,
}
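// Sketch (added for illustration, not part of the original source): deserializing a single
// InputFeature record from the JSON layout these structs describe. The field values are
// invented for the example.
#[cfg(test)]
mod parse_sketch {
    use super::InputFeature;

    #[test]
    fn parses_one_input_feature() {
        let json = r#"{
            "feature_id": 123,
            "full_feature_name": "example.feature",
            "feature_type": 2,
            "index": 0,
            "maybe_exclude": false,
            "tag": "Continuous",
            "added_at": 0
        }"#;
        let feature: InputFeature = serde_json::from_str(json).expect("valid sample JSON");
        assert_eq!(feature.feature_id, 123);
        assert_eq!(feature.feature_type, 2);
    }
}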

BIN
navi/segdense/src/util.docx Normal file

Binary file not shown.

View File

@ -1,154 +0,0 @@
use log::debug;
use std::fs;
use serde_json::{Map, Value};
use crate::error::SegDenseError;
use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter};
use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
pub fn load_config(file_name: &str) -> Result<seg_dense::Root, SegDenseError> {
let json_str = fs::read_to_string(file_name)?;
// &format!("Unable to load segdense file {}", file_name));
let seg_dense_config = parse(&json_str)?;
// &format!("Unable to parse segdense file {}", file_name));
Ok(seg_dense_config)
}
pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {
let root: seg_dense::Root = serde_json::from_str(json_str)?;
Ok(root)
}
/**
* Given a json string containing a seg dense schema create a feature mapper
* which is essentially:
*
* {feature-id -> (Tensor Index, Index of feature within the tensor)}
*
* Feature id : 64 bit hash of the feature name used in DataRecords.
*
* Tensor Index : A vector of tensors is passed to the model. Tensor
* index refers to the tensor this feature is part of.
*
* Index of feature in tensor : The tensors are vectors, the index of
* feature is the position to put the feature value.
*
* There are many assumptions made in this function that are very model-specific.
* These assumptions are called out below and need to be schematized eventually.
*
* Call this once for each segdense schema and cache the FeatureMapper.
*/
pub fn safe_load_config(json_str: &str) -> Result<FeatureMapper, SegDenseError> {
let root = parse(json_str)?;
load_from_parsed_config(root)
}
// Perf note : make 'root' un-owned
pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper, SegDenseError> {
let v = root.input_features_map;
// Do error check
let map: Map<String, Value> = match v {
Value::Object(map) => map,
_ => return Err(SegDenseError::JsonMissingObject),
};
let mut fm: FeatureMapper = FeatureMapper::new();
let items = map.values();
// Perf : Consider a way to avoid clone here
for item in items.cloned() {
let mut vec = match item {
Value::Array(v) => v,
_ => return Err(SegDenseError::JsonMissingArray),
};
if vec.len() != 1 {
return Err(SegDenseError::JsonArraySize);
}
let val = vec.pop().unwrap();
let input_feature: seg_dense::InputFeature = serde_json::from_value(val)?;
let feature_id = input_feature.feature_id;
let feature_info = to_feature_info(&input_feature);
if let Some(info) = feature_info {
debug!("{:?}", info);
fm.set(feature_id, info)
}
}
Ok(fm)
}
#[allow(dead_code)]
fn add_feature_info_to_mapper(
feature_mapper: &mut FeatureMapper,
input_features: &Vec<InputFeature>,
) {
for input_feature in input_features.iter() {
let feature_id = input_feature.feature_id;
let feature_info = to_feature_info(input_feature);
if let Some(info) = feature_info {
debug!("{:?}", info);
feature_mapper.set(feature_id, info)
}
}
}
pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<FeatureInfo> {
if input_feature.maybe_exclude {
return None;
}
// This part needs to be schema-driven.
//
// tensor index : which of these tensors this feature is part of
// [Continuous, Binary, Discrete, User_embedding, user_eng_embedding, author_embedding]
// Note that this order is fixed/hardcoded here and needs to be schematized.
//
let tensor_idx: i8 = match input_feature.feature_id {
// user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings
// Feature name is mapped to a feature-id value. The hardcoded values below correspond to a specific feature name.
-2550691008059411095 => 3,
// user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings
5390650078733277231 => 4,
// original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings
3223956748566688423 => 5,
_ => match input_feature.feature_type {
// feature_type : src/thrift/com/twitter/ml/api/data.thrift
// BINARY = 1, CONTINUOUS = 2, DISCRETE = 3,
// Map to slots in [Continuous, Binary, Discrete, ..]
1 => 1,
2 => 0,
3 => 2,
_ => -1,
},
};
if input_feature.index < 0 {
return None;
}
// Handle this case later
if tensor_idx == -1 {
return None;
}
Some(FeatureInfo {
tensor_index: tensor_idx,
index_within_tensor: input_feature.index,
})
}
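// Sketch (added for illustration, not part of the original source): what to_feature_info
// returns for a minimal InputFeature. The field values are made up; only feature_type and
// index drive the mapping for feature ids outside the hardcoded embedding list.
#[cfg(test)]
mod to_feature_info_sketch {
    use super::to_feature_info;
    use crate::segdense_transform_spec_home_recap_2022::InputFeature;

    #[test]
    fn continuous_feature_maps_to_tensor_zero() {
        let input_feature = InputFeature {
            feature_id: 42,
            full_feature_name: "example.continuous.feature".to_string(),
            feature_type: 2, // CONTINUOUS
            index: 7,
            maybe_exclude: false,
            tag: String::new(),
            added_at: 0,
        };
        let info = to_feature_info(&input_feature).expect("feature should not be excluded");
        assert_eq!(info.tensor_index, 0); // continuous features land in tensor 0
        assert_eq!(info.index_within_tensor, 7);
    }
}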

Binary file not shown.

View File

@ -1,8 +0,0 @@
[package]
name = "bpr_thrift"
description = "Thrift parser for Batch Prediction Request"
version = "0.1.0"
edition = "2021"
[dependencies]
thrift = "0.17.0"

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

View File

@ -1,78 +0,0 @@
// A feature value can be one of these
enum FeatureVal {
Empty,
U8Vector(Vec<u8>),
FloatVector(Vec<f32>),
}
// A Feature has a name and a value.
// The name for now is an 'id' of type String.
// Eventually this needs to be flexible, for example to accommodate feature-ids.
struct Feature {
id: String,
val: FeatureVal,
}
impl Feature {
fn new() -> Feature {
Feature {
id: String::new(),
val: FeatureVal::Empty
}
}
}
// A single inference record will have multiple features
struct Record {
fields: Vec<Feature>,
}
impl Record {
fn new() -> Record {
Record { fields: vec![] }
}
}
// This is the main API used by external components
// Given a serialized input, decode it into Records
fn decode(_input: Vec<u8>) -> Vec<Record> {
// For helping define the interface
vec![get_random_record(), get_random_record()]
}
// Used for testing the API; will eventually be removed
fn get_random_record() -> Record {
let mut record: Record = Record::new();
let f1: Feature = Feature {
id: String::from("continuous_features"),
val: FeatureVal::FloatVector(vec![1.0f32; 2134]),
};
record.fields.push(f1);
let f2: Feature = Feature {
id: String::from("user_embedding"),
val: FeatureVal::FloatVector(vec![2.0f32; 200]),
};
record.fields.push(f2);
let f3: Feature = Feature {
id: String::from("author_embedding"),
val: FeatureVal::FloatVector(vec![3.0f32; 200]),
};
record.fields.push(f3);
let f4: Feature = Feature {
id: String::from("binary_features"),
val: FeatureVal::U8Vector(vec![4u8; 43]),
};
record.fields.push(f4);
record
}
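// Sketch (added for illustration, not part of the original source): how a caller might walk
// the Records returned by decode, using only the types defined above.
#[allow(dead_code)]
fn describe(records: &[Record]) {
    for record in records {
        for feature in &record.fields {
            match &feature.val {
                FeatureVal::Empty => println!("{}: empty", feature.id),
                FeatureVal::U8Vector(bytes) => println!("{}: {} bytes", feature.id, bytes.len()),
                FeatureVal::FloatVector(floats) => println!("{}: {} floats", feature.id, floats.len()),
            }
        }
    }
}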

Binary file not shown.

View File

@ -1,4 +0,0 @@
pub mod prediction_service;
pub mod data;
pub mod tensor;

Binary file not shown.

View File

@ -1,81 +0,0 @@
use std::collections::BTreeSet;
use std::collections::BTreeMap;
use bpr_thrift::data::DataRecord;
use bpr_thrift::prediction_service::BatchPredictionRequest;
use thrift::OrderedFloat;
use thrift::protocol::TBinaryInputProtocol;
use thrift::protocol::TSerializable;
use thrift::transport::TBufferChannel;
use thrift::Result;
fn main() {
let data_path = "/tmp/current/timelines/output-1";
let bin_data: Vec<u8> = std::fs::read(data_path).expect("Could not read file!");
println!("Length : {}", bin_data.len());
let mut bc = TBufferChannel::with_capacity(bin_data.len(), 0);
bc.set_readable_bytes(&bin_data);
let mut protocol = TBinaryInputProtocol::new(bc, true);
let result: Result<BatchPredictionRequest> =
BatchPredictionRequest::read_from_in_protocol(&mut protocol);
match result {
Ok(bpr) => logBP(bpr),
Err(err) => println!("Error {}", err),
}
}
fn logBP(bpr: BatchPredictionRequest) {
println!("-------[OUTPUT]---------------");
println!("data {:?}", bpr);
println!("------------------------------");
/*
let common = bpr.common_features;
let recs = bpr.individual_features_list;
println!("--------[Len : {}]------------------", recs.len());
println!("-------[COMMON]---------------");
match common {
Some(dr) => logDR(dr),
None => println!("None"),
}
println!("------------------------------");
for rec in recs {
logDR(rec);
}
println!("------------------------------");
*/
}
fn logDR(dr: DataRecord) {
println!("--------[DR]------------------");
if let Some(bf) = dr.binary_features {
logBin(bf)
}
if let Some(cf) = dr.continuous_features {
logCF(cf)
}
println!("------------------------------");
}
fn logBin(bin: BTreeSet<i64>) {
println!("B: {:?}", bin)
}
fn logCF(cf: BTreeMap<i64, OrderedFloat<f64>>) {
for (id, fs) in cf {
println!("C: {} -> [{}]", id, fs);
}
}

File diff suppressed because it is too large Load Diff

Binary file not shown.

File diff suppressed because it is too large Load Diff

BIN
product-mixer/README.docx Normal file

Binary file not shown.

View File

@ -1,41 +0,0 @@
Product Mixer
=============
## Overview
Product Mixer is a common service framework and set of libraries that make it easy to build,
iterate on, and own product surface areas. It consists of:
- **Core Libraries:** A set of libraries that enable you to build execution pipelines out of
reusable components. You define your logic in small, well-defined, reusable components and focus
on expressing the business logic you want to have. Then you can define easy to understand pipelines
that compose your components. Product Mixer handles the execution and monitoring of your pipelines
allowing you to focus on what really matters, your business logic.
- **Service Framework:** A common service skeleton for teams to host their Product Mixer products.
- **Component Library:** A shared library of components made by the Product Mixer team or
contributed by users. This enables you both to easily share the reusable components you make
and to benefit from the work of other teams by using their shared components from the library.
## Architecture
The bulk of a Product Mixer can be broken down into Pipelines and Components. Components allow you
to break business logic into separate, standardized, reusable, testable, and easily composable
pieces, where each component has a well-defined abstraction. Pipelines are essentially configuration
files specifying which Components should be used and when. This makes it easy to understand how your
code will execute while keeping it organized and structured in a maintainable way.
Requests first go to Product Pipelines, which are used to select which Mixer Pipeline or
Recommendation Pipeline to run for a given request. Each Mixer or Recommendation
Pipeline may run multiple Candidate Pipelines to fetch candidates to include in the response.
Mixer Pipelines combine the results of multiple heterogeneous Candidate Pipelines
(e.g. ads, tweets, users), while Recommendation Pipelines are used to score (via Scoring Pipelines)
and rank the results of homogeneous Candidate Pipelines so that the top-ranked ones can be returned.
These pipelines also marshall candidates into a domain object and then into a transport object
to return to the caller.
Candidate Pipelines fetch candidates from underlying Candidate Sources and perform some basic
operations on the Candidates, such as filtering out unwanted candidates, applying decorations,
and hydrating features.

View File

@ -1,57 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.account_recommendations_mixer
import com.twitter.account_recommendations_mixer.{thriftscala => t}
import com.twitter.product_mixer.component_library.model.candidate.UserCandidate
import com.twitter.product_mixer.core.feature.Feature
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSourceWithExtractedFeatures
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidatesWithSourceFeatures
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton
object WhoToFollowModuleHeaderFeature extends Feature[UserCandidate, t.Header]
object WhoToFollowModuleFooterFeature extends Feature[UserCandidate, Option[t.Footer]]
object WhoToFollowModuleDisplayOptionsFeature
extends Feature[UserCandidate, Option[t.DisplayOptions]]
@Singleton
class AccountRecommendationsMixerCandidateSource @Inject() (
accountRecommendationsMixer: t.AccountRecommendationsMixer.MethodPerEndpoint)
extends CandidateSourceWithExtractedFeatures[
t.AccountRecommendationsMixerRequest,
t.RecommendedUser
] {
override val identifier: CandidateSourceIdentifier =
CandidateSourceIdentifier(name = "AccountRecommendationsMixer")
override def apply(
request: t.AccountRecommendationsMixerRequest
): Stitch[CandidatesWithSourceFeatures[t.RecommendedUser]] = {
Stitch
.callFuture(accountRecommendationsMixer.getWtfRecommendations(request))
.map { response: t.WhoToFollowResponse =>
responseToCandidatesWithSourceFeatures(
response.userRecommendations,
response.header,
response.footer,
response.displayOptions)
}
}
private def responseToCandidatesWithSourceFeatures(
userRecommendations: Seq[t.RecommendedUser],
header: t.Header,
footer: Option[t.Footer],
displayOptions: Option[t.DisplayOptions],
): CandidatesWithSourceFeatures[t.RecommendedUser] = {
val features = FeatureMapBuilder()
.add(WhoToFollowModuleHeaderFeature, header)
.add(WhoToFollowModuleFooterFeature, footer)
.add(WhoToFollowModuleDisplayOptionsFeature, displayOptions)
.build()
CandidatesWithSourceFeatures(userRecommendations, features)
}
}

View File

@ -1,22 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
platform = "java8",
strict_deps = True,
tags = ["bazel-compatible"],
dependencies = [
"account-recommendations-mixer/thrift/src/main/thrift:thrift-scala",
"finatra/inject/inject-core/src/main/scala/com/twitter/inject",
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline/pipeline_failure",
"src/thrift/com/twitter/ads/adserver:adserver_common-scala",
"stitch/stitch-core",
],
exports = [
"account-recommendations-mixer/thrift/src/main/thrift:thrift-scala",
"finatra/inject/inject-core/src/main/scala/com/twitter/inject",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
"stitch/stitch-core",
],
)

View File

@ -1,29 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ads
import com.twitter.adserver.thriftscala.AdImpression
import com.twitter.adserver.thriftscala.AdRequestParams
import com.twitter.adserver.thriftscala.AdRequestResponse
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyFetcherSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.ads.admixer.MakeAdRequestClientColumn
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class AdsProdStratoCandidateSource @Inject() (adsClient: MakeAdRequestClientColumn)
extends StratoKeyFetcherSource[
AdRequestParams,
AdRequestResponse,
AdImpression
] {
override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("AdsProdStrato")
override val fetcher: Fetcher[AdRequestParams, Unit, AdRequestResponse] = adsClient.fetcher
override protected def stratoResultTransformer(
stratoResult: AdRequestResponse
): Seq[AdImpression] =
stratoResult.impressions
}

View File

@ -1,22 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ads
import com.twitter.adserver.thriftscala.AdImpression
import com.twitter.adserver.thriftscala.AdRequestParams
import com.twitter.adserver.thriftscala.NewAdServer
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class AdsProdThriftCandidateSource @Inject() (
adServerClient: NewAdServer.MethodPerEndpoint)
extends CandidateSource[AdRequestParams, AdImpression] {
override val identifier: CandidateSourceIdentifier =
CandidateSourceIdentifier("AdsProdThrift")
override def apply(request: AdRequestParams): Stitch[Seq[AdImpression]] =
Stitch.callFuture(adServerClient.makeAdRequest(request)).map(_.impressions)
}

View File

@ -1,28 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ads
import com.twitter.adserver.thriftscala.AdImpression
import com.twitter.adserver.thriftscala.AdRequestParams
import com.twitter.adserver.thriftscala.AdRequestResponse
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyFetcherSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.ads.admixer.MakeAdRequestStagingClientColumn
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class AdsStagingCandidateSource @Inject() (adsClient: MakeAdRequestStagingClientColumn)
extends StratoKeyFetcherSource[
AdRequestParams,
AdRequestResponse,
AdImpression
] {
override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("AdsStaging")
override val fetcher: Fetcher[AdRequestParams, Unit, AdRequestResponse] = adsClient.fetcher
override protected def stratoResultTransformer(
stratoResult: AdRequestResponse
): Seq[AdImpression] =
stratoResult.impressions
}

View File

@ -1,18 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
strict_deps = True,
tags = ["bazel-compatible"],
dependencies = [
"3rdparty/jvm/javax/inject:javax.inject",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
"src/thrift/com/twitter/ads/adserver:adserver_common-scala",
"src/thrift/com/twitter/ads/adserver:adserver_rpc-scala",
"strato/config/columns/ads/admixer:admixer-strato-client",
],
exports = [
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
"src/thrift/com/twitter/ads/adserver:adserver_common-scala",
"src/thrift/com/twitter/ads/adserver:adserver_rpc-scala",
],
)

View File

@ -1,43 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ann
import com.twitter.ann.common._
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import com.twitter.util.{Time => _, _}
import com.twitter.finagle.util.DefaultTimer
/**
* @param annQueryableById Ann Queryable by Id client that returns nearest neighbors for a sequence of queries
* @param identifier Candidate Source Identifier
* @tparam T1 type of the query.
* @tparam T2 type of the result.
* @tparam P runtime parameters supported by the index.
* @tparam D distance function used in the index.
*/
class AnnCandidateSource[T1, T2, P <: RuntimeParams, D <: Distance[D]](
val annQueryableById: QueryableById[T1, T2, P, D],
val batchSize: Int,
val timeoutPerRequest: Duration,
override val identifier: CandidateSourceIdentifier)
extends CandidateSource[AnnIdQuery[T1, P], NeighborWithDistanceWithSeed[T1, T2, D]] {
implicit val timer = DefaultTimer
override def apply(
request: AnnIdQuery[T1, P]
): Stitch[Seq[NeighborWithDistanceWithSeed[T1, T2, D]]] = {
val ids = request.ids
val numOfNeighbors = request.numOfNeighbors
val runtimeParams = request.runtimeParams
Stitch
.collect(
ids
.grouped(batchSize).map { batchedIds =>
annQueryableById
.batchQueryWithDistanceById(batchedIds, numOfNeighbors, runtimeParams).map {
annResult => annResult.toSeq
}.within(timeoutPerRequest).handle { case _ => Seq.empty }
}.toSeq).map(_.flatten)
}
}

View File

@ -1,18 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ann
import com.twitter.ann.common._
/**
* An [[AnnIdQuery]] is a query class which defines the ANN entities, runtime params, and number of neighbors requested.
*
* @param ids Sequence of queries
* @param numOfNeighbors Number of neighbors requested
* @param runtimeParams ANN runtime params
* @tparam T type of query.
* @tparam P runtime parameters supported by the index.
*/
case class AnnIdQuery[T, P <: RuntimeParams](
ids: Seq[T],
numOfNeighbors: Int,
runtimeParams: P)

View File

@ -1,17 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
strict_deps = True,
tags = ["bazel-compatible"],
dependencies = [
"ann/src/main/scala/com/twitter/ann/common",
"ann/src/main/scala/com/twitter/ann/hnsw",
"ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
"product-mixer/component-library/src/main/thrift/com/twitter/product_mixer/component_library:thrift-scala",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
"servo/manhattan/src/main/scala",
"servo/repo/src/main/scala",
"servo/util/src/main/scala",
"stitch/stitch-core",
],
)

View File

@ -1,14 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
strict_deps = True,
tags = ["bazel-compatible"],
dependencies = [
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/cursor",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
"src/thrift/com/twitter/periscope/audio_space:audio_space-scala",
"strato/config/columns/periscope:periscope-strato-client",
"strato/config/src/thrift/com/twitter/strato/graphql:graphql-scala",
"strato/src/main/scala/com/twitter/strato/client",
],
)

View File

@ -1,49 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.audiospace
import com.twitter.periscope.audio_space.thriftscala.CreatedSpacesView
import com.twitter.periscope.audio_space.thriftscala.SpaceSlice
import com.twitter.product_mixer.component_library.model.cursor.NextCursorFeature
import com.twitter.product_mixer.component_library.model.cursor.PreviousCursorFeature
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyViewFetcherWithSourceFeaturesSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.periscope.CreatedSpacesSliceOnUserClientColumn
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class CreatedSpacesCandidateSource @Inject() (
column: CreatedSpacesSliceOnUserClientColumn)
extends StratoKeyViewFetcherWithSourceFeaturesSource[
Long,
CreatedSpacesView,
SpaceSlice,
String
] {
override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("CreatedSpaces")
override val fetcher: Fetcher[Long, CreatedSpacesView, SpaceSlice] = column.fetcher
override def stratoResultTransformer(
stratoKey: Long,
stratoResult: SpaceSlice
): Seq[String] =
stratoResult.items
override protected def extractFeaturesFromStratoResult(
stratoKey: Long,
stratoResult: SpaceSlice
): FeatureMap = {
val featureMapBuilder = FeatureMapBuilder()
stratoResult.sliceInfo.previousCursor.foreach { cursor =>
featureMapBuilder.add(PreviousCursorFeature, cursor)
}
stratoResult.sliceInfo.nextCursor.foreach { cursor =>
featureMapBuilder.add(NextCursorFeature, cursor)
}
featureMapBuilder.build()
}
}

View File

@ -1,13 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
strict_deps = True,
tags = ["bazel-compatible"],
dependencies = [
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/cursor",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
"strato/config/columns/consumer-identity/business-profiles:business-profiles-strato-client",
"strato/config/src/thrift/com/twitter/strato/graphql:graphql-scala",
"strato/src/main/scala/com/twitter/strato/client",
],
)

View File

@ -1,53 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.business_profiles
import com.twitter.product_mixer.component_library.model.cursor.NextCursorFeature
import com.twitter.product_mixer.component_library.model.cursor.PreviousCursorFeature
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyViewFetcherWithSourceFeaturesSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.consumer_identity.business_profiles.BusinessProfileTeamMembersOnUserClientColumn
import com.twitter.strato.generated.client.consumer_identity.business_profiles.BusinessProfileTeamMembersOnUserClientColumn.{
Value => TeamMembersSlice
}
import com.twitter.strato.generated.client.consumer_identity.business_profiles.BusinessProfileTeamMembersOnUserClientColumn.{
View => TeamMembersView
}
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class TeamMembersCandidateSource @Inject() (
column: BusinessProfileTeamMembersOnUserClientColumn)
extends StratoKeyViewFetcherWithSourceFeaturesSource[
Long,
TeamMembersView,
TeamMembersSlice,
Long
] {
override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier(
"BusinessProfileTeamMembers")
override val fetcher: Fetcher[Long, TeamMembersView, TeamMembersSlice] = column.fetcher
override def stratoResultTransformer(
stratoKey: Long,
stratoResult: TeamMembersSlice
): Seq[Long] =
stratoResult.members
override protected def extractFeaturesFromStratoResult(
stratoKey: Long,
stratoResult: TeamMembersSlice
): FeatureMap = {
val featureMapBuilder = FeatureMapBuilder()
stratoResult.previousCursor.foreach { cursor =>
featureMapBuilder.add(PreviousCursorFeature, cursor.toString)
}
stratoResult.nextCursor.foreach { cursor =>
featureMapBuilder.add(NextCursorFeature, cursor.toString)
}
featureMapBuilder.build()
}
}

View File

@ -1,12 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
strict_deps = True,
tags = ["bazel-compatible"],
dependencies = [
"cr-mixer/thrift/src/main/thrift:thrift-scala",
"finatra/inject/inject-core/src/main/scala/com/twitter/inject",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
"stitch/stitch-core",
],
)

View File

@ -1,25 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.cr_mixer
import com.twitter.cr_mixer.{thriftscala => t}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton
/**
* Returns out-of-network Tweet recommendations by using user recommendations
* from FollowRecommendationService as an input seed-set to Earlybird
*/
@Singleton
class CrMixerFrsBasedTweetRecommendationsCandidateSource @Inject() (
crMixerClient: t.CrMixer.MethodPerEndpoint)
extends CandidateSource[t.FrsTweetRequest, t.FrsTweet] {
override val identifier: CandidateSourceIdentifier =
CandidateSourceIdentifier("CrMixerFrsBasedTweetRecommendations")
override def apply(request: t.FrsTweetRequest): Stitch[Seq[t.FrsTweet]] = Stitch
.callFuture(crMixerClient.getFrsBasedTweetRecommendations(request))
.map(_.tweets)
}

View File

@ -1,21 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.cr_mixer
import com.twitter.cr_mixer.{thriftscala => t}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class CrMixerTweetRecommendationsCandidateSource @Inject() (
crMixerClient: t.CrMixer.MethodPerEndpoint)
extends CandidateSource[t.CrMixerTweetRequest, t.TweetRecommendation] {
override val identifier: CandidateSourceIdentifier =
CandidateSourceIdentifier("CrMixerTweetRecommendations")
override def apply(request: t.CrMixerTweetRequest): Stitch[Seq[t.TweetRecommendation]] = Stitch
.callFuture(crMixerClient.getTweetRecommendations(request))
.map(_.tweets)
}

View File

@ -1,12 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
strict_deps = True,
tags = ["bazel-compatible"],
dependencies = [
"finatra/inject/inject-core/src/main/scala/com/twitter/inject",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
"src/thrift/com/twitter/search:earlybird-scala",
"stitch/stitch-core",
],
)

View File

@ -1,26 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.earlybird
import com.twitter.search.earlybird.{thriftscala => t}
import com.twitter.inject.Logging
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class EarlybirdTweetCandidateSource @Inject() (
earlybirdService: t.EarlybirdService.MethodPerEndpoint)
extends CandidateSource[t.EarlybirdRequest, t.ThriftSearchResult]
with Logging {
override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("EarlybirdTweets")
override def apply(request: t.EarlybirdRequest): Stitch[Seq[t.ThriftSearchResult]] = {
Stitch
.callFuture(earlybirdService.search(request))
.map { response: t.EarlybirdResponse =>
response.searchResults.map(_.results).getOrElse(Seq.empty)
}
}
}

View File

@ -1,12 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
strict_deps = True,
tags = ["bazel-compatible"],
dependencies = [
"3rdparty/jvm/javax/inject:javax.inject",
"explore/explore-ranker/thrift/src/main/thrift:thrift-scala",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
"stitch/stitch-core",
],
)

View File

@ -1,31 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.explore_ranker
import com.twitter.explore_ranker.{thriftscala => t}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class ExploreRankerCandidateSource @Inject() (
exploreRankerService: t.ExploreRanker.MethodPerEndpoint)
extends CandidateSource[t.ExploreRankerRequest, t.ImmersiveRecsResult] {
override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("ExploreRanker")
override def apply(
request: t.ExploreRankerRequest
): Stitch[Seq[t.ImmersiveRecsResult]] = {
Stitch
.callFuture(exploreRankerService.getRankedResults(request))
.map {
case t.ExploreRankerResponse(
t.ExploreRankerProductResponse
.ImmersiveRecsResponse(t.ImmersiveRecsResponse(immersiveRecsResults))) =>
immersiveRecsResults
case response =>
throw new UnsupportedOperationException(s"Unknown response type: $response")
}
}
}

View File

@ -1,17 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
platform = "java8",
strict_deps = True,
tags = ["bazel-compatible"],
dependencies = [
"finatra/inject/inject-core/src/main/scala/com/twitter/inject",
"onboarding/service/thrift/src/main/thrift:thrift-scala",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
"stitch/stitch-core",
],
exports = [
"onboarding/service/thrift/src/main/thrift:thrift-scala",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
],
)

View File

@ -1,50 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.flexible_injection_pipeline
import com.twitter.inject.Logging
import com.twitter.onboarding.injections.{thriftscala => injectionsthrift}
import com.twitter.onboarding.task.service.{thriftscala => servicethrift}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton
/**
* Returns a list of prompts to insert into a user's timeline (inline prompts, cover modals, etc.)
* from go/flip (the prompting platform for Twitter).
*/
@Singleton
class PromptCandidateSource @Inject() (taskService: servicethrift.TaskService.MethodPerEndpoint)
extends CandidateSource[servicethrift.GetInjectionsRequest, IntermediatePrompt]
with Logging {
override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier(
"InjectionPipelinePrompts")
override def apply(
request: servicethrift.GetInjectionsRequest
): Stitch[Seq[IntermediatePrompt]] = {
Stitch
.callFuture(taskService.getInjections(request)).map {
_.injections.flatMap {
// The entire carousel is added to each IntermediatePrompt item along with the tile's
// index, to be unpacked later on to populate its TimelineEntry counterpart.
case injection: injectionsthrift.Injection.TilesCarousel =>
injection.tilesCarousel.tiles.zipWithIndex.map {
case (tile: injectionsthrift.Tile, index: Int) =>
IntermediatePrompt(injection, Some(index), Some(tile))
}
case injection => Seq(IntermediatePrompt(injection, None, None))
}
}
}
}
/**
* Provides an intermediate step to help 'explode' tile carousel tiles, since TimelineModule
* is not an extension of TimelineItem.
*/
case class IntermediatePrompt(
injection: injectionsthrift.Injection,
offsetInModule: Option[Int],
carouselTile: Option[injectionsthrift.Tile])

View File

@ -1,16 +0,0 @@
scala_library(
sources = ["*.scala"],
compiler_option_sets = ["fatal_warnings"],
strict_deps = True,
tags = ["bazel-compatible"],
dependencies = [
"3rdparty/jvm/javax/inject:javax.inject",
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
"src/thrift/com/twitter/hermit:hermit-scala",
"strato/config/columns/onboarding:onboarding-strato-client",
],
exports = [
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
"src/thrift/com/twitter/hermit:hermit-scala",
],
)

Some files were not shown because too many files have changed in this diff