Skip to content

Commit

Permalink
feat: events and tstop #42 #43 #45
Browse files Browse the repository at this point in the history
* #42 add root finder, adjust step signature for roots and tstop

* fix most compiler errorrs

* #42 fix rest of compiler complaints

* #42 only need to call root_fn at end of step

* #42 revert change to diff_tmo

* fix bugs in root finder #42

* tests pass #42

* cargo fmt #42

* fix docs #42
  • Loading branch information
martinjrobins authored May 7, 2024
1 parent 958887f commit 7e1cd1e
Show file tree
Hide file tree
Showing 22 changed files with 872 additions and 233 deletions.
10 changes: 7 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,16 @@ pub use ode_solver::sundials::SundialsIda;

use matrix::{DenseMatrix, Matrix, MatrixCommon, MatrixSparsity, MatrixView, MatrixViewMut};
pub use nonlinear_solver::newton::NewtonNonlinearSolver;
use nonlinear_solver::NonLinearSolver;
use nonlinear_solver::{root::RootFinder, NonLinearSolver};
pub use ode_solver::{
bdf::Bdf, builder::OdeBuilder, equations::OdeEquations, method::OdeSolverMethod,
method::OdeSolverState, problem::OdeSolverProblem, sdirk::Sdirk, tableau::Tableau,
method::OdeSolverState, method::OdeSolverStopReason, problem::OdeSolverProblem, sdirk::Sdirk,
tableau::Tableau,
};
use op::{
closure::Closure, closure_no_jac::ClosureNoJac, linear_closure::LinearClosure,
unit::UnitCallable, LinearOp, NonLinearOp, Op,
};
use op::{closure::Closure, linear_closure::LinearClosure, LinearOp, NonLinearOp, Op};
use scalar::{IndexType, Scalar, Scale};
use solver::SolverProblem;
use vector::{Vector, VectorCommon, VectorIndex, VectorRef, VectorView, VectorViewMut};
Expand Down
1 change: 1 addition & 0 deletions src/nonlinear_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ impl<C: Op> Convergence<C> {
}

pub mod newton;
pub mod root;

//tests
#[cfg(test)]
Expand Down
198 changes: 198 additions & 0 deletions src/nonlinear_solver/root.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use std::cell::RefCell;

use crate::{
scalar::{IndexType, Scalar},
NonLinearOp, Vector,
};
use anyhow::Result;
use num_traits::{abs, One, Zero};

pub struct RootFinder<V: Vector> {
t0: RefCell<V::T>,
g0: RefCell<V>,
g1: RefCell<V>,
gmid: RefCell<V>,
}

