diff --git a/examples/coap/src/bin/coapserver-coaphandler.rs b/examples/coap/src/bin/coapserver-coaphandler.rs index 232ed551..304bb063 100644 --- a/examples/coap/src/bin/coapserver-coaphandler.rs +++ b/examples/coap/src/bin/coapserver-coaphandler.rs @@ -15,16 +15,18 @@ const R: &[u8] = &hex!("72cc4761dbd4c78f758931aa589d348d1ef874a7e303ede2f140dcf3 #[derive(Default, Debug)] struct EdhocHandler { - connections: Vec<(u8, EdhocResponder<'static>)>, + connections: Vec<(u8, EdhocResponderWaitM3<'static>)>, } impl EdhocHandler { - fn connection_by_c_r(&mut self, c_r: u8) -> Option<&mut EdhocResponder<'static>> { - self.connections - .iter_mut() - .filter(|(current_c_r, _)| current_c_r == &c_r) - .map(|(_, responder)| responder) - .next() + fn take_connection_by_c_r(&mut self, c_r: u8) -> Option> { + let index = self + .connections + .iter() + .position(|(current_c_r, _)| current_c_r == &c_r)?; + let last = self.connections.len() - 1; + self.connections.swap(index, last); + Some(self.connections.pop().unwrap().1) } fn new_c_r(&self) -> u8 { @@ -40,7 +42,12 @@ impl EdhocHandler { } enum EdhocResponse { - OkSend2 { c_r: u8 }, + // We could also store the responder in the Vec (once we're done rendering the response, we'll + // take up a slot there anyway) if we make it an enum. + OkSend2 { + c_r: u8, + responder: EdhocResponderBuildM2<'static>, + }, Message3Processed, } @@ -54,33 +61,33 @@ impl coap_handler::Handler for EdhocHandler { if starts_with_true { let state = EdhocState::default(); - let mut responder = EdhocResponder::new(state, &R, &CRED_R, Some(&CRED_I)); + let responder = EdhocResponder::new(state, &R, &CRED_R, Some(&CRED_I)); - let error = responder + let response = responder .process_message_1(&request.payload()[1..].try_into().expect("wrong length")); - if error.is_ok() { + if let Ok(responder) = response { let c_r = self.new_c_r(); - // save edhoc connection - self.connections.push((c_r, responder)); - EdhocResponse::OkSend2 { c_r } + EdhocResponse::OkSend2 { c_r, responder } } else { panic!("How to respond to non-OK?") } } else { // potentially message 3 let c_r_rcvd = request.payload()[0]; - let mut responder = self.connection_by_c_r(c_r_rcvd).expect("No such C_R found"); + let responder = self + .take_connection_by_c_r(c_r_rcvd) + .expect("No such C_R found"); println!("Found state with connection identifier {:?}", c_r_rcvd); - let prk_out = responder + let result = responder .process_message_3(&request.payload()[1..].try_into().expect("wrong length")); - if prk_out.is_err() { - println!("EDHOC processing error: {:?}", prk_out); + let Ok((mut responder, prk_out)) = result else { + println!("EDHOC processing error: {:?}", result); // FIXME remove state from edhoc_connections panic!("Handler can't just not respond"); - } + }; println!("EDHOC exchange successfully completed"); println!("PRK_out: {:02x?}", prk_out); @@ -115,9 +122,9 @@ impl coap_handler::Handler for EdhocHandler { ) { response.set_code(coap_numbers::code::CHANGED.try_into().ok().unwrap()); match req { - EdhocResponse::OkSend2 { c_r } => { - let responder = self.connection_by_c_r(c_r).unwrap(); - let message_2 = responder.prepare_message_2(c_r).unwrap(); + EdhocResponse::OkSend2 { c_r, responder } => { + let (responder, message_2) = responder.prepare_message_2(c_r).unwrap(); + self.connections.push((c_r, responder)); response.set_payload(&message_2.content[..message_2.len]); } EdhocResponse::Message3Processed => (), // "send empty ack back"? @@ -128,7 +135,7 @@ impl coap_handler::Handler for EdhocHandler { fn build_handler() -> impl coap_handler::Handler { use coap_handler_implementations::{HandlerBuilder, ReportingHandlerBuilder}; - let mut edhoc: EdhocHandler = Default::default(); + let edhoc: EdhocHandler = Default::default(); coap_handler_implementations::new_dispatcher() .at_with_attributes(&[".well-known", "edhoc"], &[], edhoc) diff --git a/examples/coap/src/bin/coapserver.rs b/examples/coap/src/bin/coapserver.rs index 9283fcfd..e860f440 100644 --- a/examples/coap/src/bin/coapserver.rs +++ b/examples/coap/src/bin/coapserver.rs @@ -32,17 +32,17 @@ fn main() { // This is an EDHOC message if request.message.payload[0] == 0xf5 { let state = EdhocState::default(); - let mut responder = EdhocResponder::new(state, &R, &CRED_R, Some(&CRED_I)); + let responder = EdhocResponder::new(state, &R, &CRED_R, Some(&CRED_I)); - let error = responder.process_message_1( + let result = responder.process_message_1( &request.message.payload[1..] .try_into() .expect("wrong length"), ); - if error.is_ok() { + if let Ok(responder) = result { let c_r = generate_connection_identifier_cbor(); - let message_2 = responder.prepare_message_2(c_r).unwrap(); + let (responder, message_2) = responder.prepare_message_2(c_r).unwrap(); response.message.payload = Vec::from(&message_2.content[..message_2.len]); // save edhoc connection edhoc_connections.push((c_r, responder)); @@ -51,24 +51,21 @@ fn main() { // potentially message 3 println!("Received message 3"); let c_r_rcvd = request.message.payload[0]; - let (index, mut responder, ec) = lookup_state(c_r_rcvd, edhoc_connections).unwrap(); - edhoc_connections = ec; + // FIXME let's better not *panic here + let responder = take_state(c_r_rcvd, &mut edhoc_connections).unwrap(); println!("Found state with connection identifier {:?}", c_r_rcvd); - let prk_out = responder.process_message_3( + let result = responder.process_message_3( &request.message.payload[1..] .try_into() .expect("wrong length"), ); - - if prk_out.is_err() { - println!("EDHOC processing error: {:?}", prk_out); - // FIXME remove state from edhoc_connections + let Ok((mut responder, prk_out)) = result else { + println!("EDHOC processing error: {:?}", response); + // We don't get another chance, it's popped and can't be used any further + // anyway legally continue; - } - - // update edhoc connection - edhoc_connections[index] = (c_r_rcvd, responder); + }; // send empty ack back response.message.payload = b"".to_vec(); @@ -106,14 +103,16 @@ fn main() { } } -fn lookup_state<'a>( - c_r_rcvd: u8, - edhoc_protocol_states: Vec<(u8, EdhocResponder<'a>)>, -) -> Result<(usize, EdhocResponder<'a>, Vec<(u8, EdhocResponder)>), EDHOCError> { +fn take_state(c_r_rcvd: u8, edhoc_protocol_states: &mut Vec<(u8, R)>) -> Result { for (i, element) in edhoc_protocol_states.iter().enumerate() { let (c_r, responder) = element; if *c_r == c_r_rcvd { - return Ok((i, *responder, edhoc_protocol_states)); + let max_index = edhoc_protocol_states.len() - 1; + edhoc_protocol_states.swap(i, max_index); + let Some((_c_r, responder)) = edhoc_protocol_states.pop() else { + unreachable!(); + }; + return Ok(responder); } } return Err(EDHOCError::WrongState); diff --git a/examples/edhoc-rs-no_std/src/main.rs b/examples/edhoc-rs-no_std/src/main.rs index 8720458d..576dc7eb 100644 --- a/examples/edhoc-rs-no_std/src/main.rs +++ b/examples/edhoc-rs-no_std/src/main.rs @@ -108,34 +108,25 @@ fn main() -> ! { let state_initiator: EdhocState = Default::default(); let mut initiator = EdhocInitiator::new(state_initiator, I, CRED_I, Some(CRED_R)); let state_responder: EdhocState = Default::default(); - let mut responder = EdhocResponder::new(state_responder, R, CRED_R, Some(CRED_I)); + let responder = EdhocResponder::new(state_responder, R, CRED_R, Some(CRED_I)); let c_i: u8 = generate_connection_identifier_cbor().into(); - let ret = initiator.prepare_message_1(c_i); // to update the state - assert!(ret.is_ok()); - let message_1 = ret.unwrap(); + let (initiator, message_1) = initiator.prepare_message_1(c_i).unwrap(); // to update the state - let ret = responder.process_message_1(&message_1); - assert!(ret.is_ok()); + let responder = responder.process_message_1(&message_1).unwrap(); let c_r: u8 = generate_connection_identifier_cbor().into(); - let ret = responder.prepare_message_2(c_r); - assert!(ret.is_ok()); - let message_2 = ret.unwrap(); + let (responder, message_2) = responder.prepare_message_2(c_r).unwrap(); assert!(c_r != 0xff); - let _c_r = initiator.process_message_2(&message_2); - assert!(_c_r.is_ok()); + let (initiator, _c_r) = initiator.process_message_2(&message_2).unwrap(); - let ret = initiator.prepare_message_3(); - assert!(ret.is_ok()); - let (message_3, i_prk_out) = ret.unwrap(); + let (mut initiator, message_3, i_prk_out) = initiator.prepare_message_3().unwrap(); - let r_prk_out = responder.process_message_3(&message_3); - assert!(r_prk_out.is_ok()); + let (mut responder, r_prk_out) = responder.process_message_3(&message_3).unwrap(); // check that prk_out is equal at initiator and responder side - assert_eq!(i_prk_out, r_prk_out.unwrap()); + assert_eq!(i_prk_out, r_prk_out); // derive OSCORE secret and salt at both sides and compare let i_oscore_secret = initiator.edhoc_exporter(0u8, &[], 16); // label is 0 diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 38818e54..c397ac1e 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -3,7 +3,7 @@ pub use { edhoc_consts::State as EdhocState, edhoc_consts::*, edhoc_crypto::default_crypto, - edhoc_crypto_trait::Crypto as CryptoTrait, EdhocResponderState as EdhocResponder, + edhoc_crypto_trait::Crypto as CryptoTrait, }; #[cfg(any(feature = "ead-none", feature = "ead-zeroconf"))] @@ -43,24 +43,45 @@ pub struct EdhocInitiatorDone { state: State, // opaque state } -#[derive(Default, Copy, Clone, Debug)] -pub struct EdhocResponderState<'a> { +#[derive(Default, Debug)] +pub struct EdhocResponder<'a> { state: State, // opaque state r: &'a [u8], // private authentication key of R cred_r: &'a [u8], // R's full credential cred_i: Option<&'a [u8]>, // I's full credential (if provided) } -impl<'a> EdhocResponderState<'a> { +#[derive(Default, Debug)] +pub struct EdhocResponderBuildM2<'a> { + state: State, // opaque state + r: &'a [u8], // private authentication key of R + cred_r: &'a [u8], // R's full credential + cred_i: Option<&'a [u8]>, // I's full credential (if provided) +} + +#[derive(Default, Debug)] +pub struct EdhocResponderWaitM3<'a> { + state: State, // opaque state + r: &'a [u8], // private authentication key of R + cred_r: &'a [u8], // R's full credential + cred_i: Option<&'a [u8]>, // I's full credential (if provided) +} + +#[derive(Default, Debug)] +pub struct EdhocResponderDone { + state: State, // opaque state +} + +impl<'a> EdhocResponder<'a> { pub fn new( state: State, r: &'a [u8], cred_r: &'a [u8], cred_i: Option<&'a [u8]>, - ) -> EdhocResponderState<'a> { + ) -> EdhocResponder<'a> { assert!(r.len() == P256_ELEM_LEN); - EdhocResponderState { + EdhocResponder { state, r, cred_r, @@ -69,19 +90,25 @@ impl<'a> EdhocResponderState<'a> { } pub fn process_message_1( - self: &mut EdhocResponderState<'a>, + self, message_1: &BufferMessage1, - ) -> Result<(), EDHOCError> { + ) -> Result, EDHOCError> { let state = r_process_message_1(self.state, &mut default_crypto(), message_1)?; - self.state = state; - Ok(()) + Ok(EdhocResponderBuildM2 { + state, + r: self.r, + cred_r: self.cred_r, + cred_i: self.cred_i, + }) } +} +impl<'a> EdhocResponderBuildM2<'a> { pub fn prepare_message_2( - self: &mut EdhocResponderState<'a>, + self, c_r: u8, - ) -> Result { + ) -> Result<(EdhocResponderWaitM3<'a>, BufferMessage2), EDHOCError> { let (y, g_y) = default_crypto().p256_generate_key_pair(); match r_prepare_message_2( @@ -93,34 +120,40 @@ impl<'a> EdhocResponderState<'a> { g_y, c_r, ) { - Ok((state, message_2)) => { - self.state = state; - Ok(message_2) - } + Ok((state, message_2)) => Ok(( + EdhocResponderWaitM3 { + state, + r: self.r, + cred_r: self.cred_r, + cred_i: self.cred_i, + }, + message_2, + )), Err(error) => Err(error), } } +} +impl<'a> EdhocResponderWaitM3<'a> { pub fn process_message_3( - self: &mut EdhocResponderState<'a>, + self, message_3: &BufferMessage3, - ) -> Result<[u8; SHA256_DIGEST_LEN], EDHOCError> { + ) -> Result<(EdhocResponderDone, [u8; SHA256_DIGEST_LEN]), EDHOCError> { match r_process_message_3( self.state, &mut default_crypto(), message_3, self.cred_i.unwrap(), ) { - Ok((state, prk_out)) => { - self.state = state; - Ok(prk_out) - } + Ok((state, prk_out)) => Ok((EdhocResponderDone { state }, prk_out)), Err(error) => Err(error), } } +} +impl EdhocResponderDone { pub fn edhoc_exporter( - self: &mut EdhocResponderState<'a>, + &mut self, label: u8, context: &[u8], length: usize, @@ -145,7 +178,7 @@ impl<'a> EdhocResponderState<'a> { } pub fn edhoc_key_update( - self: &mut EdhocResponderState<'a>, + &mut self, context: &[u8], ) -> Result<[u8; SHA256_DIGEST_LEN], EDHOCError> { let mut context_buf = [0x00u8; MAX_KDF_CONTEXT_LEN]; @@ -353,8 +386,8 @@ mod test { #[test] fn test_new_responder() { let state: EdhocState = Default::default(); - let _responder = EdhocResponderState::new(state, R, CRED_R, Some(CRED_I)); - let _responder = EdhocResponderState::new(state, R, CRED_R, None); + let _responder = EdhocResponder::new(state, R, CRED_R, Some(CRED_I)); + let _responder = EdhocResponder::new(state, R, CRED_R, None); } #[test] @@ -372,13 +405,17 @@ mod test { let message_1_tv_first_time = EdhocMessageBuffer::from_hex(MESSAGE_1_TV_FIRST_TIME); let message_1_tv = EdhocMessageBuffer::from_hex(MESSAGE_1_TV); let state: EdhocState = Default::default(); - let mut responder = EdhocResponderState::new(state, R, CRED_R, Some(CRED_I)); + let responder = EdhocResponder::new(state, R, CRED_R, Some(CRED_I)); // process message_1 first time, when unsupported suite is selected let error = responder.process_message_1(&message_1_tv_first_time); assert!(error.is_err()); assert_eq!(error.unwrap_err(), EDHOCError::UnsupportedCipherSuite); + // We need to create a new responder -- no message is supposed to be processed twice by a + // responder or initiator + let responder = EdhocResponder::new(state, R, CRED_R, Some(CRED_I)); + // process message_1 second time let error = responder.process_message_1(&message_1_tv); assert!(error.is_ok()); @@ -396,30 +433,25 @@ mod test { let state_initiator: EdhocState = Default::default(); let mut initiator = EdhocInitiator::new(state_initiator, I, CRED_I, Some(CRED_R)); let state_responder: EdhocState = Default::default(); - let mut responder = EdhocResponderState::new(state_responder, R, CRED_R, Some(CRED_I)); + let responder = EdhocResponder::new(state_responder, R, CRED_R, Some(CRED_I)); let c_i: u8 = generate_connection_identifier_cbor(); let (initiator, result) = initiator.prepare_message_1(c_i).unwrap(); // to update the state - let error = responder.process_message_1(&result); - assert!(error.is_ok()); + let responder = responder.process_message_1(&result).unwrap(); let c_r = generate_connection_identifier_cbor(); - let ret = responder.prepare_message_2(c_r); - assert!(ret.is_ok()); - - let message_2 = ret.unwrap(); + let (responder, message_2) = responder.prepare_message_2(c_r).unwrap(); assert!(c_r != 0xff); let (initiator, _) = initiator.process_message_2(&message_2).unwrap(); let (mut initiator, message_3, i_prk_out) = initiator.prepare_message_3().unwrap(); - let r_prk_out = responder.process_message_3(&message_3); - assert!(r_prk_out.is_ok()); + let (mut responder, r_prk_out) = responder.process_message_3(&message_3).unwrap(); // check that prk_out is equal at initiator and responder side - assert_eq!(i_prk_out, r_prk_out.unwrap()); + assert_eq!(i_prk_out, r_prk_out); // derive OSCORE secret and salt at both sides and compare let i_oscore_secret = initiator.edhoc_exporter(0u8, &[], 16); // label is 0 @@ -469,7 +501,7 @@ mod test { let state_initiator: EdhocState = Default::default(); let mut initiator = EdhocInitiator::new(state_initiator, I, CRED_I, None); let state_responder: EdhocState = Default::default(); - let mut responder = EdhocResponderState::new(state_responder, R, CRED_V_TV, Some(CRED_I)); + let responder = EdhocResponder::new(state_responder, R, CRED_V_TV, Some(CRED_I)); let u: BytesP256ElemLen = U_TV.try_into().unwrap(); let id_u: EdhocMessageBuffer = ID_U_TV.try_into().unwrap(); @@ -503,14 +535,14 @@ mod test { EADInitiatorProtocolState::WaitEAD2 ); - responder.process_message_1(&message_1).unwrap(); + let responder = responder.process_message_1(&message_1).unwrap(); assert_eq!( ead_responder_state.protocol_state, EADResponderProtocolState::ProcessedEAD1 ); let c_r = generate_connection_identifier_cbor(); - let message_2 = responder.prepare_message_2(c_r).unwrap(); + let (responder, message_2) = responder.prepare_message_2(c_r).unwrap(); assert_eq!( ead_responder_state.protocol_state, EADResponderProtocolState::Completed @@ -525,7 +557,7 @@ mod test { let (initiator, message_3, i_prk_out) = initiator.prepare_message_3().unwrap(); - let r_prk_out = responder.process_message_3(&message_3).unwrap(); + let (mut responder, r_prk_out) = responder.process_message_3(&message_3).unwrap(); assert_eq!(i_prk_out, r_prk_out); assert_eq!( ead_responder_state.protocol_state,