Skip to content

Commit

Permalink
python: Make safer to use
Browse files Browse the repository at this point in the history
Previously, a call to `.edhoc_exporter()` would happily have produced
exports from the all-zero key material.
  • Loading branch information
chrysn committed Nov 4, 2024
1 parent a11fe34 commit c136d6e
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 40 deletions.
48 changes: 30 additions & 18 deletions lakers-python/src/initiator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@ use lakers_crypto::{default_crypto, CryptoTrait};
use log::trace;
use pyo3::{prelude::*, types::PyBytes};

use super::StateMismatch;

#[pyclass(name = "EdhocInitiator")]
pub struct PyEdhocInitiator {
cred_i: Option<Credential>,
// FIXME: This does *not* get taken out, so some data stays available for longer than it needs
// to be -- but that is apparently needed in selected_cipher_suite and
// compute_ephemeral_secret.
start: InitiatorStart,
wait_m2: WaitM2,
processing_m2: ProcessingM2,
processed_m2: ProcessedM2,
completed: Completed,
wait_m2: Option<WaitM2>,
processing_m2: Option<ProcessingM2>,
processed_m2: Option<ProcessedM2>,
completed: Option<Completed>,
}

#[pymethods]
Expand All @@ -31,10 +36,10 @@ impl PyEdhocInitiator {
method: EDHOCMethod::StatStat.into(),
suites_i,
},
wait_m2: WaitM2::default(),
processing_m2: ProcessingM2::default(),
processed_m2: ProcessedM2::default(),
completed: Completed::default(),
wait_m2: None,
processing_m2: None,
processed_m2: None,
completed: None,
}
}

