mirror of https://github.com/twitter/the-algorithm.git
synced 2024-11-14 15:45:13 +01:00

commit 2488f40edf (parent 65c3a3fe90)

    [docx] split commit for file 2000

    Signed-off-by: Ari Archer <ari.web.xyz@gmail.com>
BIN  navi/navi/src/bin/navi.docx  (new file; binary file not shown)
navi/navi/src/bin/navi.rs (deleted)
@@ -1,47 +0,0 @@
use anyhow::Result;
use log::info;
use navi::cli_args::{ARGS, MODEL_SPECS};
use navi::cores::validator::validatior::cli_validator;
use navi::tf_model::tf::TFModel;
use navi::{bootstrap, metrics};
use sha256::digest;

fn main() -> Result<()> {
    env_logger::init();
    cli_validator::validate_input_args();
    // only validate for TF, as other model runtimes don't have serving signatures
    assert_eq!(MODEL_SPECS.len(), ARGS.serving_sig.len());
    metrics::register_custom_metrics();

    // load all the custom ops - comma separated
    if let Some(ref customops_lib) = ARGS.customops_lib {
        for op_lib in customops_lib.split(",") {
            load_custom_op(op_lib);
        }
    }

    // version the custom-op .so library
    bootstrap::bootstrap(TFModel::new)
}

fn load_custom_op(lib_path: &str) {
    let res = tensorflow::Library::load(lib_path);
    info!("{} load status:{:?}", lib_path, res);
    let customop_version_num = get_custom_op_version(lib_path);
    // the last loaded op's version is recorded
    metrics::CUSTOMOP_VERSION.set(customop_version_num);
}

fn get_custom_op_version(customops_lib: &str) -> i64 {
    let customop_bytes = std::fs::read(customops_lib).unwrap(); // Vec<u8>
    let customop_hash = digest(customop_bytes.as_slice());
    // convert the last 4 hex digits to a version number, as Prometheus metrics
    // don't support strings; the total space is 16^4 == 65536
    let customop_version_num =
        i64::from_str_radix(&customop_hash[customop_hash.len() - 4..], 16).unwrap();
    info!(
        "customop hash: {}, version_number: {}",
        customop_hash, customop_version_num
    );
    customop_version_num
}
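
Note: the file above derives a numeric "version" for the custom-op shared library by hashing its bytes. A minimal standalone sketch (hypothetical, not part of the commit) of the same mapping, assuming the `sha256` crate the file already uses:

    // Sketch: derive a 0..=65535 version number from arbitrary library bytes,
    // mirroring get_custom_op_version above.
    use sha256::digest;

    fn version_from_bytes(bytes: &[u8]) -> i64 {
        let hash = digest(bytes); // 64 hex characters
        // last 4 hex digits => 16^4 == 65536 distinct versions
        i64::from_str_radix(&hash[hash.len() - 4..], 16).unwrap()
    }

    fn main() {
        let bytes: &[u8] = b"fake .so contents";
        let v = version_from_bytes(bytes);
        assert!((0..=0xFFFF).contains(&v));
        println!("derived custom-op version: {}", v);
    }

The trade-off is that two different libraries can collide in a 65536-value space, so the gauge identifies builds for dashboards rather than guaranteeing uniqueness.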
BIN  navi/navi/src/bin/navi_onnx.docx  (new file; binary file not shown)
navi/navi/src/bin/navi_onnx.rs (deleted)
@@ -1,24 +0,0 @@
use anyhow::Result;
use log::info;
use navi::cli_args::{ARGS, MODEL_SPECS};
use navi::onnx_model::onnx::OnnxModel;
use navi::{bootstrap, metrics};

fn main() -> Result<()> {
    env_logger::init();
    info!("global: {:?}", ARGS.onnx_global_thread_pool_options);
    let assert_session_params = if ARGS.onnx_global_thread_pool_options.is_empty() {
        // std::env::set_var("OMP_NUM_THREADS", "1");
        info!("now we use per session thread pool");
        MODEL_SPECS.len()
    } else {
        info!("now we use global thread pool");
        0
    };
    assert_eq!(assert_session_params, ARGS.inter_op_parallelism.len());
    // the original repeated the inter_op check here; the intended invariant is
    // the matching intra_op one
    assert_eq!(assert_session_params, ARGS.intra_op_parallelism.len());

    metrics::register_custom_metrics();
    bootstrap::bootstrap(OnnxModel::new)
}
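
Note: the two asserts encode an arity rule between thread-pool mode and the per-model CLI lists: with a global ONNX thread pool there must be no per-session settings, and without one there must be exactly one setting per model. A hypothetical sketch of that rule in isolation (names are illustrative):

    // Sketch of the invariant checked above.
    fn expected_thread_settings(global_pool: bool, num_models: usize) -> usize {
        if global_pool { 0 } else { num_models }
    }

    fn main() {
        // per-session pools: one inter/intra setting per model
        assert_eq!(expected_thread_settings(false, 3), 3);
        // global pool: the per-session lists must be empty
        assert_eq!(expected_thread_settings(true, 3), 0);
    }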
BIN  navi/navi/src/bin/navi_torch.docx  (new file; binary file not shown)
navi/navi/src/bin/navi_torch.rs (deleted)
@@ -1,19 +0,0 @@
use anyhow::Result;
use log::info;
use navi::cli_args::ARGS;
use navi::metrics;
use navi::torch_model::torch::TorchModel;

fn main() -> Result<()> {
    env_logger::init();
    // torch only has global thread-pool settings, whereas tf has per-model thread-pool settings
    assert_eq!(1, ARGS.inter_op_parallelism.len());
    assert_eq!(1, ARGS.intra_op_parallelism.len());
    // TODO: for now we assume each model's output has only 1 tensor;
    // this will be lifted once torch_model properly implements mtl outputs
    tch::set_num_interop_threads(ARGS.inter_op_parallelism[0].parse()?);
    tch::set_num_threads(ARGS.intra_op_parallelism[0].parse()?);
    info!("torch custom ops not used for now");
    metrics::register_custom_metrics();
    navi::bootstrap::bootstrap(TorchModel::new)
}
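
Note: because libtorch exposes a single process-wide inter-op pool and intra-op pool, this binary accepts exactly one value for each, unlike the per-model lists TF uses. A minimal sketch (assuming the tch crate used above) of the same setup factored out:

    // Sketch: apply process-global libtorch threading once at startup.
    fn configure_torch_threads(inter: i32, intra: i32) {
        tch::set_num_interop_threads(inter); // parallelism across independent ops
        tch::set_num_threads(intra);         // parallelism inside a single op
    }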
BIN  navi/navi/src/bootstrap.docx  (new file; binary file not shown)
navi/navi/src/bootstrap.rs (deleted)
@@ -1,326 +0,0 @@
use anyhow::Result;
use log::{info, warn};
use x509_parser::{prelude::parse_x509_pem, parse_x509_certificate};
use std::collections::HashMap;
use tokio::time::Instant;
use tonic::{
    Request, Response, Status,
    transport::{Certificate, Identity, Server, ServerTlsConfig},
};

// protobuf related
use crate::tf_proto::tensorflow_serving::{
    ClassificationRequest, ClassificationResponse, GetModelMetadataRequest,
    GetModelMetadataResponse, MultiInferenceRequest, MultiInferenceResponse, PredictRequest,
    PredictResponse, RegressionRequest, RegressionResponse,
};
use crate::{kf_serving::{
    grpc_inference_service_server::GrpcInferenceService, ModelInferRequest, ModelInferResponse,
    ModelMetadataRequest, ModelMetadataResponse, ModelReadyRequest, ModelReadyResponse,
    ServerLiveRequest, ServerLiveResponse, ServerMetadataRequest, ServerMetadataResponse,
    ServerReadyRequest, ServerReadyResponse,
}, ModelFactory, tf_proto::tensorflow_serving::prediction_service_server::{
    PredictionService, PredictionServiceServer,
}, VERSION, NAME};

use crate::PredictResult;
use crate::cli_args::{ARGS, INPUTS, OUTPUTS};
use crate::metrics::{
    NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
    NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR,
    CERT_EXPIRY_EPOCH,
};
use crate::predict_service::{Model, PredictService};
use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version;
use crate::tf_proto::tensorflow_serving::ModelSpec;

#[derive(Debug)]
pub enum TensorInputEnum {
    String(Vec<Vec<u8>>),
    Int(Vec<i32>),
    Int64(Vec<i64>),
    Float(Vec<f32>),
    Double(Vec<f64>),
    Boolean(Vec<bool>),
}

#[derive(Debug)]
pub struct TensorInput {
    pub tensor_data: TensorInputEnum,
    pub name: String,
    pub dims: Option<Vec<i64>>,
}

impl TensorInput {
    pub fn new(tensor_data: TensorInputEnum, name: String, dims: Option<Vec<i64>>) -> TensorInput {
        TensorInput {
            tensor_data,
            name,
            dims,
        }
    }
}

impl TensorInputEnum {
    #[inline(always)]
    pub(crate) fn extend(&mut self, another: TensorInputEnum) {
        match (self, another) {
            (Self::String(input), Self::String(ex)) => input.extend(ex),
            (Self::Int(input), Self::Int(ex)) => input.extend(ex),
            (Self::Int64(input), Self::Int64(ex)) => input.extend(ex),
            (Self::Float(input), Self::Float(ex)) => input.extend(ex),
            (Self::Double(input), Self::Double(ex)) => input.extend(ex),
            (Self::Boolean(input), Self::Boolean(ex)) => input.extend(ex),
            x => panic!("input enum type not matched. input:{:?}, ex:{:?}", x.0, x.1),
        }
    }
    #[inline(always)]
    pub(crate) fn merge_batch(input_tensors: Vec<Vec<TensorInput>>) -> Vec<TensorInput> {
        input_tensors
            .into_iter()
            .reduce(|mut acc, e| {
                for (i, ext) in acc.iter_mut().zip(e) {
                    i.tensor_data.extend(ext.tensor_data);
                }
                acc
            })
            .unwrap() // invariant: we expect there are always rows in input_tensors
    }
}

/// Entry point for the TF-Serving-compatible gRPC services.
#[tonic::async_trait]
impl<T: Model> GrpcInferenceService for PredictService<T> {
    async fn server_live(
        &self,
        _request: Request<ServerLiveRequest>,
    ) -> Result<Response<ServerLiveResponse>, Status> {
        unimplemented!()
    }
    async fn server_ready(
        &self,
        _request: Request<ServerReadyRequest>,
    ) -> Result<Response<ServerReadyResponse>, Status> {
        unimplemented!()
    }

    async fn model_ready(
        &self,
        _request: Request<ModelReadyRequest>,
    ) -> Result<Response<ModelReadyResponse>, Status> {
        unimplemented!()
    }

    async fn server_metadata(
        &self,
        _request: Request<ServerMetadataRequest>,
    ) -> Result<Response<ServerMetadataResponse>, Status> {
        unimplemented!()
    }

    async fn model_metadata(
        &self,
        _request: Request<ModelMetadataRequest>,
    ) -> Result<Response<ModelMetadataResponse>, Status> {
        unimplemented!()
    }

    async fn model_infer(
        &self,
        _request: Request<ModelInferRequest>,
    ) -> Result<Response<ModelInferResponse>, Status> {
        unimplemented!()
    }
}

#[tonic::async_trait]
impl<T: Model> PredictionService for PredictService<T> {
    async fn classify(
        &self,
        _request: Request<ClassificationRequest>,
    ) -> Result<Response<ClassificationResponse>, Status> {
        unimplemented!()
    }
    async fn regress(
        &self,
        _request: Request<RegressionRequest>,
    ) -> Result<Response<RegressionResponse>, Status> {
        unimplemented!()
    }
    async fn predict(
        &self,
        request: Request<PredictRequest>,
    ) -> Result<Response<PredictResponse>, Status> {
        NUM_REQUESTS_RECEIVED.inc();
        let start = Instant::now();
        let mut req = request.into_inner();
        let (model_spec, version) = req.take_model_spec();
        NUM_REQUESTS_RECEIVED_BY_MODEL
            .with_label_values(&[&model_spec])
            .inc();
        let idx = PredictService::<T>::get_model_index(&model_spec).ok_or_else(|| {
            Status::failed_precondition(format!("model spec not found:{}", model_spec))
        })?;
        let input_spec = match INPUTS[idx].get() {
            Some(input) => input,
            _ => return Err(Status::not_found(format!("model input spec {}", idx))),
        };
        let input_val = req.take_input_vals(input_spec);
        self.predict(idx, version, input_val, start)
            .await
            .map_or_else(
                |e| {
                    NUM_REQUESTS_FAILED.inc();
                    NUM_REQUESTS_FAILED_BY_MODEL
                        .with_label_values(&[&model_spec])
                        .inc();
                    Err(Status::internal(e.to_string()))
                },
                |res| {
                    RESPONSE_TIME_COLLECTOR
                        .with_label_values(&[&model_spec])
                        .observe(start.elapsed().as_millis() as f64);

                    match res {
                        PredictResult::Ok(tensors, version) => {
                            let mut outputs = HashMap::new();
                            NUM_PREDICTIONS.with_label_values(&[&model_spec]).inc();
                            //FIXME: uncomment when prediction scores are normal
                            // PREDICTION_SCORE_SUM
                            //     .with_label_values(&[&model_spec])
                            //     .inc_by(tensors[0] as f64);
                            for (tp, output_name) in tensors
                                .into_iter()
                                .map(|tensor| tensor.create_tensor_proto())
                                .zip(OUTPUTS[idx].iter())
                            {
                                outputs.insert(output_name.to_owned(), tp);
                            }
                            let reply = PredictResponse {
                                model_spec: Some(ModelSpec {
                                    version_choice: Some(Version(version)),
                                    ..Default::default()
                                }),
                                outputs,
                            };
                            Ok(Response::new(reply))
                        }
                        PredictResult::DropDueToOverload => Err(Status::resource_exhausted("")),
                        PredictResult::ModelNotFound(idx) => {
                            Err(Status::not_found(format!("model index {}", idx)))
                        }
                        PredictResult::ModelNotReady(idx) => {
                            Err(Status::unavailable(format!("model index {}", idx)))
                        }
                        PredictResult::ModelVersionNotFound(idx, version) => Err(
                            Status::not_found(format!("model index:{}, version {}", idx, version)),
                        ),
                    }
                },
            )
    }

    async fn multi_inference(
        &self,
        _request: Request<MultiInferenceRequest>,
    ) -> Result<Response<MultiInferenceResponse>, Status> {
        unimplemented!()
    }
    async fn get_model_metadata(
        &self,
        _request: Request<GetModelMetadataRequest>,
    ) -> Result<Response<GetModelMetadataResponse>, Status> {
        unimplemented!()
    }
}

// Record the certificate's expiry epoch in a gauge so it can be scraped and alerted on.
fn report_expiry(expiry_time: i64) {
    info!("Certificate expires at epoch: {:?}", expiry_time);
    CERT_EXPIRY_EPOCH.set(expiry_time);
}

pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
    info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
    // we follow SemVer, so here we assume MAJOR.MINOR.PATCH
    let parts = VERSION
        .split(".")
        .map(|v| v.parse::<i64>())
        .collect::<std::result::Result<Vec<_>, _>>()?;
    if let [major, minor, patch] = &parts[..] {
        NAVI_VERSION.set(major * 1_000_000 + minor * 1_000 + patch);
    } else {
        warn!(
            "version {} doesn't follow the SemVer convention of MAJOR.MINOR.PATCH",
            VERSION
        );
    }

    tokio::runtime::Builder::new_multi_thread()
        .thread_name("async worker")
        .worker_threads(ARGS.num_worker_threads)
        .max_blocking_threads(ARGS.max_blocking_threads)
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            #[cfg(feature = "navi_console")]
            console_subscriber::init();
            let addr = format!("0.0.0.0:{}", ARGS.port).parse()?;

            let ps = PredictService::init(model_factory).await;

            let mut builder = if ARGS.ssl_dir.is_empty() {
                Server::builder()
            } else {
                // Read the pem file as a string
                let pem_str =
                    std::fs::read_to_string(format!("{}/server.crt", ARGS.ssl_dir)).unwrap();
                let res = parse_x509_pem(pem_str.as_bytes());
                match res {
                    Ok((rem, pem_2)) => {
                        assert!(rem.is_empty());
                        assert_eq!(pem_2.label, String::from("CERTIFICATE"));
                        let res_x509 = parse_x509_certificate(&pem_2.contents);
                        info!("Certificate label: {}", pem_2.label);
                        assert!(res_x509.is_ok());
                        report_expiry(res_x509.unwrap().1.validity().not_after.timestamp());
                    }
                    _ => panic!("PEM parsing failed: {:?}", res),
                }

                let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir))
                    .await
                    .expect("can't find key file");
                let crt = tokio::fs::read(format!("{}/server.crt", ARGS.ssl_dir))
                    .await
                    .expect("can't find crt file");
                let chain = tokio::fs::read(format!("{}/server.chain", ARGS.ssl_dir))
                    .await
                    .expect("can't find chain file");
                let mut pem = Vec::new();
                pem.extend(crt);
                pem.extend(chain);
                let identity = Identity::from_pem(pem.clone(), key);
                let client_ca_cert = Certificate::from_pem(pem.clone());
                let tls = ServerTlsConfig::new()
                    .identity(identity)
                    .client_ca_root(client_ca_cert);
                Server::builder()
                    .tls_config(tls)
                    .expect("fail to config SSL")
            };

            info!(
                "Prometheus server started: 0.0.0.0:{}",
                ARGS.prometheus_port
            );

            let ps_server = builder
                .add_service(PredictionServiceServer::new(ps).accept_gzip().send_gzip())
                .serve(addr);
            info!("Prediction server started: {}", addr);
            ps_server.await.map_err(anyhow::Error::msg)
        })
}
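
Note: bootstrap packs the crate's SemVer string into a single integer gauge (MAJOR * 1_000_000 + MINOR * 1_000 + PATCH), since Prometheus gauges cannot hold strings. A standalone sketch of that encoding (hypothetical, not part of the commit), which assumes each component stays below 1000 so neighboring versions cannot collide:

    // Sketch: MAJOR.MINOR.PATCH -> one i64, mirroring the NAVI_VERSION gauge above.
    fn encode_semver(version: &str) -> Option<i64> {
        let parts: Vec<i64> = version
            .split('.')
            .map(|p| p.parse().ok())
            .collect::<Option<_>>()?;
        match parts[..] {
            [major, minor, patch] => Some(major * 1_000_000 + minor * 1_000 + patch),
            _ => None, // not MAJOR.MINOR.PATCH
        }
    }

    fn main() {
        assert_eq!(encode_semver("2.14.3"), Some(2_014_003));
        assert_eq!(encode_semver("2.14"), None);
    }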
BIN  navi/navi/src/cli_args.docx  (new file; binary file not shown)
navi/navi/src/cli_args.rs (deleted)
@@ -1,236 +0,0 @@
use crate::{MAX_NUM_INPUTS, MAX_NUM_MODELS, MAX_NUM_OUTPUTS};
use arrayvec::ArrayVec;
use clap::Parser;
use log::info;
use once_cell::sync::OnceCell;
use std::error::Error;
use time::OffsetDateTime;
use time::format_description::well_known::Rfc3339;

#[derive(Parser, Debug, Clone)]
/// Navi is configured through the CLI arguments (for now) defined below.
//TODO: use clap_serde to make it config-file driven
pub struct Args {
    #[clap(short, long, help = "gRPC port Navi runs on")]
    pub port: i32,
    #[clap(long, default_value_t = 9000, help = "prometheus metrics port")]
    pub prometheus_port: u16,
    #[clap(
        short,
        long,
        default_value_t = 1,
        help = "number of worker threads for tokio async runtime"
    )]
    pub num_worker_threads: usize,
    #[clap(
        long,
        default_value_t = 14,
        help = "number of blocking threads in tokio blocking thread pool"
    )]
    pub max_blocking_threads: usize,
    #[clap(long, default_value = "16", help = "maximum batch size for a batch")]
    pub max_batch_size: Vec<String>,
    #[clap(
        short,
        long,
        default_value = "2",
        help = "max wait time for accumulating a batch"
    )]
    pub batch_time_out_millis: Vec<String>,
    #[clap(
        long,
        default_value_t = 90,
        help = "threshold to start dropping batches under stress"
    )]
    pub batch_drop_millis: u64,
    #[clap(
        long,
        default_value_t = 300,
        help = "polling interval for new version of a model and META.json config"
    )]
    pub model_check_interval_secs: u64,
    #[clap(
        short,
        long,
        default_value = "models/pvideo/",
        help = "root directory for models"
    )]
    pub model_dir: Vec<String>,
    #[clap(
        long,
        help = "directory containing META.json config. separate from model_dir to facilitate remote config management"
    )]
    pub meta_json_dir: Option<String>,
    #[clap(short, long, default_value = "", help = "directory for ssl certs")]
    pub ssl_dir: String,
    #[clap(
        long,
        help = "call out to external process to check model updates. custom logic can be written to pull from hdfs, gcs etc"
    )]
    pub modelsync_cli: Option<String>,
    #[clap(
        long,
        default_value_t = 1,
        help = "specify how many versions Navi retains in memory. good for cases of rolling model upgrade"
    )]
    pub versions_per_model: usize,
    #[clap(
        short,
        long,
        help = "most runtimes support loading custom-written ops. currently only implemented for TF"
    )]
    pub customops_lib: Option<String>,
    #[clap(
        long,
        default_value = "8",
        help = "number of threads to parallelize computations inside an op"
    )]
    pub intra_op_parallelism: Vec<String>,
    #[clap(
        long,
        help = "number of threads to parallelize computations of the graph"
    )]
    pub inter_op_parallelism: Vec<String>,
    #[clap(
        long,
        help = "signature of a serving. only TF"
    )]
    pub serving_sig: Vec<String>,
    #[clap(long, default_value = "examples", help = "name of each input tensor")]
    pub input: Vec<String>,
    #[clap(long, default_value = "output_0", help = "name of each output tensor")]
    pub output: Vec<String>,
    #[clap(
        long,
        default_value_t = 500,
        help = "max warmup records to use. warmup only implemented for TF"
    )]
    pub max_warmup_records: usize,
    #[clap(long, value_parser = Args::parse_key_val::<String, String>, value_delimiter = ',')]
    pub onnx_global_thread_pool_options: Vec<(String, String)>,
    #[clap(
        long,
        default_value = "true",
        help = "whether to use graph parallelization. only for ONNX"
    )]
    pub onnx_use_parallel_mode: String,
    // #[clap(long, default_value = "false")]
    // pub onnx_use_onednn: String,
    #[clap(
        long,
        default_value = "true",
        help = "trace internal memory allocation and generate bulk memory allocations. only for ONNX. turn it off if the batch size is dynamic"
    )]
    pub onnx_use_memory_pattern: String,
    #[clap(long, value_parser = Args::parse_key_val::<String, String>, value_delimiter = ',')]
    pub onnx_ep_options: Vec<(String, String)>,
    #[clap(long, help = "choice of gpu EPs for ONNX: cuda or tensorrt")]
    pub onnx_gpu_ep: Option<String>,
    #[clap(
        long,
        default_value = "home",
        help = "converter for various input formats"
    )]
    pub onnx_use_converter: Option<String>,
    #[clap(
        long,
        help = "whether to enable runtime profiling. only implemented for ONNX for now"
    )]
    pub profiling: Option<String>,
    #[clap(
        long,
        default_value = "",
        help = "metrics reporting for discrete features. only for Home converter for now"
    )]
    pub onnx_report_discrete_feature_ids: Vec<String>,
    #[clap(
        long,
        default_value = "",
        help = "metrics reporting for continuous features. only for Home converter for now"
    )]
    pub onnx_report_continuous_feature_ids: Vec<String>,
}

impl Args {
    pub fn get_model_specs(model_dir: Vec<String>) -> Vec<String> {
        let model_specs = model_dir
            .iter()
            // let it panic if some model_dir are wrong
            .map(|dir| {
                dir.trim_end_matches('/')
                    .rsplit_once('/')
                    .unwrap()
                    .1
                    .to_owned()
            })
            .collect();
        info!("all model_specs: {:?}", model_specs);
        model_specs
    }
    pub fn version_str_to_epoch(dt_str: &str) -> Result<i64, anyhow::Error> {
        dt_str
            .parse::<i64>()
            .or_else(|_| {
                let ts = OffsetDateTime::parse(dt_str, &Rfc3339)
                    .map(|d| (d.unix_timestamp_nanos() / 1_000_000) as i64);
                if ts.is_ok() {
                    info!("original version {} -> {}", dt_str, ts.unwrap());
                }
                ts
            })
            .map_err(anyhow::Error::msg)
    }
    /// Parse a single key-value pair
    fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error + Send + Sync + 'static>>
    where
        T: std::str::FromStr,
        T::Err: Error + Send + Sync + 'static,
        U: std::str::FromStr,
        U::Err: Error + Send + Sync + 'static,
    {
        let pos = s
            .find('=')
            .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{}`", s))?;
        Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
    }
}

lazy_static! {
    pub static ref ARGS: Args = Args::parse();
    pub static ref MODEL_SPECS: ArrayVec<String, MAX_NUM_MODELS> = {
        let mut specs = ArrayVec::<String, MAX_NUM_MODELS>::new();
        Args::get_model_specs(ARGS.model_dir.clone())
            .into_iter()
            .for_each(|m| specs.push(m));
        specs
    };
    pub static ref INPUTS: ArrayVec<OnceCell<ArrayVec<String, MAX_NUM_INPUTS>>, MAX_NUM_MODELS> = {
        let mut inputs =
            ArrayVec::<OnceCell<ArrayVec<String, MAX_NUM_INPUTS>>, MAX_NUM_MODELS>::new();
        for (idx, o) in ARGS.input.iter().enumerate() {
            if o.trim().is_empty() {
                info!("input spec is empty for model {}, auto detect later", idx);
                inputs.push(OnceCell::new());
            } else {
                inputs.push(OnceCell::with_value(
                    o.split(",")
                        .map(|s| s.to_owned())
                        .collect::<ArrayVec<String, MAX_NUM_INPUTS>>(),
                ));
            }
        }
        info!("all inputs:{:?}", inputs);
        inputs
    };
    pub static ref OUTPUTS: ArrayVec<ArrayVec<String, MAX_NUM_OUTPUTS>, MAX_NUM_MODELS> = {
        let mut outputs = ArrayVec::<ArrayVec<String, MAX_NUM_OUTPUTS>, MAX_NUM_MODELS>::new();
        for o in ARGS.output.iter() {
            outputs.push(
                o.split(",")
                    .map(|s| s.to_owned())
                    .collect::<ArrayVec<String, MAX_NUM_OUTPUTS>>(),
            );
        }
        info!("all outputs:{:?}", outputs);
        outputs
    };
}
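
Note: parse_key_val turns each KEY=value token into a typed tuple, and clap's value_delimiter = ',' lets one flag carry several pairs. A hypothetical usage sketch of the same parsing in isolation (the option names shown are illustrative, not taken from the commit):

    // Sketch: what "--onnx-ep-options device_id=0,trt_fp16_enable=true"
    // would decompose into after delimiter splitting.
    fn main() {
        let raw = ["device_id=0", "trt_fp16_enable=true"];
        let parsed: Vec<(String, String)> = raw
            .iter()
            .map(|s| {
                let (k, v) = s.split_once('=').expect("invalid KEY=value: no `=` found");
                (k.to_string(), v.to_string())
            })
            .collect();
        assert_eq!(parsed[0], ("device_id".to_string(), "0".to_string()));
    }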
BIN  navi/navi/src/cores/validator.docx  (new file; binary file not shown)
navi/navi/src/cores/validator.rs (deleted; note: the module name "validatior" is misspelled in the source and kept as-is because the binaries import it under that name)
@@ -1,22 +0,0 @@
pub mod validatior {
    pub mod cli_validator {
        use crate::cli_args::{ARGS, MODEL_SPECS};

        pub fn validate_input_args() {
            assert_eq!(MODEL_SPECS.len(), ARGS.inter_op_parallelism.len());
            assert_eq!(MODEL_SPECS.len(), ARGS.intra_op_parallelism.len());
            // TODO: for now we assume each model's output has only 1 tensor;
            // this will be lifted once tf_model properly implements mtl outputs
            //assert_eq!(OUTPUTS.len(), OUTPUTS.iter().fold(0usize, |a, b| a + b.len()));
        }

        pub fn validate_ps_model_args() {
            assert!(ARGS.versions_per_model <= 2);
            assert!(ARGS.versions_per_model >= 1);
            assert_eq!(MODEL_SPECS.len(), ARGS.input.len());
            assert_eq!(MODEL_SPECS.len(), ARGS.model_dir.len());
            assert_eq!(MODEL_SPECS.len(), ARGS.max_batch_size.len());
            assert_eq!(MODEL_SPECS.len(), ARGS.batch_time_out_millis.len());
        }
    }
}
BIN  navi/navi/src/lib.docx  (new file; binary file not shown)
navi/navi/src/lib.rs (deleted)
@@ -1,215 +0,0 @@
#[macro_use]
extern crate lazy_static;
extern crate core;

use serde_json::Value;
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;
use std::ops::Deref;
use itertools::Itertools;
use crate::bootstrap::TensorInput;
use crate::predict_service::Model;
use crate::tf_proto::{DataType, TensorProto};

pub mod batch;
pub mod bootstrap;
pub mod cli_args;
pub mod metrics;
pub mod onnx_model;
pub mod predict_service;
pub mod tf_model;
pub mod torch_model;
pub mod cores {
    pub mod validator;
}

pub mod tf_proto {
    tonic::include_proto!("tensorflow");
    pub mod tensorflow_serving {
        tonic::include_proto!("tensorflow.serving");
    }
}

pub mod kf_serving {
    tonic::include_proto!("inference");
}

#[cfg(test)]
mod tests {
    use crate::cli_args::Args;
    #[test]
    fn test_version_string_to_epoch() {
        assert_eq!(
            Args::version_str_to_epoch("2022-12-20T10:18:53.000Z").unwrap_or(-1),
            1671531533000
        );
        assert_eq!(Args::version_str_to_epoch("1203444").unwrap_or(-1), 1203444);
    }
}

mod utils {
    use crate::cli_args::{ARGS, MODEL_SPECS};
    use anyhow::Result;
    use log::info;
    use serde_json::Value;

    pub fn read_config(meta_file: &String) -> Result<Value> {
        let json = std::fs::read_to_string(meta_file)?;
        let v: Value = serde_json::from_str(&json)?;
        Ok(v)
    }
    pub fn get_config_or_else<F>(model_config: &Value, key: &str, default: F) -> String
    where
        F: FnOnce() -> String,
    {
        match model_config[key] {
            Value::String(ref v) => {
                info!("from model_config: {}={}", key, v);
                v.to_string()
            }
            Value::Number(ref num) => {
                info!(
                    "from model_config: {}={} (turn number into a string)",
                    key, num
                );
                num.to_string()
            }
            _ => {
                let d = default();
                info!("from default: {}={}", key, d);
                d
            }
        }
    }
    pub fn get_config_or(model_config: &Value, key: &str, default: &str) -> String {
        get_config_or_else(model_config, key, || default.to_string())
    }
    pub fn get_meta_dir() -> &'static str {
        ARGS.meta_json_dir
            .as_ref()
            .map(|s| s.as_str())
            .unwrap_or_else(|| {
                let model_dir = &ARGS.model_dir[0];
                let meta_dir = &model_dir[0..model_dir.rfind(&MODEL_SPECS[0]).unwrap()];
                info!(
                    "no meta_json_dir specified, hence derive from first model dir:{}->{}",
                    model_dir, meta_dir
                );
                meta_dir
            })
    }
}

pub type SerializedInput = Vec<u8>;
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
pub const NAME: &str = env!("CARGO_PKG_NAME");
pub type ModelFactory<T> = fn(usize, String, &Value) -> anyhow::Result<T>;
pub const MAX_NUM_MODELS: usize = 16;
pub const MAX_NUM_OUTPUTS: usize = 30;
pub const MAX_NUM_INPUTS: usize = 120;
pub const META_INFO: &str = "META.json";

// use a heap-allocated generic type here so that both the
// TensorFlow & PyTorch implementations can return their Tensor wrapped in a Box
// without an extra memcopy to Vec
pub type TensorReturn<T> = Box<dyn Deref<Target = [T]>>;

// returned tensor may be int64, e.g. a list of relevant ad ids
pub enum TensorReturnEnum {
    FloatTensorReturn(TensorReturn<f32>),
    StringTensorReturn(TensorReturn<String>),
    Int64TensorReturn(TensorReturn<i64>),
    Int32TensorReturn(TensorReturn<i32>),
}

impl TensorReturnEnum {
    #[inline(always)]
    pub fn slice(&self, start: usize, end: usize) -> TensorScores {
        match self {
            TensorReturnEnum::FloatTensorReturn(f32_return) => {
                TensorScores::Float32TensorScores(f32_return[start..end].to_vec())
            }
            TensorReturnEnum::Int64TensorReturn(i64_return) => {
                TensorScores::Int64TensorScores(i64_return[start..end].to_vec())
            }
            TensorReturnEnum::Int32TensorReturn(i32_return) => {
                TensorScores::Int32TensorScores(i32_return[start..end].to_vec())
            }
            TensorReturnEnum::StringTensorReturn(str_return) => {
                TensorScores::StringTensorScores(str_return[start..end].to_vec())
            }
        }
    }
}

#[derive(Debug)]
pub enum PredictResult {
    Ok(Vec<TensorScores>, i64),
    DropDueToOverload,
    ModelNotFound(usize),
    ModelNotReady(usize),
    ModelVersionNotFound(usize, i64),
}

#[derive(Debug)]
pub enum TensorScores {
    Float32TensorScores(Vec<f32>),
    Int64TensorScores(Vec<i64>),
    Int32TensorScores(Vec<i32>),
    StringTensorScores(Vec<String>),
}

impl TensorScores {
    pub fn create_tensor_proto(self) -> TensorProto {
        match self {
            TensorScores::Float32TensorScores(f32_tensor) => TensorProto {
                dtype: DataType::DtFloat as i32,
                float_val: f32_tensor,
                ..Default::default()
            },
            TensorScores::Int64TensorScores(i64_tensor) => TensorProto {
                dtype: DataType::DtInt64 as i32,
                int64_val: i64_tensor,
                ..Default::default()
            },
            TensorScores::Int32TensorScores(i32_tensor) => TensorProto {
                dtype: DataType::DtInt32 as i32,
                int_val: i32_tensor,
                ..Default::default()
            },
            TensorScores::StringTensorScores(str_tensor) => TensorProto {
                dtype: DataType::DtString as i32,
                string_val: str_tensor.into_iter().map(|s| s.into_bytes()).collect_vec(),
                ..Default::default()
            },
        }
    }
    pub fn len(&self) -> usize {
        match &self {
            TensorScores::Float32TensorScores(t) => t.len(),
            TensorScores::Int64TensorScores(t) => t.len(),
            TensorScores::Int32TensorScores(t) => t.len(),
            TensorScores::StringTensorScores(t) => t.len(),
        }
    }
}

#[derive(Debug)]
pub enum PredictMessage<T: Model> {
    Predict(
        usize,
        Option<i64>,
        Vec<TensorInput>,
        Sender<PredictResult>,
        Instant,
    ),
    UpsertModel(T),
    /*
    #[allow(dead_code)]
    DeleteModel(usize),
    */
}

#[derive(Debug)]
pub struct Callback(Sender<PredictResult>, usize);

pub const MAX_VERSIONS_PER_MODEL: usize = 2;
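
Note: the TensorReturn alias, Box<dyn Deref<Target = [T]>>, is what lets each runtime hand back its own tensor buffer without copying it into a Vec first. A minimal sketch (hypothetical, not from the commit) showing why a plain Vec and a framework-owned buffer both satisfy the alias:

    // Sketch: two different owners behind one TensorReturn, no memcopy needed.
    use std::ops::Deref;

    type TensorReturn<T> = Box<dyn Deref<Target = [T]>>;

    struct FrameworkTensor(Vec<f32>); // stand-in for a TF/Torch-owned buffer

    impl Deref for FrameworkTensor {
        type Target = [f32];
        fn deref(&self) -> &[f32] {
            &self.0
        }
    }

    fn main() {
        let from_vec: TensorReturn<f32> = Box::new(vec![0.1f32, 0.9]);
        let from_framework: TensorReturn<f32> = Box::new(FrameworkTensor(vec![0.5]));
        let s1: &[f32] = &**from_vec;
        let s2: &[f32] = &**from_framework;
        assert_eq!(s1[0], 0.1);
        assert_eq!(s2.len(), 1);
    }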
BIN  navi/navi/src/metrics.docx  (new file; binary file not shown)
navi/navi/src/metrics.rs (deleted)
@@ -1,297 +0,0 @@
use log::error;
use prometheus::{
    CounterVec, HistogramOpts, HistogramVec, IntCounter, IntCounterVec, IntGauge, IntGaugeVec,
    Opts, Registry,
};
use warp::{Rejection, Reply};
use crate::{NAME, VERSION};

lazy_static! {
    pub static ref REGISTRY: Registry = Registry::new();
    pub static ref NUM_REQUESTS_RECEIVED: IntCounter =
        IntCounter::new(":navi:num_requests", "Number of Requests Received")
            .expect("metric can be created");
    pub static ref NUM_REQUESTS_FAILED: IntCounter = IntCounter::new(
        ":navi:num_requests_failed",
        "Number of Request Inference Failed"
    )
    .expect("metric can be created");
    pub static ref NUM_REQUESTS_DROPPED: IntCounter = IntCounter::new(
        ":navi:num_requests_dropped",
        "Number of Oneshot Receivers Dropped"
    )
    .expect("metric can be created");
    pub static ref NUM_BATCHES_DROPPED: IntCounter = IntCounter::new(
        ":navi:num_batches_dropped",
        "Number of Batches Proactively Dropped"
    )
    .expect("metric can be created");
    pub static ref NUM_BATCH_PREDICTION: IntCounter =
        IntCounter::new(":navi:num_batch_prediction", "Number of batch prediction")
            .expect("metric can be created");
    pub static ref BATCH_SIZE: IntGauge =
        IntGauge::new(":navi:batch_size", "Size of current batch").expect("metric can be created");
    pub static ref NAVI_VERSION: IntGauge =
        IntGauge::new(":navi:navi_version", "navi's current version")
            .expect("metric can be created");
    pub static ref RESPONSE_TIME_COLLECTOR: HistogramVec = HistogramVec::new(
        HistogramOpts::new(":navi:response_time", "Response Time in ms").buckets(Vec::from(&[
            0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0,
            140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0, 1000.0
        ] as &'static [f64])),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref NUM_PREDICTIONS: IntCounterVec = IntCounterVec::new(
        Opts::new(
            ":navi:num_predictions",
            "Number of predictions made by model"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref PREDICTION_SCORE_SUM: CounterVec = CounterVec::new(
        Opts::new(
            ":navi:prediction_score_sum",
            "Sum of prediction score made by model"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref NEW_MODEL_SNAPSHOT: IntCounterVec = IntCounterVec::new(
        Opts::new(
            ":navi:new_model_snapshot",
            "Load a new version of model snapshot"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref MODEL_SNAPSHOT_VERSION: IntGaugeVec = IntGaugeVec::new(
        Opts::new(
            ":navi:model_snapshot_version",
            "Record model snapshot version"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref NUM_REQUESTS_RECEIVED_BY_MODEL: IntCounterVec = IntCounterVec::new(
        Opts::new(
            ":navi:num_requests_by_model",
            "Number of Requests Received by model"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref NUM_REQUESTS_FAILED_BY_MODEL: IntCounterVec = IntCounterVec::new(
        Opts::new(
            ":navi:num_requests_failed_by_model",
            "Number of Request Inference Failed by model"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref NUM_REQUESTS_DROPPED_BY_MODEL: IntCounterVec = IntCounterVec::new(
        Opts::new(
            ":navi:num_requests_dropped_by_model",
            "Number of Oneshot Receivers Dropped by model"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref NUM_BATCHES_DROPPED_BY_MODEL: IntCounterVec = IntCounterVec::new(
        Opts::new(
            ":navi:num_batches_dropped_by_model",
            "Number of Batches Proactively Dropped by model"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref INFERENCE_FAILED_REQUESTS_BY_MODEL: IntCounterVec = IntCounterVec::new(
        Opts::new(
            ":navi:inference_failed_requests_by_model",
            "Number of failed inference requests by model"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref NUM_PREDICTION_BY_MODEL: IntCounterVec = IntCounterVec::new(
        Opts::new(
            ":navi:num_prediction_by_model",
            "Number of prediction by model"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref NUM_BATCH_PREDICTION_BY_MODEL: IntCounterVec = IntCounterVec::new(
        Opts::new(
            ":navi:num_batch_prediction_by_model",
            "Number of batch prediction by model"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref BATCH_SIZE_BY_MODEL: IntGaugeVec = IntGaugeVec::new(
        Opts::new(
            ":navi:batch_size_by_model",
            "Size of current batch by model"
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref CUSTOMOP_VERSION: IntGauge =
        IntGauge::new(":navi:customop_version", "The hashed Custom OP Version")
            .expect("metric can be created");
    pub static ref MPSC_CHANNEL_SIZE: IntGauge =
        IntGauge::new(":navi:mpsc_channel_size", "The mpsc channel's request size")
            .expect("metric can be created");
    pub static ref BLOCKING_REQUEST_NUM: IntGauge = IntGauge::new(
        ":navi:blocking_request_num",
        "The (batch) requests waiting or being executed"
    )
    .expect("metric can be created");
    pub static ref MODEL_INFERENCE_TIME_COLLECTOR: HistogramVec = HistogramVec::new(
        HistogramOpts::new(":navi:model_inference_time", "Model inference time in ms").buckets(
            Vec::from(&[
                0.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 65.0,
                70.0, 75.0, 80.0, 85.0, 90.0, 100.0, 110.0, 120.0, 130.0, 140.0, 150.0, 160.0,
                170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0, 1000.0
            ] as &'static [f64])
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref CONVERTER_TIME_COLLECTOR: HistogramVec = HistogramVec::new(
        HistogramOpts::new(":navi:converter_time", "converter time in microseconds").buckets(
            Vec::from(&[
                0.0, 500.0, 1000.0, 1500.0, 2000.0, 2500.0, 3000.0, 3500.0, 4000.0, 4500.0, 5000.0,
                5500.0, 6000.0, 6500.0, 7000.0, 20000.0
            ] as &'static [f64])
        ),
        &["model_name"]
    )
    .expect("metric can be created");
    pub static ref CERT_EXPIRY_EPOCH: IntGauge =
        IntGauge::new(":navi:cert_expiry_epoch", "Timestamp when the current cert expires")
            .expect("metric can be created");
}

pub fn register_custom_metrics() {
    REGISTRY
        .register(Box::new(NUM_REQUESTS_RECEIVED.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_REQUESTS_FAILED.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_REQUESTS_DROPPED.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(RESPONSE_TIME_COLLECTOR.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NAVI_VERSION.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(BATCH_SIZE.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_BATCH_PREDICTION.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_BATCHES_DROPPED.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_PREDICTIONS.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(PREDICTION_SCORE_SUM.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NEW_MODEL_SNAPSHOT.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(MODEL_SNAPSHOT_VERSION.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_REQUESTS_RECEIVED_BY_MODEL.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_REQUESTS_FAILED_BY_MODEL.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_REQUESTS_DROPPED_BY_MODEL.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_BATCHES_DROPPED_BY_MODEL.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(INFERENCE_FAILED_REQUESTS_BY_MODEL.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_PREDICTION_BY_MODEL.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(NUM_BATCH_PREDICTION_BY_MODEL.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(BATCH_SIZE_BY_MODEL.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(CUSTOMOP_VERSION.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(MPSC_CHANNEL_SIZE.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(BLOCKING_REQUEST_NUM.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(MODEL_INFERENCE_TIME_COLLECTOR.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(CONVERTER_TIME_COLLECTOR.clone()))
        .expect("collector can be registered");
    REGISTRY
        .register(Box::new(CERT_EXPIRY_EPOCH.clone()))
        .expect("collector can be registered");
}

pub fn register_dynamic_metrics(c: &HistogramVec) {
    REGISTRY
        .register(Box::new(c.clone()))
        .expect("dynamic metric collector cannot be registered");
}

pub async fn metrics_handler() -> Result<impl Reply, Rejection> {
    use prometheus::Encoder;
    let encoder = prometheus::TextEncoder::new();

    let mut buffer = Vec::new();
    if let Err(e) = encoder.encode(&REGISTRY.gather(), &mut buffer) {
        error!("could not encode custom metrics: {}", e);
    };
    let mut res = match String::from_utf8(buffer) {
        Ok(v) => format!("#{}:{}\n{}", NAME, VERSION, v),
        Err(e) => {
            error!("custom metrics could not be from_utf8'd: {}", e);
            String::default()
        }
    };

    buffer = Vec::new();
    if let Err(e) = encoder.encode(&prometheus::gather(), &mut buffer) {
        error!("could not encode prometheus metrics: {}", e);
    };
    let res_custom = match String::from_utf8(buffer) {
        Ok(v) => v,
        Err(e) => {
            error!("prometheus metrics could not be from_utf8'd: {}", e);
            String::default()
        }
    };

    res.push_str(&res_custom);
    Ok(res)
}
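
Note: metrics_handler concatenates the custom REGISTRY with the default prometheus registry into one text payload, returning warp's Reply/Rejection types. A hypothetical sketch of mounting it as an HTTP endpoint (the actual wiring lives in code outside this commit; the route name and port here are illustrative, though 9000 matches the --prometheus-port default above):

    // Sketch: serve the handler above at /metrics with warp.
    use warp::Filter;

    #[tokio::main]
    async fn main() {
        navi::metrics::register_custom_metrics();
        let metrics_route = warp::path!("metrics").and_then(navi::metrics::metrics_handler);
        warp::serve(metrics_route).run(([0, 0, 0, 0], 9000)).await;
    }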
BIN  navi/navi/src/onnx_model.docx  (new file; binary file not shown)
@ -1,275 +0,0 @@
|
|||||||
#[cfg(feature = "onnx")]
|
|
||||||
pub mod onnx {
|
|
||||||
use crate::TensorReturnEnum;
|
|
||||||
use crate::bootstrap::{TensorInput, TensorInputEnum};
|
|
||||||
use crate::cli_args::{
|
|
||||||
Args, ARGS, INPUTS, MODEL_SPECS, OUTPUTS,
|
|
||||||
};
|
|
||||||
use crate::metrics::{self, CONVERTER_TIME_COLLECTOR};
|
|
||||||
use crate::predict_service::Model;
|
|
||||||
use crate::{MAX_NUM_INPUTS, MAX_NUM_OUTPUTS, META_INFO, utils};
|
|
||||||
use anyhow::Result;
|
|
||||||
use arrayvec::ArrayVec;
|
|
||||||
use dr_transform::converter::{BatchPredictionRequestToTorchTensorConverter, Converter};
|
|
||||||
use itertools::Itertools;
|
|
||||||
use log::{debug, info};
|
|
||||||
use dr_transform::ort::environment::Environment;
|
|
||||||
use dr_transform::ort::session::Session;
|
|
||||||
use dr_transform::ort::tensor::InputTensor;
|
|
||||||
use dr_transform::ort::{ExecutionProvider, GraphOptimizationLevel, SessionBuilder};
|
|
||||||
use dr_transform::ort::LoggingLevel;
|
|
||||||
use serde_json::Value;
|
|
||||||
use std::fmt::{Debug, Display};
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::{fmt, fs};
|
|
||||||
use tokio::time::Instant;
|
|
||||||
lazy_static! {
|
|
||||||
pub static ref ENVIRONMENT: Arc<Environment> = Arc::new(
|
|
||||||
Environment::builder()
|
|
||||||
.with_name("onnx home")
|
|
||||||
.with_log_level(LoggingLevel::Error)
|
|
||||||
.with_global_thread_pool(ARGS.onnx_global_thread_pool_options.clone())
|
|
||||||
.build()
|
|
||||||
.unwrap()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct OnnxModel {
|
|
||||||
pub session: Session,
|
|
||||||
pub model_idx: usize,
|
|
||||||
pub version: i64,
|
|
||||||
pub export_dir: String,
|
|
||||||
pub output_filters: ArrayVec<usize, MAX_NUM_OUTPUTS>,
|
|
||||||
pub input_converter: Box<dyn Converter>,
|
|
||||||
}
|
|
||||||
impl Display for OnnxModel {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
write!(
|
|
||||||
f,
|
|
||||||
"idx: {}, onnx model_name:{}, version:{}, output_filters:{:?}, converter:{:}",
|
|
||||||
self.model_idx,
|
|
||||||
MODEL_SPECS[self.model_idx],
|
|
||||||
self.version,
|
|
||||||
self.output_filters,
|
|
||||||
self.input_converter
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl Drop for OnnxModel {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
if ARGS.profiling != None {
|
|
||||||
self.session.end_profiling().map_or_else(
|
|
||||||
|e| {
|
|
||||||
info!("end profiling with some error:{:?}", e);
|
|
||||||
},
|
|
||||||
|f| {
|
|
||||||
info!("profiling ended with file:{}", f);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl OnnxModel {
|
|
||||||
fn get_output_filters(session: &Session, idx: usize) -> ArrayVec<usize, MAX_NUM_OUTPUTS> {
|
|
||||||
OUTPUTS[idx]
|
|
||||||
.iter()
|
|
||||||
.map(|output| session.outputs.iter().position(|o| o.name == *output))
|
|
||||||
.flatten()
|
|
||||||
.collect::<ArrayVec<usize, MAX_NUM_OUTPUTS>>()
|
|
||||||
}
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
fn ep_choices() -> Vec<ExecutionProvider> {
|
|
||||||
match ARGS.onnx_gpu_ep.as_ref().map(|e| e.as_str()) {
|
|
||||||
Some("onednn") => vec![Self::ep_with_options(ExecutionProvider::onednn())],
|
|
||||||
Some("tensorrt") => vec![Self::ep_with_options(ExecutionProvider::tensorrt())],
|
|
||||||
Some("cuda") => vec![Self::ep_with_options(ExecutionProvider::cuda())],
|
|
||||||
_ => vec![Self::ep_with_options(ExecutionProvider::cpu())],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fn ep_with_options(mut ep: ExecutionProvider) -> ExecutionProvider {
|
|
||||||
for (ref k, ref v) in ARGS.onnx_ep_options.clone() {
|
|
||||||
ep = ep.with(k, v);
|
|
||||||
info!("setting option:{} -> {} and now ep is:{:?}", k, v, ep);
|
|
||||||
}
|
|
||||||
ep
|
|
||||||
}
|
|
||||||
#[cfg(target_os = "macos")]
|
|
||||||
fn ep_choices() -> Vec<ExecutionProvider> {
|
|
||||||
vec![Self::ep_with_options(ExecutionProvider::cpu())]
|
|
||||||
}
|
|
||||||
pub fn new(idx: usize, version: String, model_config: &Value) -> Result<OnnxModel> {
|
|
||||||
let export_dir = format!("{}/{}/model.onnx", ARGS.model_dir[idx], version);
|
|
||||||
let meta_info = format!("{}/{}/{}", ARGS.model_dir[idx], version, META_INFO);
|
|
||||||
let mut builder = SessionBuilder::new(&ENVIRONMENT)?
|
|
||||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
|
||||||
.with_parallel_execution(ARGS.onnx_use_parallel_mode == "true")?;
|
|
||||||
if ARGS.onnx_global_thread_pool_options.is_empty() {
|
|
||||||
builder = builder
|
|
||||||
.with_inter_threads(
|
|
||||||
utils::get_config_or(
|
|
||||||
model_config,
|
|
||||||
"inter_op_parallelism",
|
|
||||||
&ARGS.inter_op_parallelism[idx],
|
|
||||||
)
|
|
||||||
.parse()?,
|
|
||||||
)?
|
|
||||||
.with_intra_threads(
|
|
||||||
utils::get_config_or(
|
|
||||||
model_config,
|
|
||||||
"intra_op_parallelism",
|
|
||||||
&ARGS.intra_op_parallelism[idx],
|
|
||||||
)
|
|
||||||
.parse()?,
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
builder = builder.with_disable_per_session_threads()?;
|
|
||||||
}
|
|
||||||
builder = builder
|
|
||||||
.with_memory_pattern(ARGS.onnx_use_memory_pattern == "true")?
|
|
||||||
.with_execution_providers(&OnnxModel::ep_choices())?;
|
|
||||||
match &ARGS.profiling {
|
|
||||||
Some(p) => {
|
|
||||||
debug!("Enable profiling, writing to {}", *p);
|
|
||||||
builder = builder.with_profiling(p)?
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
let session = builder.with_model_from_file(&export_dir)?;
|
|
||||||
|
|
||||||
info!(
|
|
||||||
"inputs: {:?}, outputs: {:?}",
|
|
||||||
session.inputs.iter().format(","),
|
|
||||||
session.outputs.iter().format(",")
|
|
||||||
);
|
|
||||||
|
|
||||||
fs::read_to_string(&meta_info)
|
|
||||||
.ok()
|
|
||||||
.map(|info| info!("meta info:{}", info));
|
|
||||||
let output_filters = OnnxModel::get_output_filters(&session, idx);
|
|
||||||
let mut reporting_feature_ids: Vec<(i64, &str)> = vec![];
|
|
||||||
|
|
||||||
let input_spec_cell = &INPUTS[idx];
|
|
||||||
if input_spec_cell.get().is_none() {
|
|
||||||
let input_spec = session
|
|
||||||
.inputs
|
|
||||||
.iter()
|
|
||||||
.map(|input| input.name.clone())
|
|
||||||
.collect::<ArrayVec<String, MAX_NUM_INPUTS>>();
|
|
||||||
input_spec_cell.set(input_spec.clone()).map_or_else(
|
|
||||||
|_| info!("unable to set the input_spec for model {}", idx),
|
|
||||||
|_| info!("auto detect and set the inputs: {:?}", input_spec),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
ARGS.onnx_report_discrete_feature_ids
|
|
||||||
.iter()
|
|
||||||
.for_each(|ids| {
|
|
||||||
ids.split(",")
|
|
||||||
.filter(|s| !s.is_empty())
|
|
||||||
.map(|s| s.parse::<i64>().unwrap())
|
|
||||||
.for_each(|id| reporting_feature_ids.push((id, "discrete")))
|
|
||||||
});
|
|
||||||
ARGS.onnx_report_continuous_feature_ids
|
|
||||||
.iter()
|
|
||||||
.for_each(|ids| {
|
|
||||||
ids.split(",")
|
|
||||||
.filter(|s| !s.is_empty())
|
|
||||||
.map(|s| s.parse::<i64>().unwrap())
|
|
||||||
.for_each(|id| reporting_feature_ids.push((id, "continuous")))
|
|
||||||
});
|
|
||||||
|
|
||||||
let onnx_model = OnnxModel {
|
|
||||||
session,
|
|
||||||
model_idx: idx,
|
|
||||||
version: Args::version_str_to_epoch(&version)?,
|
|
||||||
export_dir,
|
|
||||||
output_filters,
|
|
||||||
input_converter: Box::new(BatchPredictionRequestToTorchTensorConverter::new(
|
|
||||||
&ARGS.model_dir[idx],
|
|
||||||
&version,
|
|
||||||
reporting_feature_ids,
|
|
||||||
Some(metrics::register_dynamic_metrics),
|
|
||||||
)?),
|
|
||||||
};
|
|
||||||
onnx_model.warmup()?;
|
|
||||||
Ok(onnx_model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
///Currently we only assume the input as just one string tensor.
|
|
||||||
///The string tensor will be be converted to the actual raw tensors.
|
|
||||||
/// The converter we are using is very specific to home.
|
|
||||||
/// It reads a BatchDataRecord thrift and decode it to a batch of raw input tensors.
|
|
||||||
/// Navi will then do server side batching and feed it to ONNX runtime
|
|
||||||
    impl Model for OnnxModel {
        //TODO: implement a generic online warmup for all runtimes
        fn warmup(&self) -> Result<()> {
            Ok(())
        }

        #[inline(always)]
        fn do_predict(
            &self,
            input_tensors: Vec<Vec<TensorInput>>,
            _: u64,
        ) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>) {
            let batched_tensors = TensorInputEnum::merge_batch(input_tensors);
            let (inputs, batch_ends): (Vec<Vec<InputTensor>>, Vec<Vec<usize>>) = batched_tensors
                .into_iter()
                .map(|batched_tensor| {
                    match batched_tensor.tensor_data {
                        TensorInputEnum::String(t) if ARGS.onnx_use_converter.is_some() => {
                            let start = Instant::now();
                            let (inputs, batch_ends) = self.input_converter.convert(t);
                            // info!("batch_ends:{:?}", batch_ends);
                            CONVERTER_TIME_COLLECTOR
                                .with_label_values(&[&MODEL_SPECS[self.model_idx()]])
                                .observe(
                                    start.elapsed().as_micros() as f64
                                        / (*batch_ends.last().unwrap() as f64),
                                );
                            (inputs, batch_ends)
                        }
                        _ => unimplemented!(),
                    }
                })
                .unzip();
            //invariant: we only support one input as string; will relax later
            assert_eq!(inputs.len(), 1);
            let output_tensors = self
                .session
                .run(inputs.into_iter().flatten().collect::<Vec<_>>())
                .unwrap();
            self.output_filters
                .iter()
                .map(|&idx| {
                    let mut size = 1usize;
                    let output = output_tensors[idx].try_extract::<f32>().unwrap();
                    for &dim in self.session.outputs[idx].dimensions.iter().flatten() {
                        size *= dim as usize;
                    }
                    let tensor_ends = batch_ends[0]
                        .iter()
                        .map(|&batch| batch * size)
                        .collect::<Vec<_>>();

                    (
                        //only works for batch major
                        //TODO: to_vec() is obviously wasteful, especially for large batches (GPU). Will refactor to
                        //break up the output and return Vec<Vec<TensorScore>> here
                        TensorReturnEnum::FloatTensorReturn(Box::new(
                            output.view().as_slice().unwrap().to_vec(),
                        )),
                        tensor_ends,
                    )
                })
                .unzip()
        }
        #[inline(always)]
        fn model_idx(&self) -> usize {
            self.model_idx
        }
        #[inline(always)]
        fn version(&self) -> i64 {
            self.version
        }
    }
}
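
// A minimal sketch (not from the original file) of consuming do_predict's output:
// tensor_ends are cumulative offsets into the flat score vector, so request i
// owns scores[ends[i-1]..ends[i]) with ends[-1] taken as 0. Names are illustrative.
fn split_scores<'a>(scores: &'a [f32], tensor_ends: &[usize]) -> Vec<&'a [f32]> {
    let mut out = Vec::with_capacity(tensor_ends.len());
    let mut start = 0usize;
    for &end in tensor_ends {
        out.push(&scores[start..end]);
        start = end;
    }
    out
}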
BIN
navi/navi/src/predict_service.docx
Normal file
Binary file not shown.
@ -1,315 +0,0 @@
use anyhow::{anyhow, Result};
use arrayvec::ArrayVec;
use itertools::Itertools;
use log::{error, info};
use std::fmt::{Debug, Display};
use std::string::String;
use std::sync::Arc;
use std::time::Duration;
use tokio::process::Command;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{mpsc, oneshot};
use tokio::time::{sleep, Instant};
use warp::Filter;

use crate::batch::BatchPredictor;
use crate::bootstrap::TensorInput;
use crate::{
    metrics, utils, ModelFactory, PredictMessage, PredictResult, TensorReturnEnum,
    MAX_NUM_MODELS, MAX_VERSIONS_PER_MODEL, META_INFO,
};

use crate::cli_args::{ARGS, MODEL_SPECS};
use crate::cores::validator::validatior::cli_validator;
use crate::metrics::MPSC_CHANNEL_SIZE;
use serde_json::{self, Value};

pub trait Model: Send + Sync + Display + Debug + 'static {
    fn warmup(&self) -> Result<()>;
    //TODO: refactor this to return Vec<Vec<TensorScores>>, i.e.
    //have the underlying runtime impl split the response for each client.
    //It will eliminate some inefficient memory copies in onnx_model.rs as well as simplify code
    fn do_predict(
        &self,
        input_tensors: Vec<Vec<TensorInput>>,
        total_len: u64,
    ) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>);
    fn model_idx(&self) -> usize;
    fn version(&self) -> i64;
}
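
// A minimal sketch (not part of the original file) of what a Model impl owes the
// service: warm up before serving, then return one score tensor plus cumulative
// per-client batch ends from do_predict. All names below are illustrative.
#[derive(Debug)]
struct ConstModel {
    idx: usize,
    version: i64,
}
impl Display for ConstModel {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "const model idx: {}, version: {}", self.idx, self.version)
    }
}
impl Model for ConstModel {
    fn warmup(&self) -> Result<()> {
        Ok(())
    }
    fn do_predict(
        &self,
        input_tensors: Vec<Vec<TensorInput>>,
        _total_len: u64,
    ) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>) {
        // one 0.0 score per client request; ends[i] marks where client i's scores stop
        let ends = (1..=input_tensors.len()).collect::<Vec<usize>>();
        let scores = vec![0.0f32; input_tensors.len()];
        (
            vec![TensorReturnEnum::FloatTensorReturn(Box::new(scores))],
            vec![ends],
        )
    }
    fn model_idx(&self) -> usize {
        self.idx
    }
    fn version(&self) -> i64 {
        self.version
    }
}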

#[derive(Debug)]
pub struct PredictService<T: Model> {
    tx: Sender<PredictMessage<T>>,
}
impl<T: Model> PredictService<T> {
    pub async fn init(model_factory: ModelFactory<T>) -> Self {
        cli_validator::validate_ps_model_args();
        let (tx, rx) = mpsc::channel(32_000);
        tokio::spawn(PredictService::tf_queue_manager(rx));
        tokio::spawn(PredictService::model_watcher_latest(
            model_factory,
            tx.clone(),
        ));
        let metrics_route = warp::path!("metrics").and_then(metrics::metrics_handler);
        let metric_server = warp::serve(metrics_route).run(([0, 0, 0, 0], ARGS.prometheus_port));
        tokio::spawn(metric_server);
        PredictService { tx }
    }
    #[inline(always)]
    pub async fn predict(
        &self,
        idx: usize,
        version: Option<i64>,
        val: Vec<TensorInput>,
        ts: Instant,
    ) -> Result<PredictResult> {
        let (tx, rx) = oneshot::channel();
        if let Err(e) = self
            .tx
            .clone()
            .send(PredictMessage::Predict(idx, version, val, tx, ts))
            .await
        {
            error!("mpsc send error:{}", e);
            Err(anyhow!(e))
        } else {
            MPSC_CHANNEL_SIZE.inc();
            rx.await.map_err(anyhow::Error::msg)
        }
    }
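
    // A hedged usage sketch (not in the original source): boot the service with a
    // runtime's model factory and push one request at model index 0. `factory`
    // would be e.g. OnnxModel::new; the helper name is illustrative.
    pub async fn example_roundtrip(
        factory: ModelFactory<T>,
        val: Vec<TensorInput>,
    ) -> Result<PredictResult> {
        let service = PredictService::init(factory).await;
        service.predict(0, None, val, Instant::now()).await
    }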

    async fn load_latest_model_from_model_dir(
        model_factory: ModelFactory<T>,
        model_config: &Value,
        tx: Sender<PredictMessage<T>>,
        idx: usize,
        max_version: String,
        latest_version: &mut String,
    ) {
        match model_factory(idx, max_version.clone(), model_config) {
            Ok(tf_model) => tx
                .send(PredictMessage::UpsertModel(tf_model))
                .await
                .map_or_else(
                    |e| error!("send UpsertModel error: {}", e),
                    |_| *latest_version = max_version,
                ),
            Err(e) => {
                error!("skip loading model due to failure: {:?}", e);
            }
        }
    }

    async fn scan_load_latest_model_from_model_dir(
        model_factory: ModelFactory<T>,
        model_config: &Value,
        tx: Sender<PredictMessage<T>>,
        model_idx: usize,
        cur_version: &mut String,
    ) -> Result<()> {
        let model_dir = &ARGS.model_dir[model_idx];
        let next_version = utils::get_config_or_else(model_config, "version", || {
            info!("no version found, hence use max version");
            std::fs::read_dir(model_dir)
                .map_err(|e| format!("read dir error:{}", e))
                .and_then(|paths| {
                    paths
                        .into_iter()
                        .flat_map(|p| {
                            p.map_err(|e| error!("dir entry error: {}", e))
                                .and_then(|dir| {
                                    dir.file_name()
                                        .into_string()
                                        .map_err(|e| error!("osstring error: {:?}", e))
                                })
                                .ok()
                        })
                        .filter(|f| !f.to_lowercase().contains(&META_INFO.to_lowercase()))
                        .max()
                        .ok_or_else(|| "no dir found hence no max".to_owned())
                })
                .unwrap_or_else(|e| {
                    error!(
                        "can't get the max version hence return cur_version, error is: {}",
                        e
                    );
                    cur_version.to_string()
                })
        });
        //as long as the next version doesn't match the current version we maintain, reload
        if next_version.ne(cur_version) {
            info!("reload the version: {}->{}", cur_version, next_version);
            PredictService::load_latest_model_from_model_dir(
                model_factory,
                model_config,
                tx,
                model_idx,
                next_version,
                cur_version,
            )
            .await;
        }
        Ok(())
    }
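
    // A hedged illustration (not in the original source) of the meta-info config
    // this scanner reads: one object per model spec, optionally pinning "version";
    // when the key is absent the max directory name under model_dir wins. The
    // spec name and version value below are made up.
    fn example_meta_config() -> Value {
        serde_json::from_str(r#"{ "prod_home_model": { "version": "1653081600" } }"#).unwrap()
    }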

    async fn model_watcher_latest(model_factory: ModelFactory<T>, tx: Sender<PredictMessage<T>>) {
        async fn call_external_modelsync(cli: &str, cur_versions: &Vec<String>) -> Result<()> {
            let mut args = cli.split_whitespace();

            let mut cmd = Command::new(args.next().ok_or(anyhow!("model sync cli empty"))?);
            let extr_args = MODEL_SPECS
                .iter()
                .zip(cur_versions)
                .flat_map(|(spec, version)| vec!["--model-spec", spec, "--cur-version", version])
                .collect_vec();
            info!("run model sync: {} with extra args: {:?}", cli, extr_args);
            let output = cmd.args(args).args(extr_args).output().await?;
            info!("model sync stdout:{}", String::from_utf8(output.stdout)?);
            info!("model sync stderr:{}", String::from_utf8(output.stderr)?);
            if output.status.success() {
                Ok(())
            } else {
                Err(anyhow!(
                    "model sync failed with status: {:?}!",
                    output.status
                ))
            }
        }
        let meta_dir = utils::get_meta_dir();
        let meta_file = format!("{}{}", meta_dir, META_INFO);
        //initialize the latest version array
        let mut cur_versions = vec!["".to_owned(); MODEL_SPECS.len()];
        loop {
            info!("***polling for models***"); //nice delimiter
            if let Some(ref cli) = ARGS.modelsync_cli {
                if let Err(e) = call_external_modelsync(cli, &cur_versions).await {
                    error!("model sync cli running error:{}", e)
                }
            }
            let config = utils::read_config(&meta_file).unwrap_or_else(|e| {
                info!("config file {} not found due to: {}", meta_file, e);
                Value::Null
            });
            info!("config:{}", config);
            for (idx, cur_version) in cur_versions.iter_mut().enumerate() {
                let model_dir = &ARGS.model_dir[idx];
                PredictService::scan_load_latest_model_from_model_dir(
                    model_factory,
                    &config[&MODEL_SPECS[idx]],
                    tx.clone(),
                    idx,
                    cur_version,
                )
                .await
                .map_or_else(
                    |e| error!("scanned {}, error {:?}", model_dir, e),
                    |_| info!("scanned {}, latest_version: {}", model_dir, cur_version),
                );
            }
            sleep(Duration::from_secs(ARGS.model_check_interval_secs)).await;
        }
    }
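
    // A hedged sketch (not in the original source) of the extra args
    // call_external_modelsync appends: for specs ["spec_a", "spec_b"] and current
    // versions ["1", "2"] it yields
    // ["--model-spec", "spec_a", "--cur-version", "1", "--model-spec", "spec_b", "--cur-version", "2"].
    fn modelsync_extra_args(specs: &[String], versions: &[String]) -> Vec<String> {
        specs
            .iter()
            .zip(versions.iter())
            .flat_map(|(spec, version)| {
                vec![
                    "--model-spec".to_string(),
                    spec.clone(),
                    "--cur-version".to_string(),
                    version.clone(),
                ]
            })
            .collect()
    }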

    async fn tf_queue_manager(mut rx: Receiver<PredictMessage<T>>) {
        // Start receiving messages
        info!("setting up queue manager");
        let max_batch_size = ARGS
            .max_batch_size
            .iter()
            .map(|b| b.parse().unwrap())
            .collect::<Vec<usize>>();
        let batch_time_out_millis = ARGS
            .batch_time_out_millis
            .iter()
            .map(|b| b.parse().unwrap())
            .collect::<Vec<u64>>();
        let no_msg_wait_millis = *batch_time_out_millis.iter().min().unwrap();
        let mut all_model_predictors: ArrayVec<
            ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>,
            MAX_NUM_MODELS,
        > = (0..MAX_NUM_MODELS)
            .map(|_| ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new())
            .collect();
        loop {
            let msg = rx.try_recv();
            let no_more_msg = match msg {
                Ok(PredictMessage::Predict(model_spec_at, version, val, resp, ts)) => {
                    if let Some(model_predictors) = all_model_predictors.get_mut(model_spec_at) {
                        if model_predictors.is_empty() {
                            resp.send(PredictResult::ModelNotReady(model_spec_at))
                                .unwrap_or_else(|e| {
                                    error!("cannot send back model not ready error: {:?}", e)
                                });
                        } else {
                            match version {
                                None => model_predictors[0].push(val, resp, ts),
                                Some(the_version) => match model_predictors
                                    .iter_mut()
                                    .find(|x| x.model.version() == the_version)
                                {
                                    None => resp
                                        .send(PredictResult::ModelVersionNotFound(
                                            model_spec_at,
                                            the_version,
                                        ))
                                        .unwrap_or_else(|e| {
                                            error!("cannot send back version error: {:?}", e)
                                        }),
                                    Some(predictor) => predictor.push(val, resp, ts),
                                },
                            }
                        }
                    } else {
                        resp.send(PredictResult::ModelNotFound(model_spec_at))
                            .unwrap_or_else(|e| {
                                error!("cannot send back model not found error: {:?}", e)
                            })
                    }
                    MPSC_CHANNEL_SIZE.dec();
                    false
                }
                Ok(PredictMessage::UpsertModel(tf_model)) => {
                    let idx = tf_model.model_idx();
                    let predictor = BatchPredictor {
                        model: Arc::new(tf_model),
                        input_tensors: Vec::with_capacity(max_batch_size[idx]),
                        callbacks: Vec::with_capacity(max_batch_size[idx]),
                        cur_batch_size: 0,
                        max_batch_size: max_batch_size[idx],
                        batch_time_out_millis: batch_time_out_millis[idx],
                        //initialize to be current time
                        queue_reset_ts: Instant::now(),
                        queue_earliest_rq_ts: Instant::now(),
                    };
                    assert!(idx < all_model_predictors.len());
                    metrics::NEW_MODEL_SNAPSHOT
                        .with_label_values(&[&MODEL_SPECS[idx]])
                        .inc();

                    //we can do this since the vector is small
                    let predictors = &mut all_model_predictors[idx];
                    if predictors.is_empty() {
                        info!("now we serve new model: {}", predictor.model);
                    } else {
                        info!("now we serve updated model: {}", predictor.model);
                    }
                    if predictors.len() == ARGS.versions_per_model {
                        predictors.remove(predictors.len() - 1);
                    }
                    predictors.insert(0, predictor);
                    false
                }
                Err(TryRecvError::Empty) => true,
                Err(TryRecvError::Disconnected) => true,
            };
            for predictor in all_model_predictors.iter_mut().flatten() {
                //if the predictor batch queue is not empty and it timed out, or there are no more msgs in the queue, flush
                if (!predictor.input_tensors.is_empty()
                    && (predictor.duration_past(predictor.batch_time_out_millis) || no_more_msg))
                    //if the batch queue reaches its limit, flush
                    || predictor.cur_batch_size >= predictor.max_batch_size
                {
                    predictor.batch_predict();
                }
            }
            if no_more_msg {
                sleep(Duration::from_millis(no_msg_wait_millis)).await;
            }
        }
    }
    #[inline(always)]
    pub fn get_model_index(model_spec: &str) -> Option<usize> {
        MODEL_SPECS.iter().position(|m| m == model_spec)
    }
}
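
// A hedged sketch (not from the original file) of the snapshot rotation UpsertModel
// performs above: the newest predictor goes to the front, so version-unspecified
// requests hit it, and the oldest is evicted once versions_per_model is reached.
fn rotate_snapshots<P>(predictors: &mut Vec<P>, newest: P, versions_per_model: usize) {
    if predictors.len() == versions_per_model {
        predictors.remove(predictors.len() - 1); // evict the oldest, kept at the back
    }
    predictors.insert(0, newest);
}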
BIN
navi/navi/src/tf_model.docx
Normal file
Binary file not shown.
@ -1,492 +0,0 @@
#[cfg(feature = "tf")]
pub mod tf {
    use arrayvec::ArrayVec;
    use itertools::Itertools;
    use log::{debug, error, info, warn};
    use prost::Message;
    use std::fmt;
    use std::fmt::Display;
    use std::string::String;
    use tensorflow::io::{RecordReadError, RecordReader};
    use tensorflow::Operation;
    use tensorflow::SavedModelBundle;
    use tensorflow::SessionOptions;
    use tensorflow::SessionRunArgs;
    use tensorflow::Tensor;
    use tensorflow::{DataType, FetchToken, Graph, TensorInfo, TensorType};

    use std::thread::sleep;
    use std::time::Duration;

    use crate::cli_args::{Args, ARGS, INPUTS, MODEL_SPECS, OUTPUTS};
    use crate::tf_proto::tensorflow_serving::prediction_log::LogType;
    use crate::tf_proto::tensorflow_serving::{PredictLog, PredictionLog};
    use crate::tf_proto::ConfigProto;
    use anyhow::{Context, Result};
    use serde_json::Value;

    use crate::bootstrap::{TensorInput, TensorInputEnum};
    use crate::metrics::{
        INFERENCE_FAILED_REQUESTS_BY_MODEL, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
    };
    use crate::predict_service::Model;
    use crate::TensorReturnEnum;
    use crate::{utils, MAX_NUM_INPUTS};

    #[derive(Debug)]
    pub enum TFTensorEnum {
        String(Tensor<String>),
        Int(Tensor<i32>),
        Int64(Tensor<i64>),
        Float(Tensor<f32>),
        Double(Tensor<f64>),
        Boolean(Tensor<bool>),
    }

    #[derive(Debug)]
    pub struct TFModel {
        pub model_idx: usize,
        pub bundle: SavedModelBundle,
        pub input_names: ArrayVec<String, MAX_NUM_INPUTS>,
        pub input_info: Vec<TensorInfo>,
        pub input_ops: Vec<Operation>,
        pub output_names: Vec<String>,
        pub output_info: Vec<TensorInfo>,
        pub output_ops: Vec<Operation>,
        pub export_dir: String,
        pub version: i64,
        pub inter_op: i32,
        pub intra_op: i32,
    }

    impl Display for TFModel {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(
                f,
                "idx: {}, tensorflow model_name:{}, export_dir:{}, version:{}, inter:{}, intra:{}",
                self.model_idx,
                MODEL_SPECS[self.model_idx],
                self.export_dir,
                self.version,
                self.inter_op,
                self.intra_op
            )
        }
    }

    impl TFModel {
        pub fn new(idx: usize, version: String, model_config: &Value) -> Result<TFModel> {
            // Build the session ConfigProto with per-model threading settings
            let config = ConfigProto {
                intra_op_parallelism_threads: utils::get_config_or(
                    model_config,
                    "intra_op_parallelism",
                    &ARGS.intra_op_parallelism[idx],
                )
                .parse()?,
                inter_op_parallelism_threads: utils::get_config_or(
                    model_config,
                    "inter_op_parallelism",
                    &ARGS.inter_op_parallelism[idx],
                )
                .parse()?,
                ..Default::default()
            };
            let mut buf = Vec::new();
            buf.reserve(config.encoded_len());
            config.encode(&mut buf).unwrap();
            let mut opts = SessionOptions::new();
            opts.set_config(&buf)?;
            let export_dir = format!("{}/{}", ARGS.model_dir[idx], version);
            let mut graph = Graph::new();
            let bundle = SavedModelBundle::load(&opts, ["serve"], &mut graph, &export_dir)
                .context("error load model")?;
            let signature = bundle
                .meta_graph_def()
                .get_signature(&ARGS.serving_sig[idx])
                .context("error finding signature")?;
            let input_names = INPUTS[idx]
                .get_or_init(|| {
                    let input_spec = signature
                        .inputs()
                        .iter()
                        .map(|p| p.0.clone())
                        .collect::<ArrayVec<String, MAX_NUM_INPUTS>>();
                    info!(
                        "input not set from cli, now we set from model metadata:{:?}",
                        input_spec
                    );
                    input_spec
                })
                .clone();
            let input_info = input_names
                .iter()
                .map(|i| {
                    signature
                        .get_input(i)
                        .context("error finding input op info")
                        .unwrap()
                        .clone()
                })
                .collect_vec();

            let input_ops = input_info
                .iter()
                .map(|i| {
                    graph
                        .operation_by_name_required(&i.name().name)
                        .context("error finding input op")
                        .unwrap()
                })
                .collect_vec();

            info!("Model Input size: {}", input_info.len());

            let output_names = OUTPUTS[idx].to_vec().clone();

            let output_info = output_names
                .iter()
                .map(|o| {
                    signature
                        .get_output(o)
                        .context("error finding output op info")
                        .unwrap()
                        .clone()
                })
                .collect_vec();

            let output_ops = output_info
                .iter()
                .map(|o| {
                    graph
                        .operation_by_name_required(&o.name().name)
                        .context("error finding output op")
                        .unwrap()
                })
                .collect_vec();

            let tf_model = TFModel {
                model_idx: idx,
                bundle,
                input_names,
                input_info,
                input_ops,
                output_names,
                output_info,
                output_ops,
                export_dir,
                version: Args::version_str_to_epoch(&version)?,
                inter_op: config.inter_op_parallelism_threads,
                intra_op: config.intra_op_parallelism_threads,
            };
            tf_model.warmup()?;
            Ok(tf_model)
        }

        #[inline(always)]
        fn get_tftensor_dimensions<T>(
            t: &[T],
            input_size: u64,
            batch_size: u64,
            input_dims: Option<Vec<i64>>,
        ) -> Vec<u64> {
            // if input size is 1, we just give the outgoing tensor a single dimension
            // matching the size of the input tensor. This is for backwards compatibility
            // with existing Navi clients which specify the input as a single string
            // tensor (like tfexample) and use batching support.
            let mut dims = vec![];
            if input_size > 1 {
                if batch_size == 1 && input_dims.is_some() {
                    // client side batching is enabled?
                    input_dims
                        .unwrap()
                        .iter()
                        .for_each(|axis| dims.push(*axis as u64));
                } else {
                    dims.push(batch_size);
                    dims.push(t.len() as u64 / batch_size);
                }
            } else {
                dims.push(t.len() as u64);
            }
            dims
        }
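
        // A hedged, self-contained restatement of the rule above with worked cases
        // (illustrative, not in the original source); `dims_for` mirrors
        // get_tftensor_dimensions but takes the flattened length directly:
        //   dims_for(6, 1, 2, None)                == [6]         single input stays 1-D
        //   dims_for(6, 4, 2, None)                == [2, 3]      server-side batching
        //   dims_for(6, 4, 1, Some(vec![1, 3, 2])) == [1, 3, 2]   client dims pass through
        fn dims_for(len: u64, input_size: u64, batch_size: u64, input_dims: Option<Vec<i64>>) -> Vec<u64> {
            if input_size > 1 {
                if batch_size == 1 && input_dims.is_some() {
                    input_dims.unwrap().iter().map(|&axis| axis as u64).collect()
                } else {
                    vec![batch_size, len / batch_size]
                }
            } else {
                vec![len]
            }
        }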

        fn convert_to_tftensor_enum(
            input: TensorInput,
            input_size: u64,
            batch_size: u64,
        ) -> TFTensorEnum {
            match input.tensor_data {
                TensorInputEnum::String(t) => {
                    let strings = t
                        .into_iter()
                        .map(|x| unsafe { String::from_utf8_unchecked(x) })
                        .collect_vec();
                    TFTensorEnum::String(
                        Tensor::new(&TFModel::get_tftensor_dimensions(
                            strings.as_slice(),
                            input_size,
                            batch_size,
                            input.dims,
                        ))
                        .with_values(strings.as_slice())
                        .unwrap(),
                    )
                }
                TensorInputEnum::Int(t) => TFTensorEnum::Int(
                    Tensor::new(&TFModel::get_tftensor_dimensions(
                        t.as_slice(),
                        input_size,
                        batch_size,
                        input.dims,
                    ))
                    .with_values(t.as_slice())
                    .unwrap(),
                ),
                TensorInputEnum::Int64(t) => TFTensorEnum::Int64(
                    Tensor::new(&TFModel::get_tftensor_dimensions(
                        t.as_slice(),
                        input_size,
                        batch_size,
                        input.dims,
                    ))
                    .with_values(t.as_slice())
                    .unwrap(),
                ),
                TensorInputEnum::Float(t) => TFTensorEnum::Float(
                    Tensor::new(&TFModel::get_tftensor_dimensions(
                        t.as_slice(),
                        input_size,
                        batch_size,
                        input.dims,
                    ))
                    .with_values(t.as_slice())
                    .unwrap(),
                ),
                TensorInputEnum::Double(t) => TFTensorEnum::Double(
                    Tensor::new(&TFModel::get_tftensor_dimensions(
                        t.as_slice(),
                        input_size,
                        batch_size,
                        input.dims,
                    ))
                    .with_values(t.as_slice())
                    .unwrap(),
                ),
                TensorInputEnum::Boolean(t) => TFTensorEnum::Boolean(
                    Tensor::new(&TFModel::get_tftensor_dimensions(
                        t.as_slice(),
                        input_size,
                        batch_size,
                        input.dims,
                    ))
                    .with_values(t.as_slice())
                    .unwrap(),
                ),
            }
        }
        fn fetch_output<T: TensorType>(
            args: &mut SessionRunArgs,
            token_output: &FetchToken,
            batch_size: u64,
            output_size: u64,
        ) -> (Tensor<T>, u64) {
            let tensor_output = args.fetch::<T>(*token_output).expect("fetch output failed");
            let mut tensor_width = tensor_output.dims()[1];
            if batch_size == 1 && output_size > 1 {
                tensor_width = tensor_output.dims().iter().product();
            }
            (tensor_output, tensor_width)
        }
    }
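
    // A hedged restatement (not in the original source) of the width rule in
    // fetch_output above: the per-row width defaults to dims[1]; when batch_size
    // is 1 and the model has multiple outputs, the whole element count is one row.
    fn row_width(dims: &[u64], batch_size: u64, output_size: u64) -> u64 {
        if batch_size == 1 && output_size > 1 {
            dims.iter().product()
        } else {
            dims[1]
        }
    }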

    impl Model for TFModel {
        fn warmup(&self) -> Result<()> {
            // warm up
            let warmup_file = format!(
                "{}/assets.extra/tf_serving_warmup_requests",
                self.export_dir
            );
            if std::path::Path::new(&warmup_file).exists() {
                use std::io::Cursor;
                info!(
                    "found warmup assets in {}, now perform warming up",
                    warmup_file
                );
                let f = std::fs::File::open(warmup_file).context("cannot open warmup file")?;
                // let mut buf = Vec::new();
                let read = std::io::BufReader::new(f);
                let mut reader = RecordReader::new(read);
                let mut warmup_cnt = 0;
                loop {
                    let next = reader.read_next_owned();
                    match next {
                        Ok(res) => match res {
                            Some(vec) => {
                                // info!("read one tfRecord");
                                match PredictionLog::decode(&mut Cursor::new(vec))
                                    .context("can't parse PredictionLog")?
                                {
                                    PredictionLog {
                                        log_metadata: _,
                                        log_type:
                                            Some(LogType::PredictLog(PredictLog {
                                                request: Some(mut req),
                                                response: _,
                                            })),
                                    } => {
                                        if warmup_cnt == ARGS.max_warmup_records {
                                            //warm up with at most max_warmup_records records
                                            warn!(
                                                "reached max warmup {} records, exit warmup for {}",
                                                ARGS.max_warmup_records,
                                                MODEL_SPECS[self.model_idx]
                                            );
                                            break;
                                        }
                                        self.do_predict(
                                            vec![req.take_input_vals(&self.input_names)],
                                            1,
                                        );
                                        sleep(Duration::from_millis(100));
                                        warmup_cnt += 1;
                                    }
                                    _ => error!("malformed record in warmup file"),
                                }
                            }
                            None => {
                                info!("end of warmup file, warmed up with records: {}", warmup_cnt);
                                break;
                            }
                        },
                        Err(RecordReadError::CorruptFile)
                        | Err(RecordReadError::IoError { .. }) => {
                            error!("read tfrecord error for warmup files, skip");
                        }
                        _ => {}
                    }
                }
            }
            Ok(())
        }

        #[inline(always)]
        fn do_predict(
            &self,
            input_tensors: Vec<Vec<TensorInput>>,
            batch_size: u64,
        ) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>) {
            // let mut batch_ends = input_tensors.iter().map(|t| t.len()).collect::<Vec<usize>>();
            let output_size = self.output_names.len() as u64;
            let input_size = self.input_names.len() as u64;
            debug!(
                "Request for Tensorflow with batch size: {} and input_size: {}",
                batch_size, input_size
            );
            // build a set of input TF tensors

            let batch_end = (1..=input_tensors.len()).collect_vec();
            let mut batch_ends = vec![batch_end; output_size as usize];

            let batched_tensors = TensorInputEnum::merge_batch(input_tensors)
                .into_iter()
                .map(|i| TFModel::convert_to_tftensor_enum(i, input_size, batch_size))
                .collect_vec();

            let mut args = SessionRunArgs::new();
            for (index, tf_tensor) in batched_tensors.iter().enumerate() {
                match tf_tensor {
                    TFTensorEnum::String(inner) => args.add_feed(&self.input_ops[index], 0, inner),
                    TFTensorEnum::Int(inner) => args.add_feed(&self.input_ops[index], 0, inner),
                    TFTensorEnum::Int64(inner) => args.add_feed(&self.input_ops[index], 0, inner),
                    TFTensorEnum::Float(inner) => args.add_feed(&self.input_ops[index], 0, inner),
                    TFTensorEnum::Double(inner) => args.add_feed(&self.input_ops[index], 0, inner),
                    TFTensorEnum::Boolean(inner) => args.add_feed(&self.input_ops[index], 0, inner),
                }
            }
            // For output ops, we receive the same op object by name. Actual tensor tokens are available at different offsets.
            // Since indices are ordered, it's important to specify the output flag to Navi in the same order.
            let token_outputs = self
                .output_ops
                .iter()
                .enumerate()
                .map(|(idx, op)| args.request_fetch(op, idx as i32))
                .collect_vec();
            match self.bundle.session.run(&mut args) {
                Ok(_) => (),
                Err(e) => {
                    NUM_REQUESTS_FAILED.inc_by(batch_size);
                    NUM_REQUESTS_FAILED_BY_MODEL
                        .with_label_values(&[&MODEL_SPECS[self.model_idx]])
                        .inc_by(batch_size);
                    INFERENCE_FAILED_REQUESTS_BY_MODEL
                        .with_label_values(&[&MODEL_SPECS[self.model_idx]])
                        .inc_by(batch_size);
                    panic!("{model}: {e:?}", model = MODEL_SPECS[self.model_idx], e = e);
                }
            }
            let mut predict_return = vec![];
            // Check the output.
            for (index, token_output) in token_outputs.iter().enumerate() {
                // same ops, with type info at different offsets.
                let (res, width) = match self.output_ops[index].output_type(index) {
                    DataType::Float => {
                        let (tensor_output, tensor_width) =
                            TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
                        (
                            TensorReturnEnum::FloatTensorReturn(Box::new(tensor_output)),
                            tensor_width,
                        )
                    }
                    DataType::Int64 => {
                        let (tensor_output, tensor_width) =
                            TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
                        (
                            TensorReturnEnum::Int64TensorReturn(Box::new(tensor_output)),
                            tensor_width,
                        )
                    }
                    DataType::Int32 => {
                        let (tensor_output, tensor_width) =
                            TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
                        (
                            TensorReturnEnum::Int32TensorReturn(Box::new(tensor_output)),
                            tensor_width,
                        )
                    }
                    DataType::String => {
                        let (tensor_output, tensor_width) =
                            TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
                        (
                            TensorReturnEnum::StringTensorReturn(Box::new(tensor_output)),
                            tensor_width,
                        )
                    }
                    _ => panic!("Unsupported return type!"),
                };
                let width = width as usize;
                for b in batch_ends[index].iter_mut() {
                    *b *= width;
                }
                predict_return.push(res)
            }
            //TODO: remove in the future
            //TODO: support actual mtl model outputs
            (predict_return, batch_ends)
        }
        #[inline(always)]
        fn model_idx(&self) -> usize {
            self.model_idx
        }
        #[inline(always)]
        fn version(&self) -> i64 {
            self.version
        }
    }
}
BIN
navi/navi/src/torch_model.docx
Normal file
Binary file not shown.
@ -1,183 +0,0 @@
#[cfg(feature = "torch")]
pub mod torch {
    use std::fmt;
    use std::fmt::Display;
    use std::string::String;

    use crate::bootstrap::TensorInput;
    use crate::cli_args::{Args, ARGS, MODEL_SPECS};
    use crate::metrics;
    use crate::metrics::{
        INFERENCE_FAILED_REQUESTS_BY_MODEL, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
    };
    use crate::predict_service::Model;
    use crate::SerializedInput;
    use crate::TensorReturnEnum;
    use anyhow::Result;
    use dr_transform::converter::BatchPredictionRequestToTorchTensorConverter;
    use dr_transform::converter::Converter;
    use serde_json::Value;
    use tch::Tensor;
    use tch::{kind, CModule, IValue};

    #[derive(Debug)]
    pub struct TorchModel {
        pub model_idx: usize,
        pub version: i64,
        pub module: CModule,
        pub export_dir: String,
        // FIXME: make this Box<Option<..>> so input converter can be optional.
        // Also consider adding output_converter.
        pub input_converter: Box<dyn Converter>,
    }

    impl Display for TorchModel {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(
                f,
                "idx: {}, torch model_name:{}, version:{}",
                self.model_idx, MODEL_SPECS[self.model_idx], self.version
            )
        }
    }

    impl TorchModel {
        pub fn new(idx: usize, version: String, _model_config: &Value) -> Result<TorchModel> {
            let export_dir = format!("{}/{}/model.pt", ARGS.model_dir[idx], version);
            let model = CModule::load(&export_dir).unwrap();
            let torch_model = TorchModel {
                model_idx: idx,
                version: Args::version_str_to_epoch(&version)?,
                module: model,
                export_dir,
                //TODO: move converter lookup in a registry.
                input_converter: Box::new(BatchPredictionRequestToTorchTensorConverter::new(
                    ARGS.model_dir[idx].as_str(),
                    version.as_str(),
                    vec![],
                    Some(&metrics::register_dynamic_metrics),
                )),
            };

            torch_model.warmup()?;
            Ok(torch_model)
        }
        #[inline(always)]
        pub fn decode_to_inputs(bytes: SerializedInput) -> Vec<Tensor> {
            //FIXME: for now we generate six random tensors as inputs to unblock end to end testing
            //when Shajan's decoder is ready we will swap
            let row = bytes.len() as i64;
            let t1 = Tensor::randn(&[row, 5293], kind::FLOAT_CPU); //continuous
            let t2 = Tensor::randint(10, &[row, 149], kind::INT64_CPU); //binary
            let t3 = Tensor::randint(10, &[row, 320], kind::INT64_CPU); //discrete
            let t4 = Tensor::randn(&[row, 200], kind::FLOAT_CPU); //user_embedding
            let t5 = Tensor::randn(&[row, 200], kind::FLOAT_CPU); //user_eng_embedding
            let t6 = Tensor::randn(&[row, 200], kind::FLOAT_CPU); //author_embedding

            vec![t1, t2, t3, t4, t5, t6]
        }
        #[inline(always)]
        pub fn output_to_vec(res: IValue, dst: &mut Vec<f32>) {
            match res {
                IValue::Tensor(tensor) => TorchModel::tensors_to_vec(&[tensor], dst),
                IValue::Tuple(ivalues) => {
                    TorchModel::tensors_to_vec(&TorchModel::ivalues_to_tensors(ivalues), dst)
                }
                _ => panic!("we only support output as a single tensor or a vec of tensors"),
            }
        }
        #[inline(always)]
        pub fn tensor_flatten_size(t: &Tensor) -> usize {
            t.size().into_iter().fold(1, |acc, x| acc * x) as usize
        }
        #[inline(always)]
        pub fn tensor_to_vec<T: kind::Element>(res: &Tensor) -> Vec<T> {
            let size = TorchModel::tensor_flatten_size(res);
            let mut res_f32: Vec<T> = Vec::with_capacity(size);
            // SAFETY: copy_data below fills all `size` elements before they are read
            unsafe {
                res_f32.set_len(size);
            }
            res.copy_data(res_f32.as_mut_slice(), size);
            // println!("Copied tensor:{}, {:?}", res_f32.len(), res_f32);
            res_f32
        }
        #[inline(always)]
        pub fn tensors_to_vec(tensors: &[Tensor], dst: &mut Vec<f32>) {
            let mut offset = dst.len();
            tensors.iter().for_each(|t| {
                let size = TorchModel::tensor_flatten_size(t);
                let next_size = offset + size;
                // SAFETY: copy_data below fills dst[offset..next_size] before it is read
                unsafe {
                    dst.set_len(next_size);
                }
                t.copy_data(&mut dst[offset..], size);
                offset = next_size;
            });
        }
        pub fn ivalues_to_tensors(ivalues: Vec<IValue>) -> Vec<Tensor> {
            ivalues
                .into_iter()
                .map(|t| {
                    if let IValue::Tensor(vanilla_t) = t {
                        vanilla_t
                    } else {
                        panic!("not a tensor")
                    }
                })
                .collect::<Vec<Tensor>>()
        }
    }
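
    // A hedged usage sketch (not in the original source) of the helpers above:
    // flatten a TorchScript output, whether a single tensor or a tuple of tensors,
    // into one contiguous score buffer, as do_predict does below.
    fn flatten_output(res: IValue) -> Vec<f32> {
        let mut buf: Vec<f32> = Vec::with_capacity(1024);
        TorchModel::output_to_vec(res, &mut buf);
        buf
    }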

    impl Model for TorchModel {
        fn warmup(&self) -> Result<()> {
            Ok(())
        }
        //TODO: the torch runtime needs some refactoring to make it a generic interface
        #[inline(always)]
        fn do_predict(
            &self,
            input_tensors: Vec<Vec<TensorInput>>,
            total_len: u64,
        ) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>) {
            let mut buf: Vec<f32> = Vec::with_capacity(10_000);
            let mut batch_ends = vec![0usize; input_tensors.len()];
            for (i, batch_bytes_in_request) in input_tensors.into_iter().enumerate() {
                for _ in batch_bytes_in_request.into_iter() {
                    //FIXME: for now use some hack
                    let model_input = TorchModel::decode_to_inputs(vec![0u8; 30]); //self.input_converter.convert(bytes);
                    let input_batch_tensors = model_input
                        .into_iter()
                        .map(IValue::Tensor)
                        .collect::<Vec<IValue>>();
                    // match self.module.forward_is(&input_batch_tensors) {
                    match self.module.method_is("forward_serve", &input_batch_tensors) {
                        Ok(res) => TorchModel::output_to_vec(res, &mut buf),
                        Err(e) => {
                            NUM_REQUESTS_FAILED.inc_by(total_len);
                            NUM_REQUESTS_FAILED_BY_MODEL
                                .with_label_values(&[&MODEL_SPECS[self.model_idx]])
                                .inc_by(total_len);
                            INFERENCE_FAILED_REQUESTS_BY_MODEL
                                .with_label_values(&[&MODEL_SPECS[self.model_idx]])
                                .inc_by(total_len);
                            panic!("{model}: {e:?}", model = MODEL_SPECS[self.model_idx], e = e);
                        }
                    }
                }
                batch_ends[i] = buf.len();
            }
            (
                vec![TensorReturnEnum::FloatTensorReturn(Box::new(buf))],
                vec![batch_ends],
            )
        }
        #[inline(always)]
        fn model_idx(&self) -> usize {
            self.model_idx
        }
        #[inline(always)]
        fn version(&self) -> i64 {
            self.version
        }
    }
}
BIN
navi/segdense/Cargo.docx
Normal file
Binary file not shown.
@ -1,11 +0,0 @@
[package]
name = "segdense"
version = "0.1.0"
edition = "2021"

[dependencies]
env_logger = "0.10.0"
serde = { version = "1.0.104", features = ["derive"] }
serde_json = "1.0.48"
log = "0.4.17"
BIN
navi/segdense/src/error.docx
Normal file
Binary file not shown.
@ -1,53 +0,0 @@
use std::fmt::Display;

/**
 * Custom error
 */
#[derive(Debug)]
pub enum SegDenseError {
    IoError(std::io::Error),
    Json(serde_json::Error),
    JsonMissingRoot,
    JsonMissingObject,
    JsonMissingArray,
    JsonArraySize,
    JsonMissingInputFeature,
}

impl Display for SegDenseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
            SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
            SegDenseError::JsonMissingRoot => write!(f, "SegDense JSON: Root node not found!"),
            SegDenseError::JsonMissingObject => write!(f, "SegDense JSON: Object not found!"),
            SegDenseError::JsonMissingArray => write!(f, "SegDense JSON: Array node not found!"),
            SegDenseError::JsonArraySize => write!(f, "SegDense JSON: Array size not as expected!"),
            SegDenseError::JsonMissingInputFeature => {
                write!(f, "SegDense JSON: Missing input feature!")
            }
        }
    }
}

impl std::error::Error for SegDenseError {}

impl From<std::io::Error> for SegDenseError {
    fn from(err: std::io::Error) -> Self {
        SegDenseError::IoError(err)
    }
}

impl From<serde_json::Error> for SegDenseError {
    fn from(err: serde_json::Error) -> Self {
        SegDenseError::Json(err)
    }
}
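
// A hedged sketch (not in the original file): the From impls above are what let
// `?` convert std::io and serde_json errors inside schema-loading helpers like
// those in util.rs. The function name and path handling are illustrative.
fn read_schema(path: &str) -> Result<String, SegDenseError> {
    Ok(std::fs::read_to_string(path)?) // io::Error -> SegDenseError::IoError via From
}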
BIN
navi/segdense/src/lib.docx
Normal file
Binary file not shown.
@ -1,4 +0,0 @@
pub mod error;
pub mod mapper;
pub mod segdense_transform_spec_home_recap_2022;
pub mod util;
BIN
navi/segdense/src/main.docx
Normal file
Binary file not shown.
@ -1,22 +0,0 @@
use std::env;
use std::fs;

use segdense::error::SegDenseError;
use segdense::util;

fn main() -> Result<(), SegDenseError> {
    env_logger::init();
    let args: Vec<String> = env::args().collect();

    let schema_file_name: &str = if args.len() == 1 {
        "json/compact.json"
    } else {
        &args[1]
    };

    let json_str = fs::read_to_string(schema_file_name)?;

    util::safe_load_config(&json_str)?;

    Ok(())
}
BIN
navi/segdense/src/mapper.docx
Normal file
Binary file not shown.
@ -1,45 +0,0 @@
use std::collections::HashMap;

#[derive(Debug)]
pub struct FeatureInfo {
    pub tensor_index: i8,
    pub index_within_tensor: i64,
}

pub static NULL_INFO: FeatureInfo = FeatureInfo {
    tensor_index: -1,
    index_within_tensor: -1,
};

#[derive(Debug, Default)]
pub struct FeatureMapper {
    map: HashMap<i64, FeatureInfo>,
}

impl FeatureMapper {
    pub fn new() -> FeatureMapper {
        FeatureMapper {
            map: HashMap::new(),
        }
    }
}

pub trait MapWriter {
    fn set(&mut self, feature_id: i64, info: FeatureInfo);
}

pub trait MapReader {
    fn get(&self, feature_id: &i64) -> Option<&FeatureInfo>;
}

impl MapWriter for FeatureMapper {
    fn set(&mut self, feature_id: i64, info: FeatureInfo) {
        self.map.insert(feature_id, info);
    }
}

impl MapReader for FeatureMapper {
    fn get(&self, feature_id: &i64) -> Option<&FeatureInfo> {
        self.map.get(feature_id)
    }
}
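
// A hedged usage sketch (not in the original file) of the mapper API above;
// the MapWriter/MapReader traits defined in this module are in scope here.
fn mapper_demo() {
    let mut fm = FeatureMapper::new();
    fm.set(42, FeatureInfo { tensor_index: 0, index_within_tensor: 7 });
    assert!(fm.get(&42).is_some()); // a known feature id resolves to its slot
    assert!(fm.get(&7).is_none()); // unknown feature ids simply miss
}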
BIN
navi/segdense/src/segdense_transform_spec_home_recap_2022.docx
Normal file
Binary file not shown.
@ -1,182 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Root {
    #[serde(rename = "common_prefix")]
    pub common_prefix: String,
    #[serde(rename = "densification_transform_spec")]
    pub densification_transform_spec: DensificationTransformSpec,
    #[serde(rename = "identity_transform_spec")]
    pub identity_transform_spec: Vec<IdentityTransformSpec>,
    #[serde(rename = "complex_feature_type_transform_spec")]
    pub complex_feature_type_transform_spec: Vec<ComplexFeatureTypeTransformSpec>,
    #[serde(rename = "input_features_map")]
    pub input_features_map: Value,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DensificationTransformSpec {
    pub discrete: Discrete,
    pub cont: Cont,
    pub binary: Binary,
    pub string: Value, // Use StringType
    pub blob: Blob,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Discrete {
    pub tag: String,
    #[serde(rename = "generic_feature_type")]
    pub generic_feature_type: i64,
    #[serde(rename = "feature_identifier")]
    pub feature_identifier: String,
    #[serde(rename = "fixed_length")]
    pub fixed_length: i64,
    #[serde(rename = "default_value")]
    pub default_value: DefaultValue,
    #[serde(rename = "input_features")]
    pub input_features: Vec<InputFeature>,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DefaultValue {
    #[serde(rename = "type")]
    pub type_field: String,
    pub value: String,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InputFeature {
    #[serde(rename = "feature_id")]
    pub feature_id: i64,
    #[serde(rename = "full_feature_name")]
    pub full_feature_name: String,
    #[serde(rename = "feature_type")]
    pub feature_type: i64,
    pub index: i64,
    #[serde(rename = "maybe_exclude")]
    pub maybe_exclude: bool,
    pub tag: String,
    #[serde(rename = "added_at")]
    pub added_at: i64,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Cont {
    pub tag: String,
    #[serde(rename = "generic_feature_type")]
    pub generic_feature_type: i64,
    #[serde(rename = "feature_identifier")]
    pub feature_identifier: String,
    #[serde(rename = "fixed_length")]
    pub fixed_length: i64,
    #[serde(rename = "default_value")]
    pub default_value: DefaultValue,
    #[serde(rename = "input_features")]
    pub input_features: Vec<InputFeature>,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Binary {
    pub tag: String,
    #[serde(rename = "generic_feature_type")]
    pub generic_feature_type: i64,
    #[serde(rename = "feature_identifier")]
    pub feature_identifier: String,
    #[serde(rename = "fixed_length")]
    pub fixed_length: i64,
    #[serde(rename = "default_value")]
    pub default_value: DefaultValue,
    #[serde(rename = "input_features")]
    pub input_features: Vec<InputFeature>,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StringType {
    pub tag: String,
    #[serde(rename = "generic_feature_type")]
    pub generic_feature_type: i64,
    #[serde(rename = "feature_identifier")]
    pub feature_identifier: String,
    #[serde(rename = "fixed_length")]
    pub fixed_length: i64,
    #[serde(rename = "default_value")]
    pub default_value: DefaultValue,
    #[serde(rename = "input_features")]
    pub input_features: Vec<InputFeature>,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Blob {
    pub tag: String,
    #[serde(rename = "generic_feature_type")]
    pub generic_feature_type: i64,
    #[serde(rename = "feature_identifier")]
    pub feature_identifier: String,
    #[serde(rename = "fixed_length")]
    pub fixed_length: i64,
    #[serde(rename = "default_value")]
    pub default_value: DefaultValue,
    #[serde(rename = "input_features")]
    pub input_features: Vec<Value>,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IdentityTransformSpec {
    #[serde(rename = "feature_id")]
    pub feature_id: i64,
    #[serde(rename = "full_feature_name")]
    pub full_feature_name: String,
    #[serde(rename = "feature_type")]
    pub feature_type: i64,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ComplexFeatureTypeTransformSpec {
    #[serde(rename = "feature_id")]
    pub feature_id: i64,
    #[serde(rename = "full_feature_name")]
    pub full_feature_name: String,
    #[serde(rename = "feature_type")]
    pub feature_type: i64,
    pub index: i64,
    #[serde(rename = "maybe_exclude")]
    pub maybe_exclude: bool,
    pub tag: String,
    #[serde(rename = "tensor_data_type")]
    pub tensor_data_type: Option<i64>,
    #[serde(rename = "added_at")]
    pub added_at: i64,
    #[serde(rename = "tensor_shape")]
    #[serde(default)]
    pub tensor_shape: Vec<i64>,
}

#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InputFeatureMapRecord {
    #[serde(rename = "feature_id")]
    pub feature_id: i64,
    #[serde(rename = "full_feature_name")]
    pub full_feature_name: String,
    #[serde(rename = "feature_type")]
    pub feature_type: i64,
    pub index: i64,
    #[serde(rename = "maybe_exclude")]
    pub maybe_exclude: bool,
    pub tag: String,
    #[serde(rename = "added_at")]
    pub added_at: i64,
}
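
// A hedged sketch (illustrative values, not from the original file): one
// input-feature record as it appears in a segdense schema, parsed into the
// typed struct above via the explicit serde renames.
fn parse_input_feature() -> serde_json::Result<InputFeature> {
    serde_json::from_str(
        r#"{
            "feature_id": 123,
            "full_feature_name": "user.example.feature",
            "feature_type": 2,
            "index": 0,
            "maybe_exclude": false,
            "tag": "cont",
            "added_at": 0
        }"#,
    )
}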
BIN
navi/segdense/src/util.docx
Normal file
Binary file not shown.
@ -1,154 +0,0 @@
|
|||||||
use log::debug;
|
|
||||||
use std::fs;
|
|
||||||
|
|
||||||
use serde_json::{Map, Value};
|
|
||||||
|
|
||||||
use crate::error::SegDenseError;
|
|
||||||
use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter};
|
|
||||||
use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
|
|
||||||
|
|
||||||
pub fn load_config(file_name: &str) -> Result<seg_dense::Root, SegDenseError> {
|
|
||||||
let json_str = fs::read_to_string(file_name)?;
|
|
||||||
// &format!("Unable to load segdense file {}", file_name));
|
|
||||||
let seg_dense_config = parse(&json_str)?;
|
|
||||||
// &format!("Unable to parse segdense file {}", file_name));
|
|
||||||
Ok(seg_dense_config)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {
|
|
||||||
let root: seg_dense::Root = serde_json::from_str(json_str)?;
|
|
||||||
Ok(root)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Given a json string containing a seg dense schema create a feature mapper
|
|
||||||
* which is essentially:
|
|
||||||
*
|
|
||||||
* {feature-id -> (Tensor Index, Index of feature within the tensor)}
|
|
||||||
*
|
|
||||||
* Feature id : 64 bit hash of the feature name used in DataRecords.
|
|
||||||
*
|
|
||||||
* Tensor Index : A vector of tensors is passed to the model. Tensor
|
|
||||||
* index refers to the tensor this feature is part of.
|
|
||||||
*
|
|
||||||
* Index of feature in tensor : The tensors are vectors, the index of
|
|
||||||
* feature is the position to put the feature value.
|
|
||||||
*
|
|
||||||
* There are many assumptions made in this function that is very model specific.
|
|
||||||
* These assumptions are called out below and need to be schematized eventually.
|
|
||||||
*
|
|
||||||
* Call this once for each segdense schema and cache the FeatureMapper.
|
|
||||||
*/
|
|
||||||
pub fn safe_load_config(json_str: &str) -> Result<FeatureMapper, SegDenseError> {
|
|
||||||
let root = parse(json_str)?;
|
|
||||||
load_from_parsed_config(root)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perf note : make 'root' un-owned
|
|
||||||
pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper, SegDenseError> {
|
|
||||||
let v = root.input_features_map;
|
|
||||||
|
|
||||||
// Do error check
|
|
||||||
let map: Map<String, Value> = match v {
|
|
||||||
Value::Object(map) => map,
|
|
||||||
_ => return Err(SegDenseError::JsonMissingObject),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut fm: FeatureMapper = FeatureMapper::new();
|
|
||||||
|
|
||||||
let items = map.values();
|
|
||||||
|
|
||||||
// Perf : Consider a way to avoid clone here
|
|
||||||
for item in items.cloned() {
|
|
||||||
let mut vec = match item {
|
|
||||||
Value::Array(v) => v,
|
|
||||||
_ => return Err(SegDenseError::JsonMissingArray),
|
|
||||||
};
|
|
||||||
|
|
||||||
if vec.len() != 1 {
|
|
||||||
return Err(SegDenseError::JsonArraySize);
|
|
||||||
}
|
|
||||||
|
|
||||||
let val = vec.pop().unwrap();
|
|
||||||
|
|
||||||
let input_feature: seg_dense::InputFeature = serde_json::from_value(val)?;
|
|
||||||
let feature_id = input_feature.feature_id;
|
|
||||||
let feature_info = to_feature_info(&input_feature);
|
|
||||||
|
|
||||||
match feature_info {
|
|
||||||
Some(info) => {
|
|
||||||
debug!("{:?}", info);
|
|
||||||
fm.set(feature_id, info)
|
|
||||||
}
|
|
||||||
None => (),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(fm)
|
|
||||||
}
|
|
||||||
#[allow(dead_code)]
|
|
||||||
fn add_feature_info_to_mapper(
|
|
||||||
feature_mapper: &mut FeatureMapper,
|
|
||||||
input_features: &Vec<InputFeature>,
|
|
||||||
) {
|
|
||||||
for input_feature in input_features.iter() {
|
|
||||||
let feature_id = input_feature.feature_id;
|
|
||||||
let feature_info = to_feature_info(input_feature);
|
|
||||||
|
|
||||||
match feature_info {
|
|
||||||
Some(info) => {
|
|
||||||
debug!("{:?}", info);
|
|
||||||
feature_mapper.set(feature_id, info)
|
|
||||||
}
|
|
||||||
None => (),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}

pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<FeatureInfo> {
    if input_feature.maybe_exclude {
        return None;
    }

    // This part needs to be schema-driven.
    //
    // Tensor index: which of these tensors this feature is part of:
    // [Continuous, Binary, Discrete, User_embedding, user_eng_embedding, author_embedding]
    // Note that this order is fixed/hardcoded here and needs to be schematized.
    let tensor_idx: i8 = match input_feature.feature_id {
        // Feature names are mapped to feature-id values; the hardcoded values
        // below correspond to the specific feature names in the comments.

        // user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings
        -2550691008059411095 => 3,

        // user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings
        5390650078733277231 => 4,

        // original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings
        3223956748566688423 => 5,

        _ => match input_feature.feature_type {
            // feature_type: src/thrift/com/twitter/ml/api/data.thrift
            // BINARY = 1, CONTINUOUS = 2, DISCRETE = 3
            // Map to slots in [Continuous, Binary, Discrete, ..]
            1 => 1,
            2 => 0,
            3 => 2,
            _ => -1,
        },
    };

    if input_feature.index < 0 {
        return None;
    }

    // Unknown feature types are skipped for now; handle this case later.
    if tensor_idx == -1 {
        return None;
    }

    Some(FeatureInfo {
        tensor_index: tensor_idx,
        index_within_tensor: input_feature.index,
    })
}
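
A unit-test sketch of the type-based fallback above. It assumes `seg_dense::InputFeature` exposes exactly the four public fields used in this function, which may not hold for the real generated type:

#[cfg(test)]
mod tests {
    use super::*;

    // Hypothetical test sketch: assumes InputFeature has plain public fields
    // feature_id, feature_type, index, and maybe_exclude.
    #[test]
    fn continuous_feature_maps_to_tensor_zero() {
        let input = seg_dense::InputFeature {
            feature_id: 42,  // not one of the hardcoded embedding ids
            feature_type: 2, // CONTINUOUS
            index: 7,
            maybe_exclude: false,
        };
        let info = to_feature_info(&input).unwrap();
        assert_eq!(info.tensor_index, 0); // the Continuous slot
        assert_eq!(info.index_within_tensor, 7);
    }
}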
BIN navi/thrift_bpr_adapter/thrift/Cargo.docx Normal file
Binary file not shown.
@ -1,8 +0,0 @@
[package]
name = "bpr_thrift"
description = "Thrift parser for Batch Prediction Request"
version = "0.1.0"
edition = "2021"

[dependencies]
thrift = "0.17.0"
BIN navi/thrift_bpr_adapter/thrift/src/data.docx Normal file
Binary file not shown.
File diff suppressed because it is too large.
BIN navi/thrift_bpr_adapter/thrift/src/decoder.docx Normal file
Binary file not shown.
@ -1,78 +0,0 @@

// A feature value can be one of these
enum FeatureVal {
    Empty,
    U8Vector(Vec<u8>),
    FloatVector(Vec<f32>),
}

// A Feature has a name and a value.
// The name for now is 'id', of type String.
// Eventually this needs to be flexible, for example to accommodate feature-ids.
struct Feature {
    id: String,
    val: FeatureVal,
}

impl Feature {
    fn new() -> Feature {
        Feature {
            id: String::new(),
            val: FeatureVal::Empty,
        }
    }
}

// A single inference record will have multiple features
struct Record {
    fields: Vec<Feature>,
}

impl Record {
    fn new() -> Record {
        Record { fields: vec![] }
    }
}

// This is the main API used by external components.
// Given a serialized input, decode it into Records.
fn decode(_input: Vec<u8>) -> Vec<Record> {
    // Placeholder implementation to help define the interface
    vec![get_random_record(), get_random_record()]
}

// Used for testing the API; will eventually be removed
fn get_random_record() -> Record {
    let mut record: Record = Record::new();

    let f1: Feature = Feature {
        id: String::from("continuous_features"),
        val: FeatureVal::FloatVector(vec![1.0f32; 2134]),
    };
    record.fields.push(f1);

    let f2: Feature = Feature {
        id: String::from("user_embedding"),
        val: FeatureVal::FloatVector(vec![2.0f32; 200]),
    };
    record.fields.push(f2);

    let f3: Feature = Feature {
        id: String::from("author_embedding"),
        val: FeatureVal::FloatVector(vec![3.0f32; 200]),
    };
    record.fields.push(f3);

    let f4: Feature = Feature {
        id: String::from("binary_features"),
        val: FeatureVal::U8Vector(vec![4u8; 43]),
    };
    record.fields.push(f4);

    record
}
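
A minimal caller sketch for the interface above. This is hypothetical: `decode` is still a stub, so the summary simply reflects the hardcoded test records:

// Hypothetical caller: exercises the stub decode() API defined above.
fn print_record_summary(bytes: Vec<u8>) {
    for record in decode(bytes) {
        for feature in &record.fields {
            match &feature.val {
                FeatureVal::Empty => println!("{}: empty", feature.id),
                FeatureVal::U8Vector(v) => println!("{}: {} bytes", feature.id, v.len()),
                FeatureVal::FloatVector(v) => println!("{}: {} floats", feature.id, v.len()),
            }
        }
    }
}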
BIN navi/thrift_bpr_adapter/thrift/src/lib.docx Normal file
Binary file not shown.
@ -1,4 +0,0 @@
pub mod prediction_service;
pub mod data;
pub mod tensor;
BIN navi/thrift_bpr_adapter/thrift/src/main.docx Normal file
Binary file not shown.
@ -1,81 +0,0 @@

use std::collections::BTreeMap;
use std::collections::BTreeSet;

use bpr_thrift::data::DataRecord;
use bpr_thrift::prediction_service::BatchPredictionRequest;
use thrift::OrderedFloat;

use thrift::protocol::TBinaryInputProtocol;
use thrift::protocol::TSerializable;
use thrift::transport::TBufferChannel;
use thrift::Result;

fn main() {
    let data_path = "/tmp/current/timelines/output-1";
    let bin_data: Vec<u8> = std::fs::read(data_path).expect("Could not read file!");

    println!("Length : {}", bin_data.len());

    let mut bc = TBufferChannel::with_capacity(bin_data.len(), 0);
    bc.set_readable_bytes(&bin_data);

    let mut protocol = TBinaryInputProtocol::new(bc, true);

    let result: Result<BatchPredictionRequest> =
        BatchPredictionRequest::read_from_in_protocol(&mut protocol);

    match result {
        Ok(bpr) => log_bpr(bpr),
        Err(err) => println!("Error {}", err),
    }
}

fn log_bpr(bpr: BatchPredictionRequest) {
    println!("-------[OUTPUT]---------------");
    println!("data {:?}", bpr);
    println!("------------------------------");

    /*
    let common = bpr.common_features;
    let recs = bpr.individual_features_list;

    println!("--------[Len : {}]------------------", recs.len());

    println!("-------[COMMON]---------------");
    match common {
        Some(dr) => log_data_record(dr),
        None => println!("None"),
    }
    println!("------------------------------");
    for rec in recs {
        log_data_record(rec);
    }
    println!("------------------------------");
    */
}

// Only referenced from the commented-out logging block above.
#[allow(dead_code)]
fn log_data_record(dr: DataRecord) {
    println!("--------[DR]------------------");

    if let Some(bf) = dr.binary_features {
        log_binary_features(bf);
    }

    if let Some(cf) = dr.continuous_features {
        log_continuous_features(cf);
    }
    println!("------------------------------");
}

#[allow(dead_code)]
fn log_binary_features(bin: BTreeSet<i64>) {
    println!("B: {:?}", bin)
}

#[allow(dead_code)]
fn log_continuous_features(cf: BTreeMap<i64, OrderedFloat<f64>>) {
    for (id, fs) in cf {
        println!("C: {} -> [{}]", id, fs);
    }
}
BIN navi/thrift_bpr_adapter/thrift/src/prediction_service.docx Normal file
Binary file not shown.
File diff suppressed because it is too large.
BIN navi/thrift_bpr_adapter/thrift/src/tensor.docx Normal file
Binary file not shown.
File diff suppressed because it is too large.
BIN product-mixer/README.docx Normal file
Binary file not shown.
@ -1,41 +0,0 @@
Product Mixer
=============

## Overview

Product Mixer is a common service framework and set of libraries that make it easy to build,
iterate on, and own product surface areas. It consists of:

- **Core Libraries:** A set of libraries that enable you to build execution pipelines out of
reusable components. You define your logic in small, well-defined, reusable components and focus
on expressing the business logic you want to have. Then you can define easy-to-understand pipelines
that compose your components. Product Mixer handles the execution and monitoring of your pipelines,
allowing you to focus on what really matters: your business logic.

- **Service Framework:** A common service skeleton for teams to host their Product Mixer products.

- **Component Library:** A shared library of components made by the Product Mixer team or
contributed by users. This enables you both to easily share the reusable components you make and
to benefit from the work other teams have done by utilizing their shared components in the library.

## Architecture

The bulk of a Product Mixer can be broken down into Pipelines and Components. Components allow you
to break business logic into separate, standardized, reusable, testable, and easily composable
pieces, where each component has a well-defined abstraction. Pipelines are essentially configuration
files specifying which Components should be used and when. This makes it easy to understand how your
code will execute while keeping it organized and structured in a maintainable way.

Requests first go to Product Pipelines, which are used to select which Mixer Pipeline or
Recommendation Pipeline to run for a given request. Each Mixer or Recommendation
Pipeline may run multiple Candidate Pipelines to fetch candidates to include in the response.

Mixer Pipelines combine the results of multiple heterogeneous Candidate Pipelines together
(e.g. ads, tweets, users), while Recommendation Pipelines are used to score (via Scoring Pipelines)
and rank the results of homogeneous Candidate Pipelines so that the top-ranked ones can be returned.
These pipelines also marshal candidates into a domain object and then into a transport object
to return to the caller.

Candidate Pipelines fetch candidates from underlying Candidate Sources and perform some basic
operations on the candidates, such as filtering out unwanted candidates, applying decorations,
and hydrating features.
Binary file not shown.
@ -1,57 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.account_recommendations_mixer

import com.twitter.account_recommendations_mixer.{thriftscala => t}
import com.twitter.product_mixer.component_library.model.candidate.UserCandidate
import com.twitter.product_mixer.core.feature.Feature
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSourceWithExtractedFeatures
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidatesWithSourceFeatures
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

object WhoToFollowModuleHeaderFeature extends Feature[UserCandidate, t.Header]
object WhoToFollowModuleFooterFeature extends Feature[UserCandidate, Option[t.Footer]]
object WhoToFollowModuleDisplayOptionsFeature
    extends Feature[UserCandidate, Option[t.DisplayOptions]]

@Singleton
class AccountRecommendationsMixerCandidateSource @Inject() (
  accountRecommendationsMixer: t.AccountRecommendationsMixer.MethodPerEndpoint)
    extends CandidateSourceWithExtractedFeatures[
      t.AccountRecommendationsMixerRequest,
      t.RecommendedUser
    ] {

  override val identifier: CandidateSourceIdentifier =
    CandidateSourceIdentifier(name = "AccountRecommendationsMixer")

  override def apply(
    request: t.AccountRecommendationsMixerRequest
  ): Stitch[CandidatesWithSourceFeatures[t.RecommendedUser]] = {
    Stitch
      .callFuture(accountRecommendationsMixer.getWtfRecommendations(request))
      .map { response: t.WhoToFollowResponse =>
        responseToCandidatesWithSourceFeatures(
          response.userRecommendations,
          response.header,
          response.footer,
          response.displayOptions)
      }
  }

  private def responseToCandidatesWithSourceFeatures(
    userRecommendations: Seq[t.RecommendedUser],
    header: t.Header,
    footer: Option[t.Footer],
    displayOptions: Option[t.DisplayOptions],
  ): CandidatesWithSourceFeatures[t.RecommendedUser] = {
    val features = FeatureMapBuilder()
      .add(WhoToFollowModuleHeaderFeature, header)
      .add(WhoToFollowModuleFooterFeature, footer)
      .add(WhoToFollowModuleDisplayOptionsFeature, displayOptions)
      .build()
    CandidatesWithSourceFeatures(userRecommendations, features)
  }
}
@ -1,22 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    platform = "java8",
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "account-recommendations-mixer/thrift/src/main/thrift:thrift-scala",
        "finatra/inject/inject-core/src/main/scala/com/twitter/inject",
        "product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline/pipeline_failure",
        "src/thrift/com/twitter/ads/adserver:adserver_common-scala",
        "stitch/stitch-core",
    ],
    exports = [
        "account-recommendations-mixer/thrift/src/main/thrift:thrift-scala",
        "finatra/inject/inject-core/src/main/scala/com/twitter/inject",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "stitch/stitch-core",
    ],
)
Binary file not shown.
@ -1,29 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ads

import com.twitter.adserver.thriftscala.AdImpression
import com.twitter.adserver.thriftscala.AdRequestParams
import com.twitter.adserver.thriftscala.AdRequestResponse
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyFetcherSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.ads.admixer.MakeAdRequestClientColumn
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class AdsProdStratoCandidateSource @Inject() (adsClient: MakeAdRequestClientColumn)
    extends StratoKeyFetcherSource[
      AdRequestParams,
      AdRequestResponse,
      AdImpression
    ] {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("AdsProdStrato")

  override val fetcher: Fetcher[AdRequestParams, Unit, AdRequestResponse] = adsClient.fetcher

  override protected def stratoResultTransformer(
    stratoResult: AdRequestResponse
  ): Seq[AdImpression] =
    stratoResult.impressions
}
Binary file not shown.
@ -1,22 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ads

import com.twitter.adserver.thriftscala.AdImpression
import com.twitter.adserver.thriftscala.AdRequestParams
import com.twitter.adserver.thriftscala.NewAdServer
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class AdsProdThriftCandidateSource @Inject() (
  adServerClient: NewAdServer.MethodPerEndpoint)
    extends CandidateSource[AdRequestParams, AdImpression] {

  override val identifier: CandidateSourceIdentifier =
    CandidateSourceIdentifier("AdsProdThrift")

  override def apply(request: AdRequestParams): Stitch[Seq[AdImpression]] =
    Stitch.callFuture(adServerClient.makeAdRequest(request)).map(_.impressions)
}
Binary file not shown.
@ -1,28 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ads

import com.twitter.adserver.thriftscala.AdImpression
import com.twitter.adserver.thriftscala.AdRequestParams
import com.twitter.adserver.thriftscala.AdRequestResponse
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyFetcherSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.ads.admixer.MakeAdRequestStagingClientColumn
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class AdsStagingCandidateSource @Inject() (adsClient: MakeAdRequestStagingClientColumn)
    extends StratoKeyFetcherSource[
      AdRequestParams,
      AdRequestResponse,
      AdImpression
    ] {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("AdsStaging")

  override val fetcher: Fetcher[AdRequestParams, Unit, AdRequestResponse] = adsClient.fetcher

  override protected def stratoResultTransformer(
    stratoResult: AdRequestResponse
  ): Seq[AdImpression] =
    stratoResult.impressions
}
@ -1,18 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "3rdparty/jvm/javax/inject:javax.inject",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "src/thrift/com/twitter/ads/adserver:adserver_common-scala",
        "src/thrift/com/twitter/ads/adserver:adserver_rpc-scala",
        "strato/config/columns/ads/admixer:admixer-strato-client",
    ],
    exports = [
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "src/thrift/com/twitter/ads/adserver:adserver_common-scala",
        "src/thrift/com/twitter/ads/adserver:adserver_rpc-scala",
    ],
)
Binary file not shown.
@ -1,43 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ann

import com.twitter.ann.common._
import com.twitter.finagle.util.DefaultTimer
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import com.twitter.util.{Time => _, _}

/**
 * @param annQueryableById Ann Queryable by Id client that returns nearest neighbors for a sequence of queries
 * @param identifier Candidate Source Identifier
 * @tparam T1 type of the query.
 * @tparam T2 type of the result.
 * @tparam P runtime parameters supported by the index.
 * @tparam D distance function used in the index.
 */
class AnnCandidateSource[T1, T2, P <: RuntimeParams, D <: Distance[D]](
  val annQueryableById: QueryableById[T1, T2, P, D],
  val batchSize: Int,
  val timeoutPerRequest: Duration,
  override val identifier: CandidateSourceIdentifier)
    extends CandidateSource[AnnIdQuery[T1, P], NeighborWithDistanceWithSeed[T1, T2, D]] {

  implicit val timer: Timer = DefaultTimer

  override def apply(
    request: AnnIdQuery[T1, P]
  ): Stitch[Seq[NeighborWithDistanceWithSeed[T1, T2, D]]] = {
    val ids = request.ids
    val numOfNeighbors = request.numOfNeighbors
    val runtimeParams = request.runtimeParams
    Stitch
      .collect(
        ids
          .grouped(batchSize)
          .map { batchedIds =>
            // Query each batch with its own timeout; a failed or timed-out
            // batch degrades to an empty result rather than failing the
            // whole request.
            annQueryableById
              .batchQueryWithDistanceById(batchedIds, numOfNeighbors, runtimeParams)
              .map { annResult => annResult.toSeq }
              .within(timeoutPerRequest)
              .handle { case _ => Seq.empty }
          }
          .toSeq)
      .map(_.flatten)
  }
}
Binary file not shown.
@ -1,18 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ann

import com.twitter.ann.common._

/**
 * An [[AnnIdQuery]] is a query class which defines the ANN entities, with runtime params and the number of neighbors requested.
 *
 * @param ids Sequence of queries
 * @param numOfNeighbors Number of neighbors requested
 * @param runtimeParams ANN runtime params
 * @tparam T type of query.
 * @tparam P runtime parameters supported by the index.
 */
case class AnnIdQuery[T, P <: RuntimeParams](
  ids: Seq[T],
  numOfNeighbors: Int,
  runtimeParams: P)
@ -1,17 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "ann/src/main/scala/com/twitter/ann/common",
        "ann/src/main/scala/com/twitter/ann/hnsw",
        "ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
        "product-mixer/component-library/src/main/thrift/com/twitter/product_mixer/component_library:thrift-scala",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "servo/manhattan/src/main/scala",
        "servo/repo/src/main/scala",
        "servo/util/src/main/scala",
        "stitch/stitch-core",
    ],
)
Binary file not shown.
@ -1,14 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/cursor",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "src/thrift/com/twitter/periscope/audio_space:audio_space-scala",
        "strato/config/columns/periscope:periscope-strato-client",
        "strato/config/src/thrift/com/twitter/strato/graphql:graphql-scala",
        "strato/src/main/scala/com/twitter/strato/client",
    ],
)
Binary file not shown.
@ -1,49 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.audiospace

import com.twitter.periscope.audio_space.thriftscala.CreatedSpacesView
import com.twitter.periscope.audio_space.thriftscala.SpaceSlice
import com.twitter.product_mixer.component_library.model.cursor.NextCursorFeature
import com.twitter.product_mixer.component_library.model.cursor.PreviousCursorFeature
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyViewFetcherWithSourceFeaturesSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.periscope.CreatedSpacesSliceOnUserClientColumn
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class CreatedSpacesCandidateSource @Inject() (
  column: CreatedSpacesSliceOnUserClientColumn)
    extends StratoKeyViewFetcherWithSourceFeaturesSource[
      Long,
      CreatedSpacesView,
      SpaceSlice,
      String
    ] {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("CreatedSpaces")

  override val fetcher: Fetcher[Long, CreatedSpacesView, SpaceSlice] = column.fetcher

  override def stratoResultTransformer(
    stratoKey: Long,
    stratoResult: SpaceSlice
  ): Seq[String] =
    stratoResult.items

  override protected def extractFeaturesFromStratoResult(
    stratoKey: Long,
    stratoResult: SpaceSlice
  ): FeatureMap = {
    val featureMapBuilder = FeatureMapBuilder()
    stratoResult.sliceInfo.previousCursor.foreach { cursor =>
      featureMapBuilder.add(PreviousCursorFeature, cursor)
    }
    stratoResult.sliceInfo.nextCursor.foreach { cursor =>
      featureMapBuilder.add(NextCursorFeature, cursor)
    }
    featureMapBuilder.build()
  }
}
@ -1,13 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/cursor",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "strato/config/columns/consumer-identity/business-profiles:business-profiles-strato-client",
        "strato/config/src/thrift/com/twitter/strato/graphql:graphql-scala",
        "strato/src/main/scala/com/twitter/strato/client",
    ],
)
Binary file not shown.
@ -1,53 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.business_profiles

import com.twitter.product_mixer.component_library.model.cursor.NextCursorFeature
import com.twitter.product_mixer.component_library.model.cursor.PreviousCursorFeature
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyViewFetcherWithSourceFeaturesSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.consumer_identity.business_profiles.BusinessProfileTeamMembersOnUserClientColumn
import com.twitter.strato.generated.client.consumer_identity.business_profiles.BusinessProfileTeamMembersOnUserClientColumn.{
  Value => TeamMembersSlice
}
import com.twitter.strato.generated.client.consumer_identity.business_profiles.BusinessProfileTeamMembersOnUserClientColumn.{
  View => TeamMembersView
}
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class TeamMembersCandidateSource @Inject() (
  column: BusinessProfileTeamMembersOnUserClientColumn)
    extends StratoKeyViewFetcherWithSourceFeaturesSource[
      Long,
      TeamMembersView,
      TeamMembersSlice,
      Long
    ] {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier(
    "BusinessProfileTeamMembers")

  override val fetcher: Fetcher[Long, TeamMembersView, TeamMembersSlice] = column.fetcher

  override def stratoResultTransformer(
    stratoKey: Long,
    stratoResult: TeamMembersSlice
  ): Seq[Long] =
    stratoResult.members

  override protected def extractFeaturesFromStratoResult(
    stratoKey: Long,
    stratoResult: TeamMembersSlice
  ): FeatureMap = {
    val featureMapBuilder = FeatureMapBuilder()
    stratoResult.previousCursor.foreach { cursor =>
      featureMapBuilder.add(PreviousCursorFeature, cursor.toString)
    }
    stratoResult.nextCursor.foreach { cursor =>
      featureMapBuilder.add(NextCursorFeature, cursor.toString)
    }
    featureMapBuilder.build()
  }
}
@ -1,12 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "cr-mixer/thrift/src/main/thrift:thrift-scala",
        "finatra/inject/inject-core/src/main/scala/com/twitter/inject",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "stitch/stitch-core",
    ],
)
Binary file not shown.
@ -1,25 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.cr_mixer

import com.twitter.cr_mixer.{thriftscala => t}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

/**
 * Returns out-of-network Tweet recommendations by using user recommendations
 * from FollowRecommendationService as an input seed set to Earlybird.
 */
@Singleton
class CrMixerFrsBasedTweetRecommendationsCandidateSource @Inject() (
  crMixerClient: t.CrMixer.MethodPerEndpoint)
    extends CandidateSource[t.FrsTweetRequest, t.FrsTweet] {

  override val identifier: CandidateSourceIdentifier =
    CandidateSourceIdentifier("CrMixerFrsBasedTweetRecommendations")

  override def apply(request: t.FrsTweetRequest): Stitch[Seq[t.FrsTweet]] = Stitch
    .callFuture(crMixerClient.getFrsBasedTweetRecommendations(request))
    .map(_.tweets)
}
Binary file not shown.
@ -1,21 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.cr_mixer

import com.twitter.cr_mixer.{thriftscala => t}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class CrMixerTweetRecommendationsCandidateSource @Inject() (
  crMixerClient: t.CrMixer.MethodPerEndpoint)
    extends CandidateSource[t.CrMixerTweetRequest, t.TweetRecommendation] {

  override val identifier: CandidateSourceIdentifier =
    CandidateSourceIdentifier("CrMixerTweetRecommendations")

  override def apply(request: t.CrMixerTweetRequest): Stitch[Seq[t.TweetRecommendation]] = Stitch
    .callFuture(crMixerClient.getTweetRecommendations(request))
    .map(_.tweets)
}
@ -1,12 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "finatra/inject/inject-core/src/main/scala/com/twitter/inject",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "src/thrift/com/twitter/search:earlybird-scala",
        "stitch/stitch-core",
    ],
)
Binary file not shown.
@ -1,26 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.earlybird

import com.twitter.inject.Logging
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.search.earlybird.{thriftscala => t}
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class EarlybirdTweetCandidateSource @Inject() (
  earlybirdService: t.EarlybirdService.MethodPerEndpoint)
    extends CandidateSource[t.EarlybirdRequest, t.ThriftSearchResult]
    with Logging {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("EarlybirdTweets")

  override def apply(request: t.EarlybirdRequest): Stitch[Seq[t.ThriftSearchResult]] = {
    Stitch
      .callFuture(earlybirdService.search(request))
      .map { response: t.EarlybirdResponse =>
        response.searchResults.map(_.results).getOrElse(Seq.empty)
      }
  }
}
@ -1,12 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "3rdparty/jvm/javax/inject:javax.inject",
        "explore/explore-ranker/thrift/src/main/thrift:thrift-scala",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "stitch/stitch-core",
    ],
)
Binary file not shown.
@ -1,31 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.explore_ranker

import com.twitter.explore_ranker.{thriftscala => t}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class ExploreRankerCandidateSource @Inject() (
  exploreRankerService: t.ExploreRanker.MethodPerEndpoint)
    extends CandidateSource[t.ExploreRankerRequest, t.ImmersiveRecsResult] {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("ExploreRanker")

  override def apply(
    request: t.ExploreRankerRequest
  ): Stitch[Seq[t.ImmersiveRecsResult]] = {
    Stitch
      .callFuture(exploreRankerService.getRankedResults(request))
      .map {
        case t.ExploreRankerResponse(
              t.ExploreRankerProductResponse
                .ImmersiveRecsResponse(t.ImmersiveRecsResponse(immersiveRecsResults))) =>
          immersiveRecsResults
        case response =>
          throw new UnsupportedOperationException(s"Unknown response type: $response")
      }
  }
}
@ -1,17 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    platform = "java8",
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "finatra/inject/inject-core/src/main/scala/com/twitter/inject",
        "onboarding/service/thrift/src/main/thrift:thrift-scala",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "stitch/stitch-core",
    ],
    exports = [
        "onboarding/service/thrift/src/main/thrift:thrift-scala",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
    ],
)
Binary file not shown.
@ -1,50 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.flexible_injection_pipeline

import com.twitter.inject.Logging
import com.twitter.onboarding.injections.{thriftscala => injectionsthrift}
import com.twitter.onboarding.task.service.{thriftscala => servicethrift}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

/**
 * Returns a list of prompts to insert into a user's timeline (inline prompts, cover modals, etc.)
 * from go/flip (the prompting platform for Twitter).
 */
@Singleton
class PromptCandidateSource @Inject() (taskService: servicethrift.TaskService.MethodPerEndpoint)
    extends CandidateSource[servicethrift.GetInjectionsRequest, IntermediatePrompt]
    with Logging {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier(
    "InjectionPipelinePrompts")

  override def apply(
    request: servicethrift.GetInjectionsRequest
  ): Stitch[Seq[IntermediatePrompt]] = {
    Stitch
      .callFuture(taskService.getInjections(request)).map {
        _.injections.flatMap {
          // The entire carousel is added to each IntermediatePrompt item with a
          // corresponding index, to be unpacked later to populate its TimelineEntry counterpart.
          case injection: injectionsthrift.Injection.TilesCarousel =>
            injection.tilesCarousel.tiles.zipWithIndex.map {
              case (tile: injectionsthrift.Tile, index: Int) =>
                IntermediatePrompt(injection, Some(index), Some(tile))
            }
          case injection => Seq(IntermediatePrompt(injection, None, None))
        }
      }
  }
}

/**
 * Provides an intermediate step to support 'exploding' tile-carousel tiles, since TimelineModule
 * is not an extension of TimelineItem.
 */
case class IntermediatePrompt(
  injection: injectionsthrift.Injection,
  offsetInModule: Option[Int],
  carouselTile: Option[injectionsthrift.Tile])
@ -1,16 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "3rdparty/jvm/javax/inject:javax.inject",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "src/thrift/com/twitter/hermit:hermit-scala",
        "strato/config/columns/onboarding:onboarding-strato-client",
    ],
    exports = [
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "src/thrift/com/twitter/hermit:hermit-scala",
    ],
)
Binary file not shown.
Some files were not shown because too many files have changed in this diff.