From 16c3e19f71668687c3ecdb08daeeecfd26188965 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Sat, 2 Dec 2023 14:15:42 -0800 Subject: [PATCH] montgomery: Encapsulate Rust uses of bn_mul_mont. Have all calls from Rust go through `mul_mont`, which ensures CPU feature detection has been done. --- src/arithmetic/bigint.rs | 63 ++------------------------- src/arithmetic/montgomery.rs | 83 +++++++++++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 62 deletions(-) diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index 4f127b55dd..92eb21fab0 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -44,7 +44,7 @@ pub(crate) use self::{ use crate::{ arithmetic::montgomery::*, bits::BitLength, - c, cpu, error, + c, error, limb::{self, Limb, LimbMask, LIMB_BITS}, }; use alloc::vec; @@ -491,7 +491,7 @@ pub fn elem_exp_consttime( exponent: &PrivateExponent, m: &Modulus, ) -> Result, error::Unspecified> { - use crate::limb::LIMB_BYTES; + use crate::{cpu, limb::LIMB_BYTES}; // Pretty much all the math here requires CPU feature detection to have // been done. `cpu_features` isn't threaded through all the internal @@ -701,67 +701,10 @@ pub fn elem_verify_equal_consttime( } } -/// r *= a -fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0, _cpu_features: cpu::Features) { - debug_assert_eq!(r.len(), m.len()); - debug_assert_eq!(a.len(), m.len()); - unsafe { - bn_mul_mont( - r.as_mut_ptr(), - r.as_ptr(), - a.as_ptr(), - m.as_ptr(), - n0, - r.len(), - ) - } -} - -/// r = a * b -#[cfg(not(target_arch = "x86_64"))] -fn limbs_mont_product( - r: &mut [Limb], - a: &[Limb], - b: &[Limb], - m: &[Limb], - n0: &N0, - _cpu_features: cpu::Features, -) { - debug_assert_eq!(r.len(), m.len()); - debug_assert_eq!(a.len(), m.len()); - debug_assert_eq!(b.len(), m.len()); - - unsafe { - bn_mul_mont( - r.as_mut_ptr(), - a.as_ptr(), - b.as_ptr(), - m.as_ptr(), - n0, - r.len(), - ) - } -} - -/// r = r**2 -fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0, _cpu_features: cpu::Features) { - debug_assert_eq!(r.len(), m.len()); - unsafe { - bn_mul_mont( - r.as_mut_ptr(), - r.as_ptr(), - r.as_ptr(), - m.as_ptr(), - n0, - r.len(), - ) - } -} - #[cfg(test)] mod tests { use super::*; - use crate::test; + use crate::{cpu, test}; // Type-level representation of an arbitrary modulus. struct M {} diff --git a/src/arithmetic/montgomery.rs b/src/arithmetic/montgomery.rs index a155ec36bc..b3bed1b14c 100644 --- a/src/arithmetic/montgomery.rs +++ b/src/arithmetic/montgomery.rs @@ -13,6 +13,7 @@ // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. pub use super::n0::N0; +use crate::cpu; // Indicates that the element is not encoded; there is no *R* factor // that needs to be canceled out. @@ -111,6 +112,19 @@ impl ProductEncoding for (RRR, RInverse) { #[allow(unused_imports)] use crate::{bssl, c, limb::Limb}; +#[inline(always)] +unsafe fn mul_mont( + r: *mut Limb, + a: *const Limb, + b: *const Limb, + n: *const Limb, + n0: &N0, + num_limbs: c::size_t, + _: cpu::Features, +) { + bn_mul_mont(r, a, b, n, n0, num_limbs) +} + #[cfg(not(any( target_arch = "aarch64", target_arch = "arm", @@ -120,7 +134,7 @@ use crate::{bssl, c, limb::Limb}; // TODO: Stop calling this from C and un-export it. #[allow(deprecated)] prefixed_export! { - pub(super) unsafe fn bn_mul_mont( + unsafe fn bn_mul_mont( r: *mut Limb, a: *const Limb, b: *const Limb, @@ -226,7 +240,7 @@ prefixed_extern! { ))] prefixed_extern! { // `r` and/or 'a' and/or 'b' may alias. - pub(super) fn bn_mul_mont( + fn bn_mul_mont( r: *mut Limb, a: *const Limb, b: *const Limb, @@ -236,6 +250,71 @@ prefixed_extern! { ); } +/// r *= a +pub(super) fn limbs_mont_mul( + r: &mut [Limb], + a: &[Limb], + m: &[Limb], + n0: &N0, + cpu_features: cpu::Features, +) { + debug_assert_eq!(r.len(), m.len()); + debug_assert_eq!(a.len(), m.len()); + unsafe { + mul_mont( + r.as_mut_ptr(), + r.as_ptr(), + a.as_ptr(), + m.as_ptr(), + n0, + r.len(), + cpu_features, + ) + } +} + +/// r = a * b +#[cfg(not(target_arch = "x86_64"))] +pub(super) fn limbs_mont_product( + r: &mut [Limb], + a: &[Limb], + b: &[Limb], + m: &[Limb], + n0: &N0, + cpu_features: cpu::Features, +) { + debug_assert_eq!(r.len(), m.len()); + debug_assert_eq!(a.len(), m.len()); + debug_assert_eq!(b.len(), m.len()); + + unsafe { + mul_mont( + r.as_mut_ptr(), + a.as_ptr(), + b.as_ptr(), + m.as_ptr(), + n0, + r.len(), + cpu_features, + ) + } +} + +/// r = r**2 +pub(super) fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0, cpu_features: cpu::Features) { + debug_assert_eq!(r.len(), m.len()); + unsafe { + mul_mont( + r.as_mut_ptr(), + r.as_ptr(), + r.as_ptr(), + m.as_ptr(), + n0, + r.len(), + cpu_features, + ) + } +} #[cfg(test)] mod tests { use super::*;