Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

python: Make safer to use #321

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions lakers-python/src/initiator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@ use lakers_crypto::{default_crypto, CryptoTrait};
use log::trace;
use pyo3::{prelude::*, types::PyBytes};

use super::StateMismatch;

#[pyclass(name = "EdhocInitiator")]
pub struct PyEdhocInitiator {
cred_i: Option<Credential>,
// FIXME: This does *not* get taken out, so some data stays available for longer than it needs
// to be -- but that is apparently needed in selected_cipher_suite and
// compute_ephemeral_secret.
start: InitiatorStart,
wait_m2: WaitM2,
processing_m2: ProcessingM2,
processed_m2: ProcessedM2,
completed: Completed,
wait_m2: Option<WaitM2>,
processing_m2: Option<ProcessingM2>,
processed_m2: Option<ProcessedM2>,
completed: Option<Completed>,
}

#[pymethods]
Expand All @@ -31,10 +36,10 @@ impl PyEdhocInitiator {
method: EDHOCMethod::StatStat.into(),
suites_i,
},
wait_m2: WaitM2::default(),
processing_m2: ProcessingM2::default(),
processed_m2: ProcessedM2::default(),
completed: Completed::default(),
wait_m2: None,
processing_m2: None,
processed_m2: None,
completed: None,
}
}

