Skip to content

Commit

Permalink
fix diffsl tests
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Nov 1, 2024
1 parent b37c2d1 commit 8db78bb
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 77 deletions.
33 changes: 0 additions & 33 deletions src/ode_solver/bdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1250,15 +1250,13 @@ mod test {
let (problem, soln) = exponential_decay_problem::<M>(false);
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 11
number_of_steps: 47
number_of_error_test_failures: 0
number_of_nonlinear_solver_iterations: 82
number_of_nonlinear_solver_fails: 0
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 84
number_of_jac_muls: 2
number_of_matrix_evals: 1
Expand Down Expand Up @@ -1288,15 +1286,13 @@ mod test {
let (problem, soln) = exponential_decay_problem::<M>(false);
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 11
number_of_steps: 47
number_of_error_test_failures: 0
number_of_nonlinear_solver_iterations: 82
number_of_nonlinear_solver_fails: 0
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 84
number_of_jac_muls: 2
number_of_matrix_evals: 1
Expand All @@ -1310,15 +1306,13 @@ mod test {
let (problem, soln) = exponential_decay_problem_sens::<M>(false);
test_ode_solver(&mut s, &problem, soln, None, false, true);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 11
number_of_steps: 44
number_of_error_test_failures: 0
number_of_nonlinear_solver_iterations: 217
number_of_nonlinear_solver_fails: 0
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 87
number_of_jac_muls: 136
number_of_matrix_evals: 1
Expand All @@ -1332,14 +1326,12 @@ mod test {
let (problem, soln) = exponential_decay_problem_adjoint::<M>();
let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln);
insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###"
---
number_of_calls: 84
number_of_jac_muls: 6
number_of_matrix_evals: 3
number_of_jac_adj_muls: 492
"###);
insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###"
---
number_of_linear_solver_setups: 24
number_of_steps: 86
number_of_error_test_failures: 12
Expand All @@ -1354,14 +1346,12 @@ mod test {
let (problem, soln) = exponential_decay_with_algebraic_adjoint_problem::<M>();
let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln);
insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###"
---
number_of_calls: 190
number_of_jac_muls: 24
number_of_matrix_evals: 8
number_of_jac_adj_muls: 278
"###);
insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###"
---
number_of_linear_solver_setups: 32
number_of_steps: 74
number_of_error_test_failures: 15
Expand All @@ -1376,15 +1366,13 @@ mod test {
let (problem, soln) = exponential_decay_with_algebraic_problem::<M>(false);
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 20
number_of_steps: 41
number_of_error_test_failures: 4
number_of_nonlinear_solver_iterations: 79
number_of_nonlinear_solver_fails: 0
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 83
number_of_jac_muls: 6
number_of_matrix_evals: 2
Expand All @@ -1407,15 +1395,13 @@ mod test {
let (problem, soln) = exponential_decay_with_algebraic_problem_sens::<M>();
test_ode_solver(&mut s, &problem, soln, None, false, true);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 18
number_of_steps: 43
number_of_error_test_failures: 3
number_of_nonlinear_solver_iterations: 155
number_of_nonlinear_solver_fails: 0
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 71
number_of_jac_muls: 100
number_of_matrix_evals: 3
Expand All @@ -1429,15 +1415,13 @@ mod test {
let (problem, soln) = robertson::<M>(false);
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 77
number_of_steps: 316
number_of_error_test_failures: 3
number_of_nonlinear_solver_iterations: 722
number_of_nonlinear_solver_fails: 19
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 725
number_of_jac_muls: 60
number_of_matrix_evals: 20
Expand Down Expand Up @@ -1481,15 +1465,13 @@ mod test {
let (problem, soln) = robertson_sens::<M>();
test_ode_solver(&mut s, &problem, soln, None, false, true);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 160
number_of_steps: 410
number_of_error_test_failures: 4
number_of_nonlinear_solver_iterations: 3107
number_of_nonlinear_solver_fails: 81
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 996
number_of_jac_muls: 2495
number_of_matrix_evals: 71
Expand All @@ -1503,15 +1485,13 @@ mod test {
let (problem, soln) = robertson::<M>(true);
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 77
number_of_steps: 316
number_of_error_test_failures: 3
number_of_nonlinear_solver_iterations: 722
number_of_nonlinear_solver_fails: 19
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 725
number_of_jac_muls: 63
number_of_matrix_evals: 20
Expand All @@ -1525,15 +1505,13 @@ mod test {
let (problem, soln) = robertson_ode::<M>(false, 3);
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 86
number_of_steps: 416
number_of_error_test_failures: 1
number_of_nonlinear_solver_iterations: 911
number_of_nonlinear_solver_fails: 15
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 913
number_of_jac_muls: 162
number_of_matrix_evals: 18
Expand All @@ -1547,15 +1525,13 @@ mod test {
let (problem, soln) = robertson_ode_with_sens::<M>(false);
test_ode_solver(&mut s, &problem, soln, None, false, true);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 112
number_of_steps: 467
number_of_error_test_failures: 2
number_of_nonlinear_solver_iterations: 3472
number_of_nonlinear_solver_fails: 49
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 1041
number_of_jac_muls: 2672
number_of_matrix_evals: 45
Expand All @@ -1569,15 +1545,13 @@ mod test {
let (problem, soln) = dydt_y2_problem::<M>(false, 10);
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 27
number_of_steps: 161
number_of_error_test_failures: 0
number_of_nonlinear_solver_iterations: 355
number_of_nonlinear_solver_fails: 3
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 357
number_of_jac_muls: 50
number_of_matrix_evals: 5
Expand All @@ -1591,15 +1565,13 @@ mod test {
let (problem, soln) = dydt_y2_problem::<M>(true, 10);
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 27
number_of_steps: 161
number_of_error_test_failures: 0
number_of_nonlinear_solver_iterations: 355
number_of_nonlinear_solver_fails: 3
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 357
number_of_jac_muls: 15
number_of_matrix_evals: 5
Expand All @@ -1613,15 +1585,13 @@ mod test {
let (problem, soln) = gaussian_decay_problem::<M>(false, 10);
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 14
number_of_steps: 66
number_of_error_test_failures: 1
number_of_nonlinear_solver_iterations: 130
number_of_nonlinear_solver_fails: 0
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 132
number_of_jac_muls: 20
number_of_matrix_evals: 2
Expand All @@ -1637,15 +1607,13 @@ mod test {
let (problem, soln) = head2d_problem::<SparseColMat<f64>, 10>();
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 21
number_of_steps: 167
number_of_error_test_failures: 0
number_of_nonlinear_solver_iterations: 330
number_of_nonlinear_solver_fails: 0
"###);
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
---
number_of_calls: 333
number_of_jac_muls: 128
number_of_matrix_evals: 4
Expand Down Expand Up @@ -1674,7 +1642,6 @@ mod test {
let (problem, soln) = foodweb_problem::<SparseColMat<f64>, 10>();
test_ode_solver_no_sens(&mut s, &problem, soln, None, false);
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
---
number_of_linear_solver_setups: 45
number_of_steps: 161
number_of_error_test_failures: 2
Expand Down
6 changes: 4 additions & 2 deletions src/ode_solver/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ impl OdeBuilder {
}

/// Build an ODE problem from a set of equations
pub fn build_from_eqn<Eqn>(self, eqn: Rc<Eqn>) -> Result<OdeSolverProblem<Eqn>, DiffsolError>
pub fn build_from_eqn<Eqn>(self, mut eqn: Eqn) -> Result<OdeSolverProblem<Eqn>, DiffsolError>
where
Eqn: OdeEquations,
{
Expand All @@ -737,8 +737,10 @@ impl OdeBuilder {
nout,
nparams,
)?;
let p = Rc::new(Self::build_p(self.p));
eqn.set_params(p);
OdeSolverProblem::new(
eqn,
Rc::new(eqn),
Eqn::T::from(self.rtol),
atol,
self.sens_rtol.map(Eqn::T::from),
Expand Down
31 changes: 15 additions & 16 deletions src/ode_solver/diffsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub type T = f64;
/// # Example
///
/// ```rust
/// use diffsol::{OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod, DiffSlContext, diffsl::LlvmModule};
/// use diffsol::{OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod, DiffSlContext, DiffSl, diffsl::LlvmModule};
///
/// // dy/dt = -ay
/// // y(0) = 1
Expand All @@ -32,10 +32,11 @@ pub type T = f64;
/// F { -a*u }
/// out { u }
/// ").unwrap();
/// let eqn = DiffSl::from_context(context);
/// let problem = OdeBuilder::new()
/// .rtol(1e-6)
/// .p([0.1])
/// .build_diffsl(&context).unwrap();
/// .build_from_eqn(eqn).unwrap();
/// let mut solver = Bdf::default();
/// let t = 0.4;
/// let state = OdeSolverState::new(&problem, &solver).unwrap();
Expand Down Expand Up @@ -136,14 +137,15 @@ impl<M: Matrix<T = T>, CG: CodegenModule> DiffSl<M, CG> {
ret.rhs_coloring = Some(coloring);
ret.rhs_sparsity = Some(sparsity);

let op = ret.mass().unwrap();
let non_zeros = find_matrix_non_zeros(&op, t0);
let sparsity =
M::Sparsity::try_from_indices(op.nout(), op.nstates(), non_zeros.clone())
.expect("invalid sparsity pattern");
let coloring = JacobianColoring::new(&sparsity, &non_zeros);
ret.mass_coloring = Some(coloring);
ret.mass_sparsity = Some(sparsity);
if let Some(op) = ret.mass() {
let non_zeros = find_matrix_non_zeros(&op, t0);
let sparsity =
M::Sparsity::try_from_indices(op.nout(), op.nstates(), non_zeros.clone())
.expect("invalid sparsity pattern");
let coloring = JacobianColoring::new(&sparsity, &non_zeros);
ret.mass_coloring = Some(coloring);
ret.mass_sparsity = Some(sparsity);
}
}
ret
}
Expand Down Expand Up @@ -395,7 +397,7 @@ impl<M: Matrix<T = T>, CG: CodegenModule> OdeEquations for DiffSl<M, CG> {
}

fn mass(&self) -> Option<DiffSlMass<'_, M, CG>> {
Some(DiffSlMass(self))
self.context.compiler.has_mass().then_some(DiffSlMass(self))
}

fn root(&self) -> Option<DiffSlRoot<'_, M, CG>> {
Expand Down Expand Up @@ -466,8 +468,8 @@ mod tests {
let k = 1.0;
let r = 1.0;
let context = DiffSlContext::<nalgebra::DMatrix<f64>, CG>::new(text).unwrap();
let mut eqn = DiffSl::from_context(context);
let p = DVector::from_vec(vec![r, k]);
let mut eqn = DiffSl::from_context(context);
eqn.set_params(Rc::new(p));

// test that the initial values look ok
Expand All @@ -489,10 +491,7 @@ mod tests {
mass_y.assert_eq_st(&mass_y_expect, 1e-10);

// solver a bit and check the state and output
let problem = OdeBuilder::new()
.p([r, k])
.build_from_eqn(Rc::new(eqn))
.unwrap();
let problem = OdeBuilder::new().p([r, k]).build_from_eqn(eqn).unwrap();
let mut solver = Bdf::default();
let t = 1.0;
let state = OdeSolverState::new(&problem, &solver).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/ode_solver/equations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ impl<T> OdeEquationsAdjoint for T where
/// let p = Rc::new(V::from_vec(vec![]));
/// let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p);
///
/// let problem = OdeBuilder::new().build_from_eqn(Rc::new(eqn)).unwrap();
/// let problem = OdeBuilder::new().build_from_eqn(eqn).unwrap();
///
/// let mut solver = Bdf::default();
/// let t = 0.4;
Expand Down
Loading

0 comments on commit 8db78bb

Please sign in to comment.