From 31e82d6474cf47b3695bf919c44c94d146192a03 Mon Sep 17 00:00:00 2001
From: twitter-team <>
Date: Wed, 5 Apr 2023 16:08:19 -0700
Subject: [PATCH] improvements from external prs

-fix corner case where dr converter failed when initializing

Closes twitter/the-algorithm#550
---
 navi/dr_transform/src/all_config.rs           |   3 +-
 navi/dr_transform/src/converter.rs            | 262 ++++++++++--------
 .../tensorflow/core/framework/full_type.proto |  10 +-
 .../tensorflow/core/framework/function.proto  |   2 +-
 .../tensorflow/core/framework/node_def.proto  |   2 +-
 .../tensorflow/core/framework/op_def.proto    |   4 +-
 .../core/framework/step_stats.proto           |   2 +-
 .../tensorflow/core/framework/tensor.proto    |   2 +-
 .../tensorflow/core/protobuf/config.proto     |   8 +-
 .../core/protobuf/coordination_service.proto  |   2 +-
 .../tensorflow/core/protobuf/debug.proto      |   2 +-
 .../core/protobuf/debug_event.proto           |   6 +-
 .../distributed_runtime_payloads.proto        |   2 +-
 .../core/protobuf/eager_service.proto         |   2 +-
 .../tensorflow/core/protobuf/master.proto     |   2 +-
 .../core/protobuf/saved_object_graph.proto    |   2 +-
 .../core/protobuf/tensor_bundle.proto         |   2 +-
 .../tensorflow/core/protobuf/worker.proto     |   4 +-
 .../tensorflow_serving/apis/logging.proto     |   2 +-
 .../file_system_storage_path_source.proto     |   4 +-
 .../config/model_server_config.proto          |   4 +-
 navi/navi/src/bootstrap.rs                    |  26 +-
 navi/navi/src/metrics.rs                      |   7 +
 navi/navi/src/onnx_model.rs                   |   2 +-
 navi/navi/src/predict_service.rs              |  40 ++-
 navi/segdense/src/error.rs                    |  56 ++--
 navi/segdense/src/lib.rs                      |   4 +-
 navi/segdense/src/main.rs                     |  23 +-
 navi/segdense/src/mapper.rs                   |   4 +-
 ...segdense_transform_spec_home_recap_2022.rs |   1 -
 navi/segdense/src/util.rs                     |  57 ++--
 31 files changed, 305 insertions(+), 244 deletions(-)

diff --git a/navi/dr_transform/src/all_config.rs b/navi/dr_transform/src/all_config.rs
index 29451bfd4..d5c52c362 100644
--- a/navi/dr_transform/src/all_config.rs
+++ b/navi/dr_transform/src/all_config.rs
@@ -44,5 +44,6 @@ pub struct RenamedFeatures {
 }
 
 pub fn parse(json_str: &str) -> Result<AllConfig, Error> {
-    serde_json::from_str(json_str)
+    let all_config: AllConfig = serde_json::from_str(json_str)?;
+    Ok(all_config)
 }
diff --git a/navi/dr_transform/src/converter.rs b/navi/dr_transform/src/converter.rs
index 578d766fd..3097aedc0 100644
--- a/navi/dr_transform/src/converter.rs
+++ b/navi/dr_transform/src/converter.rs
@@ -2,6 +2,9 @@ use std::collections::BTreeSet;
 use std::fmt::{self, Debug, Display};
 use std::fs;
 
+use crate::all_config;
+use crate::all_config::AllConfig;
+use anyhow::{bail, Context};
 use bpr_thrift::data::DataRecord;
 use bpr_thrift::prediction_service::BatchPredictionRequest;
 use bpr_thrift::tensor::GeneralTensor;
@@ -16,8 +19,6 @@ use segdense::util;
 use thrift::protocol::{TBinaryInputProtocol, TSerializable};
 use thrift::transport::TBufferChannel;
 
