Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
heliannuuthus committed Aug 12, 2024
1 parent 45e39b8 commit cde4567
Show file tree
Hide file tree
Showing 9 changed files with 221 additions and 165 deletions.
2 changes: 1 addition & 1 deletion sm2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ The SM2 cryptosystem is composed of three distinct algorithms:

- [x] **SM2DSA**: digital signature algorithm defined in [GBT.32918.2-2016], [ISO.IEC.14888-3] (SM2-2)
- [ ] **SM2KEP**: key exchange protocol defined in [GBT.32918.3-2016] (SM2-3)
- [ ] **SM2PKE**: public key encryption algorithm defined in [GBT.32918.4-2016] (SM2-4)
- [x] **SM2PKE**: public key encryption algorithm defined in [GBT.32918.4-2016] (SM2-4)

## Minimum Supported Rust Version

Expand Down
1 change: 0 additions & 1 deletion sm2/sm2.bin

This file was deleted.

5 changes: 0 additions & 5 deletions sm2/sm2.key

This file was deleted.

4 changes: 0 additions & 4 deletions sm2/sm2.pub

This file was deleted.

2 changes: 1 addition & 1 deletion sm2/src/arithmetic/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ use core::{
iter::{Product, Sum},
ops::{AddAssign, MulAssign, Neg, SubAssign},
};
use elliptic_curve::ops::Invert;
use elliptic_curve::{
bigint::Limb,
ff::PrimeField,
ops::Invert,
subtle::{Choice, ConstantTimeEq, CtOption},
};

Expand Down
105 changes: 83 additions & 22 deletions sm2/src/pke.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,56 @@
//! SM2 Encryption Algorithm (SM2) as defined in [draft-shen-sm2-ecdsa § 5].
//!
//! ## Usage
//!
//! NOTE: requires the `sm3` crate for digest functions and the `primeorder` crate for prime field operations.
//!
//! The `DecryptingKey` struct is used for decrypting messages that were encrypted using the SM2 encryption algorithm.
//! It is initialized with a `SecretKey` or a non-zero scalar value and can decrypt ciphertexts using the specified decryption mode.
#![cfg_attr(feature = "std", doc = "```")]
#![cfg_attr(not(feature = "std"), doc = "```ignore")]
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! use rand_core::OsRng; // requires 'getrandom` feature
//! use sm2::{
//! pke::{EncryptingKey, Mode},
//! {SecretKey, PublicKey}
//!
//! };
//!
//! // Encrypting
//! let secret_key = SecretKey::random(&mut OsRng); // serialize with `::to_bytes()`
//! let encrypting_key = EncryptingKey::new_with_mode(secret_key, Mode::C1C2C3);
//! let plaintext = b"plaintext";
//! let ciphertext = encrypting_key.encrypt(plaintext)?;
//!
//! use sm2::pke::DecryptingKey;
//! // Decrypting
//! let decrypting_key = DecryptingKey::new_with_mode(secret_key, Mode::C1C2C3);
//! assert_eq!(decrypting_key.decrypt(&ciphertext)?, plaintext);
//!
//! // Encrypting asn.1
//! let ciphertext = encrypting_key.encrypt_asna1(plaintext)?;
//!
//! // Decrypting asn.1
//! assert_eq!(decrypting_key.decrypt_asna1(&ciphertext)?, plaintext);
//! # }
//! ```
//!
//!
//!
use core::cmp::min;

use crate::AffinePoint;

#[cfg(feature = "alloc")]
use alloc::vec;

use elliptic_curve::pkcs8::der::asn1::UintRef;
use elliptic_curve::pkcs8::der::Decode;
use elliptic_curve::pkcs8::der::DecodeValue;
use elliptic_curve::pkcs8::der::Encode;
use elliptic_curve::pkcs8::der::Length;
use elliptic_curve::pkcs8::der::Reader;
use elliptic_curve::pkcs8::der::Sequence;
use elliptic_curve::pkcs8::der::Writer;
use elliptic_curve::{
bigint::{Encoding, Uint, U256},
pkcs8::der::{
asn1::UintRef, Decode, DecodeValue, Encode, Length, Reader, Sequence, Tag, Writer,
},
};

