From 08d16497db81ad819d0ce5ecabc8f15654226941 Mon Sep 17 00:00:00 2001 From: chrysn Date: Sun, 5 Nov 2023 02:23:56 +0100 Subject: [PATCH] feat!: Typestating of high-layer initiator This is a rather minimal version in that the API is only altered as necessary -- setting c_r is not deferred yet. Note that this already not only reduces the size of the Done initiator, but also frees it from lifetime constraints (because at that point it doesn't need to know the setup details any more). --- examples/coap/src/bin/coapclient.rs | 12 +-- lib/src/lib.rs | 114 ++++++++++++++++++---------- 2 files changed, 78 insertions(+), 48 deletions(-) diff --git a/examples/coap/src/bin/coapclient.rs b/examples/coap/src/bin/coapclient.rs index 6b22ff8a..78c52bfd 100644 --- a/examples/coap/src/bin/coapclient.rs +++ b/examples/coap/src/bin/coapclient.rs @@ -26,7 +26,7 @@ fn main() { // Send Message 1 over CoAP and convert the response to byte let mut msg_1_buf = Vec::from([0xf5u8]); // EDHOC message_1 when transported over CoAP is prepended with CBOR true let c_i = generate_connection_identifier_cbor(); - let message_1 = initiator.prepare_message_1(c_i).unwrap(); + let (initiator, message_1) = initiator.prepare_message_1(c_i).unwrap(); msg_1_buf.extend_from_slice(&message_1.content[..message_1.len]); println!("message_1 len = {}", msg_1_buf.len()); @@ -37,15 +37,15 @@ fn main() { println!("response_vec = {:02x?}", response.message.payload); println!("message_2 len = {}", response.message.payload.len()); - let c_r = initiator.process_message_2( + let m2result = initiator.process_message_2( &response.message.payload[..] .try_into() .expect("wrong length"), ); - if c_r.is_ok() { - let mut msg_3 = Vec::from([c_r.unwrap()]); - let (message_3, prk_out) = initiator.prepare_message_3().unwrap(); + if let Ok((initiator, c_r)) = m2result { + let mut msg_3 = Vec::from([c_r]); + let (mut initiator, message_3, prk_out) = initiator.prepare_message_3().unwrap(); msg_3.extend_from_slice(&message_3.content[..message_3.len]); println!("message_3 len = {}", msg_3.len()); @@ -76,6 +76,6 @@ fn main() { println!("OSCORE secret after key update: {:02x?}", _oscore_secret); println!("OSCORE salt after key update: {:02x?}", _oscore_salt); } else { - panic!("Message 2 processing error: {:#?}", c_r); + panic!("Message 2 processing error: {:#?}", m2result); } } diff --git a/lib/src/lib.rs b/lib/src/lib.rs index ed8d2e46..38818e54 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -3,8 +3,7 @@ pub use { edhoc_consts::State as EdhocState, edhoc_consts::*, edhoc_crypto::default_crypto, - edhoc_crypto_trait::Crypto as CryptoTrait, EdhocInitiatorState as EdhocInitiator, - EdhocResponderState as EdhocResponder, + edhoc_crypto_trait::Crypto as CryptoTrait, EdhocResponderState as EdhocResponder, }; #[cfg(any(feature = "ead-none", feature = "ead-zeroconf"))] @@ -15,14 +14,35 @@ use edhoc::*; use edhoc_consts::*; -#[derive(Default, Copy, Clone, Debug)] -pub struct EdhocInitiatorState<'a> { +#[derive(Default, Debug)] +pub struct EdhocInitiator<'a> { + state: State, // opaque state + i: &'a [u8], // private authentication key of I + cred_i: &'a [u8], // I's full credential + cred_r: Option<&'a [u8]>, // R's full credential (if provided) +} + +#[derive(Default, Debug)] +pub struct EdhocInitiatorWaitM2<'a> { + state: State, // opaque state + i: &'a [u8], // private authentication key of I + cred_i: &'a [u8], // I's full credential + cred_r: Option<&'a [u8]>, // R's full credential (if provided) +} + +#[derive(Default, Debug)] +pub struct EdhocInitiatorBuildM3<'a> { state: State, // opaque state i: &'a [u8], // private authentication key of I cred_i: &'a [u8], // I's full credential cred_r: Option<&'a [u8]>, // R's full credential (if provided) } +#[derive(Default, Debug)] +pub struct EdhocInitiatorDone { + state: State, // opaque state +} + #[derive(Default, Copy, Clone, Debug)] pub struct EdhocResponderState<'a> { state: State, // opaque state @@ -146,16 +166,16 @@ impl<'a> EdhocResponderState<'a> { } } -impl<'a> EdhocInitiatorState<'a> { +impl<'a> EdhocInitiator<'a> { pub fn new( state: State, i: &'a [u8], cred_i: &'a [u8], cred_r: Option<&'a [u8]>, - ) -> EdhocInitiatorState<'a> { + ) -> EdhocInitiator<'a> { assert!(i.len() == P256_ELEM_LEN); - EdhocInitiatorState { + EdhocInitiator { state, i, cred_i, @@ -164,24 +184,31 @@ impl<'a> EdhocInitiatorState<'a> { } pub fn prepare_message_1( - self: &mut EdhocInitiatorState<'a>, + self: EdhocInitiator<'a>, c_i: u8, - ) -> Result { + ) -> Result<(EdhocInitiatorWaitM2<'a>, BufferMessage1), EDHOCError> { let (x, g_x) = default_crypto().p256_generate_key_pair(); match i_prepare_message_1(self.state, &mut default_crypto(), x, g_x, c_i) { - Ok((state, message_1)) => { - self.state = state; - Ok(message_1) - } + Ok((state, message_1)) => Ok(( + EdhocInitiatorWaitM2 { + state, + i: self.i, + cred_i: self.cred_i, + cred_r: self.cred_r, + }, + message_1, + )), Err(error) => Err(error), } } +} +impl<'a> EdhocInitiatorWaitM2<'a> { pub fn process_message_2( - self: &mut EdhocInitiatorState<'a>, + self, message_2: &BufferMessage2, - ) -> Result { + ) -> Result<(EdhocInitiatorBuildM3<'a>, u8), EDHOCError> { match i_process_message_2( self.state, &mut default_crypto(), @@ -191,17 +218,24 @@ impl<'a> EdhocInitiatorState<'a> { .try_into() .expect("Wrong length of initiator private key"), ) { - Ok((state, c_r, _kid)) => { - self.state = state; - Ok(c_r) - } + Ok((state, c_r, _kid)) => Ok(( + EdhocInitiatorBuildM3 { + state, + i: self.i, + cred_i: self.cred_i, + cred_r: self.cred_r, + }, + c_r, + )), Err(error) => Err(error), } } +} +impl<'a> EdhocInitiatorBuildM3<'a> { pub fn prepare_message_3( - self: &mut EdhocInitiatorState<'a>, - ) -> Result<(BufferMessage3, [u8; SHA256_DIGEST_LEN]), EDHOCError> { + self, + ) -> Result<(EdhocInitiatorDone, BufferMessage3, [u8; SHA256_DIGEST_LEN]), EDHOCError> { match i_prepare_message_3( self.state, &mut default_crypto(), @@ -209,15 +243,16 @@ impl<'a> EdhocInitiatorState<'a> { self.cred_i, ) { Ok((state, message_3, prk_out)) => { - self.state = state; - Ok((message_3, prk_out)) + Ok((EdhocInitiatorDone { state }, message_3, prk_out)) } Err(error) => Err(error), } } +} +impl EdhocInitiatorDone { pub fn edhoc_exporter( - self: &mut EdhocInitiatorState<'a>, + &mut self, label: u8, context: &[u8], length: usize, @@ -242,7 +277,7 @@ impl<'a> EdhocInitiatorState<'a> { } pub fn edhoc_key_update( - self: &mut EdhocInitiatorState<'a>, + &mut self, context: &[u8], ) -> Result<[u8; SHA256_DIGEST_LEN], EDHOCError> { let mut context_buf = [0x00u8; MAX_KDF_CONTEXT_LEN]; @@ -311,8 +346,8 @@ mod test { #[test] fn test_new_initiator() { let state: EdhocState = Default::default(); - let _initiator = EdhocInitiatorState::new(state, I, CRED_I, Some(CRED_R)); - let _initiator = EdhocInitiatorState::new(state, I, CRED_I, None); + let _initiator = EdhocInitiator::new(state, I, CRED_I, Some(CRED_R)); + let _initiator = EdhocInitiator::new(state, I, CRED_I, None); } #[test] @@ -325,7 +360,7 @@ mod test { #[test] fn test_prepare_message_1() { let state: EdhocState = Default::default(); - let mut initiator = EdhocInitiatorState::new(state, I, CRED_I, Some(CRED_R)); + let mut initiator = EdhocInitiator::new(state, I, CRED_I, Some(CRED_R)); let c_i = generate_connection_identifier_cbor(); let message_1 = initiator.prepare_message_1(c_i); @@ -359,15 +394,14 @@ mod test { #[test] fn test_handshake() { let state_initiator: EdhocState = Default::default(); - let mut initiator = EdhocInitiatorState::new(state_initiator, I, CRED_I, Some(CRED_R)); + let mut initiator = EdhocInitiator::new(state_initiator, I, CRED_I, Some(CRED_R)); let state_responder: EdhocState = Default::default(); let mut responder = EdhocResponderState::new(state_responder, R, CRED_R, Some(CRED_I)); let c_i: u8 = generate_connection_identifier_cbor(); - let result = initiator.prepare_message_1(c_i); // to update the state - assert!(result.is_ok()); + let (initiator, result) = initiator.prepare_message_1(c_i).unwrap(); // to update the state - let error = responder.process_message_1(&result.unwrap()); + let error = responder.process_message_1(&result); assert!(error.is_ok()); let c_r = generate_connection_identifier_cbor(); @@ -377,13 +411,9 @@ mod test { let message_2 = ret.unwrap(); assert!(c_r != 0xff); - let _c_r = initiator.process_message_2(&message_2); - assert!(_c_r.is_ok()); - - let ret = initiator.prepare_message_3(); - assert!(ret.is_ok()); + let (initiator, _) = initiator.process_message_2(&message_2).unwrap(); - let (message_3, i_prk_out) = ret.unwrap(); + let (mut initiator, message_3, i_prk_out) = initiator.prepare_message_3().unwrap(); let r_prk_out = responder.process_message_3(&message_3); assert!(r_prk_out.is_ok()); @@ -437,7 +467,7 @@ mod test { #[test] fn test_ead_zeroconf() { let state_initiator: EdhocState = Default::default(); - let mut initiator = EdhocInitiatorState::new(state_initiator, I, CRED_I, None); + let mut initiator = EdhocInitiator::new(state_initiator, I, CRED_I, None); let state_responder: EdhocState = Default::default(); let mut responder = EdhocResponderState::new(state_responder, R, CRED_V_TV, Some(CRED_I)); @@ -467,7 +497,7 @@ mod test { )); let c_i = generate_connection_identifier_cbor(); - let message_1 = initiator.prepare_message_1(c_i).unwrap(); + let (initiator, message_1) = initiator.prepare_message_1(c_i).unwrap(); assert_eq!( ead_initiator_state.protocol_state, EADInitiatorProtocolState::WaitEAD2 @@ -486,14 +516,14 @@ mod test { EADResponderProtocolState::Completed ); - initiator.process_message_2(&message_2).unwrap(); + let (initiator, _) = initiator.process_message_2(&message_2).unwrap(); assert_eq!( ead_initiator_state.protocol_state, EADInitiatorProtocolState::Completed ); - let (message_3, i_prk_out) = initiator.prepare_message_3().unwrap(); + let (initiator, message_3, i_prk_out) = initiator.prepare_message_3().unwrap(); let r_prk_out = responder.process_message_3(&message_3).unwrap(); assert_eq!(i_prk_out, r_prk_out);