Skip to content

Commit

Permalink
arithmetic: Extend model of aliasing to cover more cases.
Browse files Browse the repository at this point in the history
AliasingSlices modeled functions that take 3 pointers that may all
alias. Rename it to AliasingSlices3. Introduce an AliasingSlices2
that is analogous but for 2 pointers. This may be used for
functions that take 2 arguments or that take 3 arguments, but
where only two may alias, e.g. `(&mut x, &a), &a)`.
  • Loading branch information
briansmith committed Jan 23, 2025
1 parent cdc10bf commit 73aeb12
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 30 deletions.
23 changes: 18 additions & 5 deletions src/arithmetic/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ pub(crate) use self::{
use super::{montgomery::*, LimbSliceError, MAX_LIMBS};
use crate::{
bits::BitLength,
c, error,
c,
error::{self, LenMismatchError},
limb::{self, Limb, LIMB_BITS},
};
use alloc::vec;
Expand Down Expand Up @@ -99,8 +100,13 @@ fn from_montgomery_amm<M>(limbs: BoxedLimbs<M>, m: &Modulus<M>) -> Elem<M, Unenc
let mut one = [0; MAX_LIMBS];
one[0] = 1;
let one = &one[..m.limbs().len()];
limbs_mul_mont((&mut limbs[..], one), m.limbs(), m.n0(), m.cpu_features())
.unwrap_or_else(unwrap_impossible_limb_slice_error);
limbs_mul_mont(
((&mut limbs[..], one), one),
m.limbs(),
m.n0(),
m.cpu_features(),
)
.unwrap_or_else(unwrap_impossible_limb_slice_error);
Elem {
limbs,
encoding: PhantomData,
Expand Down Expand Up @@ -146,7 +152,7 @@ where
(AF, BF): ProductEncoding,
{
limbs_mul_mont(
(&mut b.limbs[..], &a.limbs[..]),
((&mut b.limbs[..], &a.limbs[..]), &a.limbs[..]),
m.limbs(),
m.n0(),
m.cpu_features(),
Expand Down Expand Up @@ -235,7 +241,8 @@ pub fn elem_widen<Larger, Smaller>(

// TODO: Document why this works for all Montgomery factors.
pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
limb::limbs_add_assign_mod(&mut a.limbs, &b.limbs, m.limbs());
limb::limbs_add_assign_mod(&mut a.limbs[..], &b.limbs[..], m.limbs())
.unwrap_or_else(unwrap_impossible_len_mismatch_error);
a
}

Expand Down Expand Up @@ -723,6 +730,12 @@ pub fn elem_verify_equal_consttime<M, E>(
}
}

#[cold]
#[inline(never)]
fn unwrap_impossible_len_mismatch_error(_: LenMismatchError) {
unreachable!()
}

#[cold]
#[inline(never)]
fn unwrap_impossible_limb_slice_error(err: LimbSliceError) {
Expand Down
6 changes: 3 additions & 3 deletions src/arithmetic/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::{inout::AliasingSlices, n0::N0, LimbSliceError, MAX_LIMBS, MIN_LIMBS};
use super::{inout::AliasingSlices3, n0::N0, LimbSliceError, MAX_LIMBS, MIN_LIMBS};
use crate::{c, limb::Limb, polyfill::usize_from_u32};
use core::mem::size_of;

Expand Down Expand Up @@ -53,7 +53,7 @@ macro_rules! bn_mul_mont_ffi {

#[inline]
pub(super) unsafe fn bn_mul_mont_ffi<Cpu, const LEN_MIN: usize, const LEN_MOD: usize>(
mut in_out: impl AliasingSlices<Limb>,
mut in_out: impl AliasingSlices3<Limb>,
n: &[Limb],
n0: &N0,
cpu: Cpu,
Expand Down Expand Up @@ -84,7 +84,7 @@ pub(super) unsafe fn bn_mul_mont_ffi<Cpu, const LEN_MIN: usize, const LEN_MOD: u
}

in_out
.with_pointers(n.len(), |r, a, b| {
.with_3_pointers(n.len(), |r, a, b| {
let len = n.len();
let n = n.as_ptr();
let _: Cpu = cpu;
Expand Down
72 changes: 59 additions & 13 deletions src/arithmetic/inout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,33 @@

pub(crate) use crate::error::LenMismatchError;

pub(crate) trait AliasingSlices<T> {
fn with_pointers<R>(
pub(crate) trait AliasingSlices2<T> {
fn with_2_pointers<R>(
&mut self,
expected_len: usize,
f: impl FnOnce(*mut T, *const T, *const T) -> R,
f: impl FnOnce(*mut T, *const T) -> R,
) -> Result<R, LenMismatchError>;
}

impl<T> AliasingSlices<T> for &mut [T] {
fn with_pointers<R>(
impl<T> AliasingSlices2<T> for &mut [T] {
fn with_2_pointers<R>(
&mut self,
expected_len: usize,
f: impl FnOnce(*mut T, *const T, *const T) -> R,
f: impl FnOnce(*mut T, *const T) -> R,
) -> Result<R, LenMismatchError> {
let r = self;
if r.len() != expected_len {
return Err(LenMismatchError::new(r.len()));
}
Ok(f(r.as_mut_ptr(), r.as_ptr(), r.as_ptr()))
Ok(f(r.as_mut_ptr(), r.as_ptr()))
}
}

impl<T> AliasingSlices<T> for (&mut [T], &[T]) {
fn with_pointers<R>(
impl<T> AliasingSlices2<T> for (&mut [T], &[T]) {
fn with_2_pointers<R>(
&mut self,
expected_len: usize,
f: impl FnOnce(*mut T, *const T, *const T) -> R,
f: impl FnOnce(*mut T, *const T) -> R,
) -> Result<R, LenMismatchError> {
let (r, a) = self;
if r.len() != expected_len {
Expand All @@ -49,12 +49,41 @@ impl<T> AliasingSlices<T> for (&mut [T], &[T]) {
if a.len() != expected_len {
return Err(LenMismatchError::new(a.len()));
}
Ok(f(r.as_mut_ptr(), r.as_ptr(), a.as_ptr()))
Ok(f(r.as_mut_ptr(), a.as_ptr()))
}
}

impl<T> AliasingSlices<T> for (&mut [T], &[T], &[T]) {
fn with_pointers<R>(
pub(crate) trait AliasingSlices3<T> {
fn with_3_pointers<R>(
&mut self,
expected_len: usize,
f: impl FnOnce(*mut T, *const T, *const T) -> R,
) -> Result<R, LenMismatchError>;
}

// TODO:
// impl<A, T> AliasingSlices3<T> for A where Self: AliasingSlices2<T> {
// fn with_3_pointers<R>(
// &mut self,
// expected_len: usize,
// f: impl FnOnce(*mut T, *const T, *const T) -> R,
// ) -> Result<R, LenMismatchError> {
// <Self as AliasingSlices2<T>>::with_2_pointers(expected_len, |r, a| f(r, r, a))
// }
// }

impl<T> AliasingSlices3<T> for &mut [T] {
fn with_3_pointers<R>(
&mut self,
expected_len: usize,
f: impl FnOnce(*mut T, *const T, *const T) -> R,
) -> Result<R, LenMismatchError> {
<Self as AliasingSlices2<T>>::with_2_pointers(self, expected_len, |r, a| f(r, r, a))
}
}

impl<T> AliasingSlices3<T> for (&mut [T], &[T], &[T]) {
fn with_3_pointers<R>(
&mut self,
expected_len: usize,
f: impl FnOnce(*mut T, *const T, *const T) -> R,
Expand All @@ -72,3 +101,20 @@ impl<T> AliasingSlices<T> for (&mut [T], &[T], &[T]) {
Ok(f(r.as_mut_ptr(), a.as_ptr(), b.as_ptr()))
}
}

impl<RA, T> AliasingSlices3<T> for (RA, &[T])
where
RA: AliasingSlices2<T>,
{
fn with_3_pointers<R>(
&mut self,
expected_len: usize,
f: impl FnOnce(*mut T, *const T, *const T) -> R,
) -> Result<R, LenMismatchError> {
let (ra, b) = self;
if b.len() != expected_len {
return Err(LenMismatchError::new(b.len()));
}
ra.with_2_pointers(expected_len, |r, a| f(r, a, b.as_ptr()))
}
}
4 changes: 2 additions & 2 deletions src/arithmetic/montgomery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

pub use super::n0::N0;
use super::{inout::AliasingSlices, LimbSliceError, MIN_LIMBS};
use super::{inout::AliasingSlices3, LimbSliceError, MIN_LIMBS};
use crate::cpu;
use cfg_if::cfg_if;

Expand Down Expand Up @@ -116,7 +116,7 @@ use crate::{bssl, c, limb::Limb};

#[inline(always)]
pub(super) fn limbs_mul_mont(
in_out: impl AliasingSlices<Limb>,
in_out: impl AliasingSlices3<Limb>,
n: &[Limb],
n0: &N0,
cpu: cpu::Features,
Expand Down
15 changes: 13 additions & 2 deletions src/ec/suite_b/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use crate::{
arithmetic::limbs_from_hex, arithmetic::montgomery::*, constant_time::LeakyWord, cpu, error,
arithmetic::limbs_from_hex,
arithmetic::montgomery::*,
constant_time::LeakyWord,
cpu,
error::{self, LenMismatchError},
limb::*,
};
use core::marker::PhantomData;
Expand Down Expand Up @@ -133,7 +137,8 @@ impl<M> Modulus<M> {
&mut a.limbs[..num_limbs],
&b.limbs[..num_limbs],
&self.limbs[..num_limbs],
);
)
.unwrap_or_else(unwrap_impossible_len_mismatch_error)
}
}

Expand Down Expand Up @@ -600,6 +605,12 @@ fn parse_big_endian_fixed_consttime<M>(
Ok(r)
}

#[cold]
#[inline(never)]
fn unwrap_impossible_len_mismatch_error(_: LenMismatchError) {
unreachable!()
}

#[cfg(test)]
mod tests {
extern crate alloc;
Expand Down
19 changes: 14 additions & 5 deletions src/limb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
//! limbs use the native endianness.
use crate::{
c, constant_time, error,
arithmetic::inout::AliasingSlices3,
c, constant_time,
error::{self, LenMismatchError},
polyfill::{slice, usize_from_u32, ArrayFlatMap},
};

Expand Down Expand Up @@ -286,9 +288,14 @@ pub fn fold_5_bit_windows<R, I: FnOnce(Window) -> R, F: Fn(R, Window) -> R>(
}

#[inline]
pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) {
debug_assert_eq!(a.len(), m.len());
debug_assert_eq!(b.len(), m.len());
pub(crate) fn limbs_add_assign_mod<'io, InOut: 'io>(
in_out: InOut,
b: &'io [Limb],
m: &[Limb],
) -> Result<(), LenMismatchError>
where
(InOut, &'io [Limb]): AliasingSlices3<Limb>,
{
prefixed_extern! {
// `r` and `a` may alias.
fn LIMBS_add_mod(
Expand All @@ -299,7 +306,9 @@ pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) {
num_limbs: c::size_t,
);
}
unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) }
(in_out, b).with_3_pointers(m.len(), |r, a, b| unsafe {
LIMBS_add_mod(r, a, b, m.as_ptr(), m.len())
})
}

// r *= 2 (mod m).
Expand Down

0 comments on commit 73aeb12

Please sign in to comment.