use elliptic_curve::{
pkcs8::der::{asn1::OctetStringRef, EncodeValue},
Expand All @@ -29,10 +67,21 @@ mod encrypting;
#[cfg(feature = "arithmetic")]
pub use self::{decrypting::DecryptingKey, encrypting::EncryptingKey};

/// https://search.r-project.org/CRAN/refmans/smcryptoR/html/sm2_encrypt_asn1.html
/// Modes for the cipher encoding/decoding.
#[derive(Clone, Copy, Debug)]
pub enum Mode {
/// old mode
C1C2C3,
/// new mode
C1C3C2,
}
/// Represents a cipher structure containing encryption-related data (asn.1 format).
///
/// The `Cipher` structure includes the coordinates of the elliptic curve point (`x`, `y`),
/// the digest of the message, and the encrypted cipher text.
pub struct Cipher<'a> {
x: &'a [u8],
y: &'a [u8],
x: U256,
y: U256,
digest: &'a [u8],
cipher: &'a [u8],
}
Expand All @@ -41,14 +90,15 @@ impl<'a> Sequence<'a> for Cipher<'a> {}

impl<'a> EncodeValue for Cipher<'a> {
fn value_len(&self) -> elliptic_curve::pkcs8::der::Result<Length> {
UintRef::new(&self.x)?.encoded_len()?
+ UintRef::new(&self.y)?.encoded_len()?
UintRef::new(&self.x.to_be_bytes())?.encoded_len()?
+ UintRef::new(&self.y.to_be_bytes())?.encoded_len()?
+ OctetStringRef::new(&self.digest)?.encoded_len()?
+ OctetStringRef::new(&self.cipher)?.encoded_len()?
}

fn encode_value(&self, writer: &mut impl Writer) -> elliptic_curve::pkcs8::der::Result<()> {
UintRef::new(&self.x)?.encode(writer)?;
UintRef::new(&self.y)?.encode(writer)?;
UintRef::new(&self.x.to_be_bytes())?.encode(writer)?;
UintRef::new(&self.y.to_be_bytes())?.encode(writer)?;
OctetStringRef::new(&self.digest)?.encode(writer)?;
OctetStringRef::new(&self.cipher)?.encode(writer)?;
Ok(())
Expand All @@ -68,20 +118,16 @@ impl<'a> DecodeValue<'a> for Cipher<'a> {
let digest = OctetStringRef::decode(nr)?.into();
let cipher = OctetStringRef::decode(nr)?.into();
Ok(Cipher {
x,
y,
x: Uint::from_be_bytes(zero_byte_slice(x)?),
y: Uint::from_be_bytes(zero_byte_slice(y)?),
digest,
cipher,
})
})
}
}
#[derive(Clone, Copy, Debug)]
pub enum Mode {
C1C2C3,
C1C3C2,
}

/// Performs key derivation using a hash function and elliptic curve point.
fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<()> {
let klen = c2.len();
let mut ct: i32 = 0x00000001;
Expand All @@ -107,8 +153,23 @@ fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<()
Ok(())
}

/// XORs a portion of the buffer `c2` with a hash value.
fn xor(c2: &mut [u8], ha: &[u8], offset: usize, xor_len: usize) {
for i in 0..xor_len {
c2[offset + i] ^= ha[i];
}
}

/// Converts a byte slice to a fixed-size array, padding with leading zeroes if necessary.
pub(crate) fn zero_byte_slice<const N: usize>(
bytes: &[u8],
) -> elliptic_curve::pkcs8::der::Result<[u8; N]> {
let num_zeroes = N
.checked_sub(bytes.len())
.ok_or_else(|| Tag::Integer.length_error())?;

// Copy input into `N`-sized output buffer with leading zeroes
let mut output = [0u8; N];
output[num_zeroes..].copy_from_slice(bytes);
Ok(output)
}
106 changes: 58 additions & 48 deletions sm2/src/pke/decrypting.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
use core::fmt::{self, Debug};

