Skip to content

Commit

Permalink
Merge pull request #292 from NREL/rjf/string-road-class
Browse files Browse the repository at this point in the history
Rjf/string road class
  • Loading branch information
robfitzgerald authored Feb 19, 2025
2 parents 42a55a0 + 40cbfe7 commit 3eac860
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 92 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
pub mod road_class_builder;
pub mod road_class_model;
pub mod road_class_parser;
pub mod road_class_service;
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{road_class_parser::RoadClassParser, road_class_service::RoadClassFrontierService};
use super::road_class_service::RoadClassFrontierService;
use kdam::Bar;
use routee_compass_core::config::{CompassConfigurationField, ConfigJsonExtensions};
use routee_compass_core::{
Expand Down Expand Up @@ -27,9 +27,9 @@ impl FrontierModelBuilder for RoadClassBuilder {
))
})?;

let road_class_lookup: Box<[u8]> = read_utils::read_raw_file(
let road_class_lookup: Box<[String]> = read_utils::read_raw_file(
&road_class_file,
read_decoders::u8,
read_decoders::string,
Some(Bar::builder().desc("road class")),
None,
)
Expand All @@ -41,22 +41,9 @@ impl FrontierModelBuilder for RoadClassBuilder {
))
})?;

let road_class_parser = parameters
.get_config_serde_optional::<RoadClassParser>(
&"road_class_parser",
&"RoadClassFrontierModel",
)
.map_err(|e| {
FrontierModelError::BuildError(format!(
"unable to deserialize road_class_parser: {}",
e
))
})?
.unwrap_or_default();

