use anyhow::Result;
use log::{info, warn};
use std::collections::HashMap;
use tokio::time::Instant;
use tonic::{
    Request, Response, Status,
    transport::{Certificate, Identity, Server, ServerTlsConfig},
};

// protobuf related
use crate::tf_proto::tensorflow_serving::{
    ClassificationRequest, ClassificationResponse, GetModelMetadataRequest,
    GetModelMetadataResponse, MultiInferenceRequest, MultiInferenceResponse, PredictRequest,
    PredictResponse, RegressionRequest, RegressionResponse,
};
use crate::{
    kf_serving::{
        grpc_inference_service_server::GrpcInferenceService, ModelInferRequest,
        ModelInferResponse, ModelMetadataRequest, ModelMetadataResponse, ModelReadyRequest,
        ModelReadyResponse, ServerLiveRequest, ServerLiveResponse, ServerMetadataRequest,
        ServerMetadataResponse, ServerReadyRequest, ServerReadyResponse,
    },
    ModelFactory,
    tf_proto::tensorflow_serving::prediction_service_server::{
        PredictionService, PredictionServiceServer,
    },
    VERSION, NAME,
};
use crate::PredictResult;
use crate::cli_args::{ARGS, INPUTS, OUTPUTS};
use crate::metrics::{
    NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
    NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR,
};
use crate::predict_service::{Model, PredictService};
use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version;
use crate::tf_proto::tensorflow_serving::ModelSpec;

/// Raw tensor payload for a single model input, one variant per supported element type.
#[derive(Debug)]
pub enum TensorInputEnum {
    String(Vec<Vec<u8>>),
    Int(Vec<i32>),
    Int64(Vec<i64>),
    Float(Vec<f32>),
    Double(Vec<f64>),
    Boolean(Vec<bool>),
}

/// A named input tensor together with its optional shape.
#[derive(Debug)]
pub struct TensorInput {
    pub tensor_data: TensorInputEnum,
    pub name: String,
    pub dims: Option<Vec<i64>>,
}

impl TensorInput {
    pub fn new(tensor_data: TensorInputEnum, name: String, dims: Option<Vec<i64>>) -> TensorInput {
        TensorInput {
            tensor_data,
            name,
            dims,
        }
    }
}

impl TensorInputEnum {
    /// Append another tensor of the same variant onto this one (used when batching rows).
    #[inline(always)]
    pub(crate) fn extend(&mut self, another: TensorInputEnum) {
        match (self, another) {
            (Self::String(input), Self::String(ex)) => input.extend(ex),
            (Self::Int(input), Self::Int(ex)) => input.extend(ex),
            (Self::Int64(input), Self::Int64(ex)) => input.extend(ex),
            (Self::Float(input), Self::Float(ex)) => input.extend(ex),
            (Self::Double(input), Self::Double(ex)) => input.extend(ex),
            (Self::Boolean(input), Self::Boolean(ex)) => input.extend(ex),
            x => panic!(
                "input enum type not matched. input:{:?}, ex:{:?}",
                x.0, x.1
            ),
        }
    }

    /// Merge a batch of per-request tensor lists into one list of batched tensors,
    /// concatenating matching inputs in order.
    #[inline(always)]
    pub(crate) fn merge_batch(input_tensors: Vec<Vec<TensorInput>>) -> Vec<TensorInput> {
        input_tensors
            .into_iter()
            .reduce(|mut acc, e| {
                for (i, ext) in acc.iter_mut().zip(e) {
                    i.tensor_data.extend(ext.tensor_data);
                }
                acc
            })
            .unwrap() //invariant: we expect there's always rows in input_tensors
    }
}

/// Entry point for tfServing gRPC.
#[tonic::async_trait]
impl<T: Model> GrpcInferenceService for PredictService<T> {
    async fn server_live(
        &self,
        _request: Request<ServerLiveRequest>,
    ) -> Result<Response<ServerLiveResponse>, Status> {
        unimplemented!()
    }

    async fn server_ready(
        &self,
        _request: Request<ServerReadyRequest>,
    ) -> Result<Response<ServerReadyResponse>, Status> {
        unimplemented!()
    }

    async fn model_ready(
        &self,
        _request: Request<ModelReadyRequest>,
    ) -> Result<Response<ModelReadyResponse>, Status> {
        unimplemented!()
    }

    async fn server_metadata(
        &self,
        _request: Request<ServerMetadataRequest>,
    ) -> Result<Response<ServerMetadataResponse>, Status> {
        unimplemented!()
    }

    async fn model_metadata(
        &self,
        _request: Request<ModelMetadataRequest>,
    ) -> Result<Response<ModelMetadataResponse>, Status> {
        unimplemented!()
    }

    async fn model_infer(
        &self,
        _request: Request<ModelInferRequest>,
    ) -> Result<Response<ModelInferResponse>, Status> {
        unimplemented!()
    }
}

#[tonic::async_trait]
impl<T: Model> PredictionService for PredictService<T> {
    async fn classify(
        &self,
        _request: Request<ClassificationRequest>,
    ) -> Result<Response<ClassificationResponse>, Status> {
        unimplemented!()
    }

    async fn regress(
        &self,
        _request: Request<RegressionRequest>,
    ) -> Result<Response<RegressionResponse>, Status> {
        unimplemented!()
    }

    async fn predict(
        &self,
        request: Request<PredictRequest>,
    ) -> Result<Response<PredictResponse>, Status> {
        NUM_REQUESTS_RECEIVED.inc();
        let start = Instant::now();
        let mut req = request.into_inner();
        let (model_spec, version) = req.take_model_spec();
        NUM_REQUESTS_RECEIVED_BY_MODEL
            .with_label_values(&[&model_spec])
            .inc();
        let idx = PredictService::<T>::get_model_index(&model_spec).ok_or_else(|| {
            Status::failed_precondition(format!("model spec not found:{}", model_spec))
        })?;
        let input_spec = match INPUTS[idx].get() {
            Some(input) => input,
            _ => return Err(Status::not_found(format!("model input spec {}", idx))),
        };
        let input_val = req.take_input_vals(input_spec);
        self.predict(idx, version, input_val, start)
            .await
            .map_or_else(
                |e| {
                    NUM_REQUESTS_FAILED.inc();
                    NUM_REQUESTS_FAILED_BY_MODEL
                        .with_label_values(&[&model_spec])
                        .inc();
                    Err(Status::internal(e.to_string()))
                },
                |res| {
                    RESPONSE_TIME_COLLECTOR
                        .with_label_values(&[&model_spec])
                        .observe(start.elapsed().as_millis() as f64);
                    match res {
                        PredictResult::Ok(tensors, version) => {
                            let mut outputs = HashMap::new();
                            NUM_PREDICTIONS.with_label_values(&[&model_spec]).inc();
                            //FIXME: uncomment when prediction scores are normal
                            // PREDICTION_SCORE_SUM
                            //     .with_label_values(&[&model_spec])
                            //     .inc_by(tensors[0] as f64);
                            for (tp, output_name) in tensors
                                .into_iter()
                                .map(|tensor| tensor.create_tensor_proto())
                                .zip(OUTPUTS[idx].iter())
                            {
                                outputs.insert(output_name.to_owned(), tp);
                            }
                            let reply = PredictResponse {
                                model_spec: Some(ModelSpec {
                                    version_choice: Some(Version(version)),
                                    ..Default::default()
                                }),
                                outputs,
                            };
                            Ok(Response::new(reply))
                        }
                        PredictResult::DropDueToOverload => Err(Status::resource_exhausted("")),
                        PredictResult::ModelNotFound(idx) => {
                            Err(Status::not_found(format!("model index {}", idx)))
                        }
                        PredictResult::ModelNotReady(idx) => {
                            Err(Status::unavailable(format!("model index {}", idx)))
                        }
                        PredictResult::ModelVersionNotFound(idx, version) => Err(
                            Status::not_found(format!("model index:{}, version {}", idx, version)),
                        ),
                    }
                },
            )
    }

    async fn multi_inference(
        &self,
        _request: Request<MultiInferenceRequest>,
    ) -> Result<Response<MultiInferenceResponse>, Status> {
        unimplemented!()
    }

    async fn get_model_metadata(
        &self,
        _request: Request<GetModelMetadataRequest>,
    ) -> Result<Response<GetModelMetadataResponse>, Status> {
        unimplemented!()
    }
}
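
/// Logs the package version, builds the Tokio runtime, and serves the
/// TF Serving `PredictionService` over tonic, enabling mutual TLS when
/// `ARGS.ssl_dir` is non-empty.
///
/// A minimal usage sketch, assuming the embedding binary defines a concrete
/// `MyModel: Model` and a matching `my_model_factory: ModelFactory<MyModel>`
/// (both names are hypothetical, not items of this crate):
///
/// ```ignore
/// fn main() -> anyhow::Result<()> {
///     bootstrap::<MyModel>(my_model_factory)
/// }
/// ```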
pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
    info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
    // we follow SemVer. So here we assume MAJOR.MINOR.PATCH
    let parts = VERSION
        .split('.')
        .map(|v| v.parse::<i64>())
        .collect::<Result<Vec<_>, _>>()?;
    if let [major, minor, patch] = &parts[..] {
        NAVI_VERSION.set(major * 1_000_000 + minor * 1000 + patch);
    } else {
        warn!(
            "version {} doesn't follow the SemVer convention of MAJOR.MINOR.PATCH",
            VERSION
        );
    }
    tokio::runtime::Builder::new_multi_thread()
        .thread_name("async worker")
        .worker_threads(ARGS.num_worker_threads)
        .max_blocking_threads(ARGS.max_blocking_threads)
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            #[cfg(feature = "navi_console")]
            console_subscriber::init();
            let addr = format!("0.0.0.0:{}", ARGS.port).parse()?;
            let ps = PredictService::init(model_factory).await;
            let mut builder = if ARGS.ssl_dir.is_empty() {
                Server::builder()
            } else {
                // SSL dir configured: load the server key/cert/chain and require
                // client certificates signed against the same chain (mutual TLS).
                let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir))
                    .await
                    .expect("can't find key file");
                let crt = tokio::fs::read(format!("{}/server.crt", ARGS.ssl_dir))
                    .await
                    .expect("can't find crt file");
                let chain = tokio::fs::read(format!("{}/server.chain", ARGS.ssl_dir))
                    .await
                    .expect("can't find chain file");
                let mut pem = Vec::new();
                pem.extend(crt);
                pem.extend(chain);
                let identity = Identity::from_pem(pem.clone(), key);
                let client_ca_cert = Certificate::from_pem(pem.clone());
                let tls = ServerTlsConfig::new()
                    .identity(identity)
                    .client_ca_root(client_ca_cert);
                Server::builder()
                    .tls_config(tls)
                    .expect("fail to config SSL")
            };
            info!(
                "Prometheus server started: 0.0.0.0:{}",
                ARGS.prometheus_port
            );
            let ps_server = builder
                .add_service(PredictionServiceServer::new(ps).accept_gzip().send_gzip())
                .serve(addr);
            info!("Prediction server started: {}", addr);
            ps_server.await.map_err(anyhow::Error::msg)
        })
}
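
// A minimal test sketch for the batching helpers above. These tests are an
// illustrative addition (not part of the original module) and only rely on
// items defined in this file: `TensorInput`, `TensorInputEnum`, and
// `TensorInputEnum::merge_batch`.
#[cfg(test)]
mod tests {
    use super::{TensorInput, TensorInputEnum};

    #[test]
    fn merge_batch_concatenates_matching_tensors() {
        // Two mini-batches, each carrying one float tensor named "x".
        let batch1 = vec![TensorInput::new(
            TensorInputEnum::Float(vec![1.0, 2.0]),
            "x".to_string(),
            None,
        )];
        let batch2 = vec![TensorInput::new(
            TensorInputEnum::Float(vec![3.0]),
            "x".to_string(),
            None,
        )];

        // merge_batch folds the batches together, extending each tensor in place.
        let merged = TensorInputEnum::merge_batch(vec![batch1, batch2]);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].name, "x");
        match &merged[0].tensor_data {
            TensorInputEnum::Float(v) => assert_eq!(v, &vec![1.0f32, 2.0, 3.0]),
            other => panic!("unexpected tensor variant: {:?}", other),
        }
    }
}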