diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index 7f4bc3ab8f..257037ef01 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -45,7 +45,8 @@ pub(crate) use self::{ use super::{montgomery::*, LimbSliceError, MAX_LIMBS}; use crate::{ bits::BitLength, - c, error, + c, + error::{self, LenMismatchError}, limb::{self, Limb, LIMB_BITS}, }; use alloc::vec; @@ -716,11 +717,18 @@ pub fn elem_verify_equal_consttime( a: &Elem, b: &Elem, ) -> Result<(), error::Unspecified> { - if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs).leak() { - Ok(()) - } else { - Err(error::Unspecified) + let equal = limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) + .unwrap_or_else(unwrap_impossible_len_mismatch_error); + if !equal.leak() { + return Err(error::Unspecified); } + Ok(()) +} + +#[cold] +#[inline(never)] +fn unwrap_impossible_len_mismatch_error(LenMismatchError { .. }: LenMismatchError) -> T { + unreachable!() } #[cold] diff --git a/src/ec/suite_b/ecdsa/verification.rs b/src/ec/suite_b/ecdsa/verification.rs index b867a29303..6e5d134ccf 100644 --- a/src/ec/suite_b/ecdsa/verification.rs +++ b/src/ec/suite_b/ecdsa/verification.rs @@ -150,7 +150,7 @@ impl EcdsaVerificationAlgorithm { fn sig_r_equals_x(q: &Modulus, r: &Elem, x: &Elem, z2: &Elem) -> bool { let r_jacobian = q.elem_product(z2, r); let x = q.elem_unencoded(x); - q.elem_equals_vartime(&r_jacobian, &x) + q.elems_are_equal(&r_jacobian, &x).leak() } let mut r = self.ops.scalar_as_elem(&r); if sig_r_equals_x(q, &r, &x, &z2) { diff --git a/src/ec/suite_b/ops.rs b/src/ec/suite_b/ops.rs index 63425cf4fb..f9a7514ff0 100644 --- a/src/ec/suite_b/ops.rs +++ b/src/ec/suite_b/ops.rs @@ -143,9 +143,10 @@ impl Modulus { impl Modulus { #[inline] - pub fn elems_are_equal(&self, a: &Elem, b: &Elem) -> LimbMask { + pub fn elems_are_equal(&self, a: &Elem, b: &Elem) -> LimbMask { let num_limbs = self.num_limbs.into(); limbs_equal_limbs_consttime(&a.limbs[..num_limbs], &b.limbs[..num_limbs]) + .unwrap_or_else(unwrap_impossible_len_mismatch_error) } #[inline] @@ -434,11 +435,6 @@ impl PublicScalarOps { } impl Modulus { - pub fn elem_equals_vartime(&self, a: &Elem, b: &Elem) -> bool { - let num_limbs = self.num_limbs.into(); - limbs_equal_limbs_consttime(&a.limbs[..num_limbs], &b.limbs[..num_limbs]).leak() - } - pub fn elem_less_than_vartime(&self, a: &Elem, b: &PublicElem) -> bool { let num_limbs = self.num_limbs.into(); limbs_less_than_limbs_vartime(&a.limbs[..num_limbs], &b.limbs[..num_limbs]) @@ -604,6 +600,12 @@ fn parse_big_endian_fixed_consttime( Ok(r) } +#[cold] +#[inline(never)] +fn unwrap_impossible_len_mismatch_error(LenMismatchError { .. }: LenMismatchError) -> T { + unreachable!() +} + #[cfg(test)] mod tests { extern crate alloc; diff --git a/src/limb.rs b/src/limb.rs index f54971f20c..1c284522c5 100644 --- a/src/limb.rs +++ b/src/limb.rs @@ -40,12 +40,12 @@ pub const LIMB_BYTES: usize = (LIMB_BITS + 7) / 8; pub type LimbMask = constant_time::BoolMask; #[inline] -pub fn limbs_equal_limbs_consttime(a: &[Limb], b: &[Limb]) -> LimbMask { - prefixed_extern! { - fn LIMBS_equal(a: *const Limb, b: *const Limb, num_limbs: c::size_t) -> LimbMask; +pub fn limbs_equal_limbs_consttime(a: &[Limb], b: &[Limb]) -> Result { + if a.len() != b.len() { + return Err(LenMismatchError::new(a.len())); } - assert_eq!(a.len(), b.len()); - unsafe { LIMBS_equal(a.as_ptr(), b.as_ptr(), a.len()) } + let all = a.iter().zip(b).fold(0, |running, (a, b)| running | (a ^ b)); + Ok(limb_is_zero_constant_time(all)) } #[inline]