Skip to content

Commit

Permalink
reorganize validate and interpolate into InterpMethods impl for…
Browse files Browse the repository at this point in the history
… `Interpolator` enum
  • Loading branch information
kylecarow committed Jan 22, 2025
1 parent 8721e6d commit cd30fe0
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 155 deletions.
285 changes: 130 additions & 155 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ pub mod n;
pub mod one;
pub mod three;
pub mod two;
pub mod traits;

pub use error::*;
pub use n::*;
pub use one::*;
pub use three::*;
pub use two::*;
pub use traits::*;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -320,132 +322,6 @@ impl Interpolator {
Ok(Self::InterpND(interp))
}

pub fn validate(&self) -> Result<(), ValidationError> {
match self {
Self::Interp0D(_) => Ok(()),
Self::Interp1D(interp) => interp.validate(),
Self::Interp2D(interp) => interp.validate(),
Self::Interp3D(interp) => interp.validate(),
Self::InterpND(interp) => interp.validate(),
}
}

/// Interpolate at supplied point, after checking point validity.
/// Length of supplied point must match interpolator dimensionality.
pub fn interpolate(&self, point: &[f64]) -> Result<f64, InterpolationError> {
self.validate_point(point)?;
match self {
Self::Interp0D(value) => Ok(*value),
Self::Interp1D(interp) => {
match interp.extrapolate {
Extrapolate::Clamp => {
let clamped_point =
&[point[0].clamp(interp.x[0], *interp.x.last().unwrap())];
return interp.interpolate(clamped_point);
}
Extrapolate::Error => {
if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, grid = {:?}",
interp.x
)));
}
}
_ => {}
};
interp.interpolate(point)
}
Self::Interp2D(interp) => {
match interp.extrapolate {
Extrapolate::Clamp => {
let clamped_point = &[
point[0].clamp(interp.x[0], *interp.x.last().unwrap()),
point[1].clamp(interp.y[0], *interp.y.last().unwrap()),
];
return interp.interpolate(clamped_point);
}
Extrapolate::Error => {
if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, x grid = {:?}",
interp.x
)));
}
if !(interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, y grid = {:?}",
interp.y
)));
}
}
_ => {}
};
interp.interpolate(point)
}
Self::Interp3D(interp) => {
match interp.extrapolate {
Extrapolate::Clamp => {
let clamped_point = &[
point[0].clamp(interp.x[0], *interp.x.last().unwrap()),
point[1].clamp(interp.y[0], *interp.y.last().unwrap()),
point[2].clamp(interp.z[0], *interp.z.last().unwrap()),
];
return interp.interpolate(clamped_point);
}
Extrapolate::Error => {
if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, x grid = {:?}",
interp.x
)));
}
if !(interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, y grid = {:?}",
interp.y
)));
}
if !(interp.z[0] <= point[2] && &point[2] <= interp.z.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, z grid = {:?}",
interp.z
)));
}
}
_ => {}
};
interp.interpolate(point)
}
Self::InterpND(interp) => {
match interp.extrapolate {
Extrapolate::Clamp => {
let clamped_point: Vec<f64> = point
.iter()
.enumerate()
.map(|(dim, pt)| {
pt.clamp(interp.grid[dim][0], *interp.grid[dim].last().unwrap())
})
.collect();
return interp.interpolate(&clamped_point);
}
Extrapolate::Error => {
if !point.iter().enumerate().all(|(dim, pt_dim)| {
&interp.grid[dim][0] <= pt_dim
&& pt_dim <= interp.grid[dim].last().unwrap()
}) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, grid: {:?}",
interp.grid,
)));
}
}
_ => {}
};
interp.interpolate(point)
}
}
}

/// Ensure supplied point is valid for the given interpolator
fn validate_point(&self, point: &[f64]) -> Result<(), InterpolationError> {
let n = self.ndim();
Expand Down Expand Up @@ -712,6 +588,134 @@ impl Interpolator {
}
}

