From f61210b230c3d784a6295b250d3d0c0e6e44ab1c Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Wed, 2 Oct 2024 12:29:11 -0700 Subject: [PATCH 1/2] Add required allocation helper function --- protocol/src/lib.rs | 58 ++++++++++++++++++++++++++++++++++++------ proxy/src/bin/proxy.rs | 13 +++++++--- 2 files changed, 60 insertions(+), 11 deletions(-) diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 1ea5d32..32199a4 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -397,9 +397,9 @@ impl PacketWriter { packet_type: PacketType, ) -> Result<(), Error> { // Validate buffer capacity. - if packet.len() < plaintext.len() + NUM_PACKET_OVERHEAD_BYTES { + if packet.len() < PacketWriter::required_packet_allocation(plaintext) { return Err(Error::BufferTooSmall { - required_bytes: plaintext.len() + NUM_PACKET_OVERHEAD_BYTES, + required_bytes: PacketWriter::required_packet_allocation(plaintext), }); } @@ -433,9 +433,17 @@ impl PacketWriter { /// Encrypt plaintext bytes and serialize into a packet to be sent over the wire /// and handle necessary memory allocation. /// + /// # Arguments + /// /// * `plaintext` - Plaintext content to be encrypted. /// * `aad` - Optional associated authenticated data. /// * `packet_type` - Is this a genuine packet or a decoy. + /// + /// # Returns + /// + /// A `Result` containing: + /// * `Ok(Vec)`: Ciphertext packet. + /// * `Err(Error)`: An error that occurred encrypting plaintext. #[cfg(feature = "alloc")] pub fn encrypt_packet( &mut self, @@ -443,10 +451,23 @@ impl PacketWriter { aad: Option<&[u8]>, packet_type: PacketType, ) -> Result, Error> { - let mut packet = vec![0u8; plaintext.len() + NUM_PACKET_OVERHEAD_BYTES]; + let mut packet = vec![0u8; PacketWriter::required_packet_allocation(plaintext)]; self.encrypt_packet_no_alloc(plaintext, aad, &mut packet, packet_type)?; Ok(packet) } + + /// Require bytes to encrpt given plaintext contents as a packet. + /// + /// # Arguments + /// + /// * `plaintext` - Plaintext contents. + /// + /// # Returns + /// + /// Number of bytes necessary to be allocated for packet. + pub fn required_packet_allocation(plaintext: &[u8]) -> usize { + plaintext.len() + NUM_PACKET_OVERHEAD_BYTES + } } /// Encrypt and decrypt packets with a peer. @@ -1011,19 +1032,28 @@ where { /// New protocol session which completes the initial handshake and returns a handler. /// + /// # Arguments + /// + /// * `network` - Network which both parties are operating on. + /// * `role` - Role in handshake, initiator or responder. + /// * `garbage` - Optional garbage bytes to send in handshake. + /// * `decoys` - Optional decoy packet contents bytes to send in handshake. + /// * `reader` - Asynchronous buffer to read packets sent by peer. + /// * `writer` - Asynchronous buffer to write packets to peer. + /// /// # Returns /// /// A `Result` containing: /// * `Ok(AsyncProtocol)`: An initialized protocol handler. /// * `Err(ProtocolError)`: An error that occurred during the handshake. - pub async fn new( + pub async fn new<'a>( network: Network, role: Role, - garbage: Option<&[u8]>, + garbage: Option<&'a [u8]>, + decoys: Option<&'a [&'a [u8]]>, mut reader: R, mut writer: W, ) -> Result { - // Initialize buffer. let garbage_len = match garbage { Some(slice) => slice.len(), None => 0, @@ -1039,14 +1069,26 @@ where let mut remote_ellswift_buffer = [0u8; 64]; reader.read_exact(&mut remote_ellswift_buffer).await?; + let num_version_packet_bytes = PacketWriter::required_packet_allocation(&VERSION_CONTENT); + let num_decoy_packets_bytes: usize = match decoys { + Some(decoys) => decoys + .iter() + .map(|decoy| PacketWriter::required_packet_allocation(decoy)) + .sum(), + None => 0, + }; + // Complete materials and send terminator to remote. // Not exposing decoy packets yet. let mut terminator_and_version_buffer = - vec![0u8; NUM_GARBAGE_TERMINTOR_BYTES + NUM_PACKET_OVERHEAD_BYTES]; + vec![ + 0u8; + NUM_GARBAGE_TERMINTOR_BYTES + num_version_packet_bytes + num_decoy_packets_bytes + ]; handshake.complete_materials( remote_ellswift_buffer, &mut terminator_and_version_buffer, - None, + decoys, )?; writer.write_all(&terminator_and_version_buffer).await?; writer.flush().await?; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 8f8eb67..61612be 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -34,9 +34,16 @@ async fn proxy_conn(client: TcpStream, network: Network) -> Result<(), bip324_pr let remote_reader = remote_reader.compat(); let remote_writer = remote_writer.compat_write(); - let protocol = AsyncProtocol::new(network, Role::Initiator, None, remote_reader, remote_writer) - .await - .expect("protocol establishment"); + let protocol = AsyncProtocol::new( + network, + Role::Initiator, + None, + None, + remote_reader, + remote_writer, + ) + .await + .expect("protocol establishment"); let (client_reader, client_writer) = client.into_split(); let mut v1_client_reader = V1ProtocolReader::new(client_reader); From 75e49b57835d038d6ff164aa10aa5422b608b96a Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Thu, 3 Oct 2024 10:00:47 -0700 Subject: [PATCH 2/2] Add back garbage tests --- protocol/src/lib.rs | 91 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 77 insertions(+), 14 deletions(-) diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 32199a4..dd26664 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -67,7 +67,7 @@ pub const NUM_INITIAL_HANDSHAKE_BUFFER_BYTES: usize = 4096; const VERSION_CONTENT: [u8; 0] = []; // Number of bytes for the authentication tag of a packet. const NUM_TAG_BYTES: usize = 16; -// Maximum number of garbage bytes to read before the terminator. +// Maximum number of garbage bytes before the terminator. const MAX_NUM_GARBAGE_BYTES: usize = 4095; // Number of bytes for the garbage terminator. const NUM_GARBAGE_TERMINTOR_BYTES: usize = 16; @@ -85,8 +85,11 @@ pub enum Error { /// total required bytes for the failed packet so the /// caller can re-allocate and re-attempt. BufferTooSmall { required_bytes: usize }, - /// The maximum amount of garbage bytes was exceeded in the handshake. - MaxGarbageLength, + /// Tried to send more garbage bytes before terminator than allowed by spec. + TooMuchGarbage, + /// The remote sent the maximum amount of garbage bytes without + /// a garbage terminator in the handshake. + NoGarbageTerminator, /// A handshake step was not completed in the proper order. HandshakeOutOfOrder, /// The remote peer is communicating on the V1 protocol. @@ -111,13 +114,17 @@ impl fmt::Display for Error { "Buffer memory allocation too small, need at least {} bytes.", required_bytes ), - Error::MaxGarbageLength => { - write!(f, "More than 4095 bytes of garbage in the handshake.") + Error::NoGarbageTerminator => { + write!(f, "More than 4095 bytes of garbage recieved in the handshake before a terminator was sent.") } Error::HandshakeOutOfOrder => write!(f, "Handshake flow out of sequence."), Error::SecretGeneration(e) => write!(f, "Cannot generate secrets: {:?}.", e), Error::Decryption(e) => write!(f, "Decrytion error: {:?}.", e), Error::V1Protocol => write!(f, "The remote peer is communicating on the V1 protocol."), + Error::TooMuchGarbage => write!( + f, + "Tried to send more than 4095 bytes of garbage in handshake." + ), } } } @@ -627,6 +634,13 @@ impl<'a> Handshake<'a> { rng: &mut impl Rng, curve: &Secp256k1, ) -> Result { + if garbage + .as_ref() + .map_or(false, |g| g.len() > MAX_NUM_GARBAGE_BYTES) + { + return Err(Error::TooMuchGarbage); + }; + let mut secret_key_buffer = [0u8; 32]; rng.fill(&mut secret_key_buffer[..]); let sk = SecretKey::from_slice(&secret_key_buffer)?; @@ -920,7 +934,7 @@ impl<'a> Handshake<'a> { { Ok((&buffer[..index], &buffer[(index + garbage_term.len())..])) } else if buffer.len() >= (MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES) { - Err(Error::MaxGarbageLength) + Err(Error::NoGarbageTerminator) } else { // Terminator not found, the buffer needs more information. Err(Error::CiphertextTooSmall) @@ -1481,7 +1495,50 @@ mod tests { #[test] #[cfg(feature = "std")] - fn test_handshake_v1_protocol() { + fn test_handshake_garbage_length_check() { + let mut rng = rand::thread_rng(); + let curve = Secp256k1::new(); + let mut handshake_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES + MAX_NUM_GARBAGE_BYTES]; + + // Test with valid garbage length. + let valid_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES]; + let result = Handshake::new_with_rng( + Network::Bitcoin, + Role::Initiator, + Some(&valid_garbage), + &mut handshake_buffer, + &mut rng, + &curve, + ); + assert!(result.is_ok()); + + // Test with garbage length exceeding MAX_NUM_GARBAGE_BYTES. + let invalid_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES + 1]; + let result = Handshake::new_with_rng( + Network::Bitcoin, + Role::Initiator, + Some(&invalid_garbage), + &mut handshake_buffer, + &mut rng, + &curve, + ); + assert!(matches!(result, Err(Error::TooMuchGarbage))); + + // Test with no garbage. + let result = Handshake::new_with_rng( + Network::Bitcoin, + Role::Initiator, + None, + &mut handshake_buffer, + &mut rng, + &curve, + ); + assert!(result.is_ok()); + } + + #[test] + #[cfg(feature = "std")] + fn test_handshake_no_garbage_terminator() { let mut handshake_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; let mut rng = rand::thread_rng(); let curve = Secp256k1::signing_only(); @@ -1496,17 +1553,23 @@ mod tests { ) .expect("Handshake creation should succeed"); - // Emulate remote sending network magic for start of V1 protocol. - let mut v1_protocol = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; - v1_protocol[..4].copy_from_slice(&Network::Bitcoin.magic().to_bytes()[..]); - let mut response_buffer = [0u8; 0]; - let result = handshake.complete_materials(v1_protocol, &mut response_buffer, None); - assert!(matches!(result, Err(Error::V1Protocol))); + // Skipping material creation and just placing a mock terminator. + handshake.remote_garbage_terminator = Some([0xFF; NUM_GARBAGE_TERMINTOR_BYTES]); + + // Test with a buffer that is too long. + let test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES]; + let result = handshake.split_garbage(&test_buffer); + assert!(matches!(result, Err(Error::NoGarbageTerminator))); + + // Test with a buffer that's just short of the required length. + let short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1]; + let result = handshake.split_garbage(&short_buffer); + assert!(matches!(result, Err(Error::CiphertextTooSmall))); } #[test] #[cfg(feature = "std")] - fn test_full_handshake_with_garbage_and_decoys() { + fn test_handshake_with_garbage_and_decoys() { // Define the garbage and decoys that the initiator is sending to the responder. let initiator_garbage = vec![1u8, 2u8, 3u8]; let initiator_decoys: Vec<&[u8]> = vec![&[6u8, 7u8], &[8u8, 0u8]];