mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-01-03 08:01:53 +01:00
Merge f9b7db14c6
into 72eda9a24f
This commit is contained in:
commit
5832089c01
@ -39,14 +39,18 @@ impl PredictRequest {
|
||||
) -> Vec<TensorInput> {
|
||||
let mut model_inputs = Vec::<TensorInput>::new();
|
||||
for input_name in inputs.as_slice() {
|
||||
let input_tensor = self
|
||||
self.make_inputs(&mut model_inputs);
|
||||
}
|
||||
model_inputs
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn make_inputs(model_inputs: &mut Vec<TensorInput>) {
|
||||
let input_tensor = self
|
||||
.inputs
|
||||
.get_mut(input_name)
|
||||
.unwrap_or_else(|| panic!("can't find {:?}", input_name));
|
||||
let dims = match &input_tensor.tensor_shape {
|
||||
None => None,
|
||||
Some(data) => Some(data.dim.iter().map(|d| d.size).collect_vec()),
|
||||
};
|
||||
let dims = input_tensor.tensor_shape.as_ref().map(|data| data.dim.iter().map(|d| d.size).collect_vec());
|
||||
match input_tensor.dtype() {
|
||||
DataType::DtFloat => model_inputs.push(TensorInput::new(
|
||||
TensorInputEnum::Float(std::mem::take(&mut input_tensor.float_val)),
|
||||
@ -80,9 +84,8 @@ impl PredictRequest {
|
||||
)),
|
||||
_ => panic!("unsupport input tensor type {:?}", input_tensor.dtype()),
|
||||
}
|
||||
}
|
||||
model_inputs
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn take_model_spec(&mut self) -> (String, Option<i64>) {
|
||||
let model_spec = self.model_spec.as_mut().unwrap();
|
||||
|
@ -245,11 +245,11 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
||||
info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
|
||||
//we follow SemVer. So here we assume MAJOR.MINOR.PATCH
|
||||
let parts = VERSION
|
||||
.split(".")
|
||||
.split('.')
|
||||
.map(|v| v.parse::<i64>())
|
||||
.collect::<std::result::Result<Vec<_>, _>>()?;
|
||||
if let [major, minor, patch] = &parts[..] {
|
||||
NAVI_VERSION.set(major * 1000_000 + minor * 1000 + patch);
|
||||
NAVI_VERSION.set(major * 1_000_000 + minor * 1000 + patch);
|
||||
} else {
|
||||
warn!(
|
||||
"version {} doesn't follow SemVer conversion of MAJOR.MINOR.PATCH",
|
||||
|
@ -212,7 +212,7 @@ lazy_static! {
|
||||
inputs.push(OnceCell::new());
|
||||
} else {
|
||||
inputs.push(OnceCell::with_value(
|
||||
o.split(",")
|
||||
o.split(',')
|
||||
.map(|s| s.to_owned())
|
||||
.collect::<ArrayVec<String, MAX_NUM_INPUTS>>(),
|
||||
));
|
||||
@ -225,7 +225,7 @@ lazy_static! {
|
||||
let mut outputs = ArrayVec::<ArrayVec<String, MAX_NUM_OUTPUTS>, MAX_NUM_MODELS>::new();
|
||||
for o in ARGS.output.iter() {
|
||||
outputs.push(
|
||||
o.split(",")
|
||||
o.split(',')
|
||||
.map(|s| s.to_owned())
|
||||
.collect::<ArrayVec<String, MAX_NUM_OUTPUTS>>(),
|
||||
);
|
||||
@ -233,4 +233,4 @@ lazy_static! {
|
||||
info!("all outputs:{:?}", outputs);
|
||||
outputs
|
||||
};
|
||||
}
|
||||
}
|
@ -84,9 +84,7 @@ mod utils {
|
||||
get_config_or_else(model_config, key, || default.to_string())
|
||||
}
|
||||
pub fn get_meta_dir() -> &'static str {
|
||||
ARGS.meta_json_dir
|
||||
.as_ref()
|
||||
.map(|s| s.as_str())
|
||||
ARGS.meta_json_dir.as_deref()
|
||||
.unwrap_or_else(|| {
|
||||
let model_dir = &ARGS.model_dir[0];
|
||||
let meta_dir = &model_dir[0..model_dir.rfind(&MODEL_SPECS[0]).unwrap()];
|
||||
|
105
navi/segdense/Cargo.lock
generated
Normal file
105
navi/segdense/Cargo.lock
generated
Normal file
@ -0,0 +1,105 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.56"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
|
||||
|
||||
[[package]]
|
||||
name = "segdense"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.159"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.159"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.95"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c9da457c5285ac1f936ebd076af6dac17a61cfe7826f2076b4d015cf47bc8ec"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4"
|
@ -15,26 +15,15 @@ pub enum SegDenseError {
|
||||
}
|
||||
|
||||
impl Display for SegDenseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
|
||||
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
|
||||
SegDenseError::JsonMissingRoot => {
|
||||
write!(f, "{}", "SegDense JSON: Root Node note found!")
|
||||
}
|
||||
SegDenseError::JsonMissingObject => {
|
||||
write!(f, "{}", "SegDense JSON: Object note found!")
|
||||
}
|
||||
SegDenseError::JsonMissingArray => {
|
||||
write!(f, "{}", "SegDense JSON: Array Node note found!")
|
||||
}
|
||||
SegDenseError::JsonArraySize => {
|
||||
write!(f, "{}", "SegDense JSON: Array size not as expected!")
|
||||
}
|
||||
SegDenseError::JsonMissingInputFeature => {
|
||||
write!(f, "{}", "SegDense JSON: Missing input feature!")
|
||||
}
|
||||
}
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
|
||||
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
|
||||
SegDenseError::JsonMissingRoot => write!(f, "SegDense JSON: Root Node note found!"),
|
||||
SegDenseError::JsonMissingObject => write!(f, "SegDense JSON: Object note found!"),
|
||||
SegDenseError::JsonMissingArray => write!(f, "SegDense JSON: Array Node note found!"),
|
||||
SegDenseError::JsonArraySize => write!(f, "SegDense JSON: Array size not as expected!"),
|
||||
SegDenseError::JsonMissingInputFeature => write!(f, "SegDense JSON: Missing input feature!"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,12 +7,10 @@ use crate::error::SegDenseError;
|
||||
use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter};
|
||||
use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
|
||||
|
||||
pub fn load_config(file_name: &str) -> Result<seg_dense::Root, SegDenseError> {
|
||||
let json_str = fs::read_to_string(file_name)?;
|
||||
// &format!("Unable to load segdense file {}", file_name));
|
||||
let seg_dense_config = parse(&json_str)?;
|
||||
// &format!("Unable to parse segdense file {}", file_name));
|
||||
Ok(seg_dense_config)
|
||||
pub fn load_config(file_name: &str) -> seg_dense::Root {
|
||||
let json_str = fs::read_to_string(file_name).unwrap_or_else(|_| panic!("Unable to load segdense file {}", file_name));
|
||||
|
||||
parse(&json_str).unwrap_or_else(|_| panic!("Unable to parse segdense file {}", file_name))
|
||||
}
|
||||
|
||||
pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {
|
||||
|
Loading…
Reference in New Issue
Block a user