impl<V: Vector> RootFinder<V> {
pub fn new(n: usize) -> Self {
Self {
t0: RefCell::new(V::T::zero()),
g0: RefCell::new(V::zeros(n)),
g1: RefCell::new(V::zeros(n)),
gmid: RefCell::new(V::zeros(n)),
}
}

/// Set the lower boundary of the root search.
/// This function should be called first after [Self::new]
pub fn init(&self, root_fn: &impl NonLinearOp<V = V, T = V::T>, y: &V, t: V::T) {
root_fn.call_inplace(y, t, &mut self.g0.borrow_mut());
self.t0.replace(t);
}

/// Set the upper boundary of the root search and checks for a zero crossing.
/// If a zero crossing is found, the index of the crossing is returned
///
/// This function assumes that g0 and t0 have already beeen set via [Self::init]
/// or previous iterations of [Self::check_root]
///
/// We find the root of a function using the method proposed by Sundials [docs](https://sundials.readthedocs.io/en/latest/cvode/Mathematics_link.html#rootfinding)
pub fn check_root(
&self,
interpolate: &impl Fn(V::T) -> Result<V>,
root_fn: &impl NonLinearOp<V = V, T = V::T>,
y: &V,
t: V::T,
) -> Option<V::T> {
let g1 = &mut *self.g1.borrow_mut();
let g0 = &mut *self.g0.borrow_mut();
let gmid = &mut *self.gmid.borrow_mut();
root_fn.call_inplace(y, t, g1);

let sign_change_fn = |mut acc: (bool, V::T, i32), g0: V::T, g1: V::T, i: IndexType| {
if g1 == V::T::zero() {
acc.0 = true;
} else if g0 * g1 < V::T::zero() {
let gfrac = abs(g1 / (g1 - g0));
if gfrac > acc.1 {
acc.1 = gfrac;
acc.2 = i32::try_from(i).unwrap();
}
}
acc
};
let (rootfnd, _gfracmax, imax) =
(*g0).binary_fold(g1, (false, V::T::zero(), -1), sign_change_fn);

// if no sign change we don't need to find the root
if imax < 0 {
// setup g0 for next iteration
std::mem::swap(g0, g1);
self.t0.replace(t);
return if rootfnd {
// found a root at the upper boundary and no other sign change, return the root
Some(t)
} else {
// no root found or sign change, return None
None
};
}

// otherwise we need to do the modified secant method to find the root
let mut imax = IndexType::try_from(imax).unwrap();
let mut alpha = V::T::one();
let mut sign_change = [false, true];
let mut i = 0;
let mut t1 = t;
let mut t0 = *self.t0.borrow();
let tol = V::T::from(100.0) * V::T::EPSILON * (abs(t1) + abs(t1 - t0));
let half = V::T::from(0.5);
let double = V::T::from(2.0);
let five = V::T::from(5.0);
let pntone = V::T::from(0.1);
while abs(t1 - t0) > tol {
let mut t_mid = t1 - (t1 - t0) * g1[imax] / (g1[imax] - alpha * g0[imax]);

// adjust t_mid away from the boundaries
if abs(t_mid - t0) < half * tol {
let fracint = abs(t1 - t0) / tol;
let fracsub = if fracint > five {
pntone
} else {
half / fracint
};
t_mid = t0 + fracsub * (t1 - t0);
}
if abs(t1 - t_mid) < half * tol {
let fracint = abs(t1 - t0) / tol;
let fracsub = if fracint > five {
pntone
} else {
half / fracint
};
t_mid = t1 - fracsub * (t1 - t0);
}

let ymid = interpolate(t_mid).unwrap();
root_fn.call_inplace(&ymid, t_mid, gmid);

let (rootfnd, _gfracmax, imax_i32) =
(*g0).binary_fold(gmid, (false, V::T::zero(), -1), sign_change_fn);
let lower = imax_i32 >= 0;

if lower {
// Sign change found in (tlo,tmid); replace thi with tmid.
t1 = t_mid;
imax = IndexType::try_from(imax_i32).unwrap();
std::mem::swap(g1, gmid);
} else if rootfnd {
// we are returning so make sure g0 is set for next iteration
root_fn.call_inplace(y, t, g0);

// No sign change in (tlo,tmid), but g = 0 at tmid; return root tmid.
return Some(t_mid);
} else {
// No sign change in (tlo,tmid), and no zero at tmid. Sign change must be in (tmid,thi). Replace tlo with tmid.
t0 = t_mid;
std::mem::swap(g0, gmid);
}

sign_change[i % 2] = lower;
if i >= 2 {
alpha = if sign_change[0] != sign_change[1] {
V::T::one()
} else if sign_change[0] {
half * alpha
} else {
double * alpha
};
}
i += 1;
}
// we are returning so make sure g0 is set for next iteration
root_fn.call_inplace(y, t, g0);
Some(t1)
}
}

#[cfg(test)]
mod tests {
use std::rc::Rc;

use crate::{ClosureNoJac, RootFinder, Vector};
use anyhow::Result;

#[test]
fn test_root() {
type V = nalgebra::DVector<f64>;
type M = nalgebra::DMatrix<f64>;
let interpolate = |t: f64| -> Result<V> { Ok(Vector::from_vec(vec![t])) };
let root_fn = ClosureNoJac::<M, _>::new(
|y: &V, _p: &V, _t: f64, g: &mut V| {
g[0] = y[0] - 0.4;
},
1,
1,
Rc::new(V::zeros(0)),
);

// check no root
let root_finder = RootFinder::new(1);
root_finder.init(&root_fn, &Vector::from_vec(vec![0.0]), 0.0);
let root =
root_finder.check_root(&interpolate, &root_fn, &Vector::from_vec(vec![0.3]), 0.3);
assert_eq!(root, None);

// check root
let root_finder = RootFinder::new(1);
root_finder.init(&root_fn, &Vector::from_vec(vec![0.0]), 0.0);
let root =
root_finder.check_root(&interpolate, &root_fn, &Vector::from_vec(vec![1.3]), 1.3);
if let Some(root) = root {
assert!((root - 0.4).abs() < 1e-10);
} else {
unreachable!();
}
}
}
75 changes: 71 additions & 4 deletions src/ode_solver/bdf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@ use std::rc::Rc;

use anyhow::{anyhow, Result};

use num_traits::{One, Pow, Zero};
use num_traits::{abs, One, Pow, Zero};
use serde::Serialize;

