diff --git a/src/lib.rs b/src/lib.rs index 4f8591c..a6cd885 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,12 +23,14 @@ pub mod n; pub mod one; pub mod three; pub mod two; +pub mod traits; pub use error::*; pub use n::*; pub use one::*; pub use three::*; pub use two::*; +pub use traits::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -320,132 +322,6 @@ impl Interpolator { Ok(Self::InterpND(interp)) } - pub fn validate(&self) -> Result<(), ValidationError> { - match self { - Self::Interp0D(_) => Ok(()), - Self::Interp1D(interp) => interp.validate(), - Self::Interp2D(interp) => interp.validate(), - Self::Interp3D(interp) => interp.validate(), - Self::InterpND(interp) => interp.validate(), - } - } - - /// Interpolate at supplied point, after checking point validity. - /// Length of supplied point must match interpolator dimensionality. - pub fn interpolate(&self, point: &[f64]) -> Result { - self.validate_point(point)?; - match self { - Self::Interp0D(value) => Ok(*value), - Self::Interp1D(interp) => { - match interp.extrapolate { - Extrapolate::Clamp => { - let clamped_point = - &[point[0].clamp(interp.x[0], *interp.x.last().unwrap())]; - return interp.interpolate(clamped_point); - } - Extrapolate::Error => { - if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) { - return Err(InterpolationError::ExtrapolationError(format!( - "point = {point:?}, grid = {:?}", - interp.x - ))); - } - } - _ => {} - }; - interp.interpolate(point) - } - Self::Interp2D(interp) => { - match interp.extrapolate { - Extrapolate::Clamp => { - let clamped_point = &[ - point[0].clamp(interp.x[0], *interp.x.last().unwrap()), - point[1].clamp(interp.y[0], *interp.y.last().unwrap()), - ]; - return interp.interpolate(clamped_point); - } - Extrapolate::Error => { - if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) { - return Err(InterpolationError::ExtrapolationError(format!( - "point = {point:?}, x grid = {:?}", - interp.x - ))); - } - if !(interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap()) { - return Err(InterpolationError::ExtrapolationError(format!( - "point = {point:?}, y grid = {:?}", - interp.y - ))); - } - } - _ => {} - }; - interp.interpolate(point) - } - Self::Interp3D(interp) => { - match interp.extrapolate { - Extrapolate::Clamp => { - let clamped_point = &[ - point[0].clamp(interp.x[0], *interp.x.last().unwrap()), - point[1].clamp(interp.y[0], *interp.y.last().unwrap()), - point[2].clamp(interp.z[0], *interp.z.last().unwrap()), - ]; - return interp.interpolate(clamped_point); - } - Extrapolate::Error => { - if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) { - return Err(InterpolationError::ExtrapolationError(format!( - "point = {point:?}, x grid = {:?}", - interp.x - ))); - } - if !(interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap()) { - return Err(InterpolationError::ExtrapolationError(format!( - "point = {point:?}, y grid = {:?}", - interp.y - ))); - } - if !(interp.z[0] <= point[2] && &point[2] <= interp.z.last().unwrap()) { - return Err(InterpolationError::ExtrapolationError(format!( - "point = {point:?}, z grid = {:?}", - interp.z - ))); - } - } - _ => {} - }; - interp.interpolate(point) - } - Self::InterpND(interp) => { - match interp.extrapolate { - Extrapolate::Clamp => { - let clamped_point: Vec = point - .iter() - .enumerate() - .map(|(dim, pt)| { - pt.clamp(interp.grid[dim][0], *interp.grid[dim].last().unwrap()) - }) - .collect(); - return interp.interpolate(&clamped_point); - } - Extrapolate::Error => { - if !point.iter().enumerate().all(|(dim, pt_dim)| { - &interp.grid[dim][0] <= pt_dim - && pt_dim <= interp.grid[dim].last().unwrap() - }) { - return Err(InterpolationError::ExtrapolationError(format!( - "point = {point:?}, grid: {:?}", - interp.grid, - ))); - } - } - _ => {} - }; - interp.interpolate(point) - } - } - } - /// Ensure supplied point is valid for the given interpolator fn validate_point(&self, point: &[f64]) -> Result<(), InterpolationError> { let n = self.ndim(); @@ -712,6 +588,134 @@ impl Interpolator { } } +impl InterpMethods for Interpolator { + fn validate(&self) -> Result<(), ValidationError> { + match self { + Self::Interp0D(_) => Ok(()), + Self::Interp1D(interp) => interp.validate(), + Self::Interp2D(interp) => interp.validate(), + Self::Interp3D(interp) => interp.validate(), + Self::InterpND(interp) => interp.validate(), + } + } + + /// Interpolate at supplied point, after checking point validity. + /// Length of supplied point must match interpolator dimensionality. + fn interpolate(&self, point: &[f64]) -> Result { + self.validate_point(point)?; + match self { + Self::Interp0D(value) => Ok(*value), + Self::Interp1D(interp) => { + match interp.extrapolate { + Extrapolate::Clamp => { + let clamped_point = + &[point[0].clamp(interp.x[0], *interp.x.last().unwrap())]; + return interp.interpolate(clamped_point); + } + Extrapolate::Error => { + if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) { + return Err(InterpolationError::ExtrapolationError(format!( + "point = {point:?}, grid = {:?}", + interp.x + ))); + } + } + _ => {} + }; + interp.interpolate(point) + } + Self::Interp2D(interp) => { + match interp.extrapolate { + Extrapolate::Clamp => { + let clamped_point = &[ + point[0].clamp(interp.x[0], *interp.x.last().unwrap()), + point[1].clamp(interp.y[0], *interp.y.last().unwrap()), + ]; + return interp.interpolate(clamped_point); + } + Extrapolate::Error => { + if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) { + return Err(InterpolationError::ExtrapolationError(format!( + "point = {point:?}, x grid = {:?}", + interp.x + ))); + } + if !(interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap()) { + return Err(InterpolationError::ExtrapolationError(format!( + "point = {point:?}, y grid = {:?}", + interp.y + ))); + } + } + _ => {} + }; + interp.interpolate(point) + } + Self::Interp3D(interp) => { + match interp.extrapolate { + Extrapolate::Clamp => { + let clamped_point = &[ + point[0].clamp(interp.x[0], *interp.x.last().unwrap()), + point[1].clamp(interp.y[0], *interp.y.last().unwrap()), + point[2].clamp(interp.z[0], *interp.z.last().unwrap()), + ]; + return interp.interpolate(clamped_point); + } + Extrapolate::Error => { + if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) { + return Err(InterpolationError::ExtrapolationError(format!( + "point = {point:?}, x grid = {:?}", + interp.x + ))); + } + if !(interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap()) { + return Err(InterpolationError::ExtrapolationError(format!( + "point = {point:?}, y grid = {:?}", + interp.y + ))); + } + if !(interp.z[0] <= point[2] && &point[2] <= interp.z.last().unwrap()) { + return Err(InterpolationError::ExtrapolationError(format!( + "point = {point:?}, z grid = {:?}", + interp.z + ))); + } + } + _ => {} + }; + interp.interpolate(point) + } + Self::InterpND(interp) => { + match interp.extrapolate { + Extrapolate::Clamp => { + let clamped_point: Vec = point + .iter() + .enumerate() + .map(|(dim, pt)| { + pt.clamp(interp.grid[dim][0], *interp.grid[dim].last().unwrap()) + }) + .collect(); + return interp.interpolate(&clamped_point); + } + Extrapolate::Error => { + if !point.iter().enumerate().all(|(dim, pt_dim)| { + &interp.grid[dim][0] <= pt_dim + && pt_dim <= interp.grid[dim].last().unwrap() + }) { + return Err(InterpolationError::ExtrapolationError(format!( + "point = {point:?}, grid: {:?}", + interp.grid, + ))); + } + } + _ => {} + }; + interp.interpolate(point) + } + } + } +} + /// Interpolation strategy #[derive(Clone, Debug, PartialEq, Default)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] @@ -745,35 +749,6 @@ pub enum Extrapolate { Error, } -/// Methods applicable to all interpolator helper structs -pub trait InterpMethods { - /// Validate data stored in [Self]. By design, [Self] can be instantiatated - /// only via the `new` method, which calls `validate`. - fn validate(&self) -> Result<(), ValidationError>; - /// Interpolate at given point - fn interpolate(&self, point: &[f64]) -> Result; -} - -/// Linear interpolation: -pub trait Linear { - fn linear(&self, point: &[f64]) -> Result; -} - -/// Left-nearest (previous value) interpolation: -pub trait LeftNearest { - fn left_nearest(&self, point: &[f64]) -> Result; -} - -/// Right-nearest (next value) interpolation: -pub trait RightNearest { - fn right_nearest(&self, point: &[f64]) -> Result; -} - -/// Nearest value (left or right, whichever nearest) interpolation: -pub trait Nearest { - fn nearest(&self, point: &[f64]) -> Result; -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/traits.rs b/src/traits.rs new file mode 100644 index 0000000..b11eb13 --- /dev/null +++ b/src/traits.rs @@ -0,0 +1,30 @@ +use super::*; + +/// Methods applicable to all interpolators +pub trait InterpMethods { + /// Validate data stored in [Self]. By design, [Self] can be instantiatated + /// only via the `new` method, which calls `validate`. + fn validate(&self) -> Result<(), ValidationError>; + /// Interpolate at given point + fn interpolate(&self, point: &[f64]) -> Result; +} + +/// Linear interpolation: +pub trait Linear { + fn linear(&self, point: &[f64]) -> Result; +} + +/// Left-nearest (previous value) interpolation: +pub trait LeftNearest { + fn left_nearest(&self, point: &[f64]) -> Result; +} + +/// Right-nearest (next value) interpolation: +pub trait RightNearest { + fn right_nearest(&self, point: &[f64]) -> Result; +} + +/// Nearest value (left or right, whichever nearest) interpolation: +pub trait Nearest { + fn nearest(&self, point: &[f64]) -> Result; +}