-use crate::{all_config, all_config::AllConfig};
-
 pub fn log_feature_match(
     dr: &DataRecord,
     seg_dense_config: &DensificationTransformSpec,
@@ -28,20 +29,24 @@ pub fn log_feature_match(
     for (feature_id, feature_value) in dr.continuous_features.as_ref().unwrap() {
         debug!(
-            "{dr_type} - Continuous Datarecord => Feature ID: {feature_id}, Feature value: {feature_value}"
+            "{} - Continuous Datarecord => Feature ID: {}, Feature value: {}",
+            dr_type, feature_id, feature_value
         );
         for input_feature in &seg_dense_config.cont.input_features {
             if input_feature.feature_id == *feature_id {
-                debug!("Matching input feature: {input_feature:?}")
+                debug!("Matching input feature: {:?}", input_feature)
             }
         }
     }
 
     for feature_id in dr.binary_features.as_ref().unwrap() {
-        debug!("{dr_type} - Binary Datarecord => Feature ID: {feature_id}");
debug!("{dr_type} - Binary Datarecord => Feature ID: {feature_id}"); + debug!( + "{} - Binary Datarecord => Feature ID: {}", + dr_type, feature_id + ); for input_feature in &seg_dense_config.binary.input_features { if input_feature.feature_id == *feature_id { - debug!("Found input feature: {input_feature:?}") + debug!("Found input feature: {:?}", input_feature) } } } @@ -90,18 +95,19 @@ impl BatchPredictionRequestToTorchTensorConverter { model_version: &str, reporting_feature_ids: Vec<(i64, &str)>, register_metric_fn: Option, - ) -> BatchPredictionRequestToTorchTensorConverter { - let all_config_path = format!("{model_dir}/{model_version}/all_config.json"); - let seg_dense_config_path = - format!("{model_dir}/{model_version}/segdense_transform_spec_home_recap_2022.json"); - let seg_dense_config = util::load_config(&seg_dense_config_path); + ) -> anyhow::Result { + let all_config_path = format!("{}/{}/all_config.json", model_dir, model_version); + let seg_dense_config_path = format!( + "{}/{}/segdense_transform_spec_home_recap_2022.json", + model_dir, model_version + ); + let seg_dense_config = util::load_config(&seg_dense_config_path)?; let all_config = all_config::parse( &fs::read_to_string(&all_config_path) - .unwrap_or_else(|error| panic!("error loading all_config.json - {error}")), - ) - .unwrap(); + .with_context(|| "error loading all_config.json - ")?, + )?; - let feature_mapper = util::load_from_parsed_config_ref(&seg_dense_config); + let feature_mapper = util::load_from_parsed_config(seg_dense_config.clone())?; let user_embedding_feature_id = Self::get_feature_id( &all_config @@ -131,11 +137,11 @@ impl BatchPredictionRequestToTorchTensorConverter { let (discrete_feature_metrics, continuous_feature_metrics) = METRICS.get_or_init(|| { let discrete = HistogramVec::new( HistogramOpts::new(":navi:feature_id:discrete", "Discrete Feature ID values") - .buckets(Vec::from([ - 0.0f64, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, + .buckets(Vec::from(&[ + 0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0, 1000.0, 10000.0, 100000.0, - ])), + ] as &'static [f64])), &["feature_id"], ) .expect("metric cannot be created"); @@ -144,18 +150,18 @@ impl BatchPredictionRequestToTorchTensorConverter { ":navi:feature_id:continuous", "continuous Feature ID values", ) - .buckets(Vec::from([ - 0.0f64, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, - 120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0, - 500.0, 1000.0, 10000.0, 100000.0, - ])), + .buckets(Vec::from(&[ + 0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, + 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0, + 1000.0, 10000.0, 100000.0, + ] as &'static [f64])), &["feature_id"], ) .expect("metric cannot be created"); - if let Some(r) = register_metric_fn { + register_metric_fn.map(|r| { r(&discrete); r(&continuous); - } + }); (discrete, continuous) }); @@ -164,13 +170,16 @@ impl BatchPredictionRequestToTorchTensorConverter { for (feature_id, feature_type) in reporting_feature_ids.iter() { match *feature_type { - "discrete" => discrete_features_to_report.insert(*feature_id), - "continuous" => continuous_features_to_report.insert(*feature_id), - _ => panic!("Invalid feature type {feature_type} for reporting metrics!"), + "discrete" => discrete_features_to_report.insert(feature_id.clone()), + "continuous" => 
continuous_features_to_report.insert(feature_id.clone()), + _ => bail!( + "Invalid feature type {} for reporting metrics!", + feature_type + ), }; } - BatchPredictionRequestToTorchTensorConverter { + Ok(BatchPredictionRequestToTorchTensorConverter { all_config, seg_dense_config, all_config_path, @@ -183,7 +192,7 @@ impl BatchPredictionRequestToTorchTensorConverter { continuous_features_to_report, discrete_feature_metrics, continuous_feature_metrics, - } + }) } fn get_feature_id(feature_name: &str, seg_dense_config: &Root) -> i64 { @@ -218,43 +227,45 @@ impl BatchPredictionRequestToTorchTensorConverter { let mut working_set = vec![0 as f32; total_size]; let mut bpr_start = 0; for (bpr, &bpr_end) in bprs.iter().zip(batch_size) { - if bpr.common_features.is_some() - && bpr.common_features.as_ref().unwrap().tensors.is_some() - && bpr - .common_features - .as_ref() - .unwrap() - .tensors - .as_ref() - .unwrap() - .contains_key(&feature_id) - { - let source_tensor = bpr - .common_features - .as_ref() - .unwrap() - .tensors - .as_ref() - .unwrap() - .get(&feature_id) - .unwrap(); - let tensor = match source_tensor { - GeneralTensor::FloatTensor(float_tensor) => - //Tensor::of_slice( + if bpr.common_features.is_some() { + if bpr.common_features.as_ref().unwrap().tensors.is_some() { + if bpr + .common_features + .as_ref() + .unwrap() + .tensors + .as_ref() + .unwrap() + .contains_key(&feature_id) { - float_tensor - .floats - .iter() - .map(|x| x.into_inner() as f32) - .collect::>() - } - _ => vec![0 as f32; cols], - }; + let source_tensor = bpr + .common_features + .as_ref() + .unwrap() + .tensors + .as_ref() + .unwrap() + .get(&feature_id) + .unwrap(); + let tensor = match source_tensor { + GeneralTensor::FloatTensor(float_tensor) => + //Tensor::of_slice( + { + float_tensor + .floats + .iter() + .map(|x| x.into_inner() as f32) + .collect::>() + } + _ => vec![0 as f32; cols], + }; - // since the tensor is found in common feature, add it in all batches - for row in bpr_start..bpr_end { - for col in 0..cols { - working_set[row * cols + col] = tensor[col]; + // since the tensor is found in common feature, add it in all batches + for row in bpr_start..bpr_end { + for col in 0..cols { + working_set[row * cols + col] = tensor[col]; + } + } } } } @@ -298,9 +309,9 @@ impl BatchPredictionRequestToTorchTensorConverter { // (INT64 --> INT64, DataRecord.discrete_feature) fn get_continuous(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor { // These need to be part of model schema - let rows = batch_ends[batch_ends.len() - 1]; - let cols = 5293; - let full_size = rows * cols; + let rows: usize = batch_ends[batch_ends.len() - 1]; + let cols: usize = 5293; + let full_size: usize = rows * cols; let default_val = f32::NAN; let mut tensor = vec![default_val; full_size]; @@ -325,15 +336,18 @@ impl BatchPredictionRequestToTorchTensorConverter { .unwrap(); for feature in common_features { - if let Some(f_info) = self.feature_mapper.get(feature.0) { - let idx = f_info.index_within_tensor as usize; - if idx < cols { - // Set value in each row - for r in bpr_start..bpr_end { - let flat_index = r * cols + idx; - tensor[flat_index] = feature.1.into_inner() as f32; + match self.feature_mapper.get(feature.0) { + Some(f_info) => { + let idx = f_info.index_within_tensor as usize; + if idx < cols { + // Set value in each row + for r in bpr_start..bpr_end { + let flat_index: usize = r * cols + idx; + tensor[flat_index] = feature.1.into_inner() as f32; + } } } + None => (), } if 
                        self.continuous_feature_metrics
                            .with_label_values(&[feature.0.to_string().as_str()])
                            .observe(feature.1.into_inner() as f64)
                    } else if self.discrete_features_to_report.contains(feature.0) {
                        self.discrete_feature_metrics
                            .with_label_values(&[feature.0.to_string().as_str()])
                            .observe(feature.1.into_inner() as f64)
                    }
                }
            }

@@ -349,24 +363,28 @@ impl BatchPredictionRequestToTorchTensorConverter {
             // Process the batch of datarecords
             for r in bpr_start..bpr_end {
-                let dr: &DataRecord = &bpr.individual_features_list[r - bpr_start];
+                let dr: &DataRecord =
+                    &bpr.individual_features_list[usize::try_from(r - bpr_start).unwrap()];
                 if dr.continuous_features.is_some() {
                     for feature in dr.continuous_features.as_ref().unwrap() {
-                        if let Some(f_info) = self.feature_mapper.get(feature.0) {
-                            let idx = f_info.index_within_tensor as usize;
-                            let flat_index = r * cols + idx;
-                            if flat_index < tensor.len() && idx < cols {
-                                tensor[flat_index] = feature.1.into_inner() as f32;
+                        match self.feature_mapper.get(&feature.0) {
+                            Some(f_info) => {
+                                let idx = f_info.index_within_tensor as usize;
+                                let flat_index: usize = r * cols + idx;
+                                if flat_index < tensor.len() && idx < cols {
+                                    tensor[flat_index] = feature.1.into_inner() as f32;
+                                }
                             }
+                            None => (),
                         }
                         if self.continuous_features_to_report.contains(feature.0) {
                             self.continuous_feature_metrics
                                 .with_label_values(&[feature.0.to_string().as_str()])
-                                .observe(feature.1.into_inner())
+                                .observe(feature.1.into_inner() as f64)
                         } else if self.discrete_features_to_report.contains(feature.0) {
                             self.discrete_feature_metrics
                                 .with_label_values(&[feature.0.to_string().as_str()])
-                                .observe(feature.1.into_inner())
+                                .observe(feature.1.into_inner() as f64)
                         }
                     }
                 }
@@ -383,10 +401,10 @@ impl BatchPredictionRequestToTorchTensorConverter {
 
     fn get_binary(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
         // These need to be part of model schema
-        let rows = batch_ends[batch_ends.len() - 1];
-        let cols = 149;
-        let full_size = rows * cols;
-        let default_val = 0;
+        let rows: usize = batch_ends[batch_ends.len() - 1];
+        let cols: usize = 149;
+        let full_size: usize = rows * cols;
+        let default_val: i64 = 0;
 
         let mut v = vec![default_val; full_size];
 
@@ -410,15 +428,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
                     .unwrap();
 
                 for feature in common_features {
-                    if let Some(f_info) = self.feature_mapper.get(feature) {
-                        let idx = f_info.index_within_tensor as usize;
-                        if idx < cols {
-                            // Set value in each row
-                            for r in bpr_start..bpr_end {
-                                let flat_index = r * cols + idx;
-                                v[flat_index] = 1;
+                    match self.feature_mapper.get(feature) {
+                        Some(f_info) => {
+                            let idx = f_info.index_within_tensor as usize;
+                            if idx < cols {
+                                // Set value in each row
+                                for r in bpr_start..bpr_end {
+                                    let flat_index: usize = r * cols + idx;
+                                    v[flat_index] = 1;
+                                }
                             }
                         }
+                        None => (),
                     }
                 }
             }
@@ -428,10 +449,13 @@ impl BatchPredictionRequestToTorchTensorConverter {
                 let dr: &DataRecord = &bpr.individual_features_list[r - bpr_start];
                 if dr.binary_features.is_some() {
                     for feature in dr.binary_features.as_ref().unwrap() {
-                        if let Some(f_info) = self.feature_mapper.get(feature) {
-                            let idx = f_info.index_within_tensor as usize;
-                            let flat_index = r * cols + idx;
-                            v[flat_index] = 1;
+                        match self.feature_mapper.get(&feature) {
+                            Some(f_info) => {
+                                let idx = f_info.index_within_tensor as usize;
+                                let flat_index: usize = r * cols + idx;
+                                v[flat_index] = 1;
+                            }
+                            None => (),
                         }
                     }
                 }
@@ -448,10 +472,10 @@ impl BatchPredictionRequestToTorchTensorConverter {
     #[allow(dead_code)]
     fn get_discrete(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
         // These need to be part of model schema
-        let rows = batch_ends[batch_ends.len() - 1];
-        let cols = 320;
-        let full_size = rows * cols;
-        let default_val = 0;
+        let rows: usize = batch_ends[batch_ends.len() - 1];
+        let cols: usize = 320;
+        let full_size: usize = rows * cols;
+        let default_val: i64 = 0;
 
         let mut v = vec![default_val; full_size];
 
@@ -475,15 +499,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
                     .unwrap();
 
                 for feature in common_features {
-                    if let Some(f_info) = self.feature_mapper.get(feature.0) {
-                        let idx = f_info.index_within_tensor as usize;
-                        if idx < cols {
-                            // Set value in each row
-                            for r in bpr_start..bpr_end {
-                                let flat_index = r * cols + idx;
-                                v[flat_index] = *feature.1;
+                    match self.feature_mapper.get(feature.0) {
+                        Some(f_info) => {
+                            let idx = f_info.index_within_tensor as usize;
+                            if idx < cols {
+                                // Set value in each row
+                                for r in bpr_start..bpr_end {
+                                    let flat_index: usize = r * cols + idx;
+                                    v[flat_index] = *feature.1;
+                                }
                             }
                         }
+                        None => (),
                    }
                    if self.discrete_features_to_report.contains(feature.0) {
                        self.discrete_feature_metrics
@@ -495,15 +522,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
 
             // Process the batch of datarecords
             for r in bpr_start..bpr_end {
-                let dr: &DataRecord = &bpr.individual_features_list[r];
+                let dr: &DataRecord = &bpr.individual_features_list[usize::try_from(r).unwrap()];
                 if dr.discrete_features.is_some() {
                     for feature in dr.discrete_features.as_ref().unwrap() {
-                        if let Some(f_info) = self.feature_mapper.get(feature.0) {
-                            let idx = f_info.index_within_tensor as usize;
-                            let flat_index = r * cols + idx;
-                            if flat_index < v.len() && idx < cols {
-                                v[flat_index] = *feature.1;
+                        match self.feature_mapper.get(&feature.0) {
+                            Some(f_info) => {
+                                let idx = f_info.index_within_tensor as usize;
+                                let flat_index: usize = r * cols + idx;
+                                if flat_index < v.len() && idx < cols {
+                                    v[flat_index] = *feature.1;
+                                }
+                            }
+                            None => (),
                         }
                         if self.discrete_features_to_report.contains(feature.0) {
                             self.discrete_feature_metrics
@@ -569,7 +599,7 @@ impl Converter for BatchPredictionRequestToTorchTensorConverter {
             .map(|bpr| bpr.individual_features_list.len())
             .scan(0usize, |acc, e| {
                 //running total
-                *acc += e;
+                *acc = *acc + e;
                 Some(*acc)
             })
             .collect::<Vec<usize>>();
diff --git a/navi/navi/proto/tensorflow/core/framework/full_type.proto b/navi/navi/proto/tensorflow/core/framework/full_type.proto
index e8175ed3d..ddf05ec8f 100644
--- a/navi/navi/proto/tensorflow/core/framework/full_type.proto
+++ b/navi/navi/proto/tensorflow/core/framework/full_type.proto
@@ -122,7 +122,7 @@ enum FullTypeId {
   //   TFT_TENSOR[TFT_INT32, TFT_UNKNOWN]
   //     is a Tensor of int32 element type and unknown shape.
   //
-  // TODO: Define TFT_SHAPE and add more examples.
+  // TODO(mdan): Define TFT_SHAPE and add more examples.
   TFT_TENSOR = 1000;
 
   // Array (or tensorflow::TensorList in the variant type registry).
@@ -178,7 +178,7 @@ enum FullTypeId {
   // object (for now).
 
   // The bool element type.
-  // TODO
+  // TODO(mdan): Quantized types, legacy representations (e.g. ref)
   TFT_BOOL = 200;
   // Integer element types.
   TFT_UINT8 = 201;
@@ -195,7 +195,7 @@ enum FullTypeId {
   TFT_DOUBLE = 211;
   TFT_BFLOAT16 = 215;
   // Complex element types.
-  // TODO: Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
+  // TODO(mdan): Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
   TFT_COMPLEX64 = 212;
   TFT_COMPLEX128 = 213;
   // The string element type.
@@ -240,7 +240,7 @@ enum FullTypeId {
   // ownership is in the true sense: "the op argument representing the lock is
   // available".
   // Mutex locks are the dynamic counterpart of control dependencies.
-  // TODO: Properly document this thing.
+  // TODO(mdan): Properly document this thing.
   //
   // Parametrization: TFT_MUTEX_LOCK[].
   TFT_MUTEX_LOCK = 10202;
@@ -271,6 +271,6 @@ message FullTypeDef {
   oneof attr {
     string s = 3;
     int64 i = 4;
-    // TODO: list/tensor, map? Need to reconcile with TFT_RECORD, etc.
+    // TODO(mdan): list/tensor, map? Need to reconcile with TFT_RECORD, etc.
   }
 }
diff --git a/navi/navi/proto/tensorflow/core/framework/function.proto b/navi/navi/proto/tensorflow/core/framework/function.proto
index efa3c9aeb..6e59df718 100644
--- a/navi/navi/proto/tensorflow/core/framework/function.proto
+++ b/navi/navi/proto/tensorflow/core/framework/function.proto
@@ -23,7 +23,7 @@ message FunctionDefLibrary {
 // with a value. When a GraphDef has a call to a function, it must
 // have binding for every attr defined in the signature.
 //
-// TODO:
+// TODO(zhifengc):
 //   * device spec, etc.
 message FunctionDef {
   // The definition of the function's name, arguments, return values,
diff --git a/navi/navi/proto/tensorflow/core/framework/node_def.proto b/navi/navi/proto/tensorflow/core/framework/node_def.proto
index 801759817..705e90aa3 100644
--- a/navi/navi/proto/tensorflow/core/framework/node_def.proto
+++ b/navi/navi/proto/tensorflow/core/framework/node_def.proto
@@ -61,7 +61,7 @@ message NodeDef {
   // one of the names from the corresponding OpDef's attr field).
   // The values must have a type matching the corresponding OpDef
   // attr's type field.
-  // TODO: Add some examples here showing best practices.
+  // TODO(josh11b): Add some examples here showing best practices.
   map<string, AttrValue> attr = 5;
 
   message ExperimentalDebugInfo {
diff --git a/navi/navi/proto/tensorflow/core/framework/op_def.proto b/navi/navi/proto/tensorflow/core/framework/op_def.proto
index a53fdf028..b71f5ce87 100644
--- a/navi/navi/proto/tensorflow/core/framework/op_def.proto
+++ b/navi/navi/proto/tensorflow/core/framework/op_def.proto
@@ -96,7 +96,7 @@ message OpDef {
     // Human-readable description.
     string description = 4;
 
-    // TODO: bool is_optional?
+    // TODO(josh11b): bool is_optional?
 
     // --- Constraints ---
     // These constraints are only in effect if specified. Default is no
@@ -139,7 +139,7 @@ message OpDef {
   // taking input from multiple devices with a tree of aggregate ops
   // that aggregate locally within each device (and possibly within
   // groups of nearby devices) before communicating.
-  // TODO: Implement that optimization.
+  // TODO(josh11b): Implement that optimization.
   bool is_aggregate = 16;  // for things like add
 
   // Other optimizations go here, like
diff --git a/navi/navi/proto/tensorflow/core/framework/step_stats.proto b/navi/navi/proto/tensorflow/core/framework/step_stats.proto
index 62238234d..762487f02 100644
--- a/navi/navi/proto/tensorflow/core/framework/step_stats.proto
+++ b/navi/navi/proto/tensorflow/core/framework/step_stats.proto
@@ -53,7 +53,7 @@ message MemoryStats {
 
 // Time/size stats recorded for a single execution of a graph node.
 message NodeExecStats {
-  // TODO: Use some more compact form of node identity than
+  // TODO(tucker): Use some more compact form of node identity than
   // the full string name. Either all processes should agree on a
   // global id (cost_id?) for each node, or we should use a hash of
   // the name.
diff --git a/navi/navi/proto/tensorflow/core/framework/tensor.proto b/navi/navi/proto/tensorflow/core/framework/tensor.proto index 2d4b593be..eb057b127 100644 --- a/navi/navi/proto/tensorflow/core/framework/tensor.proto +++ b/navi/navi/proto/tensorflow/core/framework/tensor.proto @@ -16,7 +16,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framewo message TensorProto { DataType dtype = 1; - // Shape of the tensor. TODO: sort out the 0-rank issues. + // Shape of the tensor. TODO(touts): sort out the 0-rank issues. TensorShapeProto tensor_shape = 2; // Only one of the representations below is set, one of "tensor_contents" and diff --git a/navi/navi/proto/tensorflow/core/protobuf/config.proto b/navi/navi/proto/tensorflow/core/protobuf/config.proto index ff78e1f22..e454309fc 100644 --- a/navi/navi/proto/tensorflow/core/protobuf/config.proto +++ b/navi/navi/proto/tensorflow/core/protobuf/config.proto @@ -532,7 +532,7 @@ message ConfigProto { // We removed the flag client_handles_error_formatting. Marking the tag // number as reserved. - // TODO: Should we just remove this tag so that it can be + // TODO(shikharagarwal): Should we just remove this tag so that it can be // used in future for other purpose? reserved 2; @@ -576,7 +576,7 @@ message ConfigProto { // - If isolate_session_state is true, session states are isolated. // - If isolate_session_state is false, session states are shared. // - // TODO: Add a single API that consistently treats + // TODO(b/129330037): Add a single API that consistently treats // isolate_session_state and ClusterSpec propagation. bool share_session_state_in_clusterspec_propagation = 8; @@ -704,7 +704,7 @@ message ConfigProto { // Options for a single Run() call. message RunOptions { - // TODO Turn this into a TraceOptions proto which allows + // TODO(pbar) Turn this into a TraceOptions proto which allows // tracing to be controlled in a more orthogonal manner? enum TraceLevel { NO_TRACE = 0; @@ -781,7 +781,7 @@ message RunMetadata { repeated GraphDef partition_graphs = 3; message FunctionGraphs { - // TODO: Include some sort of function/cache-key identifier? + // TODO(nareshmodi): Include some sort of function/cache-key identifier? repeated GraphDef partition_graphs = 1; GraphDef pre_optimization_graph = 2; diff --git a/navi/navi/proto/tensorflow/core/protobuf/coordination_service.proto b/navi/navi/proto/tensorflow/core/protobuf/coordination_service.proto index e190bb028..730fb8c10 100644 --- a/navi/navi/proto/tensorflow/core/protobuf/coordination_service.proto +++ b/navi/navi/proto/tensorflow/core/protobuf/coordination_service.proto @@ -194,7 +194,7 @@ service CoordinationService { // Report error to the task. RPC sets the receiving instance of coordination // service agent to error state permanently. - // TODO: Consider splitting this into a different RPC service. + // TODO(b/195990880): Consider splitting this into a different RPC service. rpc ReportErrorToAgent(ReportErrorToAgentRequest) returns (ReportErrorToAgentResponse); diff --git a/navi/navi/proto/tensorflow/core/protobuf/debug.proto b/navi/navi/proto/tensorflow/core/protobuf/debug.proto index 1cc76f1ed..2fabd0319 100644 --- a/navi/navi/proto/tensorflow/core/protobuf/debug.proto +++ b/navi/navi/proto/tensorflow/core/protobuf/debug.proto @@ -46,7 +46,7 @@ message DebugTensorWatch { // are to be debugged, the callers of Session::Run() must use distinct // debug_urls to make sure that the streamed or dumped events do not overlap // among the invocations. 
-  // TODO: More visible documentation of this in g3docs.
+  // TODO(cais): More visible documentation of this in g3docs.
   repeated string debug_urls = 4;
 
   // Do not error out if debug op creation fails (e.g., due to dtype
diff --git a/navi/navi/proto/tensorflow/core/protobuf/debug_event.proto b/navi/navi/proto/tensorflow/core/protobuf/debug_event.proto
index b68f45d4d..5530004d7 100644
--- a/navi/navi/proto/tensorflow/core/protobuf/debug_event.proto
+++ b/navi/navi/proto/tensorflow/core/protobuf/debug_event.proto
@@ -12,7 +12,7 @@ option java_package = "org.tensorflow.util";
 option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
 
 // Available modes for extracting debugging information from a Tensor.
-// TODO: Document the detailed column names and semantics in a separate
+// TODO(cais): Document the detailed column names and semantics in a separate
 // markdown file once the implementation settles.
 enum TensorDebugMode {
   UNSPECIFIED = 0;
@@ -223,7 +223,7 @@ message DebuggedDevice {
   // A debugger-generated ID for the device. Guaranteed to be unique within
   // the scope of the debugged TensorFlow program, including single-host and
   // multi-host settings.
-  // TODO: Test the uniqueness guarantee in multi-host settings.
+  // TODO(cais): Test the uniqueness guarantee in multi-host settings.
   int32 device_id = 2;
 }
 
@@ -264,7 +264,7 @@ message Execution {
   // field with the DebuggedDevice messages.
   repeated int32 output_tensor_device_ids = 9;
 
-  // TODO support, add more fields
+  // TODO(cais): When backporting to V1 Session.run() support, add more fields
   // such as fetches and feeds.
 }
diff --git a/navi/navi/proto/tensorflow/core/protobuf/distributed_runtime_payloads.proto b/navi/navi/proto/tensorflow/core/protobuf/distributed_runtime_payloads.proto
index c19da9d82..ddb346afa 100644
--- a/navi/navi/proto/tensorflow/core/protobuf/distributed_runtime_payloads.proto
+++ b/navi/navi/proto/tensorflow/core/protobuf/distributed_runtime_payloads.proto
@@ -7,7 +7,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
 
 // Used to serialize and transmit tensorflow::Status payloads through
 // grpc::Status `error_details` since grpc::Status lacks payload API.
-// TODO: Use GRPC API once supported.
+// TODO(b/204231601): Use GRPC API once supported.
 message GrpcPayloadContainer {
   map<string, bytes> payloads = 1;
 }
diff --git a/navi/navi/proto/tensorflow/core/protobuf/eager_service.proto b/navi/navi/proto/tensorflow/core/protobuf/eager_service.proto
index 9d658c7d9..204acf6b1 100644
--- a/navi/navi/proto/tensorflow/core/protobuf/eager_service.proto
+++ b/navi/navi/proto/tensorflow/core/protobuf/eager_service.proto
@@ -172,7 +172,7 @@ message WaitQueueDoneRequest {
 }
 
 message WaitQueueDoneResponse {
-  // TODO: Consider adding NodeExecStats here to be able to
+  // TODO(nareshmodi): Consider adding NodeExecStats here to be able to
   // propagate some stats.
 }
 
diff --git a/navi/navi/proto/tensorflow/core/protobuf/master.proto b/navi/navi/proto/tensorflow/core/protobuf/master.proto
index 60555cd58..e1732a932 100644
--- a/navi/navi/proto/tensorflow/core/protobuf/master.proto
+++ b/navi/navi/proto/tensorflow/core/protobuf/master.proto
@@ -94,7 +94,7 @@ message ExtendSessionRequest {
 }
 
 message ExtendSessionResponse {
-  // TODO: Return something about the operation?
+  // TODO(mrry): Return something about the operation?
 
   // The new version number for the extended graph, to be used in the next call
   // to ExtendSession.
diff --git a/navi/navi/proto/tensorflow/core/protobuf/saved_object_graph.proto b/navi/navi/proto/tensorflow/core/protobuf/saved_object_graph.proto index 70b31f0e6..a59ad0ed2 100644 --- a/navi/navi/proto/tensorflow/core/protobuf/saved_object_graph.proto +++ b/navi/navi/proto/tensorflow/core/protobuf/saved_object_graph.proto @@ -176,7 +176,7 @@ message SavedBareConcreteFunction { // allows the ConcreteFunction to be called with nest structure inputs. This // field may not be populated. If this field is absent, the concrete function // can only be called with flat inputs. - // TODO: support calling saved ConcreteFunction with structured + // TODO(b/169361281): support calling saved ConcreteFunction with structured // inputs in C++ SavedModel API. FunctionSpec function_spec = 4; } diff --git a/navi/navi/proto/tensorflow/core/protobuf/tensor_bundle.proto b/navi/navi/proto/tensorflow/core/protobuf/tensor_bundle.proto index 4433afae2..999195cc9 100644 --- a/navi/navi/proto/tensorflow/core/protobuf/tensor_bundle.proto +++ b/navi/navi/proto/tensorflow/core/protobuf/tensor_bundle.proto @@ -17,7 +17,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu // Special header that is associated with a bundle. // -// TODO: maybe in the future, we can add information about +// TODO(zongheng,zhifengc): maybe in the future, we can add information about // which binary produced this checkpoint, timestamp, etc. Sometime, these can be // valuable debugging information. And if needed, these can be used as defensive // information ensuring reader (binary version) of the checkpoint and the writer diff --git a/navi/navi/proto/tensorflow/core/protobuf/worker.proto b/navi/navi/proto/tensorflow/core/protobuf/worker.proto index 0df080c77..18d60b568 100644 --- a/navi/navi/proto/tensorflow/core/protobuf/worker.proto +++ b/navi/navi/proto/tensorflow/core/protobuf/worker.proto @@ -188,7 +188,7 @@ message DeregisterGraphRequest { } message DeregisterGraphResponse { - // TODO: Optionally add summary stats for the graph. + // TODO(mrry): Optionally add summary stats for the graph. } //////////////////////////////////////////////////////////////////////////////// @@ -294,7 +294,7 @@ message RunGraphResponse { // If the request asked for execution stats, the cost graph, or the partition // graphs, these are returned here. - // TODO: Package these in a RunMetadata instead. + // TODO(suharshs): Package these in a RunMetadata instead. StepStats step_stats = 2; CostGraphDef cost_graph = 3; repeated GraphDef partition_graph = 4; diff --git a/navi/navi/proto/tensorflow_serving/apis/logging.proto b/navi/navi/proto/tensorflow_serving/apis/logging.proto index 9d304f44d..6298bb4b2 100644 --- a/navi/navi/proto/tensorflow_serving/apis/logging.proto +++ b/navi/navi/proto/tensorflow_serving/apis/logging.proto @@ -13,5 +13,5 @@ message LogMetadata { SamplingConfig sampling_config = 2; // List of tags used to load the relevant MetaGraphDef from SavedModel. repeated string saved_model_tags = 3; - // TODO: Add more metadata as mentioned in the bug. + // TODO(b/33279154): Add more metadata as mentioned in the bug. 
 }
diff --git a/navi/navi/proto/tensorflow_serving/config/file_system_storage_path_source.proto b/navi/navi/proto/tensorflow_serving/config/file_system_storage_path_source.proto
index 8d8541d4f..add7aa2a2 100644
--- a/navi/navi/proto/tensorflow_serving/config/file_system_storage_path_source.proto
+++ b/navi/navi/proto/tensorflow_serving/config/file_system_storage_path_source.proto
@@ -58,7 +58,7 @@ message FileSystemStoragePathSourceConfig {
   // A single servable name/base_path pair to monitor.
   // DEPRECATED: Use 'servables' instead.
-  // TODO: Stop using these fields, and ultimately remove them here.
+  // TODO(b/30898016): Stop using these fields, and ultimately remove them here.
   string servable_name = 1 [deprecated = true];
   string base_path = 2 [deprecated = true];
 
@@ -76,7 +76,7 @@ message FileSystemStoragePathSourceConfig {
   // check for a version to appear later.)
   // DEPRECATED: Use 'servable_versions_always_present' instead, which includes
   // this behavior.
-  // TODO: Remove 2019-10-31 or later.
+  // TODO(b/30898016): Remove 2019-10-31 or later.
   bool fail_if_zero_versions_at_startup = 4 [deprecated = true];
 
   // If true, the servable is always expected to exist on the underlying
diff --git a/navi/navi/proto/tensorflow_serving/config/model_server_config.proto b/navi/navi/proto/tensorflow_serving/config/model_server_config.proto
index 0f80aa1c7..cadc2b6e6 100644
--- a/navi/navi/proto/tensorflow_serving/config/model_server_config.proto
+++ b/navi/navi/proto/tensorflow_serving/config/model_server_config.proto
@@ -9,7 +9,7 @@ import "tensorflow_serving/config/logging_config.proto";
 option cc_enable_arenas = true;
 
 // The type of model.
-// TODO: DEPRECATED.
+// TODO(b/31336131): DEPRECATED.
 enum ModelType {
   MODEL_TYPE_UNSPECIFIED = 0 [deprecated = true];
   TENSORFLOW = 1 [deprecated = true];
@@ -31,7 +31,7 @@ message ModelConfig {
   string base_path = 2;
 
   // Type of model.
-  // TODO: DEPRECATED. Please use 'model_platform' instead.
+  // TODO(b/31336131): DEPRECATED. Please use 'model_platform' instead.
   ModelType model_type = 3 [deprecated = true];
 
   // Type of model (e.g. "tensorflow").
diff --git a/navi/navi/src/bootstrap.rs b/navi/navi/src/bootstrap.rs
index 56215292f..1f767f17e 100644
--- a/navi/navi/src/bootstrap.rs
+++ b/navi/navi/src/bootstrap.rs
@@ -1,5 +1,6 @@
 use anyhow::Result;
 use log::{info, warn};
+use x509_parser::{prelude::{parse_x509_pem}, parse_x509_certificate};
 use std::collections::HashMap;
 use tokio::time::Instant;
 use tonic::{
@@ -27,6 +28,7 @@ use crate::cli_args::{ARGS, INPUTS, OUTPUTS};
 use crate::metrics::{
     NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
     NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR,
+    CERT_EXPIRY_EPOCH
 };
 use crate::predict_service::{Model, PredictService};
 use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version;
@@ -233,6 +235,12 @@ impl PredictionService for PredictService {
     }
 }
 
+// Logs the certificate expiry timestamp and records it in the CERT_EXPIRY_EPOCH gauge
+fn report_expiry(expiry_time: i64) {
+    info!("Certificate expires at epoch: {:?}", expiry_time);
+    CERT_EXPIRY_EPOCH.set(expiry_time as i64);
+}
+
 pub fn bootstrap(model_factory: ModelFactory) -> Result<()> {
     info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS); //we follow SemVer. So here we assume MAJOR.MINOR.PATCH
@@ -249,6 +257,7 @@ pub fn bootstrap(model_factory: ModelFactory) -> Result<()> {
         );
     }
 
+
     tokio::runtime::Builder::new_multi_thread()
         .thread_name("async worker")
         .worker_threads(ARGS.num_worker_threads)
@@ -266,6 +275,21 @@ pub fn bootstrap(model_factory: ModelFactory) -> Result<()> {
             let mut builder = if ARGS.ssl_dir.is_empty() {
                 Server::builder()
             } else {
+                // Read the pem file as a string
+                let pem_str = std::fs::read_to_string(format!("{}/server.crt", ARGS.ssl_dir)).unwrap();
+                let res = parse_x509_pem(&pem_str.as_bytes());
+                match res {
+                    Ok((rem, pem_2)) => {
+                        assert!(rem.is_empty());
+                        assert_eq!(pem_2.label, String::from("CERTIFICATE"));
+                        let res_x509 = parse_x509_certificate(&pem_2.contents);
+                        info!("Certificate label: {}", pem_2.label);
+                        assert!(res_x509.is_ok());
+                        report_expiry(res_x509.unwrap().1.validity().not_after.timestamp());
+                    },
+                    _ => panic!("PEM parsing failed: {:?}", res),
+                }
+
                 let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir))
                     .await
                     .expect("can't find key file");
@@ -281,7 +305,7 @@ pub fn bootstrap(model_factory: ModelFactory) -> Result<()> {
                 let identity = Identity::from_pem(pem.clone(), key);
                 let client_ca_cert = Certificate::from_pem(pem.clone());
                 let tls = ServerTlsConfig::new()
-                    .identity(identity)
+                    .identity(identity)
                     .client_ca_root(client_ca_cert);
                 Server::builder()
                     .tls_config(tls)
diff --git a/navi/navi/src/metrics.rs b/navi/navi/src/metrics.rs
index 7cc9e6fcf..373f84f0f 100644
--- a/navi/navi/src/metrics.rs
+++ b/navi/navi/src/metrics.rs
@@ -171,6 +171,9 @@ lazy_static! {
         &["model_name"]
     )
     .expect("metric can be created");
+    pub static ref CERT_EXPIRY_EPOCH: IntGauge =
+        IntGauge::new(":navi:cert_expiry_epoch", "Timestamp when the current cert expires")
+            .expect("metric can be created");
 }
 
 pub fn register_custom_metrics() {
@@ -249,6 +252,10 @@ pub fn register_custom_metrics() {
     REGISTRY
         .register(Box::new(CONVERTER_TIME_COLLECTOR.clone()))
         .expect("collector can be registered");
+    REGISTRY
+        .register(Box::new(CERT_EXPIRY_EPOCH.clone()))
+        .expect("collector can be registered");
+
 }
 
 pub fn register_dynamic_metrics(c: &HistogramVec) {
diff --git a/navi/navi/src/onnx_model.rs b/navi/navi/src/onnx_model.rs
index a0d75c8c9..18f116570 100644
--- a/navi/navi/src/onnx_model.rs
+++ b/navi/navi/src/onnx_model.rs
@@ -189,7 +189,7 @@ pub mod onnx {
                     &version,
                     reporting_feature_ids,
                     Some(metrics::register_dynamic_metrics),
-                )),
+                )?),
             };
             onnx_model.warmup()?;
             Ok(onnx_model)
diff --git a/navi/navi/src/predict_service.rs b/navi/navi/src/predict_service.rs
index 8650662cf..fc355d7ea 100644
--- a/navi/navi/src/predict_service.rs
+++ b/navi/navi/src/predict_service.rs
@@ -24,7 +24,7 @@ use serde_json::{self, Value};
 
 pub trait Model: Send + Sync + Display + Debug + 'static {
     fn warmup(&self) -> Result<()>;
-    //TODO: refactor this to return Vec<Vec<TensorScores>>, i.e.
+    //TODO: refactor this to return vec<vec<TensorScores>>, i.e.
     //we have the underlying runtime impl to split the response to each client.
    //It will eliminate some inefficient memory copy in onnx_model.rs as well as simplify code
     fn do_predict(
@@ -222,8 +222,8 @@ impl<T: Model> PredictService<T> {
                 .map(|b| b.parse().unwrap())
                 .collect::<Vec<u64>>();
             let no_msg_wait_millis = *batch_time_out_millis.iter().min().unwrap();
-            let mut all_model_predictors =
-                ArrayVec::<ArrayVec<Arc<PredictModel>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS>::new();
+            let mut all_model_predictors: ArrayVec::<ArrayVec<Arc<PredictModel>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS> =
+                (0 ..MAX_NUM_MODELS).map( |_| ArrayVec::<Arc<PredictModel>, MAX_VERSIONS_PER_MODEL>::new()).collect();
             loop {
                 let msg = rx.try_recv();
                 let no_more_msg = match msg {
@@ -272,27 +272,23 @@ impl<T: Model> PredictService<T> {
                             queue_reset_ts: Instant::now(),
                             queue_earliest_rq_ts: Instant::now(),
                         };
-                        if idx < all_model_predictors.len() {
-                            metrics::NEW_MODEL_SNAPSHOT
-                                .with_label_values(&[&MODEL_SPECS[idx]])
-                                .inc();
+                        assert!(idx < all_model_predictors.len());
+                        metrics::NEW_MODEL_SNAPSHOT
+                            .with_label_values(&[&MODEL_SPECS[idx]])
+                            .inc();
 
-                            info!("now we serve updated model: {}", predictor.model);
-                            //we can do this since the vector is small
-                            let predictors = &mut all_model_predictors[idx];
-                            if predictors.len() == ARGS.versions_per_model {
-                                predictors.remove(predictors.len() - 1);
-                            }
-                            predictors.insert(0, predictor);
-                        } else {
-                            info!("now we serve new model: {:}", predictor.model);
-                            let mut predictors =
-                                ArrayVec::<Arc<PredictModel>, MAX_VERSIONS_PER_MODEL>::new();
-                            predictors.push(predictor);
-                            all_model_predictors.push(predictors);
-                            //check the invariant that we always push the last model to the end
-                            assert_eq!(all_model_predictors.len(), idx + 1)
+                        //we can do this since the vector is small
+                        let predictors = &mut all_model_predictors[idx];
+                        if predictors.len() == 0 {
+                            info!("now we serve new model: {}", predictor.model);
                         }
+                        else {
+                            info!("now we serve updated model: {}", predictor.model);
+                        }
+                        if predictors.len() == ARGS.versions_per_model {
+                            predictors.remove(predictors.len() - 1);
+                        }
+                        predictors.insert(0, predictor);
                         false
                     }
                     Err(TryRecvError::Empty) => true,
diff --git a/navi/segdense/src/error.rs b/navi/segdense/src/error.rs
index d997b6933..4c1d9af7d 100644
--- a/navi/segdense/src/error.rs
+++ b/navi/segdense/src/error.rs
@@ -5,39 +5,49 @@ use std::fmt::Display;
  */
 #[derive(Debug)]
 pub enum SegDenseError {
-  IoError(std::io::Error),
-  Json(serde_json::Error),
-  JsonMissingRoot,
-  JsonMissingObject,
-  JsonMissingArray,
-  JsonArraySize,
-  JsonMissingInputFeature,
+    IoError(std::io::Error),
+    Json(serde_json::Error),
+    JsonMissingRoot,
+    JsonMissingObject,
+    JsonMissingArray,
+    JsonArraySize,
+    JsonMissingInputFeature,
 }
 
 impl Display for SegDenseError {
-  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-    match self {
-      SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
-      SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
-      SegDenseError::JsonMissingRoot => write!(f, "{}", "SegDense JSON: Root Node note found!"),
-      SegDenseError::JsonMissingObject => write!(f, "{}", "SegDense JSON: Object note found!"),
-      SegDenseError::JsonMissingArray => write!(f, "{}", "SegDense JSON: Array Node note found!"),
-      SegDenseError::JsonArraySize => write!(f, "{}", "SegDense JSON: Array size not as expected!"),
-      SegDenseError::JsonMissingInputFeature => write!(f, "{}", "SegDense JSON: Missing input feature!"),
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
+            SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
+            SegDenseError::JsonMissingRoot => {
+                write!(f, "{}", "SegDense JSON: Root Node not found!")
+            }
+            SegDenseError::JsonMissingObject => {
+                write!(f, "{}", "SegDense JSON: Object not found!")
+            }
+            SegDenseError::JsonMissingArray => {
+                write!(f, "{}", "SegDense JSON: Array Node not found!")
+            }
+            SegDenseError::JsonArraySize => {
+                write!(f, "{}", "SegDense JSON: Array size not as expected!")
+            }
+            SegDenseError::JsonMissingInputFeature => {
+                write!(f, "{}", "SegDense JSON: Missing input feature!")
+            }
+        }
     }
-  }
 }
 
 impl std::error::Error for SegDenseError {}
 
 impl From<std::io::Error> for SegDenseError {
-  fn from(err: std::io::Error) -> Self {
-    SegDenseError::IoError(err)
-  }
+    fn from(err: std::io::Error) -> Self {
+        SegDenseError::IoError(err)
+    }
 }
 
 impl From<serde_json::Error> for SegDenseError {
-  fn from(err: serde_json::Error) -> Self {
-    SegDenseError::Json(err)
-  }
+    fn from(err: serde_json::Error) -> Self {
+        SegDenseError::Json(err)
+    }
 }
diff --git a/navi/segdense/src/lib.rs b/navi/segdense/src/lib.rs
index 476411702..f9930da64 100644
--- a/navi/segdense/src/lib.rs
+++ b/navi/segdense/src/lib.rs
@@ -1,4 +1,4 @@
 pub mod error;
-pub mod segdense_transform_spec_home_recap_2022;
 pub mod mapper;
-pub mod util;
\ No newline at end of file
+pub mod segdense_transform_spec_home_recap_2022;
+pub mod util;
diff --git a/navi/segdense/src/main.rs b/navi/segdense/src/main.rs
index 1515df101..d8f7f8bc4 100644
--- a/navi/segdense/src/main.rs
+++ b/navi/segdense/src/main.rs
@@ -5,19 +5,18 @@ use segdense::error::SegDenseError;
 use segdense::util;
 
 fn main() -> Result<(), SegDenseError> {
-  env_logger::init();
-  let args: Vec<String> = env::args().collect();
-
-  let schema_file_name: &str = if args.len() == 1 {
-    "json/compact.json"
-  } else {
-    &args[1]
-  };
+    env_logger::init();
+    let args: Vec<String> = env::args().collect();
 
-  let json_str = fs::read_to_string(schema_file_name)?;
+    let schema_file_name: &str = if args.len() == 1 {
+        "json/compact.json"
+    } else {
+        &args[1]
+    };
 
-  util::safe_load_config(&json_str)?;
+    let json_str = fs::read_to_string(schema_file_name)?;
 
-  Ok(())
+    util::safe_load_config(&json_str)?;
+
+    Ok(())
 }
-
diff --git a/navi/segdense/src/mapper.rs b/navi/segdense/src/mapper.rs
index f640f2aeb..f5a1d6532 100644
--- a/navi/segdense/src/mapper.rs
+++ b/navi/segdense/src/mapper.rs
@@ -19,13 +19,13 @@ pub struct FeatureMapper {
 impl FeatureMapper {
     pub fn new() -> FeatureMapper {
         FeatureMapper {
-            map: HashMap::new()
+            map: HashMap::new(),
         }
     }
 }
 
 pub trait MapWriter {
-  fn set(&mut self, feature_id: i64, info: FeatureInfo);
+    fn set(&mut self, feature_id: i64, info: FeatureInfo);
 }
 
 pub trait MapReader {
diff --git a/navi/segdense/src/segdense_transform_spec_home_recap_2022.rs b/navi/segdense/src/segdense_transform_spec_home_recap_2022.rs
index a3b3513f8..ff6d3ae17 100644
--- a/navi/segdense/src/segdense_transform_spec_home_recap_2022.rs
+++ b/navi/segdense/src/segdense_transform_spec_home_recap_2022.rs
@@ -164,7 +164,6 @@ pub struct ComplexFeatureTypeTransformSpec {
     pub tensor_shape: Vec<i64>,
 }
 
-
 #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct InputFeatureMapRecord {
diff --git a/navi/segdense/src/util.rs b/navi/segdense/src/util.rs
index 5d020cea3..116725189 100644
--- a/navi/segdense/src/util.rs
+++ b/navi/segdense/src/util.rs
@@ -1,23 +1,23 @@
+use log::debug;
 use std::fs;
 
-use log::{debug};
-use serde_json::{Value, Map};
+use serde_json::{Map, Value};
 
 use crate::error::SegDenseError;
-use crate::mapper::{FeatureMapper, FeatureInfo, MapWriter};
+use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter};
 use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
 
-pub fn load_config(file_name: &str) -> seg_dense::Root {
-  let json_str = fs::read_to_string(file_name).expect(
-      &format!("Unable to load segdense file {}", file_name));
-  let seg_dense_config = parse(&json_str).expect(
-      &format!("Unable to parse segdense file {}", file_name));
-  return seg_dense_config;
+pub fn load_config(file_name: &str) -> Result<seg_dense::Root, SegDenseError> {
+    let json_str = fs::read_to_string(file_name)?;
+    // &format!("Unable to load segdense file {}", file_name));
+    let seg_dense_config = parse(&json_str)?;
+    // &format!("Unable to parse segdense file {}", file_name));
+    Ok(seg_dense_config)
 }
 
 pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {
     let root: seg_dense::Root = serde_json::from_str(json_str)?;
-    return Ok(root);
+    Ok(root)
 }
 
 /**
@@ -44,15 +44,8 @@ pub fn safe_load_config(json_str: &str) -> Result<FeatureMapper, SegDenseError>
     load_from_parsed_config(root)
 }
 
-pub fn load_from_parsed_config_ref(root: &seg_dense::Root) -> FeatureMapper {
-    load_from_parsed_config(root.clone()).unwrap_or_else(
-        |error| panic!("Error loading all_config.json - {}", error))
-}
-
 // Perf note : make 'root' un-owned
-pub fn load_from_parsed_config(root: seg_dense::Root) ->
-    Result<FeatureMapper, SegDenseError> {
-
+pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper, SegDenseError> {
     let v = root.input_features_map;
 
     // Do error check
@@ -86,7 +79,7 @@ pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper,
             Some(info) => {
                 debug!("{:?}", info);
                 fm.set(feature_id, info)
-            },
+            }
             None => (),
         }
     }
@@ -94,19 +87,22 @@ pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper,
 
     Ok(fm)
 }
 #[allow(dead_code)]
-fn add_feature_info_to_mapper(feature_mapper: &mut FeatureMapper, input_features: &Vec<InputFeature>) {
+fn add_feature_info_to_mapper(
+    feature_mapper: &mut FeatureMapper,
+    input_features: &Vec<InputFeature>,
+) {
     for input_feature in input_features.iter() {
-      let feature_id = input_feature.feature_id;
-      let feature_info = to_feature_info(input_feature);
-
-      match feature_info {
-          Some(info) => {
-              debug!("{:?}", info);
-              feature_mapper.set(feature_id, info)
-          },
-          None => (),
+        let feature_id = input_feature.feature_id;
+        let feature_info = to_feature_info(input_feature);
+
+        match feature_info {
+            Some(info) => {
+                debug!("{:?}", info);
+                feature_mapper.set(feature_id, info)
+            }
+            None => (),
         }
+    }
 }
 
 pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<FeatureInfo> {
@@ -139,7 +135,7 @@ pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<FeatureInfo>
             2 => 0,
             3 => 2,
             _ => -1,
-        }
+        },
    };
 
    if input_feature.index < 0 {
@@ -156,4 +152,3 @@ pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<FeatureInfo>