let m: Arc<dyn FrontierModelService> = Arc::new(RoadClassFrontierService {
road_class_lookup: Arc::new(road_class_lookup),
road_class_parser,
road_class_by_edge: Arc::new(road_class_lookup),
// road_class_parser,
});
Ok(m)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::{

pub struct RoadClassFrontierModel {
pub service: Arc<RoadClassFrontierService>,
pub road_classes: Option<HashSet<u8>>,
pub query_road_classes: Option<HashSet<String>>,
}

impl FrontierModel for RoadClassFrontierModel {
Expand All @@ -30,11 +30,11 @@ impl FrontierModel for RoadClassFrontierModel {
}

fn valid_edge(&self, edge: &Edge) -> Result<bool, FrontierModelError> {
match &self.road_classes {
match &self.query_road_classes {
None => Ok(true),
Some(road_classes) => self
.service
.road_class_lookup
.road_class_by_edge
.get(edge.edge_id.0)
.ok_or_else(|| {
FrontierModelError::FrontierModelError(format!(
Expand All @@ -46,3 +46,111 @@ impl FrontierModel for RoadClassFrontierModel {
}
}
}

#[cfg(test)]
mod test {
use crate::app::compass::model::frontier_model::road_class::road_class_service::RoadClassFrontierService;
use routee_compass_core::model::{
frontier::{FrontierModel, FrontierModelService},
network::Edge,
state::StateModel,
};
use serde_json::{json, Value};
use std::sync::Arc;

/// builds the test model for a given RoadClassModel test
/// # Arguments
/// * `road_class_vector` - the value assumed to be read from a file, with road classes by EdgeId index value
/// * `query` - the user query which should provide the set of valid road classes for this search
fn mock(road_class_vector: Box<[String]>, query: Value) -> Arc<dyn FrontierModel> {
let service = Arc::new(RoadClassFrontierService {
road_class_by_edge: Arc::new(road_class_vector),
});
let state_model = Arc::new(StateModel::empty());
service.build(&query, state_model.clone()).unwrap()
}

#[test]
fn test_no_road_classes() {
let model = mock(Box::new([String::from("a")]), json!({}));
let result = model.valid_edge(&Edge::new(0, 0, 1, 1.0)).unwrap();
assert!(result)
}

#[test]
fn test_valid_class() {
let model = mock(
Box::new([String::from("a")]),
json!({"road_classes": ["a"]}),
);
let result = model.valid_edge(&Edge::new(0, 0, 1, 1.0)).unwrap();
assert!(result)
}

#[test]
fn test_invalid_class() {
let model = mock(
Box::new([String::from("oh no!")]),
json!({"road_classes": ["a"]}),
);
let result = model.valid_edge(&Edge::new(0, 0, 1, 1.0)).unwrap();
assert!(!result)
}

#[test]
fn test_one_of_valid_classes() {
let model = mock(
Box::new([String::from("a")]),
json!({"road_classes": ["a", "b", "c"]}),
);
let result = model.valid_edge(&Edge::new(0, 0, 1, 1.0)).unwrap();
assert!(result)
}

#[test]
fn test_none_of_valid_classes() {
let model = mock(
Box::new([String::from("oh no!")]),
json!({"road_classes": ["a", "b", "c"]}),
);
let result = model.valid_edge(&Edge::new(0, 0, 1, 1.0)).unwrap();
assert!(!result)
}

#[test]
fn test_valid_numeric_class() {
let model = mock(Box::new([String::from("1")]), json!({"road_classes": [1]}));
let result = model.valid_edge(&Edge::new(0, 0, 1, 1.0)).unwrap();
assert!(result)
}

#[test]
fn test_invalid_numeric_class() {
let model = mock(
Box::new([String::from("OH NO!")]),
json!({"road_classes": [1]}),
);
let result = model.valid_edge(&Edge::new(0, 0, 1, 1.0)).unwrap();
assert!(!result)
}

#[test]
fn test_valid_boolean_class() {
let model = mock(
Box::new([String::from("true")]),
json!({"road_classes": [true]}),
);
let result = model.valid_edge(&Edge::new(0, 0, 1, 1.0)).unwrap();
assert!(result)
}

#[test]
fn test_invalid_boolean_class() {
let model = mock(
Box::new([String::from("OH NO!")]),
json!({"road_classes": [true]}),
);
let result = model.valid_edge(&Edge::new(0, 0, 1, 1.0)).unwrap();
assert!(!result)
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use super::{road_class_model::RoadClassFrontierModel, road_class_parser::RoadClassParser};
use super::road_class_model::RoadClassFrontierModel;
use routee_compass_core::model::{
frontier::{FrontierModel, FrontierModelError, FrontierModelService},
state::StateModel,
};
use std::sync::Arc;
use serde_json::Value;
use std::{collections::HashSet, sync::Arc};

#[derive(Clone)]
pub struct RoadClassFrontierService {
pub road_class_lookup: Arc<Box<[u8]>>,
pub road_class_parser: RoadClassParser,
pub road_class_by_edge: Arc<Box<[String]>>,
}

impl FrontierModelService for RoadClassFrontierService {
Expand All @@ -17,17 +17,43 @@ impl FrontierModelService for RoadClassFrontierService {
query: &serde_json::Value,
_state_model: Arc<StateModel>,
) -> Result<Arc<dyn FrontierModel>, FrontierModelError> {
let query_road_classes = match query.get("road_classes").map(read_road_classes_from_query) {
Some(Err(e)) => Err(e),
Some(Ok(road_classes)) => Ok(Some(road_classes)),
None => Ok(None),
}?;

let service: Arc<RoadClassFrontierService> = Arc::new(self.clone());
let road_classes = self.road_class_parser.read_query(query).map_err(|e| {
FrontierModelError::BuildError(format!(
"Unable to parse incoming query road_classes due to: {}",
e
))
})?;
let model = RoadClassFrontierModel {
service,
road_classes,
query_road_classes,
};
Ok(Arc::new(model))
}
}

/// decodes the query `road_classes` value into a set of road class identifiers
fn read_road_classes_from_query(value: &Value) -> Result<HashSet<String>, FrontierModelError> {
let arr = value.as_array().ok_or_else(|| {
FrontierModelError::BuildError(format!(
"query 'road_classes' value must be an array, found '{}'",
value
))
})?;
// if the value is a string (or number or bool), store it as a valid road class
let arr_str = arr
.iter()
.enumerate()
.map(|(idx, c)| match c {
Value::Bool(b) => Ok(b.to_string()),
Value::Number(number) => Ok(number.to_string()),
Value::String(string) => Ok(string.clone()),
_ => Err(FrontierModelError::BuildError(format!(
"query 'road_classes[{}]' value must be a string, found '{}'",
idx, c
))),
})
.collect::<Result<HashSet<_>, _>>()?;

Ok(arr_str)
}

0 comments on commit 3eac860

Please sign in to comment.