Skip to content

Commit

Permalink
ml-dsa: support for encoding PKCS#8 private keys (#892)
Browse files Browse the repository at this point in the history
  • Loading branch information
baloo authored Jan 26, 2025
1 parent f7e7312 commit e7c698a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 11 deletions.
40 changes: 33 additions & 7 deletions ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
//!
//! ```
//! use ml_dsa::{MlDsa65, KeyGen};
//! use signature::{Signer, Verifier};
//! use signature::{Keypair, Signer, Verifier};
//!
//! let mut rng = rand::thread_rng();
//! let kp = MlDsa65::key_gen(&mut rng);
//!
//! let msg = b"Hello world";
//! let sig = kp.signing_key.sign(msg);
//! let sig = kp.signing_key().sign(msg);
//!
//! assert!(kp.verifying_key.verify(msg, &sig).is_ok());
//! assert!(kp.verifying_key().verify(msg, &sig).is_ok());
//! ```
mod algebra;
Expand Down Expand Up @@ -71,9 +71,9 @@ use {

#[cfg(all(feature = "alloc", feature = "pkcs8"))]
use pkcs8::{
der::asn1::{BitString, BitStringRef},
der::asn1::{BitString, BitStringRef, OctetStringRef},
spki::{SignatureBitStringEncoding, SubjectPublicKeyInfo},
EncodePublicKey,
EncodePrivateKey, EncodePublicKey,
};

use crate::algebra::{AlgebraExt, Elem, NttMatrix, NttVector, Truncate, Vector};
Expand Down Expand Up @@ -178,10 +178,20 @@ fn message_representative(tr: &[u8], Mp: &[&[u8]]) -> B64 {
/// An ML-DSA key pair
pub struct KeyPair<P: MlDsaParams> {
/// The signing key of the key pair
pub signing_key: SigningKey<P>,
signing_key: SigningKey<P>,

/// The verifying key of the key pair
pub verifying_key: VerifyingKey<P>,
verifying_key: VerifyingKey<P>,

/// The seed this signing key was derived from
seed: B32,
}

impl<P: MlDsaParams> KeyPair<P> {
/// The signing key of the key pair
pub fn signing_key(&self) -> &SigningKey<P> {
&self.signing_key
}
}

impl<P: MlDsaParams> AsRef<VerifyingKey<P>> for KeyPair<P> {
Expand Down Expand Up @@ -234,6 +244,21 @@ where
Signature::<P>::ALGORITHM_IDENTIFIER;
}

#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P> EncodePrivateKey for KeyPair<P>
where
P: MlDsaParams,
P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
fn to_pkcs8_der(&self) -> pkcs8::Result<der::SecretDocument> {
let pkcs8_key = pkcs8::PrivateKeyInfoRef::new(
P::ALGORITHM_IDENTIFIER,
OctetStringRef::new(&self.seed)?,
);
Ok(der::SecretDocument::encode_msg(&pkcs8_key)?)
}
}

/// An ML-DSA signing key
#[derive(Clone, PartialEq)]
pub struct SigningKey<P: MlDsaParams> {
Expand Down Expand Up @@ -793,6 +818,7 @@ where
KeyPair {
signing_key,
verifying_key,
seed: xi.clone(),
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions ml-dsa/tests/key-gen.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use ml_dsa::*;

use hybrid_array::Array;
use signature::Keypair;
use std::{fs::read_to_string, path::PathBuf};

#[test]
Expand Down Expand Up @@ -32,8 +33,8 @@ fn verify<P: MlDsaParams>(tc: &acvp::TestCase) {
let sk_bytes = EncodedSigningKey::<P>::try_from(tc.sk.as_slice()).unwrap();

let kp = P::key_gen_internal(&seed);
let sk = kp.signing_key;
let vk = kp.verifying_key;
let sk = kp.signing_key().clone();
let vk = kp.verifying_key().clone();

// Verify correctness via serialization
assert_eq!(sk.encode(), sk_bytes);
Expand Down
10 changes: 8 additions & 2 deletions ml-dsa/tests/pkcs8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ml_dsa::{KeyPair, MlDsa44, MlDsa65, MlDsa87, MlDsaParams, SigningKey, Verify
use pkcs8::{
der::{pem::LineEnding, AnyRef},
spki::AssociatedAlgorithmIdentifier,
DecodePrivateKey, DecodePublicKey, EncodePublicKey,
DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey,
};
use signature::Keypair;

Expand All @@ -18,7 +18,13 @@ fn private_key_serialization() {
{
let sk = SigningKey::<P>::from_pkcs8_pem(private_bytes).expect("parse private key");
let kp = KeyPair::<P>::from_pkcs8_pem(private_bytes).expect("parse private key");
assert!(sk == kp.signing_key);
assert!(sk == *kp.signing_key());
assert_eq!(
kp.to_pkcs8_pem(LineEnding::LF)
.expect("serialize private seed")
.deref(),
private_bytes
);

let pk = VerifyingKey::<P>::from_public_key_pem(public_bytes).expect("parse public key");
assert_eq!(
Expand Down

0 comments on commit e7c698a

Please sign in to comment.