Skip to content

Commit

Permalink
Pass closures on the stack and remove the need for "Pin"
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris00 committed Jan 2, 2024
1 parent cc9a2ae commit 4b5a1b1
Showing 1 changed file with 33 additions and 36 deletions.
69 changes: 33 additions & 36 deletions src/cvode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
use std::{
ffi::{c_int, c_void, c_long},
pin::Pin,
marker::PhantomData, ptr,
};
use sundials_sys::*;
Expand Down Expand Up @@ -57,7 +56,7 @@ where V: Vector {
matrix: Option<Matrix>,
linsolver: Option<LinSolver>,
rootsfound: Vec<c_int>, // cache, with len() == number of eq
user_data: Pin<Box<UserData<F, G>>>,
user_data: UserData<F, G>,
}

// The user-data may be updated according to the type of `G`.
Expand All @@ -84,24 +83,15 @@ where Ctx: Context,
self.ctx
}

fn new_with_fn<G1>(ctx: Ctx, cvode_mem: CVodeMem, t0: f64,
rtol: f64, atol: f64,
matrix: Option<Matrix>, linsolver: Option<LinSolver>,
f: F, g: G1, ng: usize
) -> CVode<Ctx, V, F, G1> {
// FIXME: One can move `f` and `g` and set the user data
// before calling solving functions.
let user_data = Box::pin(UserData { f, g });
let user_data_ptr =
user_data.as_ref().get_ref() as *const _ as *mut c_void;
unsafe { CVodeSetUserData(cvode_mem.0, user_data_ptr) };
let mut rootsfound = Vec::with_capacity(ng);
rootsfound.resize(ng, 0);
CVode {
ctx, cvode_mem,
t0, vec: PhantomData,
rtol, atol, matrix, linsolver, rootsfound, user_data,
}
/// Set the user data for CVode. Since the closures are on the
/// stack, their location changes. One must let Sundials know
/// about the new locations before launching a solver.
fn update_user_data(&mut self) {
let ptr = &self.user_data as *const _;
let ret = unsafe { CVodeSetUserData(
self.cvode_mem.0,
ptr as *mut c_void) };
debug_assert_eq!(ret, 0);
}
}

Expand Down Expand Up @@ -182,10 +172,13 @@ where Ctx: Context,
msg: "could not attach linear solver"
})
}
Ok(Self::new_with_fn(
ctx, cvode_mem, t0,
rtol, atol, None, Some(linsolver),
f, (), 0))
Ok(Self {
ctx, cvode_mem,
t0, vec: PhantomData,
rtol, atol, matrix: None, linsolver: Some(linsolver),
rootsfound: vec![],
user_data: UserData { f, g: () }
})
}

/// Solver using the Adams linear multistep method. Recommended
Expand Down Expand Up @@ -266,8 +259,7 @@ where V: Vector {
impl<Ctx, V, F, G> CVode<Ctx, V, F, G>
where Ctx: Context,
V: Vector,
F: FnMut(f64, &V, &mut V) + Unpin,
G: Unpin
F: FnMut(f64, &V, &mut V),
{
/// Callback for the root-finding callback for `N` functions,
/// where `N` is known at compile time.
Expand All @@ -293,8 +285,8 @@ where Ctx: Context,
/// 0 ≤ `i` < `N` (given by `g`(t,y, [g₁,...,gₙ])) are to be
/// found while the IVP is being solved.
///
pub fn root<const M: usize, G1>(self, g: G1) -> CVode<Ctx, V, F, G1>
where G1: FnMut(f64, &V, &mut [f64; M]) {
pub fn root<const M: usize, R>(self, g: R) -> CVode<Ctx, V, F, R>
where R: FnMut(f64, &V, &mut [f64; M]) {
// FIXME: Do we want a second (because it will not work when V
// = [f64;N] since the number of equations is usually not the
// same as the dimension of the problem) function accepting
Expand All @@ -303,16 +295,20 @@ where Ctx: Context,
// returned vector at each invocation.
let r = unsafe {
CVodeRootInit(self.cvode_mem.0, M as _,
Some(Self::cvroot1::<M, G1>)) };
Some(Self::cvroot1::<M, R>)) };
if r == CV_MEM_FAIL {
panic!("Sundials::CVode::root: memory allocation failed.");
panic!("Sundials::cvode::CVode::root: memory allocation failed.");
}
let mut rootsfound = Vec::with_capacity(M);
rootsfound.resize(M, 0);
CVode {
ctx: self.ctx, cvode_mem: self.cvode_mem,
t0: self.t0, vec: PhantomData,
rtol: self.rtol, atol: self.atol,
matrix: self.matrix, linsolver: self.linsolver,
rootsfound,
user_data: UserData { f: self.user_data.f, g },
}
let u = *Pin::into_inner(self.user_data);
Self::new_with_fn(
self.ctx,
self.cvode_mem, self.t0,
self.rtol, self.atol,
self.matrix, self.linsolver, u.f, g, M)
}
}

Expand Down Expand Up @@ -365,6 +361,7 @@ where Ctx: Context,
fn integrate(
&mut self, t: f64, y: &mut V, itask: c_int
) -> (f64, CVStatus) {
self.update_user_data();
// Safety: `yout` does not escape this function and so will
// not outlive `self.ctx`.
//let n = y.len();
Expand Down

0 comments on commit 4b5a1b1

Please sign in to comment.