From 4b5a1b1a4a74c1de816b26a5596eb06dab06f227 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Tue, 2 Jan 2024 12:07:35 +0100 Subject: [PATCH] Pass closures on the stack and remove the need for "Pin" --- src/cvode.rs | 69 +++++++++++++++++++++++++--------------------------- 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/src/cvode.rs b/src/cvode.rs index af820c8..1507e7e 100644 --- a/src/cvode.rs +++ b/src/cvode.rs @@ -15,7 +15,6 @@ use std::{ ffi::{c_int, c_void, c_long}, - pin::Pin, marker::PhantomData, ptr, }; use sundials_sys::*; @@ -57,7 +56,7 @@ where V: Vector { matrix: Option, linsolver: Option, rootsfound: Vec, // cache, with len() == number of eq - user_data: Pin>>, + user_data: UserData, } // The user-data may be updated according to the type of `G`. @@ -84,24 +83,15 @@ where Ctx: Context, self.ctx } - fn new_with_fn(ctx: Ctx, cvode_mem: CVodeMem, t0: f64, - rtol: f64, atol: f64, - matrix: Option, linsolver: Option, - f: F, g: G1, ng: usize - ) -> CVode { - // FIXME: One can move `f` and `g` and set the user data - // before calling solving functions. - let user_data = Box::pin(UserData { f, g }); - let user_data_ptr = - user_data.as_ref().get_ref() as *const _ as *mut c_void; - unsafe { CVodeSetUserData(cvode_mem.0, user_data_ptr) }; - let mut rootsfound = Vec::with_capacity(ng); - rootsfound.resize(ng, 0); - CVode { - ctx, cvode_mem, - t0, vec: PhantomData, - rtol, atol, matrix, linsolver, rootsfound, user_data, - } + /// Set the user data for CVode. Since the closures are on the + /// stack, their location changes. One must let Sundials know + /// about the new locations before launching a solver. + fn update_user_data(&mut self) { + let ptr = &self.user_data as *const _; + let ret = unsafe { CVodeSetUserData( + self.cvode_mem.0, + ptr as *mut c_void) }; + debug_assert_eq!(ret, 0); } } @@ -182,10 +172,13 @@ where Ctx: Context, msg: "could not attach linear solver" }) } - Ok(Self::new_with_fn( - ctx, cvode_mem, t0, - rtol, atol, None, Some(linsolver), - f, (), 0)) + Ok(Self { + ctx, cvode_mem, + t0, vec: PhantomData, + rtol, atol, matrix: None, linsolver: Some(linsolver), + rootsfound: vec![], + user_data: UserData { f, g: () } + }) } /// Solver using the Adams linear multistep method. Recommended @@ -266,8 +259,7 @@ where V: Vector { impl CVode where Ctx: Context, V: Vector, - F: FnMut(f64, &V, &mut V) + Unpin, - G: Unpin + F: FnMut(f64, &V, &mut V), { /// Callback for the root-finding callback for `N` functions, /// where `N` is known at compile time. @@ -293,8 +285,8 @@ where Ctx: Context, /// 0 ≤ `i` < `N` (given by `g`(t,y, [g₁,...,gₙ])) are to be /// found while the IVP is being solved. /// - pub fn root(self, g: G1) -> CVode - where G1: FnMut(f64, &V, &mut [f64; M]) { + pub fn root(self, g: R) -> CVode + where R: FnMut(f64, &V, &mut [f64; M]) { // FIXME: Do we want a second (because it will not work when V // = [f64;N] since the number of equations is usually not the // same as the dimension of the problem) function accepting @@ -303,16 +295,20 @@ where Ctx: Context, // returned vector at each invocation. let r = unsafe { CVodeRootInit(self.cvode_mem.0, M as _, - Some(Self::cvroot1::)) }; + Some(Self::cvroot1::)) }; if r == CV_MEM_FAIL { - panic!("Sundials::CVode::root: memory allocation failed."); + panic!("Sundials::cvode::CVode::root: memory allocation failed."); + } + let mut rootsfound = Vec::with_capacity(M); + rootsfound.resize(M, 0); + CVode { + ctx: self.ctx, cvode_mem: self.cvode_mem, + t0: self.t0, vec: PhantomData, + rtol: self.rtol, atol: self.atol, + matrix: self.matrix, linsolver: self.linsolver, + rootsfound, + user_data: UserData { f: self.user_data.f, g }, } - let u = *Pin::into_inner(self.user_data); - Self::new_with_fn( - self.ctx, - self.cvode_mem, self.t0, - self.rtol, self.atol, - self.matrix, self.linsolver, u.f, g, M) } } @@ -365,6 +361,7 @@ where Ctx: Context, fn integrate( &mut self, t: f64, y: &mut V, itask: c_int ) -> (f64, CVStatus) { + self.update_user_data(); // Safety: `yout` does not escape this function and so will // not outlive `self.ctx`. //let n = y.len();