Skip to content

Commit

Permalink
made a start converting &V to View and &mut to ViewMut #47
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed May 7, 2024
1 parent 7e1cd1e commit 2d18d96
Show file tree
Hide file tree
Showing 21 changed files with 78 additions and 75 deletions.
14 changes: 7 additions & 7 deletions src/jacobian/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub fn find_non_zeros_linear<F: LinearOp + ?Sized>(op: &F, t: F::T) -> Vec<(usiz
let mut triplets = Vec::with_capacity(op.nstates());
for j in 0..op.nstates() {
v[j] = F::T::NAN;
op.call_inplace(&v, t, &mut col);
op.call_inplace(v.view(), t, col.view_mut());
for i in 0..op.nout() {
if col[i].is_nan() {
triplets.push((i, j));
Expand Down Expand Up @@ -144,7 +144,7 @@ impl<M: Matrix> JacobianColoring<M> {
let dst_indices = &self.dst_indices_per_color[c];
let src_indices = &self.src_indices_per_color[c];
v.assign_at_indices(input, F::T::one());
op.call_inplace(&v, t, &mut col);
op.call_inplace(v.view(), t, col.view_mut());
y.set_data_with_indices(dst_indices, src_indices, &col);
v.assign_at_indices(input, F::T::zero());
}
Expand All @@ -165,7 +165,7 @@ mod tests {
jacobian::{coloring::nonzeros2graph, greedy_coloring::color_graph_greedy},
op::closure::Closure,
};
use nalgebra::{DMatrix, DVector};
use nalgebra::{DMatrix, DVector, DVectorView, DVectorViewMut};
use std::ops::MulAssign;

fn helper_triplets2op_nonlinear(
Expand All @@ -175,17 +175,17 @@ mod tests {
) -> impl NonLinearOp<M = DMatrix<f64>, V = DVector<f64>, T = f64> + '_ {
let nstates = ncols;
let nout = nrows;
let f = move |x: &DVector<f64>, y: &mut DVector<f64>| {
let f = move |x: DVectorView<'_, f64>, mut y: DVectorViewMut<'_, f64>| {
for (i, j, v) in triplets {
y[*i] += x[*j] * v;
}
};
let mut ret = Closure::new(
move |x: &DVector<f64>, _p: &DVector<f64>, _t, y: &mut DVector<f64>| {
move |x: DVectorView<'_, f64>, _p: &DVector<f64>, _t, mut y: DVectorViewMut<'_, f64>| {
y.fill(0.0);
f(x, y);
},
move |_x: &DVector<f64>, _p: &DVector<f64>, _t, v, y: &mut DVector<f64>| {
move |_x: DVectorView<'_, f64>, _p: &DVector<f64>, _t, v, mut y: DVectorViewMut<'_, f64>| {
y.fill(0.0);
f(v, y);
},
Expand Down Expand Up @@ -302,7 +302,7 @@ mod tests {
coloring.matrix_inplace(&op, t0, &mut jac);
let mut gemv1 = V::zeros(n);
let v = V::from_element(3, 1.0);
op.gemv_inplace(&v, t0, 0.0, &mut gemv1);
op.gemv_inplace(Vector::view(&v), t0, 0.0, Vector::view_mut(&mut gemv1));
let mut gemv2 = V::zeros(n);
jac.gemv(1.0, &v, 0.0, &mut gemv2);
gemv1.assert_eq_st(&gemv2, 1e-10);
Expand Down
4 changes: 2 additions & 2 deletions src/matrix/dense_faer_serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::ops::{Mul, MulAssign};

use super::default_solver::DefaultSolver;
use super::{Dense, DenseMatrix, Matrix, MatrixCommon, MatrixSparsity, MatrixView, MatrixViewMut};
use crate::op::NonLinearOp;
use crate::op::{NonLinearOp, VView, VViewMut};
use crate::scalar::{IndexType, Scalar, Scale};
use crate::vector::Vector;
use crate::FaerLU;
Expand Down Expand Up @@ -155,7 +155,7 @@ impl<T: Scalar> Matrix for Mat<T> {
}
Ok(m)
}
fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V) {
fn gemv(&self, alpha: Self::T, x: VView<'_, Self>, beta: Self::T, mut y: VViewMut<'_, Self>) {
*y = faer::scale(alpha) * self * x + faer::scale(beta) * &*y;
}
fn zeros(nrows: IndexType, ncols: IndexType) -> Self {
Expand Down
6 changes: 3 additions & 3 deletions src/matrix/dense_nalgebra_serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::{AddAssign, Mul, MulAssign};
use anyhow::Result;
use nalgebra::{DMatrix, DMatrixView, DMatrixViewMut, DVector, DVectorView, DVectorViewMut};

use crate::op::NonLinearOp;
use crate::op::{NonLinearOp, VView, VViewMut};
use crate::{scalar::Scale, IndexType, Scalar};

use crate::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut, NalgebraLU};
Expand Down Expand Up @@ -127,8 +127,8 @@ impl<T: Scalar> Matrix for DMatrix<T> {
self.diagonal()
}

fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V) {
y.gemv(alpha, self, x, beta);
fn gemv(&self, alpha: Self::T, x: VView<'_, Self>, beta: Self::T, mut y: VViewMut<'_, Self>) {
y.gemv(alpha, self, &x, beta);
}
fn copy_from(&mut self, other: &Self) {
self.copy_from(other);
Expand Down
3 changes: 2 additions & 1 deletion src/matrix/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign};

use crate::op::{VView, VViewMut};
use crate::scalar::Scale;
use crate::{IndexType, Scalar, Vector};
use anyhow::Result;
Expand Down Expand Up @@ -210,7 +211,7 @@ pub trait Matrix:
fn diagonal(&self) -> Self::V;

/// Perform a matrix-vector multiplication `y = self * x + beta * y`.
fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
fn gemv(&self, alpha: Self::T, x: VView<'_, Self>, beta: Self::T, y: VViewMut<'_, Self>);

/// Copy the contents of `other` into `self`
fn copy_from(&mut self, other: &Self);
Expand Down
4 changes: 2 additions & 2 deletions src/matrix/sparse_serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use anyhow::Result;
use nalgebra::DVector;
use nalgebra_sparse::{pattern::SparsityPattern, CooMatrix, CscMatrix};

use crate::{scalar::Scale, IndexType, Scalar};
use crate::{op::{VView, VViewMut}, scalar::Scale, IndexType, Scalar};

use super::{Matrix, MatrixCommon, MatrixSparsity};

Expand Down Expand Up @@ -180,7 +180,7 @@ impl<T: Scalar> Matrix for CscMatrix<T> {
fn copy_from(&mut self, other: &Self) {
self.clone_from(other);
}
fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V) {
fn gemv(&self, alpha: Self::T, x: VView<'_, Self>, beta: Self::T, mut y: VViewMut<'_, Self>) {
let mut tmp = self * x;
tmp *= alpha;
y.axpy(alpha, &tmp, beta);
Expand Down
4 changes: 2 additions & 2 deletions src/matrix/sundials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use sundials_sys::{

use crate::{
ode_solver::sundials::sundials_check,
op::NonLinearOp,
op::{NonLinearOp, VView, VViewMut},
scalar::scale,
vector::sundials::{get_suncontext, SundialsVector},
IndexType, Scale, SundialsLinearSolver, Vector,
Expand Down Expand Up @@ -302,7 +302,7 @@ impl Matrix for SundialsMatrix {
}

/// Perform a matrix-vector multiplication `y = alpha * self * x + beta * y`.
fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V) {
fn gemv(&self, alpha: Self::T, x: VView<'_, Self>, beta: Self::T, mut y: VViewMut<'_, Self>) {
let a = self.sundials_matrix();
let tmp = SundialsVector::new_serial(self.nrows());
sundials_check(unsafe { SUNMatMatvecSetup(a) }).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/nonlinear_solver/newton.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl<C: NonLinearOp, Ls: LinearSolver<C>> NonLinearSolver<C> for NewtonNonlinear
self.niter = 0;
loop {
self.niter += 1;
problem.f.call_inplace(xn, t, &mut tmp);
problem.f.call_inplace(xn.view(), t, tmp.view_mut());
//tmp = f_at_n

self.linear_solver.solve_in_place(&mut tmp)?;
Expand Down
10 changes: 5 additions & 5 deletions src/nonlinear_solver/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl<V: Vector> RootFinder<V> {
/// Set the lower boundary of the root search.
/// This function should be called first after [Self::new]
pub fn init(&self, root_fn: &impl NonLinearOp<V = V, T = V::T>, y: &V, t: V::T) {
root_fn.call_inplace(y, t, &mut self.g0.borrow_mut());
root_fn.call_inplace(y.view(), t, self.g0.borrow_mut().view_mut());
self.t0.replace(t);
}

Expand All @@ -48,7 +48,7 @@ impl<V: Vector> RootFinder<V> {
let g1 = &mut *self.g1.borrow_mut();
let g0 = &mut *self.g0.borrow_mut();
let gmid = &mut *self.gmid.borrow_mut();
root_fn.call_inplace(y, t, g1);
root_fn.call_inplace(y.view(), t, g1.view_mut());

let sign_change_fn = |mut acc: (bool, V::T, i32), g0: V::T, g1: V::T, i: IndexType| {
if g1 == V::T::zero() {
Expand Down Expand Up @@ -115,7 +115,7 @@ impl<V: Vector> RootFinder<V> {
}

let ymid = interpolate(t_mid).unwrap();
root_fn.call_inplace(&ymid, t_mid, gmid);
root_fn.call_inplace(ymid.view(), t_mid, gmid.view_mut());

let (rootfnd, _gfracmax, imax_i32) =
(*g0).binary_fold(gmid, (false, V::T::zero(), -1), sign_change_fn);
Expand All @@ -128,7 +128,7 @@ impl<V: Vector> RootFinder<V> {
std::mem::swap(g1, gmid);
} else if rootfnd {
// we are returning so make sure g0 is set for next iteration
root_fn.call_inplace(y, t, g0);
root_fn.call_inplace(y.view(), t, g0.view_mut());

// No sign change in (tlo,tmid), but g = 0 at tmid; return root tmid.
return Some(t_mid);
Expand All @@ -151,7 +151,7 @@ impl<V: Vector> RootFinder<V> {
i += 1;
}
// we are returning so make sure g0 is set for next iteration
root_fn.call_inplace(y, t, g0);
root_fn.call_inplace(y.view(), t, g0.view_mut());
Some(t1)
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/ode_solver/bdf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,11 @@ where
scale_factor *= scale(problem.rtol);
scale_factor += problem.atol.as_ref();

let f0 = problem.eqn.rhs().call(&state.y, state.t);
let f0 = problem.eqn.rhs().call(state.y.view(), state.t);
let hf0 = &f0 * scale(state.h);
let y1 = &state.y + &hf0;
let t1 = state.t + state.h;
let f1 = problem.eqn.rhs().call(&y1, t1);
let f1 = problem.eqn.rhs().call(y1.view(), t1);

// store f1 in diff[1] for use in step size control
self.diff.column_mut(1).copy_from(&hf0);
Expand Down
14 changes: 6 additions & 8 deletions src/ode_solver/diffsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ use anyhow::Result;
use diffsl::execution::Compiler;

use crate::{
jacobian::{find_non_zeros_linear, find_non_zeros_nonlinear, JacobianColoring},
op::{LinearOp, NonLinearOp, Op},
OdeEquations,
jacobian::{find_non_zeros_linear, find_non_zeros_nonlinear, JacobianColoring}, op::{LinearOp, NonLinearOp, Op}, vector::Vector, OdeEquations
};

pub type T = f64;
Expand Down Expand Up @@ -170,7 +168,7 @@ impl Op for DiffSlRoot<'_> {
}

impl NonLinearOp for DiffSlRoot<'_> {
fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
fn call_inplace(&self, x: <Self::V as Vector>::View<'_>, t: Self::T, y: <Self::V as Vector>::ViewMut<'_>) {
self.context.compiler.calc_stop(
t,
x.as_slice(),
Expand All @@ -185,7 +183,7 @@ impl NonLinearOp for DiffSlRoot<'_> {
}

impl NonLinearOp for DiffSlRhs<'_> {
fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
fn call_inplace(&self, x: <Self::V as Vector>::View<'_>, t: Self::T, y: <Self::V as Vector>::ViewMut<'_>) {
self.context.compiler.rhs(
t,
x.as_slice(),
Expand Down Expand Up @@ -217,7 +215,7 @@ impl NonLinearOp for DiffSlRhs<'_> {
}

impl LinearOp for DiffSlMass<'_> {
fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
fn gemv_inplace(&self, x: <Self::V as Vector>::View<'_>, t: Self::T, beta: Self::T, y: <Self::V as Vector>::ViewMut<'_>) {
let mut tmp = self.context.tmp.borrow_mut();
self.context.compiler.mass(
t,
Expand Down Expand Up @@ -339,7 +337,7 @@ mod tests {
let init = eqn.init(0.0);
let init_expect = DVector::from_vec(vec![y0, 0.0]);
init.assert_eq_st(&init_expect, 1e-10);
let rhs = eqn.rhs().call(&init, 0.0);
let rhs = eqn.rhs().call(Vector::view(&init), 0.0);
let rhs_expect = DVector::from_vec(vec![r * y0 * (1.0 - y0 / k), 2.0 * y0]);
rhs.assert_eq_st(&rhs_expect, 1e-10);
let v = DVector::from_vec(vec![1.0, 1.0]);
Expand All @@ -348,7 +346,7 @@ mod tests {
rhs_jac.assert_eq_st(&rhs_jac_expect, 1e-10);
let mut mass_y = DVector::from_vec(vec![0.0, 0.0]);
let v = DVector::from_vec(vec![1.0, 1.0]);
eqn.mass().call_inplace(&v, 0.0, &mut mass_y);
eqn.mass().call_inplace(Vector::view(&v), 0.0, Vector::view_mut(&mut mass_y));
let mass_y_expect = DVector::from_vec(vec![1.0, 0.0]);
mass_y.assert_eq_st(&mass_y_expect, 1e-10);

Expand Down
4 changes: 2 additions & 2 deletions src/ode_solver/equations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ mod tests {
fn ode_equation_test() {
let (problem, _soln) = exponential_decay_problem::<Mcpu>(false);
let y = DVector::from_vec(vec![1.0, 1.0]);
let rhs_y = problem.eqn.rhs().call(&y, 0.0);
let rhs_y = problem.eqn.rhs().call(Vector::view(&y), 0.0);
let expect_rhs_y = DVector::from_vec(vec![-0.1, -0.1]);
rhs_y.assert_eq_st(&expect_rhs_y, 1e-10);
let jac_rhs_y = problem.eqn.rhs().jac_mul(&y, 0.0, &y);
Expand All @@ -203,7 +203,7 @@ mod tests {
fn ode_with_mass_test() {
let (problem, _soln) = exponential_decay_with_algebraic_problem::<Mcpu>(false);
let y = DVector::from_vec(vec![1.0, 1.0, 1.0]);
let rhs_y = problem.eqn.rhs().call(&y, 0.0);
let rhs_y = problem.eqn.rhs().call(Vector::view(&y), 0.0);
let expect_rhs_y = DVector::from_vec(vec![-0.1, -0.1, 0.0]);
rhs_y.assert_eq_st(&expect_rhs_y, 1e-10);
let jac_rhs_y = problem.eqn.rhs().jac_mul(&y, 0.0, &y);
Expand Down
2 changes: 1 addition & 1 deletion src/ode_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ mod tests {
}

impl<M: Matrix> NonLinearOp for TestEqnRhs<M> {
fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) {
fn call_inplace(&self, _x: <Self::V as Vector>::View<'_>, _t: Self::T, y: <Self::V as Vector>::ViewMut<'_>) {
y[0] = M::T::zero();
}

Expand Down
4 changes: 2 additions & 2 deletions src/ode_solver/sdirk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ where
// compute first step based on alg in Hairer, Norsett, Wanner
// Solving Ordinary Differential Equations I, Nonstiff Problems
// Section II.4.2
let f0 = problem.eqn.rhs().call(&state.y, state.t);
let f0 = problem.eqn.rhs().call(state.y.view(), state.t);
let hf0 = &f0 * scale(state.h);

let mut tmp = f0.clone();
Expand All @@ -237,7 +237,7 @@ where

let y1 = &state.y + hf0;
let t1 = state.t + h0;
let f1 = problem.eqn.rhs().call(&y1, t1);
let f1 = problem.eqn.rhs().call(y1.view(), t1);

let mut df = f1 - &f0;
df *= scale(Eqn::T::one() / h0);
Expand Down
4 changes: 2 additions & 2 deletions src/ode_solver/sundials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ where
let mut rr = SundialsVector::new_not_owned(rr);
// F(t, y, y') = M y' - f(t, y)
// rr = f(t, y)
data.eqn.rhs().call_inplace(&y, t, &mut rr);
data.eqn.rhs().call_inplace(y.view(), t, rr.view_mut());
// rr = M y' - rr
data.eqn.mass().gemv_inplace(&yp, t, -1.0, &mut rr);
data.eqn.mass().gemv_inplace(yp.view(), t, -1.0, rr.view_mut());
0
}

Expand Down
10 changes: 5 additions & 5 deletions src/op/bdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,25 @@ where
for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
{
// F(y) = M (y - y0 + psi) - c * f(y) = 0
fn call_inplace(&self, x: &Eqn::V, t: Eqn::T, y: &mut Eqn::V) {
fn call_inplace(&self, x: <Eqn::V as Vector>::View<'_>, t: Eqn::T, y: <Eqn::V as Vector>::ViewMut<'_>) {
let psi_neg_y0_ref = self.psi_neg_y0.borrow();
let psi_neg_y0 = psi_neg_y0_ref.deref();

self.eqn.rhs().call_inplace(x, t, y);

let mut tmp = self.tmp.borrow_mut();
tmp.copy_from(x);
tmp.copy_from_view(&x);
tmp.add_assign(psi_neg_y0);
let c = *self.c.borrow().deref();
// y = M tmp - c * y
self.eqn.mass().gemv_inplace(&tmp, t, -c, y);
self.eqn.mass().gemv_inplace(tmp.view(), t, -c, y);
}
// (M - c * f'(y)) v
fn jac_mul_inplace(&self, x: &Eqn::V, t: Eqn::T, v: &Eqn::V, y: &mut Eqn::V) {
self.eqn.rhs().jac_mul_inplace(x, t, v, y);
let c = *self.c.borrow().deref();
// y = Mv - c y
self.eqn.mass().gemv_inplace(v, t, -c, y);
self.eqn.mass().gemv_inplace(v.view(), t, -c, y.view_mut());
}

fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
Expand Down Expand Up @@ -200,7 +200,7 @@ mod tests {
// |-0.1|
// i.e. F(y) = |1 0| |2.1| - 0.1 * |-0.1| = |2.11|
// |0 1| |2.2| |-0.1| |2.21|
bdf_callable.call_inplace(&y, t, &mut y_out);
bdf_callable.call_inplace(Vector::view(&y), t, Vector::view_mut(&mut y_out));
let y_out_expect = Vcpu::from_vec(vec![2.11, 2.21]);
y_out.assert_eq_st(&y_out_expect, 1e-10);

Expand Down
Loading

0 comments on commit 2d18d96

Please sign in to comment.