use crate::arithmetic::field::FieldElement;
use crate::{AffinePoint, EncodedPoint, FieldBytes, NonZeroScalar, PublicKey, Scalar, SecretKey};

use alloc::vec::Vec;
use elliptic_curve::ops::Reduce;
use elliptic_curve::pkcs8::der::Decode;
use elliptic_curve::sec1::FromEncodedPoint;
use elliptic_curve::subtle::{Choice, ConstantTimeEq};
use elliptic_curve::Error;
use elliptic_curve::{bigint::U256, sec1::ToEncodedPoint, Group, Result};
use crate::{
arithmetic::field::FieldElement, AffinePoint, EncodedPoint, FieldBytes, NonZeroScalar,
PublicKey, Scalar, SecretKey,
};

use alloc::{borrow::ToOwned, vec::Vec};
use elliptic_curve::{
bigint::U256,
ops::Reduce,
pkcs8::der::Decode,
sec1::{FromEncodedPoint, ToEncodedPoint},
subtle::{Choice, ConstantTimeEq},
Error, Group, Result,
};
use primeorder::PrimeField;

use sm3::digest::DynDigest;
use sm3::{Digest, Sm3};

use super::encrypting::EncryptingKey;
use super::{kdf, vec, Cipher, Mode};
use sm3::{digest::DynDigest, Digest, Sm3};

