From c136d6eef429c2cc000d207c31e74c9b203ada2a Mon Sep 17 00:00:00 2001 From: chrysn Date: Mon, 4 Nov 2024 15:18:42 +0000 Subject: [PATCH] python: Make safer to use Previously, a call to `.edhoc_exporter()` would happily have produced exports from the all-zero key material. --- lakers-python/src/initiator.rs | 48 +++++++++++++++++----------- lakers-python/src/lib.rs | 22 +++++++++++++ lakers-python/src/responder.rs | 57 +++++++++++++++++++++------------- 3 files changed, 87 insertions(+), 40 deletions(-) diff --git a/lakers-python/src/initiator.rs b/lakers-python/src/initiator.rs index c9e080b0..93760982 100644 --- a/lakers-python/src/initiator.rs +++ b/lakers-python/src/initiator.rs @@ -3,14 +3,19 @@ use lakers_crypto::{default_crypto, CryptoTrait}; use log::trace; use pyo3::{prelude::*, types::PyBytes}; +use super::StateMismatch; + #[pyclass(name = "EdhocInitiator")] pub struct PyEdhocInitiator { cred_i: Option, + // FIXME: This does *not* get taken out, so some data stays available for longer than it needs + // to be -- but that is apparently needed in selected_cipher_suite and + // compute_ephemeral_secret. start: InitiatorStart, - wait_m2: WaitM2, - processing_m2: ProcessingM2, - processed_m2: ProcessedM2, - completed: Completed, + wait_m2: Option, + processing_m2: Option, + processed_m2: Option, + completed: Option, } #[pymethods] @@ -31,10 +36,10 @@ impl PyEdhocInitiator { method: EDHOCMethod::StatStat.into(), suites_i, }, - wait_m2: WaitM2::default(), - processing_m2: ProcessingM2::default(), - processed_m2: ProcessedM2::default(), - completed: Completed::default(), + wait_m2: None, + processing_m2: None, + processed_m2: None, + completed: None, } } @@ -54,7 +59,7 @@ impl PyEdhocInitiator { match i_prepare_message_1(&self.start, &mut default_crypto(), c_i, &ead_1) { Ok((state, message_1)) => { - self.wait_m2 = state; + self.wait_m2 = Some(state); Ok(PyBytes::new_bound(py, message_1.as_slice())) } Err(error) => Err(error.into()), @@ -68,9 +73,13 @@ impl PyEdhocInitiator { ) -> PyResult<(Bound<'a, PyBytes>, Bound<'a, PyBytes>, Option)> { let message_2 = EdhocMessageBuffer::new_from_slice(message_2.as_slice())?; - match i_parse_message_2(&self.wait_m2, &mut default_crypto(), &message_2) { + match i_parse_message_2( + &self.wait_m2.take().ok_or(StateMismatch)?, + &mut default_crypto(), + &message_2, + ) { Ok((state, c_r, id_cred_r, ead_2)) => { - self.processing_m2 = state; + self.processing_m2 = Some(state); Ok(( PyBytes::new_bound(py, c_r.as_slice()), PyBytes::new_bound(py, id_cred_r.bytes.as_slice()), @@ -91,7 +100,7 @@ impl PyEdhocInitiator { let valid_cred_r = valid_cred_r.to_credential()?; match i_verify_message_2( - &self.processing_m2, + &self.processing_m2.take().ok_or(StateMismatch)?, &mut default_crypto(), valid_cred_r, i.as_slice() @@ -99,7 +108,7 @@ impl PyEdhocInitiator { .expect("Wrong length of initiator private key"), ) { Ok(state) => { - self.processed_m2 = state; + self.processed_m2 = Some(state); self.cred_i = Some(cred_i); Ok(()) } @@ -115,14 +124,14 @@ impl PyEdhocInitiator { ead_3: Option, ) -> PyResult<(Bound<'a, PyBytes>, Bound<'a, PyBytes>)> { match i_prepare_message_3( - &mut self.processed_m2, + &mut self.processed_m2.take().ok_or(StateMismatch)?, &mut default_crypto(), self.cred_i.unwrap(), cred_transfer, &ead_3, ) { Ok((state, message_3, prk_out)) => { - self.completed = state; + self.completed = Some(state); Ok(( PyBytes::new_bound(py, message_3.as_slice()), PyBytes::new_bound(py, prk_out.as_slice()), @@ -143,7 +152,7 @@ impl PyEdhocInitiator { context_buf[..context.len()].copy_from_slice(context.as_slice()); let res = edhoc_exporter( - &self.completed, + self.completed.as_ref().ok_or(StateMismatch)?, &mut default_crypto(), label, &context_buf, @@ -162,7 +171,7 @@ impl PyEdhocInitiator { context_buf[..context.len()].copy_from_slice(context.as_slice()); let res = edhoc_key_update( - &mut self.completed, + self.completed.as_mut().ok_or(StateMismatch)?, &mut default_crypto(), &context_buf, context.len(), @@ -171,7 +180,10 @@ impl PyEdhocInitiator { } pub fn get_h_message_1<'a>(&self, py: Python<'a>) -> PyResult> { - Ok(PyBytes::new_bound(py, &self.wait_m2.h_message_1[..])) + Ok(PyBytes::new_bound( + py, + &self.wait_m2.as_ref().ok_or(StateMismatch)?.h_message_1[..], + )) } pub fn compute_ephemeral_secret<'a>( diff --git a/lakers-python/src/lib.rs b/lakers-python/src/lib.rs index 89309ef2..d89a22c1 100644 --- a/lakers-python/src/lib.rs +++ b/lakers-python/src/lib.rs @@ -12,6 +12,28 @@ mod ead_authz; mod initiator; mod responder; +/// Error raised when operations on a Python object did not happen in the sequence in which they +/// were intended. +/// +/// This currently has no more detailed response because for every situation this can occur in, +/// there are different possible explainations that we can't get across easily in a single message. +/// For example, if `responder.processing_m1` is absent, that can be either because no message 1 +/// was processed into it yet, or because message 2 was already generated. +#[derive(Debug)] +pub(crate) struct StateMismatch; + +impl std::error::Error for StateMismatch {} +impl std::fmt::Display for StateMismatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Type state mismatch") + } +} +impl From for PyErr { + fn from(err: StateMismatch) -> PyErr { + pyo3::exceptions::PyRuntimeError::new_err(err.to_string()) + } +} + // NOTE: throughout this implementation, we use Vec for incoming byte lists and PyBytes for outgoing byte lists. // This is because the incoming lists of bytes are automatically converted to `Vec` by pyo3, // but the outgoing ones must be explicitly converted to `PyBytes`. diff --git a/lakers-python/src/responder.rs b/lakers-python/src/responder.rs index 16d6aaad..f70e430e 100644 --- a/lakers-python/src/responder.rs +++ b/lakers-python/src/responder.rs @@ -3,15 +3,17 @@ use lakers_crypto::{default_crypto, CryptoTrait}; use log::trace; use pyo3::{prelude::*, types::PyBytes}; +use super::StateMismatch; + #[pyclass(name = "EdhocResponder")] pub struct PyEdhocResponder { r: Vec, cred_r: Credential, - start: ResponderStart, - processing_m1: ProcessingM1, - wait_m3: WaitM3, - processing_m3: ProcessingM3, - completed: Completed, + start: Option, + processing_m1: Option, + wait_m3: Option, + processing_m3: Option, + completed: Option, } #[pymethods] @@ -26,15 +28,15 @@ impl PyEdhocResponder { Ok(Self { r, cred_r, - start: ResponderStart { + start: Some(ResponderStart { method: EDHOCMethod::StatStat.into(), y, g_y, - }, - processing_m1: ProcessingM1::default(), - wait_m3: WaitM3::default(), - processing_m3: ProcessingM3::default(), - completed: Completed::default(), + }), + processing_m1: None, + wait_m3: None, + processing_m3: None, + completed: None, }) } @@ -44,9 +46,12 @@ impl PyEdhocResponder { message_1: Vec, ) -> PyResult<(Bound<'a, PyBytes>, Option)> { let message_1 = EdhocMessageBuffer::new_from_slice(message_1.as_slice())?; - let (state, c_i, ead_1) = - r_process_message_1(&self.start, &mut default_crypto(), &message_1)?; - self.processing_m1 = state; + let (state, c_i, ead_1) = r_process_message_1( + &self.start.take().ok_or(StateMismatch)?, + &mut default_crypto(), + &message_1, + )?; + self.processing_m1 = Some(state); let c_i = PyBytes::new_bound(py, c_i.as_slice()); Ok((c_i, ead_1)) @@ -73,7 +78,7 @@ impl PyEdhocResponder { r.copy_from_slice(self.r.as_slice()); match r_prepare_message_2( - &self.processing_m1, + self.processing_m1.as_ref().ok_or(StateMismatch)?, &mut default_crypto(), self.cred_r, &r, @@ -82,7 +87,7 @@ impl PyEdhocResponder { &ead_2, ) { Ok((state, message_2)) => { - self.wait_m3 = state; + self.wait_m3 = Some(state); Ok(PyBytes::new_bound(py, message_2.as_slice())) } Err(error) => Err(error.into()), @@ -95,9 +100,13 @@ impl PyEdhocResponder { message_3: Vec, ) -> PyResult<(Bound<'a, PyBytes>, Option)> { let message_3 = EdhocMessageBuffer::new_from_slice(message_3.as_slice())?; - match r_parse_message_3(&mut self.wait_m3, &mut default_crypto(), &message_3) { + match r_parse_message_3( + &mut self.wait_m3.take().ok_or(StateMismatch)?, + &mut default_crypto(), + &message_3, + ) { Ok((state, id_cred_i, ead_3)) => { - self.processing_m3 = state; + self.processing_m3 = Some(state); Ok((PyBytes::new_bound(py, id_cred_i.bytes.as_slice()), ead_3)) } Err(error) => Err(error.into()), @@ -110,9 +119,13 @@ impl PyEdhocResponder { valid_cred_i: super::AutoCredential, ) -> PyResult> { let valid_cred_i = valid_cred_i.to_credential()?; - match r_verify_message_3(&mut self.processing_m3, &mut default_crypto(), valid_cred_i) { + match r_verify_message_3( + &mut self.processing_m3.take().ok_or(StateMismatch)?, + &mut default_crypto(), + valid_cred_i, + ) { Ok((state, prk_out)) => { - self.completed = state; + self.completed = Some(state); Ok(PyBytes::new_bound(py, prk_out.as_slice())) } Err(error) => Err(error.into()), @@ -130,7 +143,7 @@ impl PyEdhocResponder { context_buf[..context.len()].copy_from_slice(context.as_slice()); let res = edhoc_exporter( - &self.completed, + self.completed.as_ref().ok_or(StateMismatch)?, &mut default_crypto(), label, &context_buf, @@ -149,7 +162,7 @@ impl PyEdhocResponder { context_buf[..context.len()].copy_from_slice(context.as_slice()); let res = edhoc_key_update( - &mut self.completed, + self.completed.as_mut().ok_or(StateMismatch)?, &mut default_crypto(), &context_buf, context.len(),