From c922274cd014b28d7f98267e03be9b980cd228df Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Wed, 6 Dec 2023 08:42:32 -0800 Subject: [PATCH] bigint: Remove `cpu::Features` from OwnedModulus. Since all the arithmetic is actually done on `Modulus` now, `OwnedModulus` doesn't need access to the CPU features. --- src/arithmetic/bigint.rs | 28 ++++++++----------- src/arithmetic/bigint/modulus.rs | 13 ++------- src/rsa/keypair.rs | 48 ++++++++++++++++++++------------ src/rsa/public_key.rs | 13 ++++++--- src/rsa/public_modulus.rs | 8 +++--- src/rsa/verification.rs | 7 +++-- 6 files changed, 63 insertions(+), 54 deletions(-) diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index e8a7ea22ee..b326c35e74 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -719,8 +719,8 @@ mod tests { |section, test_case| { assert_eq!(section, ""); - let m = consume_modulus::(test_case, "M", cpu_features); - let m = m.modulus(); + let m = consume_modulus::(test_case, "M"); + let m = m.modulus(cpu_features); let expected_result = consume_elem(test_case, "ModExp", &m); let base = consume_elem(test_case, "A", &m); let e = { @@ -749,8 +749,8 @@ mod tests { |section, test_case| { assert_eq!(section, ""); - let m = consume_modulus::(test_case, "M", cpu_features); - let m = m.modulus(); + let m = consume_modulus::(test_case, "M"); + let m = m.modulus(cpu_features); let expected_result = consume_elem(test_case, "ModMul", &m); let a = consume_elem(test_case, "A", &m); let b = consume_elem(test_case, "B", &m); @@ -774,8 +774,8 @@ mod tests { |section, test_case| { assert_eq!(section, ""); - let m = consume_modulus::(test_case, "M", cpu_features); - let m = m.modulus(); + let m = consume_modulus::(test_case, "M"); + let m = m.modulus(cpu_features); let expected_result = consume_elem(test_case, "ModSquare", &m); let a = consume_elem(test_case, "A", &m); @@ -799,8 +799,8 @@ mod tests { struct M {} - let m_ = consume_modulus::(test_case, "M", cpu_features); - let m = m_.modulus(); + let m_ = consume_modulus::(test_case, "M"); + let m = m_.modulus(cpu_features); let expected_result = consume_elem(test_case, "R", &m); let a = consume_elem_unchecked::(test_case, "A", expected_result.limbs.len() * 2); @@ -826,8 +826,8 @@ mod tests { struct M {} struct O {} - let m = consume_modulus::(test_case, "m", cpu_features); - let m = m.modulus(); + let m = consume_modulus::(test_case, "m"); + let m = m.modulus(cpu_features); let a = consume_elem_unchecked::(test_case, "a", m.limbs().len()); let expected_result = consume_elem::(test_case, "r", &m); let other_modulus_len_bits = m.len_bits(); @@ -864,13 +864,9 @@ mod tests { } } - fn consume_modulus( - test_case: &mut test::TestCase, - name: &str, - cpu_features: cpu::Features, - ) -> OwnedModulus { + fn consume_modulus(test_case: &mut test::TestCase, name: &str) -> OwnedModulus { let value = test_case.consume_bytes(name); - OwnedModulus::from_be_bytes(untrusted::Input::from(&value), cpu_features).unwrap() + OwnedModulus::from_be_bytes(untrusted::Input::from(&value)).unwrap() } fn assert_elem_eq(a: &Elem, b: &Elem) { diff --git a/src/arithmetic/bigint/modulus.rs b/src/arithmetic/bigint/modulus.rs index d10ff9c978..3f87053c01 100644 --- a/src/arithmetic/bigint/modulus.rs +++ b/src/arithmetic/bigint/modulus.rs @@ -75,8 +75,6 @@ pub struct OwnedModulus { n0: N0, len_bits: BitLength, - - cpu_features: cpu::Features, } impl Clone for OwnedModulus { @@ -85,16 +83,12 @@ impl Clone for OwnedModulus { limbs: self.limbs.clone(), n0: self.n0, len_bits: self.len_bits, - cpu_features: self.cpu_features, } } } impl OwnedModulus { - pub(crate) fn from_be_bytes( - input: untrusted::Input, - cpu_features: cpu::Features, - ) -> Result { + pub(crate) fn from_be_bytes(input: untrusted::Input) -> Result { let n = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?; if n.len() > MODULUS_MAX_LIMBS { return Err(error::KeyRejected::too_large()); @@ -135,7 +129,6 @@ impl OwnedModulus { limbs: n, n0, len_bits, - cpu_features, }) } @@ -158,13 +151,13 @@ impl OwnedModulus { encoding: PhantomData, }) } - pub fn modulus(&self) -> Modulus { + pub(crate) fn modulus(&self, cpu_features: cpu::Features) -> Modulus { Modulus { limbs: &self.limbs, n0: self.n0, len_bits: self.len_bits, m: PhantomData, - cpu_features: self.cpu_features, + cpu_features, } } diff --git a/src/rsa/keypair.rs b/src/rsa/keypair.rs index f485182fc4..a2012789e6 100644 --- a/src/rsa/keypair.rs +++ b/src/rsa/keypair.rs @@ -285,7 +285,7 @@ impl KeyPair { )?; let n_one = public_key.inner().n().oneRR(); - let n = &public_key.inner().n().value(); + let n = &public_key.inner().n().value(cpu_features); // 6.4.1.4.3 says to skip 6.4.1.2.1 Step 2. @@ -338,7 +338,7 @@ impl KeyPair { // First, validate `2**half_n_bits < d`. Since 2**half_n_bits has a bit // length of half_n_bits + 1, this check gives us 2**half_n_bits <= d, // and knowing d is odd makes the inequality strict. - let d = bigint::OwnedModulus::::from_be_bytes(d, cpu_features) + let d = bigint::OwnedModulus::::from_be_bytes(d) .map_err(|_| error::KeyRejected::invalid_component())?; if !(n_bits.half_rounded_up() < d.len_bits()) { return Err(KeyRejected::inconsistent_components()); @@ -350,7 +350,7 @@ impl KeyPair { // Step 6.b is omitted as explained above. - let pm = &p.modulus.modulus(); + let pm = &p.modulus.modulus(cpu_features); // 6.4.1.4.3 - Step 7. @@ -371,8 +371,8 @@ impl KeyPair { // This should never fail since `n` and `e` were validated above. - let p = PrivateCrtPrime::new(p, dP)?; - let q = PrivateCrtPrime::new(q, dQ)?; + let p = PrivateCrtPrime::new(p, dP, cpu_features)?; + let q = PrivateCrtPrime::new(q, dQ, cpu_features)?; Ok(Self { p, @@ -416,7 +416,7 @@ impl PrivatePrime { n_bits: BitLength, cpu_features: cpu::Features, ) -> Result { - let p = bigint::OwnedModulus::from_be_bytes(p, cpu_features)?; + let p = bigint::OwnedModulus::from_be_bytes(p)?; // 5.c / 5.g: // @@ -438,7 +438,7 @@ impl PrivatePrime { // Steps 5.e and 5.f are omitted as explained above. - let oneRR = bigint::One::newRR(&p.modulus()); + let oneRR = bigint::One::newRR(&p.modulus(cpu_features)); Ok(Self { modulus: p, oneRR }) } @@ -453,8 +453,12 @@ struct PrivateCrtPrime { impl PrivateCrtPrime { /// Constructs a `PrivateCrtPrime` from the private prime `p` and `dP` where /// dP == d % (p - 1). - fn new(p: PrivatePrime, dP: untrusted::Input) -> Result { - let m = &p.modulus.modulus(); + fn new( + p: PrivatePrime, + dP: untrusted::Input, + cpu_features: cpu::Features, + ) -> Result { + let m = &p.modulus.modulus(cpu_features); // [NIST SP-800-56B rev. 1] 6.4.1.4.3 - Steps 7.a & 7.b. let dP = bigint::PrivateExponent::from_be_bytes_padded(dP, m) .map_err(|error::Unspecified| KeyRejected::inconsistent_components())?; @@ -482,8 +486,9 @@ fn elem_exp_consttime( c: &bigint::Elem, p: &PrivateCrtPrime, other_prime_len_bits: BitLength, + cpu_features: cpu::Features, ) -> Result, error::Unspecified> { - let m = &p.modulus.modulus(); + let m = &p.modulus.modulus(cpu_features); let c_mod_m = bigint::elem_reduced(c, m, other_prime_len_bits); let c_mod_m = bigint::elem_mul(p.oneRRR.as_ref(), c_mod_m, m); bigint::elem_exp_consttime(c_mod_m, &p.exponent, m) @@ -523,6 +528,8 @@ impl KeyPair { msg: &[u8], signature: &mut [u8], ) -> Result<(), error::Unspecified> { + let cpu_features = cpu::features(); + if signature.len() != self.public().modulus_len() { return Err(error::Unspecified); } @@ -537,7 +544,7 @@ impl KeyPair { // with Garner's algorithm. // Steps 1 and 2. - let m = self.private_exponentiate(signature)?; + let m = self.private_exponentiate(signature, cpu_features)?; // Step 3. m.fill_be_bytes(signature); @@ -552,13 +559,17 @@ impl KeyPair { /// leaked that would endanger the private key. /// /// Panics if `in_out` is not `self.public().modulus_len()`. - fn private_exponentiate(&self, base: &[u8]) -> Result, error::Unspecified> { + fn private_exponentiate( + &self, + base: &[u8], + cpu_features: cpu::Features, + ) -> Result, error::Unspecified> { assert_eq!(base.len(), self.public().modulus_len()); // RFC 8017 Section 5.1.2: RSADP, using the Chinese Remainder Theorem // with Garner's algorithm. - let n = &self.public.inner().n().value(); + let n = &self.public.inner().n().value(cpu_features); let n_one = self.public.inner().n().oneRR(); // Step 1. The value zero is also rejected. @@ -569,14 +580,14 @@ impl KeyPair { // Step 2.b.i. let q_bits = self.q.modulus.len_bits(); - let m_1 = elem_exp_consttime(&c, &self.p, q_bits)?; - let m_2 = elem_exp_consttime(&c, &self.q, self.p.modulus.len_bits())?; + let m_1 = elem_exp_consttime(&c, &self.p, q_bits, cpu_features)?; + let m_2 = elem_exp_consttime(&c, &self.q, self.p.modulus.len_bits(), cpu_features)?; // Step 2.b.ii isn't needed since there are only two primes. // Step 2.b.iii. let h = { - let p = &self.p.modulus.modulus(); + let p = &self.p.modulus.modulus(cpu_features); let m_2 = bigint::elem_reduced_once(&m_2, p, q_bits); let m_1_minus_m_2 = bigint::elem_sub(m_1, &m_2, p); bigint::elem_mul(&self.qInv, m_1_minus_m_2, p) @@ -605,7 +616,7 @@ impl KeyPair { // minimum value, since the relationship of `e` to `d`, `p`, and `q` is // not verified during `KeyPair` construction. { - let verify = self.public.inner().exponentiate_elem(&m); + let verify = self.public.inner().exponentiate_elem(&m, cpu_features); bigint::elem_verify_equal_consttime(&verify, &c)?; } @@ -623,6 +634,7 @@ mod tests { #[test] fn test_rsakeypair_private_exponentiate() { + let cpu = cpu::features(); test::run( test_file!("keypair_private_exponentiate_tests.txt"), |section, test_case| { @@ -645,7 +657,7 @@ mod tests { let mut padded = vec![0; key.public.modulus_len()]; let zeroes = padded.len() - test_case.len(); padded[zeroes..].copy_from_slice(test_case); - let _: bigint::Elem<_> = key.private_exponentiate(&padded).unwrap(); + let _: bigint::Elem<_> = key.private_exponentiate(&padded, cpu).unwrap(); } Ok(()) }, diff --git a/src/rsa/public_key.rs b/src/rsa/public_key.rs index 9fd875c870..df8b7f2628 100644 --- a/src/rsa/public_key.rs +++ b/src/rsa/public_key.rs @@ -144,8 +144,9 @@ impl Inner { &self, base: untrusted::Input, out_buffer: &'out mut [u8; PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN], + cpu_features: cpu::Features, ) -> Result<&'out [u8], error::Unspecified> { - let n = &self.n.value(); + let n = &self.n.value(cpu_features); // The encoded value of the base must be the same length as the modulus, // in bytes. @@ -162,7 +163,7 @@ impl Inner { } // Step 2. - let m = self.exponentiate_elem(&s); + let m = self.exponentiate_elem(&s, cpu_features); // Step 3. Ok(fill_be_bytes_n(m, self.n.len_bits(), out_buffer)) @@ -171,13 +172,17 @@ impl Inner { /// Calculates base**e (mod n). /// /// This is constant-time with respect to `base` only. - pub(super) fn exponentiate_elem(&self, base: &bigint::Elem) -> bigint::Elem { + pub(super) fn exponentiate_elem( + &self, + base: &bigint::Elem, + cpu_features: cpu::Features, + ) -> bigint::Elem { // The exponent was already checked to be at least 3. let exponent_without_low_bit = NonZeroU64::try_from(self.e.value().get() & !1).unwrap(); // The exponent was already checked to be odd. debug_assert_ne!(exponent_without_low_bit, self.e.value()); - let n = &self.n.value(); + let n = &self.n.value(cpu_features); let base_r = bigint::elem_mul(self.n.oneRR(), base.clone(), n); diff --git a/src/rsa/public_modulus.rs b/src/rsa/public_modulus.rs index 06365f15a7..f4bebe9e92 100644 --- a/src/rsa/public_modulus.rs +++ b/src/rsa/public_modulus.rs @@ -37,7 +37,7 @@ impl PublicModulus { const MIN_BITS: bits::BitLength = bits::BitLength::from_usize_bits(1024); // Step 3 / Step c for `n` (out of order). - let value = bigint::OwnedModulus::from_be_bytes(n, cpu_features)?; + let value = bigint::OwnedModulus::from_be_bytes(n)?; let bits = value.len_bits(); // Step 1 / Step a. XXX: SP800-56Br1 and SP800-89 require the length of @@ -52,7 +52,7 @@ impl PublicModulus { if bits > max_bits { return Err(error::KeyRejected::too_large()); } - let oneRR = bigint::One::newRR(&value.modulus()); + let oneRR = bigint::One::newRR(&value.modulus(cpu_features)); Ok(Self { value, oneRR }) } @@ -69,8 +69,8 @@ impl PublicModulus { self.value.len_bits() } - pub(super) fn value(&self) -> bigint::Modulus { - self.value.modulus() + pub(super) fn value(&self, cpu_features: cpu::Features) -> bigint::Modulus { + self.value.modulus(cpu_features) } pub(super) fn oneRR(&self) -> &bigint::Elem { diff --git a/src/rsa/verification.rs b/src/rsa/verification.rs index 80627816db..c128097e68 100644 --- a/src/rsa/verification.rs +++ b/src/rsa/verification.rs @@ -35,6 +35,7 @@ impl signature::VerificationAlgorithm for RsaParameters { ), msg, signature, + cpu::features(), ) } } @@ -184,6 +185,7 @@ where ), untrusted::Input::from(message), untrusted::Input::from(signature), + cpu::features(), ) } } @@ -193,6 +195,7 @@ pub(crate) fn verify_rsa_( (n, e): (untrusted::Input, untrusted::Input), msg: untrusted::Input, signature: untrusted::Input, + cpu_features: cpu::Features, ) -> Result<(), error::Unspecified> { let max_bits: bits::BitLength = bits::BitLength::from_usize_bytes(PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN)?; @@ -207,12 +210,12 @@ pub(crate) fn verify_rsa_( params.min_bits, max_bits, PublicExponent::_3, - cpu::features(), + cpu_features, )?; // RFC 8017 Section 5.2.2: RSAVP1. let mut decoded = [0u8; PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN]; - let decoded = key.exponentiate(signature, &mut decoded)?; + let decoded = key.exponentiate(signature, &mut decoded, cpu_features)?; // Verify the padded message is correct. let m_hash = digest::digest(params.padding_alg.digest_alg(), msg.as_slice_less_safe());