impl InterpMethods for Interpolator {
fn validate(&self) -> Result<(), ValidationError> {
match self {
Self::Interp0D(_) => Ok(()),
Self::Interp1D(interp) => interp.validate(),
Self::Interp2D(interp) => interp.validate(),
Self::Interp3D(interp) => interp.validate(),
Self::InterpND(interp) => interp.validate(),
}
}

/// Interpolate at supplied point, after checking point validity.
/// Length of supplied point must match interpolator dimensionality.
fn interpolate(&self, point: &[f64]) -> Result<f64, InterpolationError> {
self.validate_point(point)?;
match self {
Self::Interp0D(value) => Ok(*value),
Self::Interp1D(interp) => {
match interp.extrapolate {
Extrapolate::Clamp => {
let clamped_point =
&[point[0].clamp(interp.x[0], *interp.x.last().unwrap())];
return interp.interpolate(clamped_point);
}
Extrapolate::Error => {
if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, grid = {:?}",
interp.x
)));
}
}
_ => {}
};
interp.interpolate(point)
}
Self::Interp2D(interp) => {
match interp.extrapolate {
Extrapolate::Clamp => {
let clamped_point = &[
point[0].clamp(interp.x[0], *interp.x.last().unwrap()),
point[1].clamp(interp.y[0], *interp.y.last().unwrap()),
];
return interp.interpolate(clamped_point);
}
Extrapolate::Error => {
if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, x grid = {:?}",
interp.x
)));
}
if !(interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, y grid = {:?}",
interp.y
)));
}
}
_ => {}
};
interp.interpolate(point)
}
Self::Interp3D(interp) => {
match interp.extrapolate {
Extrapolate::Clamp => {
let clamped_point = &[
point[0].clamp(interp.x[0], *interp.x.last().unwrap()),
point[1].clamp(interp.y[0], *interp.y.last().unwrap()),
point[2].clamp(interp.z[0], *interp.z.last().unwrap()),
];
return interp.interpolate(clamped_point);
}
Extrapolate::Error => {
if !(interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, x grid = {:?}",
interp.x
)));
}
if !(interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, y grid = {:?}",
interp.y
)));
}
if !(interp.z[0] <= point[2] && &point[2] <= interp.z.last().unwrap()) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, z grid = {:?}",
interp.z
)));
}
}
_ => {}
};
interp.interpolate(point)
}
Self::InterpND(interp) => {
match interp.extrapolate {
Extrapolate::Clamp => {
let clamped_point: Vec<f64> = point
.iter()
.enumerate()
.map(|(dim, pt)| {
pt.clamp(interp.grid[dim][0], *interp.grid[dim].last().unwrap())
})
.collect();
return interp.interpolate(&clamped_point);
}
Extrapolate::Error => {
if !point.iter().enumerate().all(|(dim, pt_dim)| {
&interp.grid[dim][0] <= pt_dim
&& pt_dim <= interp.grid[dim].last().unwrap()
}) {
return Err(InterpolationError::ExtrapolationError(format!(
"point = {point:?}, grid: {:?}",
interp.grid,
)));
}
}
_ => {}
};
interp.interpolate(point)
}
}
}
}

/// Interpolation strategy
#[derive(Clone, Debug, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
Expand Down Expand Up @@ -745,35 +749,6 @@ pub enum Extrapolate {
Error,
}

/// Methods applicable to all interpolator helper structs
pub trait InterpMethods {
/// Validate data stored in [Self]. By design, [Self] can be instantiatated
/// only via the `new` method, which calls `validate`.
fn validate(&self) -> Result<(), ValidationError>;
/// Interpolate at given point
fn interpolate(&self, point: &[f64]) -> Result<f64, InterpolationError>;
}

/// Linear interpolation: <https://en.wikipedia.org/wiki/Linear_interpolation>
pub trait Linear {
fn linear(&self, point: &[f64]) -> Result<f64, InterpolationError>;
}

/// Left-nearest (previous value) interpolation: <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
pub trait LeftNearest {
fn left_nearest(&self, point: &[f64]) -> Result<f64, InterpolationError>;
}

/// Right-nearest (next value) interpolation: <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
pub trait RightNearest {
fn right_nearest(&self, point: &[f64]) -> Result<f64, InterpolationError>;
}

/// Nearest value (left or right, whichever nearest) interpolation: <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
pub trait Nearest {
fn nearest(&self, point: &[f64]) -> Result<f64, InterpolationError>;
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
30 changes: 30 additions & 0 deletions src/traits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use super::*;

/// Methods applicable to all interpolators
pub trait InterpMethods {
/// Validate data stored in [Self]. By design, [Self] can be instantiatated
/// only via the `new` method, which calls `validate`.
fn validate(&self) -> Result<(), ValidationError>;
/// Interpolate at given point
fn interpolate(&self, point: &[f64]) -> Result<f64, InterpolationError>;
}

/// Linear interpolation: <https://en.wikipedia.org/wiki/Linear_interpolation>
pub trait Linear {
fn linear(&self, point: &[f64]) -> Result<f64, InterpolationError>;
}

/// Left-nearest (previous value) interpolation: <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
pub trait LeftNearest {
fn left_nearest(&self, point: &[f64]) -> Result<f64, InterpolationError>;
}

/// Right-nearest (next value) interpolation: <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
pub trait RightNearest {
fn right_nearest(&self, point: &[f64]) -> Result<f64, InterpolationError>;
}

/// Nearest value (left or right, whichever nearest) interpolation: <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
pub trait Nearest {
fn nearest(&self, point: &[f64]) -> Result<f64, InterpolationError>;
}

0 comments on commit cd30fe0

Please sign in to comment.