use crate::{
matrix::{default_solver::DefaultSolver, Matrix, MatrixRef},
nonlinear_solver::root::RootFinder,
op::bdf::BdfCallable,
scalar::scale,
vector::DefaultDenseMatrix,
DenseMatrix, IndexType, MatrixViewMut, NewtonNonlinearSolver, NonLinearOp, NonLinearSolver,
OdeSolverMethod, OdeSolverProblem, OdeSolverState, Op, Scalar, SolverProblem, Vector,
VectorRef, VectorView, VectorViewMut,
OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, Op, Scalar,
SolverProblem, Vector, VectorRef, VectorView, VectorViewMut,
};

pub mod faer;
Expand Down Expand Up @@ -78,6 +79,8 @@ pub struct Bdf<
error_const: Vec<Eqn::T>,
statistics: BdfStatistics<Eqn::T>,
state: Option<OdeSolverState<Eqn::V>>,
tstop: Option<Eqn::T>,
root_finder: Option<RootFinder<Eqn::V>>,
}

impl<Eqn> Default
Expand Down Expand Up @@ -112,6 +115,8 @@ where
u: <M<Eqn::V> as Matrix>::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1),
statistics: BdfStatistics::default(),
state: None,
tstop: None,
root_finder: None,
}
}
}
Expand Down Expand Up @@ -249,6 +254,26 @@ where
};
(y_predict, t_new)
}

fn handle_tstop(&mut self, tstop: Eqn::T) -> Result<Option<OdeSolverStopReason<Eqn::T>>> {
// check if the we are at tstop
let state = self.state.as_ref().unwrap();
let troundoff = Eqn::T::from(100.0) * Eqn::T::EPSILON * (abs(state.t) + abs(state.h));
if abs(state.t - tstop) <= troundoff {
self.tstop = None;
return Ok(Some(OdeSolverStopReason::TstopReached));
} else if tstop < state.t - troundoff {
self.tstop = None;
return Err(anyhow!("tstop is before current time"));
}

// check if the next step will be beyond tstop, if so adjust the step size
if state.t + state.h > tstop + troundoff {
let factor = (tstop - state.t) / state.h;
self._update_step_size(factor);
}
Ok(None)
}
}

impl<M: DenseMatrix<T = Eqn::T, V = Eqn::V>, Eqn: OdeEquations, Nls> OdeSolverMethod<Eqn>
Expand Down Expand Up @@ -369,9 +394,17 @@ where

// store state
self.state = Some(state);
if let Some(root_fn) = problem.eqn.root() {
let state = self.state.as_ref().unwrap();
self.root_finder = Some(RootFinder::new(root_fn.nout()));
self.root_finder
.as_ref()
.unwrap()
.init(root_fn.as_ref(), &state.y, state.t);
}
}

fn step(&mut self) -> Result<()> {
fn step(&mut self) -> Result<OdeSolverStopReason<Eqn::T>> {
let mut d: Eqn::V;
let mut safety: Eqn::T;
let mut error_norm: Eqn::T;
Expand All @@ -380,6 +413,7 @@ where
if self.state.is_none() {
return Err(anyhow!("State not set"));
}

let (mut y_predict, mut t_new) = self._predict_forward();

// loop until step is accepted
Expand Down Expand Up @@ -531,6 +565,39 @@ where
}
self._update_step_size(factor);
}

// check for root within accepted step
if let Some(root_fn) = self.problem().as_ref().unwrap().eqn.root() {
let ret = self.root_finder.as_ref().unwrap().check_root(
&|t| self.interpolate(t),
root_fn.as_ref(),
&self.state.as_ref().unwrap().y,
self.state.as_ref().unwrap().t,
);
if let Some(root) = ret {
return Ok(OdeSolverStopReason::RootFound(root));
}
}

if let Some(tstop) = self.tstop {
if let Some(reason) = self.handle_tstop(tstop).unwrap() {
return Ok(reason);
}
}

// just a normal step, no roots or tstop reached
Ok(OdeSolverStopReason::InternalTimestep)
}

fn set_stop_time(&mut self, tstop: <Eqn as OdeEquations>::T) -> Result<()> {
self.tstop = Some(tstop);
if let Some(OdeSolverStopReason::TstopReached) = self.handle_tstop(tstop)? {
self.tstop = None;
return Err(anyhow!(
"tstop is at or before current time t = {}",
self.state.as_ref().unwrap().t
));
}
Ok(())
}
}
Loading

0 comments on commit 7e1cd1e

Please sign in to comment.