Skip to content

Commit

Permalink
Skeleton for Context to support MPI
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris00 committed Dec 24, 2023
1 parent 86bffde commit 76ca952
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 41 deletions.
6 changes: 2 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ license = "BSD-3-Clause"
keywords = ["ODE", "math", "numerics", "simulation", "science"]
categories = ["mathematics", "science"]

[build-dependencies]
sundials-sys = "0.2.5"

[dependencies]
sundials-sys = "0.2.5"
sundials-sys = "0.3.0"
mpi = { version = "0.7.0", optional = true, default-features = false }

[dev-dependencies]
eyre = "0.6.8"
11 changes: 0 additions & 11 deletions build.rs

This file was deleted.

65 changes: 39 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,11 @@ impl std::error::Error for Error {}
//
// SUNContext

/// Context is an object associated with the thread of execution.
pub struct Context(
#[cfg(has_context)]
SUNContext,
#[cfg(not(has_context))]
PhantomData<()>,
);

#[cfg(has_context)]
impl Drop for Context {
fn drop(&mut self) {
// FIXME: Make sure the remark about MPI is followed (when
Expand All @@ -82,19 +79,28 @@ impl Drop for Context {
}

impl Context {
#[cfg(has_context)]
fn new() -> Result<Self, Error> {
unsafe fn with_communicator(
comm: *mut std::os::raw::c_void
) -> Result<Self, Error> {
let mut ctx: SUNContext = ptr::null_mut();
if unsafe { SUNContext_Create(ptr::null_mut(),
&mut ctx as *mut _) } < 0 {
if unsafe { SUNContext_Create(comm, &mut ctx as *mut _) } < 0 {
return Err(Error::Fail { name: "Context::new",
msg: "Failed to create a context" })
}
Ok(Context(ctx))
}
#[cfg(not(has_context))]
fn new() -> Result<Self, Error> { Ok(Context(PhantomData)) }

fn new() -> Result<Self, Error> {
unsafe { Self::with_communicator(ptr::null_mut()) }
}

#[cfg(feature = "mpi")]
fn with_mpi(conn: impl mpi::topology::Communicator) -> Self {
// https://crates.io/crates/mpi-sys https://crates.io/crates/mpi
todo!()
let comm = ptr::null_mut();
unsafe { Self::with_communicator(comm) }
}
}

////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -141,14 +147,14 @@ unsafe impl<const N: usize> NVector for [f64; N] {
fn to_nvector(v: &Self, ctx: &Context) -> N_Vector {
unsafe { N_VMake_Serial(N.try_into().unwrap(),
v.as_ptr() as *mut _,
#[cfg(has_context)] ctx.0) }
ctx.0) }
}

#[inline]
fn to_nvector_mut(v: &mut Self, ctx: &Context) -> N_Vector {
unsafe { N_VMake_Serial(N.try_into().unwrap(),
v.as_mut_ptr(),
#[cfg(has_context)] ctx.0) }
ctx.0) }
}

#[inline]
Expand Down Expand Up @@ -223,9 +229,10 @@ struct UserData<F, G> {
impl<'a, V, F, G> CVode<'a, V, F, G> {
#[inline]
fn new<G1>(ctx: Context,
cvode_mem: CVodeMem, t0: f64, y0: SharedNVector,
rtol: f64, atol: f64, matrix: Matrix, linsolver: LinSolver,
f: F, g: G1, ng: usize) -> CVode<'a, V, F, G1> {
cvode_mem: CVodeMem, t0: f64, y0: SharedNVector,
rtol: f64, atol: f64, matrix: Matrix, linsolver: LinSolver,
f: F, g: G1, ng: usize
) -> CVode<'a, V, F, G1> {
let user_data = Box::pin(UserData { f, g });
let user_data_ptr =
user_data.as_ref().get_ref() as *const _ as *mut c_void;
Expand All @@ -242,10 +249,12 @@ impl<'a, V, F, G> CVode<'a, V, F, G> {

impl<'a, V, F> CVode<'a, V, F, ()>
where V: NVector,
F: FnMut(f64, &V, &mut V) {
F: FnMut(f64, &V, &mut V)
{
/// Callback for the right-hand side of the equation.
extern "C" fn cvrhs(t: f64, nvy: N_Vector, nvdy: N_Vector,
user_data: *mut c_void) -> c_int {
user_data: *mut c_void
) -> c_int {
// Protect against unwinding in C code.
match std::panic::catch_unwind(|| {
// Get f from user_data, whatever the type of g.
Expand All @@ -266,12 +275,13 @@ where V: NVector,
/// Initialize a CVode with method `llm`.
#[inline]
fn init(name: &'static str, lmm: c_int,
t0: f64, y0: &'a V, f: F) -> Result<Self, Error> {
t0: f64, y0: &'a V, f: F
) -> Result<Self, Error> {
let ctx = Context::new()?;
// FIXME: who will reclaim the N_Vector from y0? Need to wrap it to
// enable a `Drop`.
let cvode_mem = unsafe {
CVodeMem(CVodeCreate(lmm, #[cfg(has_context)] ctx.0)) };
CVodeMem(CVodeCreate(lmm, ctx.0)) };
if cvode_mem.0.is_null() {
return Err(Error::Fail{name, msg: "Allocation failed"})
}
Expand All @@ -286,18 +296,16 @@ where V: NVector,
return Err(Error::Fail{name, msg})
}
let rtol = 1e-6;
let atol = 1e-12;
let atol = 1e-12;
// Set default tolerances (otherwise the solver will complain).
unsafe { CVodeSStolerances(cvode_mem.0, rtol, atol); }
unsafe { CVodeSStolerances(cvode_mem.0, rtol, atol); }
// Set the default linear solver (FIXME: configurable)
let mat = unsafe {
SUNDenseMatrix(y0.len() as _, y0.len() as _,
#[cfg(has_context)] ctx.0) };
let mat = unsafe {
SUNDenseMatrix(y0.len() as _, y0.len() as _, ctx.0) };
if mat.is_null() {
return Err(Error::Fail{name, msg: "matrix allocation failed"})
}
let linsolver = unsafe {
SUNLinSol_Dense(nvy0, mat, #[cfg(has_context)] ctx.0) };
let linsolver = unsafe { SUNLinSol_Dense(nvy0, mat, ctx.0) };
if linsolver.is_null() {
return Err(Error::Fail{
name, msg: "linear solver allocation failed"})
Expand Down Expand Up @@ -346,6 +354,8 @@ where V: NVector {

/// Specify the maximum number of steps to be taken by the solver
/// in its attempt to reach the next output time. Default: 500.
// FIXME: make sure "mxstep steps taken before reaching tout" does
// not abort the program.
pub fn mxsteps(self, n: usize) -> Self {
let n =
if n <= c_long::MAX as usize { n as _ } else { c_long::MAX };
Expand Down Expand Up @@ -419,6 +429,9 @@ pub enum CV {
/// # Solving the IVP
impl<'a, V, F, G> CVode<'a, V, F, G>
where V: NVector {
// FIXME: for arrays, it would be more convenient to return the
// array. For types that are not on the stack (i.e. not Copy),
// taking it as an additional parameter is better.
pub fn solve(&mut self, t: f64, y: &mut V) -> CV {
Self::integrate(self, t, y, CV_NORMAL)
}
Expand Down

0 comments on commit 76ca952

Please sign in to comment.