Skip to content

Commit

Permalink
First proof of concept of any Rust vector ↔ N_Vector
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris00 committed Jan 1, 2024
1 parent e2217cf commit 5cced85
Show file tree
Hide file tree
Showing 7 changed files with 742 additions and 442 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use sundials::{context, cvode::CVode};
fn main() -> Result<(), Box<sundials::Error>> {
let ctx = context!()?;
let mut ode = CVode::adams(ctx, 0., &[0.], |t, u, du| *du = [1.])?;
let (u1, _) = ode.solution(1.);
let (u1, _) = ode.solution(0., &[0.], 1.);
assert_eq!(u1[0], 1.);
Ok(())
}
Expand Down
10 changes: 7 additions & 3 deletions examples/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let ctx = context!().unwrap();
let mut ode = CVode::adams(ctx, 0., &[0.],
|_t, _u, du| *du = [1.])?;
let mut u1 = [f64::NAN];
ode.solve(1., &mut u1);
assert_eq!(u1[0], 1.);
let mut u = [f64::NAN];
let (t, st) = ode.step(1., &mut u);
println!("t = {t:e}, u = {u:?}, status: {st:?}");
assert_eq!(u[0], t);
let st = ode.solve(1., &mut u);
println!("t = 1., u = {u:?}, status: {st:?}");
assert_eq!(u[0], 1.);
Ok(())
}
114 changes: 74 additions & 40 deletions src/cvode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
//! use sundials::{context, cvode::CVode};
//! let ctx = context!()?;
//! let mut ode = CVode::adams(ctx, 0., &[0.], |t, u, du| *du = [1.])?;
//! let (u1, _) = ode.solution(1.);
//! let (u1, _) = ode.solution(0., &[0.], 1.);
//! assert_eq!(u1[0], 1.);
//! # Ok::<(), sundials::Error>(())
//! ```
use std::{
ffi::{c_int, c_void, c_long},
pin::Pin,
marker::PhantomData,
marker::PhantomData, ptr,
};
use sundials_sys::*;
use super::{
Expand Down Expand Up @@ -54,8 +54,8 @@ where V: Vector {
atol: f64,
// We hold `Matrix` and `LinSolver` so they are freed when `CVode`
// is dropped.
matrix: Matrix,
linsolver: LinSolver,
matrix: Option<Matrix>,
linsolver: Option<LinSolver>,
rootsfound: Vec<c_int>, // cache, with len() == number of eq
user_data: Pin<Box<UserData<F, G>>>,
}
Expand Down Expand Up @@ -85,8 +85,9 @@ where Ctx: Context,
}

