diff --git a/examples/coap/src/bin/coapclient.rs b/examples/coap/src/bin/coapclient.rs index 6b22ff8a..78c52bfd 100644 --- a/examples/coap/src/bin/coapclient.rs +++ b/examples/coap/src/bin/coapclient.rs @@ -26,7 +26,7 @@ fn main() { // Send Message 1 over CoAP and convert the response to byte let mut msg_1_buf = Vec::from([0xf5u8]); // EDHOC message_1 when transported over CoAP is prepended with CBOR true let c_i = generate_connection_identifier_cbor(); - let message_1 = initiator.prepare_message_1(c_i).unwrap(); + let (initiator, message_1) = initiator.prepare_message_1(c_i).unwrap(); msg_1_buf.extend_from_slice(&message_1.content[..message_1.len]); println!("message_1 len = {}", msg_1_buf.len()); @@ -37,15 +37,15 @@ fn main() { println!("response_vec = {:02x?}", response.message.payload); println!("message_2 len = {}", response.message.payload.len()); - let c_r = initiator.process_message_2( + let m2result = initiator.process_message_2( &response.message.payload[..] .try_into() .expect("wrong length"), ); - if c_r.is_ok() { - let mut msg_3 = Vec::from([c_r.unwrap()]); - let (message_3, prk_out) = initiator.prepare_message_3().unwrap(); + if let Ok((initiator, c_r)) = m2result { + let mut msg_3 = Vec::from([c_r]); + let (mut initiator, message_3, prk_out) = initiator.prepare_message_3().unwrap(); msg_3.extend_from_slice(&message_3.content[..message_3.len]); println!("message_3 len = {}", msg_3.len()); @@ -76,6 +76,6 @@ fn main() { println!("OSCORE secret after key update: {:02x?}", _oscore_secret); println!("OSCORE salt after key update: {:02x?}", _oscore_salt); } else { - panic!("Message 2 processing error: {:#?}", c_r); + panic!("Message 2 processing error: {:#?}", m2result); } } diff --git a/lib/src/lib.rs b/lib/src/lib.rs index ed8d2e46..38818e54 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -3,8 +3,7 @@ pub use { edhoc_consts::State as EdhocState, edhoc_consts::*, edhoc_crypto::default_crypto, - edhoc_crypto_trait::Crypto as CryptoTrait, EdhocInitiatorState as EdhocInitiator, - EdhocResponderState as EdhocResponder, + edhoc_crypto_trait::Crypto as CryptoTrait, EdhocResponderState as EdhocResponder, }; #[cfg(any(feature = "ead-none", feature = "ead-zeroconf"))] @@ -15,14 +14,35 @@ use edhoc::*; use edhoc_consts::*; -#[derive(Default, Copy, Clone, Debug)] -pub struct EdhocInitiatorState<'a> { +#[derive(Default, Debug)] +pub struct EdhocInitiator<'a> { + state: State, // opaque state + i: &'a [u8], // private authentication key of I + cred_i: &'a [u8], // I's full credential + cred_r: Option<&'a [u8]>, // R's full credential (if provided) +} + +#[derive(Default, Debug)] +pub struct EdhocInitiatorWaitM2<'a> { + state: State, // opaque state + i: &'a [u8], // private authentication key of I + cred_i: &'a [u8], // I's full credential + cred_r: Option<&'a [u8]>, // R's full credential (if provided) +} + +#[derive(Default, Debug)] +pub struct EdhocInitiatorBuildM3<'a> { state: State, // opaque state i: &'a [u8], // private authentication key of I cred_i: &'a [u8], // I's full credential cred_r: Option<&'a [u8]>, // R's full credential (if provided) } +#[derive(Default, Debug)] +pub struct EdhocInitiatorDone { + state: State, // opaque state +} + #[derive(Default, Copy, Clone, Debug)] pub struct EdhocResponderState<'a> { state: State, // opaque state @@ -146,16 +166,16 @@ impl<'a> EdhocResponderState<'a> { } } -impl<'a> EdhocInitiatorState<'a> { +impl<'a> EdhocInitiator<'a> { pub fn new( state: State, i: &'a [u8], cred_i: &'a [u8], cred_r: Option<&'a [u8]>, - ) -> EdhocInitiatorState<'a> { + ) -> EdhocInitiator<'a> { assert!(i.len() == P256_ELEM_LEN); - EdhocInitiatorState { + EdhocInitiator { state, i, cred_i, @@ -164,24 +184,31 @@ impl<'a> EdhocInitiatorState<'a> { } pub fn prepare_message_1( - self: &mut EdhocInitiatorState<'a>, + self: EdhocInitiator<'a>, c_i: u8, - ) -> Result { + ) -> Result<(EdhocInitiatorWaitM2<'a>, BufferMessage1), EDHOCError> { let (x, g_x) = default_crypto().p256_generate_key_pair(); match i_prepare_message_1(self.state, &mut default_crypto(), x, g_x, c_i) { - Ok((state, message_1)) => { - self.state = state; - Ok(message_1) - } + Ok((state, message_1)) => Ok(( + EdhocInitiatorWaitM2 { + state, + i: self.i, + cred_i: self.cred_i, + cred_r: self.cred_r, + }, + message_1, + )), Err(error) => Err(error), } } +} +impl<'a> EdhocInitiatorWaitM2<'a> { pub fn process_message_2( - self: &mut EdhocInitiatorState<'a>, + self, message_2: &BufferMessage2, - ) -> Result { + ) -> Result<(EdhocInitiatorBuildM3<'a>, u8), EDHOCError> { match i_process_message_2( self.state, &mut default_crypto(), @@ -191,17 +218,24 @@ impl<'a> EdhocInitiatorState<'a> { .try_into() .expect("Wrong length of initiator private key"), ) { - Ok((state, c_r, _kid)) => { - self.state = state; - Ok(c_r) - } + Ok((state, c_r, _kid)) => Ok(( + EdhocInitiatorBuildM3 { + state, + i: self.i, + cred_i: self.cred_i, + cred_r: self.cred_r, + }, + c_r, + )), Err(error) => Err(error), } } +} +impl<'a> EdhocInitiatorBuildM3<'a> { pub fn prepare_message_3( - self: &mut EdhocInitiatorState<'a>, - ) -> Result<(BufferMessage3, [u8; SHA256_DIGEST_LEN]), EDHOCError> { + self, + ) -> Result<(EdhocInitiatorDone, BufferMessage3, [u8; SHA256_DIGEST_LEN]), EDHOCError> { match i_prepare_message_3( self.state, &mut default_crypto(), @@ -209,15 +243,16 @@ impl<'a> EdhocInitiatorState<'a> { self.cred_i, ) { Ok((state, message_3, prk_out)) => { - self.state = state; - Ok((message_3, prk_out)) + Ok((EdhocInitiatorDone { state }, message_3, prk_out)) } Err(error) => Err(error), } } +} +impl EdhocInitiatorDone { pub fn edhoc_exporter( - self: &mut EdhocInitiatorState<'a>, + &mut self, label: u8, context: &[u8], length: usize, @@ -242,7 +277,7 @@ impl<'a> EdhocInitiatorState<'a> { } pub fn edhoc_key_update( - self: &mut EdhocInitiatorState<'a>, + &mut self, context: &[u8], ) -> Result<[u8; SHA256_DIGEST_LEN], EDHOCError> { let mut context_buf = [0x00u8; MAX_KDF_CONTEXT_LEN]; @@ -311,8 +346,8 @@ mod test { #[test] fn test_new_initiator() { let state: EdhocState = Default::default(); - let _initiator = EdhocInitiatorState::new(state, I, CRED_I, Some(CRED_R)); - let _initiator = EdhocInitiatorState::new(state, I, CRED_I, None); + let _initiator = EdhocInitiator::new(state, I, CRED_I, Some(CRED_R)); + let _initiator = EdhocInitiator::new(state, I, CRED_I, None); } #[test] @@ -325,7 +360,7 @@ mod test { #[test] fn test_prepare_message_1() { let state: EdhocState = Default::default(); - let mut initiator = EdhocInitiatorState::new(state, I, CRED_I, Some(CRED_R)); + let mut initiator = EdhocInitiator::new(state, I, CRED_I, Some(CRED_R)); let c_i = generate_connection_identifier_cbor(); let message_1 = initiator.prepare_message_1(c_i); @@ -359,15 +394,14 @@ mod test { #[test] fn test_handshake() { let state_initiator: EdhocState = Default::default(); - let mut initiator = EdhocInitiatorState::new(state_initiator, I, CRED_I, Some(CRED_R)); + let mut initiator = EdhocInitiator::new(state_initiator, I, CRED_I, Some(CRED_R)); let state_responder: EdhocState = Default::default(); let mut responder = EdhocResponderState::new(state_responder, R, CRED_R, Some(CRED_I)); let c_i: u8 = generate_connection_identifier_cbor(); - let result = initiator.prepare_message_1(c_i); // to update the state - assert!(result.is_ok()); + let (initiator, result) = initiator.prepare_message_1(c_i).unwrap(); // to update the state - let error = responder.process_message_1(&result.unwrap()); + let error = responder.process_message_1(&result); assert!(error.is_ok()); let c_r = generate_connection_identifier_cbor(); @@ -377,13 +411,9 @@ mod test { let message_2 = ret.unwrap(); assert!(c_r != 0xff); - let _c_r = initiator.process_message_2(&message_2); - assert!(_c_r.is_ok()); - - let ret = initiator.prepare_message_3(); - assert!(ret.is_ok()); + let (initiator, _) = initiator.process_message_2(&message_2).unwrap(); - let (message_3, i_prk_out) = ret.unwrap(); + let (mut initiator, message_3, i_prk_out) = initiator.prepare_message_3().unwrap(); let r_prk_out = responder.process_message_3(&message_3); assert!(r_prk_out.is_ok()); @@ -437,7 +467,7 @@ mod test { #[test] fn test_ead_zeroconf() { let state_initiator: EdhocState = Default::default(); - let mut initiator = EdhocInitiatorState::new(state_initiator, I, CRED_I, None); + let mut initiator = EdhocInitiator::new(state_initiator, I, CRED_I, None); let state_responder: EdhocState = Default::default(); let mut responder = EdhocResponderState::new(state_responder, R, CRED_V_TV, Some(CRED_I)); @@ -467,7 +497,7 @@ mod test { )); let c_i = generate_connection_identifier_cbor(); - let message_1 = initiator.prepare_message_1(c_i).unwrap(); + let (initiator, message_1) = initiator.prepare_message_1(c_i).unwrap(); assert_eq!( ead_initiator_state.protocol_state, EADInitiatorProtocolState::WaitEAD2 @@ -486,14 +516,14 @@ mod test { EADResponderProtocolState::Completed ); - initiator.process_message_2(&message_2).unwrap(); + let (initiator, _) = initiator.process_message_2(&message_2).unwrap(); assert_eq!( ead_initiator_state.protocol_state, EADInitiatorProtocolState::Completed ); - let (message_3, i_prk_out) = initiator.prepare_message_3().unwrap(); + let (initiator, message_3, i_prk_out) = initiator.prepare_message_3().unwrap(); let r_prk_out = responder.process_message_3(&message_3).unwrap(); assert_eq!(i_prk_out, r_prk_out);