Expand All @@ -54,7 +59,7 @@ impl PyEdhocInitiator {

match i_prepare_message_1(&self.start, &mut default_crypto(), c_i, &ead_1) {
Ok((state, message_1)) => {
self.wait_m2 = state;
self.wait_m2 = Some(state);
Ok(PyBytes::new_bound(py, message_1.as_slice()))
}
Err(error) => Err(error.into()),
Expand All @@ -68,9 +73,13 @@ impl PyEdhocInitiator {
) -> PyResult<(Bound<'a, PyBytes>, Bound<'a, PyBytes>, Option<EADItem>)> {
let message_2 = EdhocMessageBuffer::new_from_slice(message_2.as_slice())?;

match i_parse_message_2(&self.wait_m2, &mut default_crypto(), &message_2) {
match i_parse_message_2(
&self.wait_m2.take().ok_or(StateMismatch)?,
&mut default_crypto(),
&message_2,
) {
Ok((state, c_r, id_cred_r, ead_2)) => {
self.processing_m2 = state;
self.processing_m2 = Some(state);
Ok((
PyBytes::new_bound(py, c_r.as_slice()),
PyBytes::new_bound(py, id_cred_r.bytes.as_slice()),
Expand All @@ -91,15 +100,15 @@ impl PyEdhocInitiator {
let valid_cred_r = valid_cred_r.to_credential()?;

match i_verify_message_2(
&self.processing_m2,
&self.processing_m2.take().ok_or(StateMismatch)?,
&mut default_crypto(),
valid_cred_r,
i.as_slice()
.try_into()
.expect("Wrong length of initiator private key"),
) {
Ok(state) => {
self.processed_m2 = state;
self.processed_m2 = Some(state);
self.cred_i = Some(cred_i);
Ok(())
}
Expand All @@ -115,14 +124,14 @@ impl PyEdhocInitiator {
ead_3: Option<EADItem>,
) -> PyResult<(Bound<'a, PyBytes>, Bound<'a, PyBytes>)> {
match i_prepare_message_3(
&mut self.processed_m2,
&mut self.processed_m2.take().ok_or(StateMismatch)?,
&mut default_crypto(),
self.cred_i.unwrap(),
cred_transfer,
&ead_3,
) {
Ok((state, message_3, prk_out)) => {
self.completed = state;
self.completed = Some(state);
Ok((
PyBytes::new_bound(py, message_3.as_slice()),
PyBytes::new_bound(py, prk_out.as_slice()),
Expand All @@ -143,7 +152,7 @@ impl PyEdhocInitiator {
context_buf[..context.len()].copy_from_slice(context.as_slice());

let res = edhoc_exporter(
&self.completed,
self.completed.as_ref().ok_or(StateMismatch)?,
&mut default_crypto(),
label,
&context_buf,
Expand All @@ -162,7 +171,7 @@ impl PyEdhocInitiator {
context_buf[..context.len()].copy_from_slice(context.as_slice());

let res = edhoc_key_update(
&mut self.completed,
self.completed.as_mut().ok_or(StateMismatch)?,
&mut default_crypto(),
&context_buf,
context.len(),
Expand All @@ -171,7 +180,10 @@ impl PyEdhocInitiator {
}

pub fn get_h_message_1<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyBytes>> {
Ok(PyBytes::new_bound(py, &self.wait_m2.h_message_1[..]))
Ok(PyBytes::new_bound(
py,
&self.wait_m2.as_ref().ok_or(StateMismatch)?.h_message_1[..],
))
}

pub fn compute_ephemeral_secret<'a>(
Expand Down
22 changes: 22 additions & 0 deletions lakers-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@ mod ead_authz;
mod initiator;
mod responder;

/// Error raised when operations on a Python object did not happen in the sequence in which they
/// were intended.
///
/// This currently has no more detailed response because for every situation this can occur in,
/// there are different possible explainations that we can't get across easily in a single message.
/// For example, if `responder.processing_m1` is absent, that can be either because no message 1
/// was processed into it yet, or because message 2 was already generated.
#[derive(Debug)]
pub(crate) struct StateMismatch;

impl std::error::Error for StateMismatch {}
impl std::fmt::Display for StateMismatch {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Type state mismatch")
}
}
impl From<StateMismatch> for PyErr {
fn from(err: StateMismatch) -> PyErr {
pyo3::exceptions::PyRuntimeError::new_err(err.to_string())
}
}

// NOTE: throughout this implementation, we use Vec<u8> for incoming byte lists and PyBytes for outgoing byte lists.
// This is because the incoming lists of bytes are automatically converted to `Vec<u8>` by pyo3,
// but the outgoing ones must be explicitly converted to `PyBytes`.
Expand Down
57 changes: 35 additions & 22 deletions lakers-python/src/responder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ use lakers_crypto::{default_crypto, CryptoTrait};
use log::trace;
use pyo3::{prelude::*, types::PyBytes};

use super::StateMismatch;

#[pyclass(name = "EdhocResponder")]
pub struct PyEdhocResponder {
r: Vec<u8>,
cred_r: Credential,
start: ResponderStart,
processing_m1: ProcessingM1,
wait_m3: WaitM3,
processing_m3: ProcessingM3,
completed: Completed,
start: Option<ResponderStart>,
processing_m1: Option<ProcessingM1>,
wait_m3: Option<WaitM3>,
processing_m3: Option<ProcessingM3>,
completed: Option<Completed>,
}

#[pymethods]
Expand All @@ -26,15 +28,15 @@ impl PyEdhocResponder {
Ok(Self {
r,
cred_r,
start: ResponderStart {
start: Some(ResponderStart {
method: EDHOCMethod::StatStat.into(),
y,
g_y,
},
processing_m1: ProcessingM1::default(),
wait_m3: WaitM3::default(),
processing_m3: ProcessingM3::default(),
completed: Completed::default(),
}),
processing_m1: None,
wait_m3: None,
processing_m3: None,
completed: None,
})
}

Expand All @@ -44,9 +46,12 @@ impl PyEdhocResponder {
message_1: Vec<u8>,
) -> PyResult<(Bound<'a, PyBytes>, Option<EADItem>)> {
let message_1 = EdhocMessageBuffer::new_from_slice(message_1.as_slice())?;
let (state, c_i, ead_1) =
r_process_message_1(&self.start, &mut default_crypto(), &message_1)?;
self.processing_m1 = state;
let (state, c_i, ead_1) = r_process_message_1(
&self.start.take().ok_or(StateMismatch)?,
&mut default_crypto(),
&message_1,
)?;
self.processing_m1 = Some(state);
let c_i = PyBytes::new_bound(py, c_i.as_slice());

Ok((c_i, ead_1))
Expand All @@ -73,7 +78,7 @@ impl PyEdhocResponder {
r.copy_from_slice(self.r.as_slice());

match r_prepare_message_2(
&self.processing_m1,
self.processing_m1.as_ref().ok_or(StateMismatch)?,
&mut default_crypto(),
self.cred_r,
&r,
Expand All @@ -82,7 +87,7 @@ impl PyEdhocResponder {
&ead_2,
) {
Ok((state, message_2)) => {
self.wait_m3 = state;
self.wait_m3 = Some(state);
Ok(PyBytes::new_bound(py, message_2.as_slice()))
}
Err(error) => Err(error.into()),
Expand All @@ -95,9 +100,13 @@ impl PyEdhocResponder {
message_3: Vec<u8>,
) -> PyResult<(Bound<'a, PyBytes>, Option<EADItem>)> {
let message_3 = EdhocMessageBuffer::new_from_slice(message_3.as_slice())?;
match r_parse_message_3(&mut self.wait_m3, &mut default_crypto(), &message_3) {
match r_parse_message_3(
&mut self.wait_m3.take().ok_or(StateMismatch)?,
&mut default_crypto(),
&message_3,
) {
Ok((state, id_cred_i, ead_3)) => {
self.processing_m3 = state;
self.processing_m3 = Some(state);
Ok((PyBytes::new_bound(py, id_cred_i.bytes.as_slice()), ead_3))
}
Err(error) => Err(error.into()),
Expand All @@ -110,9 +119,13 @@ impl PyEdhocResponder {
valid_cred_i: super::AutoCredential,
) -> PyResult<Bound<'a, PyBytes>> {
let valid_cred_i = valid_cred_i.to_credential()?;
match r_verify_message_3(&mut self.processing_m3, &mut default_crypto(), valid_cred_i) {
match r_verify_message_3(
&mut self.processing_m3.take().ok_or(StateMismatch)?,
&mut default_crypto(),
valid_cred_i,
) {
Ok((state, prk_out)) => {
self.completed = state;
self.completed = Some(state);
Ok(PyBytes::new_bound(py, prk_out.as_slice()))
}
Err(error) => Err(error.into()),
Expand All @@ -130,7 +143,7 @@ impl PyEdhocResponder {
context_buf[..context.len()].copy_from_slice(context.as_slice());

let res = edhoc_exporter(
&self.completed,
self.completed.as_ref().ok_or(StateMismatch)?,
&mut default_crypto(),
label,
&context_buf,
Expand All @@ -149,7 +162,7 @@ impl PyEdhocResponder {
context_buf[..context.len()].copy_from_slice(context.as_slice());

let res = edhoc_key_update(
&mut self.completed,
self.completed.as_mut().ok_or(StateMismatch)?,
&mut default_crypto(),
&context_buf,
context.len(),
Expand Down

0 comments on commit c136d6e

Please sign in to comment.