fn new_with_fn<G1>(ctx: Ctx, cvode_mem: CVodeMem, t0: f64,
rtol: f64, atol: f64, matrix: Matrix, linsolver: LinSolver,
f: F, g: G1, ng: usize
rtol: f64, atol: f64,
matrix: Option<Matrix>, linsolver: Option<LinSolver>,
f: F, g: G1, ng: usize
) -> CVode<Ctx, V, F, G1> {
// FIXME: One can move `f` and `g` and set the user data
// before calling solving functions.
Expand All @@ -107,7 +108,7 @@ where Ctx: Context,
impl<Ctx, V, F> CVode<Ctx, V, F, ()>
where Ctx: Context,
V: Vector + Sized,
F: FnMut(f64, V::View<'_>, V::ViewMut<'_>)
F: FnMut(f64, &V, &mut V)
{
/// Callback for the right-hand side of the equation.
extern "C" fn cvrhs(t: f64, nvy: N_Vector, nvdy: N_Vector,
Expand Down Expand Up @@ -142,11 +143,11 @@ where Ctx: Context,
if cvode_mem.0.is_null() {
return Err(Error::Fail{name, msg: "Allocation failed"})
}
let n = y0.len();
// let n = V::len(y0);
// SAFETY: Once `y0` has been passed to `CVodeInit`, it is
// copied to internal structures and thus can be freed.
let y0 =
match unsafe { y0.to_nvector(ctx.as_ptr()) } {
match unsafe { V::as_nvector(y0, ctx.as_ptr()) } {
Some(y0) => y0,
None => panic!("The context of y0 is not the same as the \
context of the CVode solver."),
Expand All @@ -169,18 +170,18 @@ where Ctx: Context,
// Set default tolerances (otherwise the solver will complain).
unsafe { CVodeSStolerances(
cvode_mem.0, rtol, atol); }
// Set the default linear solver (FIXME: configurable)
let mat = Matrix::new(name, &ctx, n, n)?;
let linsolver = unsafe { LinSolver::new(
name, ctx.as_ptr(), V::as_ptr(&y0) as *mut _, &mat)? };
let r = unsafe {
CVodeSetLinearSolver(cvode_mem.0, linsolver.0, mat.0) };
// Set the default linear solver to one that does not require
// the `…nvgetarraypointer` on vectors (FIXME: configurable)
let linsolver = unsafe { LinSolver::spgmr(
name, ctx.as_ptr(), V::as_ptr(&y0) as *mut _)? };
let r = unsafe { CVodeSetLinearSolver(
cvode_mem.0, linsolver.0, ptr::null_mut()) };
if r != CVLS_SUCCESS as i32 {
return Err(Error::Fail{name, msg: "could not attach linear solver"})
}
Ok(Self::new_with_fn(
ctx, cvode_mem, t0,
rtol, atol, mat, linsolver,
rtol, atol, None, Some(linsolver),
f, (), 0))
}

Expand Down Expand Up @@ -269,7 +270,7 @@ where Ctx: Context,
/// where `N` is known at compile time.
extern "C" fn cvroot1<const N: usize, G1>(
t: f64, y: N_Vector, gout: *mut f64, user_data: *mut c_void) -> c_int
where G1: FnMut(f64, V::View<'_>, &mut [f64; N]) {
where G1: FnMut(f64, &V, &mut [f64; N]) {
// Protect against unwinding in C code.
match std::panic::catch_unwind(|| {
let u = unsafe { &mut *(user_data as *mut UserData<F, G1>) };
Expand All @@ -290,7 +291,7 @@ where Ctx: Context,
/// found while the IVP is being solved.
///
pub fn root<const M: usize, G1>(self, g: G1) -> CVode<Ctx, V, F, G1>
where G1: FnMut(f64, V::View<'_>, &mut [f64; M]) {
where G1: FnMut(f64, &V, &mut [f64; M]) {
// FIXME: Do we want a second (because it will not work when V
// = [f64;N] since the number of equations is usually not the
// same as the dimension of the problem) function accepting
Expand Down Expand Up @@ -349,21 +350,23 @@ where Ctx: Context,
/// - A root of one of the root functions was found both at a point `t`
/// and also very near `t`.
pub fn solve(&mut self, t: f64, y: &mut V) -> CVStatus {
Self::integrate(self, t, y, CV_NORMAL)
Self::integrate(self, t, y, CV_NORMAL).1
}

/// Same as [`CVode::solve`] but only perform one time step in the
/// direction of `t`.
pub fn step(&mut self, t: f64, y: &mut V) -> CVStatus {
pub fn step(&mut self, t: f64, y: &mut V) -> (f64, CVStatus) {
Self::integrate(self, t, y, CV_ONE_STEP)
}

fn integrate(&mut self, t: f64, y: &mut V, itask: c_int) -> CVStatus {
fn integrate(
&mut self, t: f64, y: &mut V, itask: c_int
) -> (f64, CVStatus) {
// Safety: `yout` does not escape this function and so will
// not outlive `self.ctx`.
//let n = y.len();
let yout =
match unsafe { V::to_nvector_mut(y, self.ctx.as_ptr()) }{
match unsafe { V::as_mut_nvector(y, self.ctx.as_ptr()) }{
Some(yout) => yout,
None => panic!("The context of the output vector y is not \
the same as the context of CVode."),
Expand All @@ -374,7 +377,7 @@ where Ctx: Context,
t,
V::as_mut_ptr(&yout),
&mut tret, itask) };
match r {
let status = match r {
CV_SUCCESS => CVStatus::Ok,
CV_TSTOP_RETURN => CVStatus::Tstop(tret),
CV_ROOT_RETURN => {
Expand All @@ -388,28 +391,59 @@ where Ctx: Context,
}
CV_MEM_NULL | CV_NO_MALLOC => unreachable!(),
CV_ILL_INPUT => CVStatus::IllInput,
CV_TOO_CLOSE => CVStatus::TooClose,
CV_TOO_MUCH_WORK => CVStatus::TooMuchWork,
CV_TOO_MUCH_ACC => CVStatus::TooMuchAcc,
CV_ERR_FAILURE => CVStatus::ErrFailure,
CV_CONV_FAILURE => CVStatus::ConvFailure,
CV_LINIT_FAIL => panic!("CV_LINIT_FAIL"),
CV_LSETUP_FAIL => panic!("CV_LSETUP_FAIL"),
CV_LSOLVE_FAIL => panic!("CV_LSOLVE_FAIL"),
CV_LINIT_FAIL => panic!("The linear solver interface’s \
initialization function failed."),
CV_LSETUP_FAIL => panic!("The linear solver interface’s setup \
function failed in an unrecoverable manner."),
CV_LSOLVE_FAIL => panic!("The linear solver interface’s solve \
function failed in an unrecoverable manner."),
CV_CONSTR_FAIL => panic!("The inequality constraints were \
violated and the solver was unable to recover."),
CV_RHSFUNC_FAIL => panic!("The right-hand side function failed \
in an unrecoverable manner."),
CV_REPTD_RHSFUNC_ERR => panic!("Convergence test failures \
occurred too many times due to repeated recoverable errors \
in the right-hand side function."),
CV_UNREC_RHSFUNC_ERR => panic!("The right-hand function had a \
recoverable error, but no recovery was possible."),
CV_RTFUNC_FAIL => panic!("The root function failed"),
CV_TOO_CLOSE => CVStatus::TooClose,
_ => panic!("sundials::CVode: unexpected return code {}", r),
}
};
(tret, status)
}

}

impl<const N: usize, Ctx, F, G> CVode<Ctx, [f64; N], F, G>
where Ctx: Context {
/// Return the solution at time `t`.
// FIXME: provide it for any type `V` — which must implement a
// creation function.
pub fn solution(&mut self, t: f64) -> ([f64; N], CVStatus) {
let mut y = [f64::NAN; N];
impl<Ctx, V, F, G> CVode<Ctx, V, F, G>
where Ctx: Context,
V: Vector
{
/// Return the solution with initial conditions (`t0`, `y0`) at
/// time `t`. This is a convenience function.
pub fn solution(&mut self, t0: f64, y0: &V, t: f64) -> (V, CVStatus) {
let mut y = y0.clone();
// Avoid CVStatus::TooClose
if t == t0 {
return (y, CVStatus::Ok)
}
let y0 = match unsafe { V::as_nvector(y0, self.ctx.as_ptr()) } {
Some(y0) => y0,
None => panic!("The context of `y0` differs from the one \
of the ODE solver."),
};
// Reinitialize to allow any time `t`, even if not monotonic
// w.r.t. previous calls.
let ret = unsafe {
CVodeReInit(self.cvode_mem.0, t0, V::as_ptr(&y0) as *mut _)
};
if ret != CV_SUCCESS {
panic!("CVodeReInit returned code {ret}. Please report.");
}
let cv = self.solve(t, &mut y);
(y, cv)
}
Expand All @@ -436,7 +470,7 @@ mod tests {
let ctx = context!().unwrap();
let mut ode = CVode::adams(ctx, 0., &[0.],
|_,_, du| *du = [1.]).unwrap();
assert_eq!(ode.solution(1.).0, [1.]);
assert_eq!(ode.solution(0., &[0.], 1.).0, [1.]);
}

#[test]
Expand Down Expand Up @@ -477,7 +511,7 @@ mod tests {
let ode = move || {
CVode::adams(ctx, 0., &init, |_,_, du| *du = [1.]).unwrap()
};
assert_eq!(ode().solution(1.).0, [1.]);
assert_eq!(ode().solution(0., &init, 1.).0, [1.]);
}

#[test]
Expand All @@ -489,13 +523,13 @@ mod tests {
*du = [1., 1.]
}).unwrap()
};
assert_eq!(ode().solution(1.).0, [2., 3.]);
assert_eq!(ode().solution(0., &init, 1.).0, [2., 3.]);
let (u, cv) = ode()
.root(|_, &u, z| *z = [u[0] - 2.])
.solution(2.);
.solution(0., &init, 2.);
assert!(matches!(cv, CVStatus::Root(_,_)));
assert_eq!(u, [2., 3.]);
assert_eq!(ode().solution(2.).0, [3., 4.]);
assert_eq!(ode().solution(0., &init, 2.).0, [3., 4.]);
}

#[test]
Expand Down
11 changes: 6 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let ctx = context!()?;
//! let mut ode = CVode::adams(ctx, 0., &[0.], |t, u, du| *du = [1.])?;
//! let (u1, _) = ode.solution(1.);
//! let (u1, _) = ode.solution(0., &[0.], 1.);
//! assert_eq!(u1[0], 1.);
//! # Ok(()) }
//! ```
Expand Down Expand Up @@ -281,7 +281,7 @@ impl Drop for Matrix {
}

impl Matrix {
fn new(
fn dense(
name: &'static str, ctx: &impl Context, m: usize, n: usize,
) -> Result<Self, Error> {
let mat = unsafe {
Expand All @@ -308,11 +308,12 @@ impl LinSolver {
///
/// # Safety
/// The return value must not outlive `ctx`.
unsafe fn new(
name: &'static str, ctx: SUNContext, vec: N_Vector, mat: &Matrix,
unsafe fn spgmr(
name: &'static str, ctx: SUNContext,
vec: N_Vector,
) -> Result<Self, Error> {
let linsolver = unsafe {
SUNLinSol_Dense(vec, mat.0, ctx) };
SUNLinSol_SPGMR(vec, SUN_PREC_NONE as _, 30, ctx) };
if linsolver.is_null() {
Err(Error::Fail{ name, msg: "linear solver allocation failed"})
} else {
Expand Down
Loading

0 comments on commit 5cced85

Please sign in to comment.