Skip to content

Commit

Permalink
montgomery: Encapsulate Rust uses of bn_mul_mont.
Browse files Browse the repository at this point in the history
Have all calls from Rust go through `mul_mont`, which ensures
CPU feature detection has been done.
  • Loading branch information
briansmith committed Dec 2, 2023
1 parent 40e147d commit 1f55dc1
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 62 deletions.
63 changes: 3 additions & 60 deletions src/arithmetic/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub(crate) use self::{
use crate::{
arithmetic::montgomery::*,
bits::BitLength,
c, cpu, error,
c, error,
limb::{self, Limb, LimbMask, LIMB_BITS},
};
use alloc::vec;
Expand Down Expand Up @@ -491,7 +491,7 @@ pub fn elem_exp_consttime<M>(
exponent: &PrivateExponent,
m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
use crate::limb::LIMB_BYTES;
use crate::{cpu, limb::LIMB_BYTES};

// Pretty much all the math here requires CPU feature detection to have
// been done. `cpu_features` isn't threaded through all the internal
Expand Down Expand Up @@ -701,67 +701,10 @@ pub fn elem_verify_equal_consttime<M, E>(
}
}

/// r *= a
fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0, _cpu_features: cpu::Features) {
debug_assert_eq!(r.len(), m.len());
debug_assert_eq!(a.len(), m.len());
unsafe {
bn_mul_mont(
r.as_mut_ptr(),
r.as_ptr(),
a.as_ptr(),
m.as_ptr(),
n0,
r.len(),
)
}
}

/// r = a * b
#[cfg(not(target_arch = "x86_64"))]
fn limbs_mont_product(
r: &mut [Limb],
a: &[Limb],
b: &[Limb],
m: &[Limb],
n0: &N0,
_cpu_features: cpu::Features,
) {
debug_assert_eq!(r.len(), m.len());
debug_assert_eq!(a.len(), m.len());
debug_assert_eq!(b.len(), m.len());

unsafe {
bn_mul_mont(
r.as_mut_ptr(),
a.as_ptr(),
b.as_ptr(),
m.as_ptr(),
n0,
r.len(),
)
}
}

/// r = r**2
fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0, _cpu_features: cpu::Features) {
debug_assert_eq!(r.len(), m.len());
unsafe {
bn_mul_mont(
r.as_mut_ptr(),
r.as_ptr(),
r.as_ptr(),
m.as_ptr(),
n0,
r.len(),
)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test;
use crate::{cpu, test};

// Type-level representation of an arbitrary modulus.
struct M {}
Expand Down
83 changes: 81 additions & 2 deletions src/arithmetic/montgomery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

pub use super::n0::N0;
use crate::cpu;

// Indicates that the element is not encoded; there is no *R* factor
// that needs to be canceled out.
Expand Down Expand Up @@ -111,6 +112,19 @@ impl ProductEncoding for (RRR, RInverse) {
#[allow(unused_imports)]
use crate::{bssl, c, limb::Limb};

#[inline(always)]
unsafe fn mul_mont(
r: *mut Limb,
a: *const Limb,
b: *const Limb,
n: *const Limb,
n0: &N0,
num_limbs: c::size_t,
_: cpu::Features,
) {
bn_mul_mont(r, a, b, n, n0, num_limbs)
}

#[cfg(not(any(
target_arch = "aarch64",
target_arch = "arm",
Expand All @@ -120,7 +134,7 @@ use crate::{bssl, c, limb::Limb};
// TODO: Stop calling this from C and un-export it.
#[allow(deprecated)]
prefixed_export! {
pub(super) unsafe fn bn_mul_mont(
unsafe fn bn_mul_mont(
r: *mut Limb,
a: *const Limb,
b: *const Limb,
Expand Down Expand Up @@ -226,7 +240,7 @@ prefixed_extern! {
))]
prefixed_extern! {
// `r` and/or 'a' and/or 'b' may alias.
pub(super) fn bn_mul_mont(
fn bn_mul_mont(
r: *mut Limb,
a: *const Limb,
b: *const Limb,
Expand All @@ -236,6 +250,71 @@ prefixed_extern! {
);
}

/// r *= a
pub(super) fn limbs_mont_mul(
r: &mut [Limb],
a: &[Limb],
m: &[Limb],
n0: &N0,
cpu_features: cpu::Features,
) {
debug_assert_eq!(r.len(), m.len());
debug_assert_eq!(a.len(), m.len());
unsafe {
mul_mont(
r.as_mut_ptr(),
r.as_ptr(),
a.as_ptr(),
m.as_ptr(),
n0,
r.len(),
cpu_features,
)
}
}

/// r = a * b
#[cfg(not(target_arch = "x86_64"))]
pub(super) fn limbs_mont_product(
r: &mut [Limb],
a: &[Limb],
b: &[Limb],
m: &[Limb],
n0: &N0,
cpu_features: cpu::Features,
) {
debug_assert_eq!(r.len(), m.len());
debug_assert_eq!(a.len(), m.len());
debug_assert_eq!(b.len(), m.len());

unsafe {
mul_mont(
r.as_mut_ptr(),
a.as_ptr(),
b.as_ptr(),
m.as_ptr(),
n0,
r.len(),
cpu_features,
)
}
}

/// r = r**2
pub(super) fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0, cpu_features: cpu::Features) {
debug_assert_eq!(r.len(), m.len());
unsafe {
mul_mont(
r.as_mut_ptr(),
r.as_ptr(),
r.as_ptr(),
m.as_ptr(),
n0,
r.len(),
cpu_features,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 1f55dc1

Please sign in to comment.