This commit is contained in:
Erfan Safari 2023-07-17 18:50:53 +02:00 committed by GitHub
commit 5832089c01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 134 additions and 41 deletions

View File

@ -39,14 +39,18 @@ impl PredictRequest {
) -> Vec<TensorInput> { ) -> Vec<TensorInput> {
let mut model_inputs = Vec::<TensorInput>::new(); let mut model_inputs = Vec::<TensorInput>::new();
for input_name in inputs.as_slice() { for input_name in inputs.as_slice() {
self.make_inputs(&mut model_inputs);
}
model_inputs
}
#[inline(always)]
fn make_inputs(model_inputs: &mut Vec<TensorInput>) {
let input_tensor = self let input_tensor = self
.inputs .inputs
.get_mut(input_name) .get_mut(input_name)
.unwrap_or_else(|| panic!("can't find {:?}", input_name)); .unwrap_or_else(|| panic!("can't find {:?}", input_name));
let dims = match &input_tensor.tensor_shape { let dims = input_tensor.tensor_shape.as_ref().map(|data| data.dim.iter().map(|d| d.size).collect_vec());
None => None,
Some(data) => Some(data.dim.iter().map(|d| d.size).collect_vec()),
};
match input_tensor.dtype() { match input_tensor.dtype() {
DataType::DtFloat => model_inputs.push(TensorInput::new( DataType::DtFloat => model_inputs.push(TensorInput::new(
TensorInputEnum::Float(std::mem::take(&mut input_tensor.float_val)), TensorInputEnum::Float(std::mem::take(&mut input_tensor.float_val)),
@ -81,8 +85,7 @@ impl PredictRequest {
_ => panic!("unsupport input tensor type {:?}", input_tensor.dtype()), _ => panic!("unsupport input tensor type {:?}", input_tensor.dtype()),
} }
} }
model_inputs
}
#[inline(always)] #[inline(always)]
pub fn take_model_spec(&mut self) -> (String, Option<i64>) { pub fn take_model_spec(&mut self) -> (String, Option<i64>) {
let model_spec = self.model_spec.as_mut().unwrap(); let model_spec = self.model_spec.as_mut().unwrap();

View File

@ -245,11 +245,11 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS); info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
//we follow SemVer. So here we assume MAJOR.MINOR.PATCH //we follow SemVer. So here we assume MAJOR.MINOR.PATCH
let parts = VERSION let parts = VERSION
.split(".") .split('.')
.map(|v| v.parse::<i64>()) .map(|v| v.parse::<i64>())
.collect::<std::result::Result<Vec<_>, _>>()?; .collect::<std::result::Result<Vec<_>, _>>()?;
if let [major, minor, patch] = &parts[..] { if let [major, minor, patch] = &parts[..] {
NAVI_VERSION.set(major * 1000_000 + minor * 1000 + patch); NAVI_VERSION.set(major * 1_000_000 + minor * 1000 + patch);
} else { } else {
warn!( warn!(
"version {} doesn't follow SemVer conversion of MAJOR.MINOR.PATCH", "version {} doesn't follow SemVer conversion of MAJOR.MINOR.PATCH",

View File

@ -212,7 +212,7 @@ lazy_static! {
inputs.push(OnceCell::new()); inputs.push(OnceCell::new());
} else { } else {
inputs.push(OnceCell::with_value( inputs.push(OnceCell::with_value(
o.split(",") o.split(',')
.map(|s| s.to_owned()) .map(|s| s.to_owned())
.collect::<ArrayVec<String, MAX_NUM_INPUTS>>(), .collect::<ArrayVec<String, MAX_NUM_INPUTS>>(),
)); ));
@ -225,7 +225,7 @@ lazy_static! {
let mut outputs = ArrayVec::<ArrayVec<String, MAX_NUM_OUTPUTS>, MAX_NUM_MODELS>::new(); let mut outputs = ArrayVec::<ArrayVec<String, MAX_NUM_OUTPUTS>, MAX_NUM_MODELS>::new();
for o in ARGS.output.iter() { for o in ARGS.output.iter() {
outputs.push( outputs.push(
o.split(",") o.split(',')
.map(|s| s.to_owned()) .map(|s| s.to_owned())
.collect::<ArrayVec<String, MAX_NUM_OUTPUTS>>(), .collect::<ArrayVec<String, MAX_NUM_OUTPUTS>>(),
); );

View File

@ -84,9 +84,7 @@ mod utils {
get_config_or_else(model_config, key, || default.to_string()) get_config_or_else(model_config, key, || default.to_string())
} }
pub fn get_meta_dir() -> &'static str { pub fn get_meta_dir() -> &'static str {
ARGS.meta_json_dir ARGS.meta_json_dir.as_deref()
.as_ref()
.map(|s| s.as_str())
.unwrap_or_else(|| { .unwrap_or_else(|| {
let model_dir = &ARGS.model_dir[0]; let model_dir = &ARGS.model_dir[0];
let meta_dir = &model_dir[0..model_dir.rfind(&MODEL_SPECS[0]).unwrap()]; let meta_dir = &model_dir[0..model_dir.rfind(&MODEL_SPECS[0]).unwrap()];

105
navi/segdense/Cargo.lock generated Normal file
View File

@ -0,0 +1,105 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "itoa"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
[[package]]
name = "log"
version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e"
dependencies = [
"cfg-if",
]
[[package]]
name = "proc-macro2"
version = "1.0.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc"
dependencies = [
"proc-macro2",
]
[[package]]
name = "ryu"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
[[package]]
name = "segdense"
version = "0.1.0"
dependencies = [
"log",
"serde",
"serde_json",
]
[[package]]
name = "serde"
version = "1.0.159"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.159"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "syn"
version = "2.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c9da457c5285ac1f936ebd076af6dac17a61cfe7826f2076b4d015cf47bc8ec"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "unicode-ident"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4"

View File

@ -19,22 +19,11 @@ impl Display for SegDenseError {
match self { match self {
SegDenseError::IoError(io_error) => write!(f, "{}", io_error), SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json), SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
SegDenseError::JsonMissingRoot => { SegDenseError::JsonMissingRoot => write!(f, "SegDense JSON: Root Node note found!"),
write!(f, "{}", "SegDense JSON: Root Node note found!") SegDenseError::JsonMissingObject => write!(f, "SegDense JSON: Object note found!"),
} SegDenseError::JsonMissingArray => write!(f, "SegDense JSON: Array Node note found!"),
SegDenseError::JsonMissingObject => { SegDenseError::JsonArraySize => write!(f, "SegDense JSON: Array size not as expected!"),
write!(f, "{}", "SegDense JSON: Object note found!") SegDenseError::JsonMissingInputFeature => write!(f, "SegDense JSON: Missing input feature!"),
}
SegDenseError::JsonMissingArray => {
write!(f, "{}", "SegDense JSON: Array Node note found!")
}
SegDenseError::JsonArraySize => {
write!(f, "{}", "SegDense JSON: Array size not as expected!")
}
SegDenseError::JsonMissingInputFeature => {
write!(f, "{}", "SegDense JSON: Missing input feature!")
}
}
} }
} }

View File

@ -7,12 +7,10 @@ use crate::error::SegDenseError;
use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter}; use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter};
use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature}; use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
pub fn load_config(file_name: &str) -> Result<seg_dense::Root, SegDenseError> { pub fn load_config(file_name: &str) -> seg_dense::Root {
let json_str = fs::read_to_string(file_name)?; let json_str = fs::read_to_string(file_name).unwrap_or_else(|_| panic!("Unable to load segdense file {}", file_name));
// &format!("Unable to load segdense file {}", file_name));
let seg_dense_config = parse(&json_str)?; parse(&json_str).unwrap_or_else(|_| panic!("Unable to parse segdense file {}", file_name))
// &format!("Unable to parse segdense file {}", file_name));
Ok(seg_dense_config)
} }
pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> { pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {