Skip to content

Commit

Permalink
add tests for extrapolation
Browse files Browse the repository at this point in the history
  • Loading branch information
kylecarow committed Jun 6, 2024
1 parent 165195e commit 9bb2f3e
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 25 deletions.
44 changes: 19 additions & 25 deletions fastsim-core/src/utils/interp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,6 @@ impl Interpolator {
Self::Interp0D(value) => Ok(*value),
Self::Interp1D(interp) => {
match interp.extrapolate {
Extrapolate::Extrapolate => {
ensure!(
matches!(interp.strategy, Strategy::Linear),
"`Extrapolate` is only implemented for 1-D linear, use `Clamp` or `Error` extrapolation strategy instead"
);
ensure!(
interp.x.len() >= 2,
"At least 2 data points are required for extrapolation: x = {:?}, f_x = {:?}",
interp.x,
interp.f_x,
);
}
Extrapolate::Clamp => {
let clamped_point =
&[point[0].clamp(interp.x[0], *interp.x.last().unwrap())];
Expand All @@ -82,60 +70,65 @@ impl Interpolator {
interp.x,
);
}
_ => {}
};
interp.interpolate(point)
}
Self::Interp2D(interp) => {
match interp.extrapolate {
Extrapolate::Extrapolate => bail!("`Extrapolate` is not implemented for 2-D, use `Clamp` or `Error` extrapolation strategy instead"),
Extrapolate::Clamp => {
let clamped_point = &[
point[0].clamp(interp.x[0], *interp.x.last().unwrap()),
point[1].clamp(interp.y[0], *interp.y.last().unwrap()),
];
return interp.interpolate(clamped_point);
},
}
Extrapolate::Error => {
let x_dim_ok = interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap();
let y_dim_ok = interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap();
let x_dim_ok =
interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap();
let y_dim_ok =
interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap();
ensure!(
x_dim_ok && y_dim_ok,
"Attempted to interpolate at point beyond grid data: point = {point:?}, x grid = {:?}, y grid = {:?}",
interp.x,
interp.y,
);
},
}
_ => {}
};
interp.interpolate(point)
}
Self::Interp3D(interp) => {
match interp.extrapolate {
Extrapolate::Extrapolate => bail!("`Extrapolate` is not implemented for 3-D, use `Clamp` or `Error` extrapolation strategy instead"),
Extrapolate::Clamp => {
let clamped_point = &[
point[0].clamp(interp.x[0], *interp.x.last().unwrap()),
point[1].clamp(interp.y[0], *interp.y.last().unwrap()),
point[2].clamp(interp.z[0], *interp.z.last().unwrap()),
];
return interp.interpolate(clamped_point);
},
}
Extrapolate::Error => {
let x_dim_ok = interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap();
let y_dim_ok = interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap();
let z_dim_ok = interp.z[0] <= point[2] && &point[2] <= interp.z.last().unwrap();
let x_dim_ok =
interp.x[0] <= point[0] && &point[0] <= interp.x.last().unwrap();
let y_dim_ok =
interp.y[0] <= point[1] && &point[1] <= interp.y.last().unwrap();
let z_dim_ok =
interp.z[0] <= point[2] && &point[2] <= interp.z.last().unwrap();
ensure!(x_dim_ok && y_dim_ok && z_dim_ok,
"Attempted to interpolate at point beyond grid data: point = {point:?}, x grid = {:?}, y grid = {:?}, z grid = {:?}",
interp.x,
interp.y,
interp.z,
);
},
}
_ => {}
};
interp.interpolate(point)
}
Self::InterpND(interp) => {
match interp.extrapolate {
Extrapolate::Extrapolate => bail!("`Extrapolate` is not implemented for multilinear interpolator, use `Clamp` or `Error` extrapolation strategy instead"),
Extrapolate::Clamp => {
let clamped_point: Vec<f64> = point
.iter()
Expand All @@ -144,12 +137,13 @@ impl Interpolator {
pt.clamp(interp.grid[dim].min().unwrap(), interp.grid[dim].max().unwrap())
).collect();
return interp.interpolate(&clamped_point);
},
}
Extrapolate::Error => ensure!(
point.iter().enumerate().all(|(dim, pt_dim)| &interp.grid[dim][0] <= pt_dim && pt_dim <= interp.grid[dim].last().unwrap()),
"Attempted to interpolate at point beyond grid data: point = {point:?}, grid: {:?}",
interp.grid,
),
_ => {}
};
interp.interpolate(&point)
}
Expand Down
41 changes: 41 additions & 0 deletions fastsim-core/src/utils/interp/n.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ impl InterpMethods for InterpND {
fn validate(&self) -> anyhow::Result<()> {
let n = self.ndim();

ensure!(!matches!(self.extrapolate, Extrapolate::Extrapolate), "`Extrapolate` is not implemented for N-D, use `Clamp` or `Error` extrapolation strategy instead");

// Check that each grid dimension has elements
for i in 0..n {
// Indexing `grid` directly is okay because `grid == vec![]` is caught at compilation
Expand Down Expand Up @@ -278,4 +280,43 @@ mod tests {
3.1999999999999997
) // 3.2
}

#[test]
fn test_extrapolate_inputs() {
// Extrapolate::Extrapolate
assert!(InterpND::new(
vec![vec![0., 1.], vec![0., 1.], vec![0., 1.]],
array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(),
Strategy::Linear,
Extrapolate::Extrapolate,
)
.is_err());
// Extrapolate::Error
let interp = Interpolator::InterpND(
InterpND::new(
vec![vec![0., 1.], vec![0., 1.], vec![0., 1.]],
array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(),
Strategy::Linear,
Extrapolate::Error,
)
.unwrap(),
);
assert!(interp.interpolate(&[-1., -1., -1.]).is_err());
assert!(interp.interpolate(&[2., 2., 2.]).is_err());
}

#[test]
fn test_extrapolate_clamp() {
let interp = Interpolator::InterpND(
InterpND::new(
vec![vec![0., 1.], vec![0., 1.], vec![0., 1.]],
array![[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]],].into_dyn(),
Strategy::Linear,
Extrapolate::Clamp,
)
.unwrap(),
);
assert_eq!(interp.interpolate(&[-1., -1., -1.]).unwrap(), 0.);
assert_eq!(interp.interpolate(&[2., 2., 2.]).unwrap(), 7.);
}
}
73 changes: 73 additions & 0 deletions fastsim-core/src/utils/interp/one.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,15 @@ impl Interp1D {
// Extrapolate, if applicable
if matches!(self.extrapolate, Extrapolate::Extrapolate) {
if point < self.x[0] {
log::warn!("Extrapolating: point = {}, x_min = {}", point, self.x[0]);
let slope = (self.f_x[1] - self.f_x[0]) / (self.x[1] - self.x[0]);
return Ok(slope * (point - self.x[0]) + self.f_x[0]);
} else if &point > self.x.last().unwrap() {
log::warn!(
"Extrapolating: point = {}, x_max = {}",
point,
self.x.last().unwrap()
);
let slope = (self.f_x.last().unwrap() - self.f_x[self.f_x.len() - 2])
/ (self.x.last().unwrap() - self.x[self.x.len() - 2]);
return Ok(slope * (point - self.x.last().unwrap()) + self.f_x.last().unwrap());
Expand Down Expand Up @@ -83,6 +89,19 @@ impl InterpMethods for Interp1D {
fn validate(&self) -> anyhow::Result<()> {
let x_grid_len = self.x.len();

if matches!(self.extrapolate, Extrapolate::Extrapolate) {
ensure!(
matches!(self.strategy, Strategy::Linear),
"`Extrapolate` is only implemented for 1-D linear, use `Clamp` or `Error` extrapolation strategy instead"
);
ensure!(
self.x.len() >= 2,
"At least 2 data points are required for extrapolation: x = {:?}, f_x = {:?}",
self.x,
self.f_x,
);
}

// Check that each grid dimension has elements
ensure!(x_grid_len != 0, "Supplied x-coordinates cannot be empty");
// Check that grid points are monotonically increasing
Expand Down Expand Up @@ -193,4 +212,58 @@ mod tests {
assert_eq!(interp.interpolate(&[3.75]).unwrap(), 1.0);
assert_eq!(interp.interpolate(&[4.00]).unwrap(), 1.0);
}

#[test]
fn test_extrapolate_inputs() {
// Incorrect strategy
assert!(Interp1D::new(
vec![0., 1., 2., 3., 4.],
vec![0.2, 0.4, 0.6, 0.8, 1.0],
Strategy::Nearest,
Extrapolate::Extrapolate,
)
.is_err());
// Extrapolate::Error
let interp = Interpolator::Interp1D(
Interp1D::new(
vec![0., 1., 2., 3., 4.],
vec![0.2, 0.4, 0.6, 0.8, 1.0],
Strategy::Linear,
Extrapolate::Error,
)
.unwrap(),
);
assert!(interp.interpolate(&[-1.]).is_err());
assert!(interp.interpolate(&[5.]).is_err());
}

#[test]
fn test_extrapolate_clamp() {
let interp = Interpolator::Interp1D(
Interp1D::new(
vec![0., 1., 2., 3., 4.],
vec![0.2, 0.4, 0.6, 0.8, 1.0],
Strategy::Linear,
Extrapolate::Clamp,
)
.unwrap(),
);
assert_eq!(interp.interpolate(&[-1.]).unwrap(), 0.2);
assert_eq!(interp.interpolate(&[5.]).unwrap(), 1.0);
}

#[test]
fn test_extrapolate() {
let interp = Interpolator::Interp1D(
Interp1D::new(
vec![0., 1., 2., 3., 4.],
vec![0.2, 0.4, 0.6, 0.8, 1.0],
Strategy::Linear,
Extrapolate::Extrapolate,
)
.unwrap(),
);
assert_eq!(interp.interpolate(&[-1.]).unwrap(), 0.0);
assert_eq!(interp.interpolate(&[5.]).unwrap(), 1.2);
}
}
56 changes: 56 additions & 0 deletions fastsim-core/src/utils/interp/three.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ impl InterpMethods for Interp3D {
let y_grid_len = self.y.len();
let z_grid_len = self.z.len();

ensure!(!matches!(self.extrapolate, Extrapolate::Extrapolate), "`Extrapolate` is not implemented for 3-D, use `Clamp` or `Error` extrapolation strategy instead");

// Check that each grid dimension has elements
ensure!(
x_grid_len != 0 || y_grid_len != 0 || z_grid_len != 0,
Expand Down Expand Up @@ -199,4 +201,58 @@ mod tests {
3.1999999999999997
) // 3.2
}

#[test]
fn test_extrapolate_inputs() {
// Extrapolate::Extrapolate
assert!(Interp3D::new(
vec![0., 1.],
vec![0., 1.],
vec![0., 1.],
vec![
vec![vec![0., 1.], vec![2., 3.]],
vec![vec![4., 5.], vec![6., 7.]],
],
Strategy::Linear,
Extrapolate::Extrapolate,
)
.is_err());
// Extrapolate::Error
let interp = Interpolator::Interp3D(
Interp3D::new(
vec![0., 1.],
vec![0., 1.],
vec![0., 1.],
vec![
vec![vec![0., 1.], vec![2., 3.]],
vec![vec![4., 5.], vec![6., 7.]],
],
Strategy::Linear,
Extrapolate::Error,
)
.unwrap(),
);
assert!(interp.interpolate(&[-1., -1., -1.]).is_err());
assert!(interp.interpolate(&[2., 2., 2.]).is_err());
}

#[test]
fn test_extrapolate_clamp() {
let interp = Interpolator::Interp3D(
Interp3D::new(
vec![0., 1.],
vec![0., 1.],
vec![0., 1.],
vec![
vec![vec![0., 1.], vec![2., 3.]],
vec![vec![4., 5.], vec![6., 7.]],
],
Strategy::Linear,
Extrapolate::Clamp,
)
.unwrap(),
);
assert_eq!(interp.interpolate(&[-1., -1., -1.]).unwrap(), 0.);
assert_eq!(interp.interpolate(&[2., 2., 2.]).unwrap(), 7.);
}
}
44 changes: 44 additions & 0 deletions fastsim-core/src/utils/interp/two.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ impl InterpMethods for Interp2D {
let x_grid_len = self.x.len();
let y_grid_len = self.y.len();

ensure!(!matches!(self.extrapolate, Extrapolate::Extrapolate), "`Extrapolate` is not implemented for 2-D, use `Clamp` or `Error` extrapolation strategy instead");

// Check that each grid dimension has elements
ensure!(
x_grid_len != 0 && y_grid_len != 0,
Expand Down Expand Up @@ -129,4 +131,46 @@ mod tests {
let interp_res = interp.interpolate(&[0.25, 0.65]).unwrap();
assert_eq!(interp_res, 1.1500000000000001) // 1.15
}

#[test]
fn test_extrapolate_inputs() {
// Extrapolate::Extrapolate
assert!(Interp2D::new(
vec![0., 1.],
vec![0., 1.],
vec![vec![0., 1.], vec![2., 3.]],
Strategy::Linear,
Extrapolate::Extrapolate,
)
.is_err());
// Extrapolate::Error
let interp = Interpolator::Interp2D(
Interp2D::new(
vec![0., 1.],
vec![0., 1.],
vec![vec![0., 1.], vec![2., 3.]],
Strategy::Linear,
Extrapolate::Error,
)
.unwrap(),
);
assert!(interp.interpolate(&[-1.]).is_err());
assert!(interp.interpolate(&[2.]).is_err());
}

#[test]
fn test_extrapolate_clamp() {
let interp = Interpolator::Interp2D(
Interp2D::new(
vec![0., 1.],
vec![0., 1.],
vec![vec![0., 1.], vec![2., 3.]],
Strategy::Linear,
Extrapolate::Clamp,
)
.unwrap(),
);
assert_eq!(interp.interpolate(&[-1., -1.]).unwrap(), 0.);
assert_eq!(interp.interpolate(&[2., 2.]).unwrap(), 3.);
}
}

0 comments on commit 9bb2f3e

Please sign in to comment.