Skip to content

Commit

Permalink
Refactor AEAD code to make it more reusable (#9397)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex authored Aug 10, 2023
1 parent 27e8b3d commit 1336f17
Showing 1 changed file with 76 additions and 65 deletions.
141 changes: 76 additions & 65 deletions src/rust/src/backend/aead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,76 @@ use crate::buf::CffiBuf;
use crate::error::{CryptographyError, CryptographyResult};
use crate::exceptions;

fn check_length(data: &[u8]) -> CryptographyResult<()> {
if data.len() > (i32::MAX as usize) {
// This is OverflowError to match what cffi would raise
return Err(CryptographyError::from(
pyo3::exceptions::PyOverflowError::new_err(
"Data or associated data too long. Max 2**31 - 1 bytes",
),
));
}

Ok(())
}

fn encrypt_value<'p>(
py: pyo3::Python<'p>,
mut ctx: openssl::cipher_ctx::CipherCtx,
plaintext: &[u8],
tag_len: usize,
tag_first: bool,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
Ok(pyo3::types::PyBytes::new_with(
py,
plaintext.len() + tag_len,
|b| {
let ciphertext;
let tag;
// TODO: remove once we have a second AEAD implemented here.
assert!(tag_first);
(tag, ciphertext) = b.split_at_mut(tag_len);

let n = ctx
.cipher_update(plaintext, Some(ciphertext))
.map_err(CryptographyError::from)?;
assert_eq!(n, ciphertext.len());

let mut final_block = [0];
let n = ctx
.cipher_final(&mut final_block)
.map_err(CryptographyError::from)?;
assert_eq!(n, 0);

ctx.tag(tag).map_err(CryptographyError::from)?;

Ok(())
},
)?)
}

fn decrypt_value<'p>(
py: pyo3::Python<'p>,
mut ctx: openssl::cipher_ctx::CipherCtx,
ciphertext: &[u8],
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
Ok(pyo3::types::PyBytes::new_with(py, ciphertext.len(), |b| {
// AES SIV can error here if the data is invalid on decrypt
let n = ctx
.cipher_update(ciphertext, Some(b))
.map_err(|_| exceptions::InvalidTag::new_err(()))?;
assert_eq!(n, b.len());

let mut final_block = [0];
let n = ctx
.cipher_final(&mut final_block)
.map_err(|_| exceptions::InvalidTag::new_err(()))?;
assert_eq!(n, 0);

Ok(())
})?)
}

#[pyo3::prelude::pyclass(
frozen,
module = "cryptography.hazmat.bindings._rust.openssl.aead",
Expand Down Expand Up @@ -85,59 +155,21 @@ impl AesSiv {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("data must not be zero length"),
));
} else if data_bytes.len() > (i32::MAX as usize) {
// This is OverflowError to match what cffi would raise
return Err(CryptographyError::from(
pyo3::exceptions::PyOverflowError::new_err(
"Data or associated data too long. Max 2**31 - 1 bytes",
),
));
}
};
check_length(data_bytes)?;

let mut ctx = openssl::cipher_ctx::CipherCtx::new()?;
ctx.encrypt_init(Some(&self.cipher), Some(key_buf.as_bytes()), None)?;

if let Some(ads) = associated_data {
for ad in ads.iter() {
let ad = ad.extract::<CffiBuf<'_>>()?;
if ad.as_bytes().len() > (i32::MAX as usize) {
// This is OverflowError to match what cffi would raise
return Err(CryptographyError::from(
pyo3::exceptions::PyOverflowError::new_err(
"Data or associated data too long. Max 2**31 - 1 bytes",
),
));
}

check_length(ad.as_bytes())?;
ctx.cipher_update(ad.as_bytes(), None)?;
}
}

Ok(pyo3::types::PyBytes::new_with(
py,
data_bytes.len() + 16,
|b| {
// RFC 5297 defines the output as IV || C, where the tag we
// generate is the "IV" and C is the ciphertext. This is the
// opposite of our other AEADs, which are Ciphertext || Tag.
let (tag, ciphertext) = b.split_at_mut(16);

let n = ctx
.cipher_update(data_bytes, Some(ciphertext))
.map_err(CryptographyError::from)?;
assert_eq!(n, ciphertext.len());

let mut final_block = [0];
let n = ctx
.cipher_final(&mut final_block)
.map_err(CryptographyError::from)?;
assert_eq!(n, 0);

ctx.tag(tag).map_err(CryptographyError::from)?;

Ok(())
},
)?)
encrypt_value(py, ctx, data_bytes, 16, true)
}

fn decrypt<'p>(
Expand Down Expand Up @@ -170,34 +202,13 @@ impl AesSiv {
if let Some(ads) = associated_data {
for ad in ads.iter() {
let ad = ad.extract::<CffiBuf<'_>>()?;
if ad.as_bytes().len() > (i32::MAX as usize) {
// This is OverflowError to match what cffi would raise
return Err(CryptographyError::from(
pyo3::exceptions::PyOverflowError::new_err(
"Data or associated data too long. Max 2**31 - 1 bytes",
),
));
}
check_length(ad.as_bytes())?;

ctx.cipher_update(ad.as_bytes(), None)?;
}
}

Ok(pyo3::types::PyBytes::new_with(py, ciphertext.len(), |b| {
// AES SIV can error here if the data is invalid on decrypt
let n = ctx
.cipher_update(ciphertext, Some(b))
.map_err(|_| exceptions::InvalidTag::new_err(()))?;
assert_eq!(n, b.len());

let mut final_block = [0];
let n = ctx
.cipher_final(&mut final_block)
.map_err(|_| exceptions::InvalidTag::new_err(()))?;
assert_eq!(n, 0);

Ok(())
})?)
decrypt_value(py, ctx, ciphertext)
}
}

Expand Down

0 comments on commit 1336f17

Please sign in to comment.