Expand All @@ -54,7 +59,7 @@ impl PyEdhocInitiator {

match i_prepare_message_1(&self.start, &mut default_crypto(), c_i, &ead_1) {
Ok((state, message_1)) => {
self.wait_m2 = state;
self.wait_m2 = Some(state);
Ok(PyBytes::new_bound(py, message_1.as_slice()))
}
Err(error) => Err(error.into()),
Expand All @@ -68,9 +73,13 @@ impl PyEdhocInitiator {
) -> PyResult<(Bound<'a, PyBytes>, Bound<'a, PyBytes>, Option<EADItem>)> {
let message_2 = EdhocMessageBuffer::new_from_slice(message_2.as_slice())?;

match i_parse_message_2(&self.wait_m2, &mut default_crypto(), &message_2) {
match i_parse_message_2(
&self.wait_m2.take().ok_or(StateMismatch)?,
&mut default_crypto(),
&message_2,
) {
Ok((state, c_r, id_cred_r, ead_2)) => {
self.processing_m2 = state;
self.processing_m2 = Some(state);
Ok((
PyBytes::new_bound(py, c_r.as_slice()),
PyBytes::new_bound(py, id_cred_r.bytes.as_slice()),
Expand All @@ -91,15 +100,15 @@ impl PyEdhocInitiator {
let valid_cred_r = valid_cred_r.to_credential()?;

match i_verify_message_2(
&self.processing_m2,
&self.processing_m2.take().ok_or(StateMismatch)?,
&mut default_crypto(),
valid_cred_r,
i.as_slice()
.try_into()
.expect("Wrong length of initiator private key"),
) {
Ok(state) => {
self.processed_m2 = state;
self.processed_m2 = Some(state);
self.cred_i = Some(cred_i);
Ok(())
}
Expand All @@ -115,14 +124,14 @@ impl PyEdhocInitiator {
ead_3: Option<EADItem>,
) -> PyResult<(Bound<'a, PyBytes>, Bound<'a, PyBytes>)> {
match i_prepare_message_3(
&mut self.processed_m2,
&mut self.processed_m2.take().ok_or(StateMismatch)?,
&mut default_crypto(),
self.cred_i.unwrap(),
cred_transfer,
&ead_3,
) {
Ok((state, message_3, prk_out)) => {
self.completed = state;
self.completed = Some(state);
Ok((
PyBytes::new_bound(py, message_3.as_slice()),
PyBytes::new_bound(py, prk_out.as_slice()),
Expand All @@ -143,7 +152,7 @@ impl PyEdhocInitiator {
context_buf[..context.len()].copy_from_slice(context.as_slice());

let res = edhoc_exporter(
&self.completed,
self.completed.as_ref().ok_or(StateMismatch)?,
&mut default_crypto(),
label,
&context_buf,
Expand All @@ -162,7 +171,7 @@ impl PyEdhocInitiator {
context_buf[..context.len()].copy_from_slice(context.as_slice());

let res = edhoc_key_update(
&mut self.completed,
self.completed.as_mut().ok_or(StateMismatch)?,
&mut default_crypto(),
&context_buf,
context.len(),
Expand All @@ -171,7 +180,10 @@ impl PyEdhocInitiator {
}

pub fn get_h_message_1<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyBytes>> {
Ok(PyBytes::new_bound(py, &self.wait_m2.h_message_1[..]))
Ok(PyBytes::new_bound(
py,
&self.wait_m2.as_ref().ok_or(StateMismatch)?.h_message_1[..],
))
}

pub fn compute_ephemeral_secret<'a>(
Expand Down
22 changes: 22 additions & 0 deletions lakers-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@ mod ead_authz;
mod initiator;
mod responder;

/// Error raised when operations on a Python object did not happen in the sequence in which they
/// were intended.
///
/// This currently has no more detailed response because for every situation this can occur in,
/// there are different possible explainations that we can't get across easily in a single message.
/// For example, if `responder.processing_m1` is absent, that can be either because no message 1
/// was processed into it yet, or because message 2 was already generated.
#[derive(Debug)]
pub(crate) struct StateMismatch;

impl std::error::Error for StateMismatch {}
impl std::fmt::Display for StateMismatch {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Type state mismatch")
}
}
impl From<StateMismatch> for PyErr {
fn from(err: StateMismatch) -> PyErr {
pyo3::exceptions::PyRuntimeError::new_err(err.to_string())
}
}

// NOTE: throughout this implementation, we use Vec<u8> for incoming byte lists and PyBytes for outgoing byte lists.
// This is because the incoming lists of bytes are automatically converted to `Vec<u8>` by pyo3,
// but the outgoing ones must be explicitly converted to `PyBytes`.
Expand Down
57 changes: 35 additions & 22 deletions lakers-python/src/responder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ use lakers_crypto::{default_crypto, CryptoTrait};
use log::trace;
use pyo3::{prelude::*, types::PyBytes};

use super::StateMismatch;

#[pyclass(name = "EdhocResponder")]
pub struct PyEdhocResponder {
r: Vec<u8>,
cred_r: Credential,
start: ResponderStart,
processing_m1: ProcessingM1,
wait_m3: WaitM3,
processing_m3: ProcessingM3,
completed: Completed,
start: Option<ResponderStart>,
processing_m1: Option<ProcessingM1>,
wait_m3: Option<WaitM3>,
processing_m3: Option<ProcessingM3>,
completed: Option<Completed>,
}

#[pymethods]
Expand All @@ -26,15 +28,15 @@ impl PyEdhocResponder {
Ok(Self {
r,
cred_r,
start: ResponderStart {
start: Some(ResponderStart {
method: EDHOCMethod::StatStat.into(),
y,
g_y,
},
processing_m1: ProcessingM1::default(),
wait_m3: WaitM3::default(),
processing_m3: ProcessingM3::default(),
completed: Completed::default(),
}),
processing_m1: None,
wait_m3: None,
processing_m3: None,
completed: None,
})
}

Expand All @@ -44,9 +46,12 @@ impl PyEdhocResponder {
message_1: Vec<u8>,
) -> PyResult<(Bound<'a, PyBytes>, Option<EADItem>)> {
let message_1 = EdhocMessageBuffer::new_from_slice(message_1.as_slice())?;
let (state, c_i, ead_1) =
r_process_message_1(&self.start, &mut default_crypto(), &message_1)?;
self.processing_m1 = state;
let (state, c_i, ead_1) = r_process_message_1(
&self.start.take().ok_or(StateMismatch)?,
&mut default_crypto(),
&message_1,
)?;
self.processing_m1 = Some(state);
let c_i = PyBytes::new_bound(py, c_i.as_slice());

Ok((c_i, ead_1))
Expand All @@ -73,7 +78,7 @@ impl PyEdhocResponder {
r.copy_from_slice(self.r.as_slice());

match r_prepare_message_2(
&self.processing_m1,
self.processing_m1.as_ref().ok_or(StateMismatch)?,
&mut default_crypto(),
self.cred_r,
&r,
Expand All @@ -82,7 +87,7 @@ impl PyEdhocResponder {
&ead_2,
) {
Ok((state, message_2)) => {
self.wait_m3 = state;
self.wait_m3 = Some(state);
Ok(PyBytes::new_bound(py, message_2.as_slice()))
}
Err(error) => Err(error.into()),
Expand All @@ -95,9 +100,13 @@ impl PyEdhocResponder {
message_3: Vec<u8>,
) -> PyResult<(Bound<'a, PyBytes>, Option<EADItem>)> {
let message_3 = EdhocMessageBuffer::new_from_slice(message_3.as_slice())?;
match r_parse_message_3(&mut self.wait_m3, &mut default_crypto(), &message_3) {
match r_parse_message_3(
&mut self.wait_m3.take().ok_or(StateMismatch)?,
&mut default_crypto(),
&message_3,
) {
Ok((state, id_cred_i, ead_3)) => {
self.processing_m3 = state;
self.processing_m3 = Some(state);
Ok((PyBytes::new_bound(py, id_cred_i.bytes.as_slice()), ead_3))
}
Err(error) => Err(error.into()),
Expand All @@ -110,9 +119,13 @@ impl PyEdhocResponder {
valid_cred_i: super::AutoCredential,
) -> PyResult<Bound<'a, PyBytes>> {
let valid_cred_i = valid_cred_i.to_credential()?;
match r_verify_message_3(&mut self.processing_m3, &mut default_crypto(), valid_cred_i) {
match r_verify_message_3(
&mut self.processing_m3.take().ok_or(StateMismatch)?,
&mut default_crypto(),
valid_cred_i,
) {
Ok((state, prk_out)) => {
self.completed = state;
self.completed = Some(state);
Ok(PyBytes::new_bound(py, prk_out.as_slice()))
}
Err(error) => Err(error.into()),
Expand All @@ -130,7 +143,7 @@ impl PyEdhocResponder {
context_buf[..context.len()].copy_from_slice(context.as_slice());

let res = edhoc_exporter(
&self.completed,
self.completed.as_ref().ok_or(StateMismatch)?,
&mut default_crypto(),
label,
&context_buf,
Expand All @@ -149,7 +162,7 @@ impl PyEdhocResponder {
context_buf[..context.len()].copy_from_slice(context.as_slice());

let res = edhoc_key_update(
&mut self.completed,
self.completed.as_mut().ok_or(StateMismatch)?,
&mut default_crypto(),
&context_buf,
context.len(),
Expand Down