Skip to content

Commit

Permalink
bigint: Remove cpu::Features from OwnedModulus.
Browse files Browse the repository at this point in the history
Since all the arithmetic is actually done on `Modulus` now,
`OwnedModulus` doesn't need access to the CPU features.
  • Loading branch information
briansmith committed Dec 6, 2023
1 parent 7ae06ad commit 3227d9d
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 54 deletions.
28 changes: 12 additions & 16 deletions src/arithmetic/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,8 @@ mod tests {
|section, test_case| {
assert_eq!(section, "");

let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m.modulus();
let m = consume_modulus::<M>(test_case, "M");
let m = m.modulus(cpu_features);
let expected_result = consume_elem(test_case, "ModExp", &m);
let base = consume_elem(test_case, "A", &m);
let e = {
Expand Down Expand Up @@ -749,8 +749,8 @@ mod tests {
|section, test_case| {
assert_eq!(section, "");

let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m.modulus();
let m = consume_modulus::<M>(test_case, "M");
let m = m.modulus(cpu_features);
let expected_result = consume_elem(test_case, "ModMul", &m);
let a = consume_elem(test_case, "A", &m);
let b = consume_elem(test_case, "B", &m);
Expand All @@ -774,8 +774,8 @@ mod tests {
|section, test_case| {
assert_eq!(section, "");

let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m.modulus();
let m = consume_modulus::<M>(test_case, "M");
let m = m.modulus(cpu_features);
let expected_result = consume_elem(test_case, "ModSquare", &m);
let a = consume_elem(test_case, "A", &m);

Expand All @@ -799,8 +799,8 @@ mod tests {

struct M {}

let m_ = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m_.modulus();
let m_ = consume_modulus::<M>(test_case, "M");
let m = m_.modulus(cpu_features);
let expected_result = consume_elem(test_case, "R", &m);
let a =
consume_elem_unchecked::<M>(test_case, "A", expected_result.limbs.len() * 2);
Expand All @@ -826,8 +826,8 @@ mod tests {

struct M {}
struct O {}
let m = consume_modulus::<M>(test_case, "m", cpu_features);
let m = m.modulus();
let m = consume_modulus::<M>(test_case, "m");
let m = m.modulus(cpu_features);
let a = consume_elem_unchecked::<O>(test_case, "a", m.limbs().len());
let expected_result = consume_elem::<M>(test_case, "r", &m);
let other_modulus_len_bits = m.len_bits();
Expand Down Expand Up @@ -864,13 +864,9 @@ mod tests {
}
}

fn consume_modulus<M>(
test_case: &mut test::TestCase,
name: &str,
cpu_features: cpu::Features,
) -> OwnedModulus<M> {
fn consume_modulus<M>(test_case: &mut test::TestCase, name: &str) -> OwnedModulus<M> {
let value = test_case.consume_bytes(name);
OwnedModulus::from_be_bytes(untrusted::Input::from(&value), cpu_features).unwrap()
OwnedModulus::from_be_bytes(untrusted::Input::from(&value)).unwrap()
}

fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
Expand Down
13 changes: 3 additions & 10 deletions src/arithmetic/bigint/modulus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ pub struct OwnedModulus<M> {
n0: N0,

len_bits: BitLength,

cpu_features: cpu::Features,
}

impl<M: PublicModulus> Clone for OwnedModulus<M> {
Expand All @@ -85,16 +83,12 @@ impl<M: PublicModulus> Clone for OwnedModulus<M> {
limbs: self.limbs.clone(),
n0: self.n0,
len_bits: self.len_bits,
cpu_features: self.cpu_features,
}
}
}

impl<M> OwnedModulus<M> {
pub(crate) fn from_be_bytes(
input: untrusted::Input,
cpu_features: cpu::Features,
) -> Result<Self, error::KeyRejected> {
pub(crate) fn from_be_bytes(input: untrusted::Input) -> Result<Self, error::KeyRejected> {
let n = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?;
if n.len() > MODULUS_MAX_LIMBS {
return Err(error::KeyRejected::too_large());
Expand Down Expand Up @@ -135,7 +129,6 @@ impl<M> OwnedModulus<M> {
limbs: n,
n0,
len_bits,
cpu_features,
})
}

Expand All @@ -158,13 +151,13 @@ impl<M> OwnedModulus<M> {
encoding: PhantomData,
})
}
pub fn modulus(&self) -> Modulus<M> {
pub(crate) fn modulus(&self, cpu: cpu::Features) -> Modulus<M> {
Modulus {
limbs: &self.limbs,
n0: self.n0,
len_bits: self.len_bits,
m: PhantomData,
cpu_features: self.cpu_features,
cpu_features: cpu,
}
}

Expand Down
48 changes: 30 additions & 18 deletions src/rsa/keypair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ impl KeyPair {
)?;

let n_one = public_key.inner().n().oneRR();
let n = &public_key.inner().n().value();
let n = &public_key.inner().n().value(cpu_features);

// 6.4.1.4.3 says to skip 6.4.1.2.1 Step 2.

Expand Down Expand Up @@ -338,7 +338,7 @@ impl KeyPair {
// First, validate `2**half_n_bits < d`. Since 2**half_n_bits has a bit
// length of half_n_bits + 1, this check gives us 2**half_n_bits <= d,
// and knowing d is odd makes the inequality strict.
let d = bigint::OwnedModulus::<D>::from_be_bytes(d, cpu_features)
let d = bigint::OwnedModulus::<D>::from_be_bytes(d)
.map_err(|_| error::KeyRejected::invalid_component())?;
if !(n_bits.half_rounded_up() < d.len_bits()) {
return Err(KeyRejected::inconsistent_components());
Expand All @@ -350,7 +350,7 @@ impl KeyPair {

// Step 6.b is omitted as explained above.

let pm = &p.modulus.modulus();
let pm = &p.modulus.modulus(cpu_features);

// 6.4.1.4.3 - Step 7.

Expand All @@ -371,8 +371,8 @@ impl KeyPair {

// This should never fail since `n` and `e` were validated above.

let p = PrivateCrtPrime::new(p, dP)?;
let q = PrivateCrtPrime::new(q, dQ)?;
let p = PrivateCrtPrime::new(p, dP, cpu_features)?;
let q = PrivateCrtPrime::new(q, dQ, cpu_features)?;

Ok(Self {
p,
Expand Down Expand Up @@ -416,7 +416,7 @@ impl<M> PrivatePrime<M> {
n_bits: BitLength,
cpu_features: cpu::Features,
) -> Result<Self, KeyRejected> {
let p = bigint::OwnedModulus::from_be_bytes(p, cpu_features)?;
let p = bigint::OwnedModulus::from_be_bytes(p)?;

// 5.c / 5.g:
//
Expand All @@ -438,7 +438,7 @@ impl<M> PrivatePrime<M> {

// Steps 5.e and 5.f are omitted as explained above.

let oneRR = bigint::One::newRR(&p.modulus());
let oneRR = bigint::One::newRR(&p.modulus(cpu_features));

Ok(Self { modulus: p, oneRR })
}
Expand All @@ -453,8 +453,12 @@ struct PrivateCrtPrime<M> {
impl<M> PrivateCrtPrime<M> {
/// Constructs a `PrivateCrtPrime` from the private prime `p` and `dP` where
/// dP == d % (p - 1).
fn new(p: PrivatePrime<M>, dP: untrusted::Input) -> Result<Self, KeyRejected> {
let m = &p.modulus.modulus();
fn new(
p: PrivatePrime<M>,
dP: untrusted::Input,
cpu_features: cpu::Features,
) -> Result<Self, KeyRejected> {
let m = &p.modulus.modulus(cpu_features);
// [NIST SP-800-56B rev. 1] 6.4.1.4.3 - Steps 7.a & 7.b.
let dP = bigint::PrivateExponent::from_be_bytes_padded(dP, m)
.map_err(|error::Unspecified| KeyRejected::inconsistent_components())?;
Expand Down Expand Up @@ -482,8 +486,9 @@ fn elem_exp_consttime<M>(
c: &bigint::Elem<N>,
p: &PrivateCrtPrime<M>,
other_prime_len_bits: BitLength,
cpu_features: cpu::Features,
) -> Result<bigint::Elem<M>, error::Unspecified> {
let m = &p.modulus.modulus();
let m = &p.modulus.modulus(cpu_features);
let c_mod_m = bigint::elem_reduced(c, m, other_prime_len_bits);
let c_mod_m = bigint::elem_mul(p.oneRRR.as_ref(), c_mod_m, m);
bigint::elem_exp_consttime(c_mod_m, &p.exponent, m)
Expand Down Expand Up @@ -523,6 +528,8 @@ impl KeyPair {
msg: &[u8],
signature: &mut [u8],
) -> Result<(), error::Unspecified> {
let cpu_features = cpu::features();

if signature.len() != self.public().modulus_len() {
return Err(error::Unspecified);
}
Expand All @@ -537,7 +544,7 @@ impl KeyPair {
// with Garner's algorithm.

// Steps 1 and 2.
let m = self.private_exponentiate(signature)?;
let m = self.private_exponentiate(signature, cpu_features)?;

// Step 3.
m.fill_be_bytes(signature);
Expand All @@ -552,13 +559,17 @@ impl KeyPair {
/// leaked that would endanger the private key.
///
/// Panics if `in_out` is not `self.public().modulus_len()`.
fn private_exponentiate(&self, base: &[u8]) -> Result<bigint::Elem<N>, error::Unspecified> {
fn private_exponentiate(
&self,
base: &[u8],
cpu: cpu::Features,
) -> Result<bigint::Elem<N>, error::Unspecified> {
assert_eq!(base.len(), self.public().modulus_len());

// RFC 8017 Section 5.1.2: RSADP, using the Chinese Remainder Theorem
// with Garner's algorithm.

let n = &self.public.inner().n().value();
let n = &self.public.inner().n().value(cpu);
let n_one = self.public.inner().n().oneRR();

// Step 1. The value zero is also rejected.
Expand All @@ -569,14 +580,14 @@ impl KeyPair {

// Step 2.b.i.
let q_bits = self.q.modulus.len_bits();
let m_1 = elem_exp_consttime(&c, &self.p, q_bits)?;
let m_2 = elem_exp_consttime(&c, &self.q, self.p.modulus.len_bits())?;
let m_1 = elem_exp_consttime(&c, &self.p, q_bits, cpu)?;
let m_2 = elem_exp_consttime(&c, &self.q, self.p.modulus.len_bits(), cpu)?;

// Step 2.b.ii isn't needed since there are only two primes.

// Step 2.b.iii.
let h = {
let p = &self.p.modulus.modulus();
let p = &self.p.modulus.modulus(cpu);
let m_2 = bigint::elem_reduced_once(&m_2, p, q_bits);
let m_1_minus_m_2 = bigint::elem_sub(m_1, &m_2, p);
bigint::elem_mul(&self.qInv, m_1_minus_m_2, p)
Expand Down Expand Up @@ -605,7 +616,7 @@ impl KeyPair {
// minimum value, since the relationship of `e` to `d`, `p`, and `q` is
// not verified during `KeyPair` construction.
{
let verify = self.public.inner().exponentiate_elem(&m);
let verify = self.public.inner().exponentiate_elem(&m, cpu);
bigint::elem_verify_equal_consttime(&verify, &c)?;
}

Expand All @@ -623,6 +634,7 @@ mod tests {

#[test]
fn test_rsakeypair_private_exponentiate() {
let cpu = cpu::features();
test::run(
test_file!("keypair_private_exponentiate_tests.txt"),
|section, test_case| {
Expand All @@ -645,7 +657,7 @@ mod tests {
let mut padded = vec![0; key.public.modulus_len()];
let zeroes = padded.len() - test_case.len();
padded[zeroes..].copy_from_slice(test_case);
let _: bigint::Elem<_> = key.private_exponentiate(&padded).unwrap();
let _: bigint::Elem<_> = key.private_exponentiate(&padded, cpu).unwrap();
}
Ok(())
},
Expand Down
13 changes: 9 additions & 4 deletions src/rsa/public_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@ impl Inner {
&self,
base: untrusted::Input,
out_buffer: &'out mut [u8; PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN],
cpu: cpu::Features,
) -> Result<&'out [u8], error::Unspecified> {
let n = &self.n.value();
let n = &self.n.value(cpu);

// The encoded value of the base must be the same length as the modulus,
// in bytes.
Expand All @@ -162,7 +163,7 @@ impl Inner {
}

// Step 2.
let m = self.exponentiate_elem(&s);
let m = self.exponentiate_elem(&s, cpu);

// Step 3.
Ok(fill_be_bytes_n(m, self.n.len_bits(), out_buffer))
Expand All @@ -171,13 +172,17 @@ impl Inner {
/// Calculates base**e (mod n).
///
/// This is constant-time with respect to `base` only.
pub(super) fn exponentiate_elem(&self, base: &bigint::Elem<N>) -> bigint::Elem<N> {
pub(super) fn exponentiate_elem(
&self,
base: &bigint::Elem<N>,
cpu: cpu::Features,
) -> bigint::Elem<N> {
// The exponent was already checked to be at least 3.
let exponent_without_low_bit = NonZeroU64::try_from(self.e.value().get() & !1).unwrap();
// The exponent was already checked to be odd.
debug_assert_ne!(exponent_without_low_bit, self.e.value());

let n = &self.n.value();
let n = &self.n.value(cpu);

let base_r = bigint::elem_mul(self.n.oneRR(), base.clone(), n);

Expand Down
8 changes: 4 additions & 4 deletions src/rsa/public_modulus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl PublicModulus {
const MIN_BITS: bits::BitLength = bits::BitLength::from_usize_bits(1024);

// Step 3 / Step c for `n` (out of order).
let value = bigint::OwnedModulus::from_be_bytes(n, cpu_features)?;
let value = bigint::OwnedModulus::from_be_bytes(n)?;
let bits = value.len_bits();

// Step 1 / Step a. XXX: SP800-56Br1 and SP800-89 require the length of
Expand All @@ -52,7 +52,7 @@ impl PublicModulus {
if bits > max_bits {
return Err(error::KeyRejected::too_large());
}
let oneRR = bigint::One::newRR(&value.modulus());
let oneRR = bigint::One::newRR(&value.modulus(cpu_features));

Ok(Self { value, oneRR })
}
Expand All @@ -69,8 +69,8 @@ impl PublicModulus {
self.value.len_bits()
}

pub(super) fn value(&self) -> bigint::Modulus<N> {
self.value.modulus()
pub(super) fn value(&self, cpu_features: cpu::Features) -> bigint::Modulus<N> {
self.value.modulus(cpu_features)
}

pub(super) fn oneRR(&self) -> &bigint::Elem<N, RR> {
Expand Down
7 changes: 5 additions & 2 deletions src/rsa/verification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ impl signature::VerificationAlgorithm for RsaParameters {
),
msg,
signature,
cpu::features(),
)
}
}
Expand Down Expand Up @@ -184,6 +185,7 @@ where
),
untrusted::Input::from(message),
untrusted::Input::from(signature),
cpu::features(),
)
}
}
Expand All @@ -193,6 +195,7 @@ pub(crate) fn verify_rsa_(
(n, e): (untrusted::Input, untrusted::Input),
msg: untrusted::Input,
signature: untrusted::Input,
cpu: cpu::Features,
) -> Result<(), error::Unspecified> {
let max_bits: bits::BitLength =
bits::BitLength::from_usize_bytes(PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN)?;
Expand All @@ -207,12 +210,12 @@ pub(crate) fn verify_rsa_(
params.min_bits,
max_bits,
PublicExponent::_3,
cpu::features(),
cpu,
)?;

// RFC 8017 Section 5.2.2: RSAVP1.
let mut decoded = [0u8; PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN];
let decoded = key.exponentiate(signature, &mut decoded)?;
let decoded = key.exponentiate(signature, &mut decoded, cpu)?;

// Verify the padded message is correct.
let m_hash = digest::digest(params.padding_alg.digest_alg(), msg.as_slice_less_safe());
Expand Down

0 comments on commit 3227d9d

Please sign in to comment.