use super::{encrypting::EncryptingKey, kdf, vec, Cipher, Mode};
/// Represents a decryption key used for decrypting messages using elliptic curve cryptography.
#[derive(Clone)]
pub struct DecryptingKey {
secret_scalar: NonZeroScalar,
Expand All @@ -26,10 +28,12 @@ pub struct DecryptingKey {
}

impl DecryptingKey {
/// Creates a new `DecryptingKey` from a `SecretKey` with the default decryption mode (`C1C3C2`).
pub fn new(secret_key: SecretKey) -> Self {
Self::new_with_mode(secret_key.to_nonzero_scalar(), Mode::C1C3C2)
}

/// Creates a new `DecryptingKey` from a non-zero scalar and sets the decryption mode.
pub fn new_with_mode(secret_scalar: NonZeroScalar, mode: Mode) -> Self {
Self {
secret_scalar,
Expand Down Expand Up @@ -79,37 +83,41 @@ impl DecryptingKey {
&self.encryting_key
}

/// Decrypt inplace
pub fn decrypt(&self, ciphertext: &mut [u8]) -> Result<Vec<u8>> {
/// Decrypts a ciphertext in-place using the default digest algorithm (`Sm3`).
pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
self.decrypt_digest::<Sm3>(ciphertext)
}
/// Decrypt inplace
pub fn decrypt_digest<D>(&self, ciphertext: &mut [u8]) -> Result<Vec<u8>>

/// Decrypts a ciphertext in-place using the specified digest algorithm.
pub fn decrypt_digest<D>(&self, ciphertext: &[u8]) -> Result<Vec<u8>>
where
D: 'static + Digest + DynDigest + Send + Sync,
{
let mut digest = D::new();
decrypt(&self.secret_scalar, self.mode, &mut digest, ciphertext)
}

/// Decrypts a ciphertext in-place from ASN.1 format using the default digest algorithm (`Sm3`).
pub fn decrypt_asna1(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
self.decrypt_asna1_digest::<Sm3>(ciphertext)
}

/// Decrypt inplace
/// Decrypts a ciphertext in-place from ASN.1 format using the specified digest algorithm.
pub fn decrypt_asna1_digest<D>(&self, ciphertext: &[u8]) -> Result<Vec<u8>>
where
D: 'static + Digest + DynDigest + Send + Sync,
{
let cipher =
Cipher::from_der(&ciphertext).map_err(|e| elliptic_curve::pkcs8::Error::from(e))?;

let prefix: &[u8] = &[0x04];
let x: [u8; 32] = cipher.x.to_be_bytes();
let y: [u8; 32] = cipher.y.to_be_bytes();
let mut cipher = match self.mode {
Mode::C1C2C3 => [&[0x04], cipher.x, cipher.y, cipher.cipher, cipher.digest].concat(),
Mode::C1C3C2 => [&[0x04], cipher.x, cipher.y, cipher.digest, cipher.cipher].concat(),
Mode::C1C2C3 => [prefix, &x, &y, cipher.cipher, cipher.digest].concat(),
Mode::C1C3C2 => [prefix, &x, &y, cipher.digest, cipher.cipher].concat(),
};

Ok(self.decrypt_digest::<D>(&mut cipher)?)
Ok(self.decrypt_digest::<D>(&mut cipher)?.to_vec())
}
}

Expand Down Expand Up @@ -150,58 +158,60 @@ fn decrypt(
secret_scalar: &Scalar,
mode: Mode,
hasher: &mut dyn DynDigest,
cipher: &mut [u8],
cipher: &[u8],
) -> Result<Vec<u8>> {
let q = U256::from_be_hex(FieldElement::MODULUS);
let c1_len = (q.bits() + 7) / 8 * 2 + 1;

let (c1, c) = cipher.split_at_mut(c1_len as usize);
// B1: get 𝐶1 from 𝐶
let (c1, c) = cipher.split_at(c1_len as usize);
let encoded_c1 = EncodedPoint::from_bytes(c1).unwrap();

// verify that point c1 satisfies the elliptic curve
let mut c1_point = AffinePoint::from_encoded_point(&encoded_c1).unwrap();

// B2: compute point 𝑆 = [ℎ]𝐶1
let s = c1_point * Scalar::reduce(U256::from_u32(FieldElement::S));

if s.is_identity().into() {
return Err(Error);
}

// B3: compute [𝑑𝐵]𝐶1 = (𝑥2, 𝑦2)
c1_point = (c1_point * secret_scalar).to_affine();

let digest_size = hasher.output_size();

let (c2, c3) = match mode {
Mode::C1C3C2 => {
let (c3, c2) = c.split_at_mut(digest_size);
let (c3, c2) = c.split_at(digest_size);
(c2, c3)
}
Mode::C1C2C3 => c.split_at_mut(c.len() - digest_size),
Mode::C1C2C3 => c.split_at(c.len() - digest_size),
};

kdf(hasher, c1_point, c2)?;
// B4: compute 𝑡 = 𝐾𝐷𝐹(𝑥2 ∥ 𝑦2, 𝑘𝑙𝑒𝑛)
// B5: get 𝐶2 from 𝐶 and compute 𝑀′ = 𝐶2 ⊕ t
let mut c2 = c2.to_owned();
kdf(hasher, c1_point, &mut c2)?;

let mut c3_checked = vec![0u8; digest_size];
// compute 𝑢 = 𝐻𝑎𝑠ℎ(𝑥2 ∥ 𝑀′∥ 𝑦2).
let mut u = vec![0u8; digest_size];
let encode_point = c1_point.to_encoded_point(false);

hasher.update(&encode_point.x().unwrap());
hasher.update(c2);
hasher.update(&mut c2);
hasher.update(&encode_point.y().unwrap());
hasher
.finalize_into_reset(&mut c3_checked)
.map_err(|_e| Error)?;

let checked =
c3_checked
.iter()
.zip(c3)
.fold(0, |mut check, (&c3_byte, &mut c3checked_byte)| {
check |= c3_byte ^ c3checked_byte;
check
});

hasher.finalize_into_reset(&mut u).map_err(|_e| Error)?;
let checked = u
.iter()
.zip(c3)
.fold(0, |mut check, (&c3_byte, &c3checked_byte)| {
check |= c3_byte ^ c3checked_byte;
check
});

// If 𝑢 ≠ 𝐶3, output “ERROR” and exit
if checked != 0 {
return Err(Error);
}

// B7: output the plaintext 𝑀′.
Ok(c2.to_vec())
}
Loading

0 comments on commit cde4567

Please sign in to comment.