Skip to content

Commit

Permalink
Merge pull request #71 from nyonson/decoy-calc
Browse files Browse the repository at this point in the history
Robust packet size calculations and garbage tests
  • Loading branch information
rustaceanrob authored Oct 18, 2024
2 parents 73d9361 + 75e49b5 commit b7de828
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 25 deletions.
149 changes: 127 additions & 22 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub const NUM_INITIAL_HANDSHAKE_BUFFER_BYTES: usize = 4096;
const VERSION_CONTENT: [u8; 0] = [];
// Number of bytes for the authentication tag of a packet.
const NUM_TAG_BYTES: usize = 16;
// Maximum number of garbage bytes to read before the terminator.
// Maximum number of garbage bytes before the terminator.
const MAX_NUM_GARBAGE_BYTES: usize = 4095;
// Number of bytes for the garbage terminator.
const NUM_GARBAGE_TERMINTOR_BYTES: usize = 16;
Expand All @@ -85,8 +85,11 @@ pub enum Error {
/// total required bytes for the failed packet so the
/// caller can re-allocate and re-attempt.
BufferTooSmall { required_bytes: usize },
/// The maximum amount of garbage bytes was exceeded in the handshake.
MaxGarbageLength,
/// Tried to send more garbage bytes before terminator than allowed by spec.
TooMuchGarbage,
/// The remote sent the maximum amount of garbage bytes without
/// a garbage terminator in the handshake.
NoGarbageTerminator,
/// A handshake step was not completed in the proper order.
HandshakeOutOfOrder,
/// The remote peer is communicating on the V1 protocol.
Expand All @@ -111,13 +114,17 @@ impl fmt::Display for Error {
"Buffer memory allocation too small, need at least {} bytes.",
required_bytes
),
Error::MaxGarbageLength => {
write!(f, "More than 4095 bytes of garbage in the handshake.")
Error::NoGarbageTerminator => {
write!(f, "More than 4095 bytes of garbage recieved in the handshake before a terminator was sent.")
}
Error::HandshakeOutOfOrder => write!(f, "Handshake flow out of sequence."),
Error::SecretGeneration(e) => write!(f, "Cannot generate secrets: {:?}.", e),
Error::Decryption(e) => write!(f, "Decrytion error: {:?}.", e),
Error::V1Protocol => write!(f, "The remote peer is communicating on the V1 protocol."),
Error::TooMuchGarbage => write!(
f,
"Tried to send more than 4095 bytes of garbage in handshake."
),
}
}
}
Expand Down Expand Up @@ -397,9 +404,9 @@ impl PacketWriter {
packet_type: PacketType,
) -> Result<(), Error> {
// Validate buffer capacity.
if packet.len() < plaintext.len() + NUM_PACKET_OVERHEAD_BYTES {
if packet.len() < PacketWriter::required_packet_allocation(plaintext) {
return Err(Error::BufferTooSmall {
required_bytes: plaintext.len() + NUM_PACKET_OVERHEAD_BYTES,
required_bytes: PacketWriter::required_packet_allocation(plaintext),
});
}

Expand Down Expand Up @@ -433,20 +440,41 @@ impl PacketWriter {
/// Encrypt plaintext bytes and serialize into a packet to be sent over the wire
/// and handle necessary memory allocation.
///
/// # Arguments
///
/// * `plaintext` - Plaintext content to be encrypted.
/// * `aad` - Optional associated authenticated data.
/// * `packet_type` - Is this a genuine packet or a decoy.
///
/// # Returns
///
/// A `Result` containing:
/// * `Ok(Vec<u8>)`: Ciphertext packet.
/// * `Err(Error)`: An error that occurred encrypting plaintext.
#[cfg(feature = "alloc")]
pub fn encrypt_packet(
&mut self,
plaintext: &[u8],
aad: Option<&[u8]>,
packet_type: PacketType,
) -> Result<Vec<u8>, Error> {
let mut packet = vec![0u8; plaintext.len() + NUM_PACKET_OVERHEAD_BYTES];
let mut packet = vec![0u8; PacketWriter::required_packet_allocation(plaintext)];
self.encrypt_packet_no_alloc(plaintext, aad, &mut packet, packet_type)?;
Ok(packet)
}

/// Require bytes to encrpt given plaintext contents as a packet.
///
/// # Arguments
///
/// * `plaintext` - Plaintext contents.
///
/// # Returns
///
/// Number of bytes necessary to be allocated for packet.
pub fn required_packet_allocation(plaintext: &[u8]) -> usize {
plaintext.len() + NUM_PACKET_OVERHEAD_BYTES
}
}

/// Encrypt and decrypt packets with a peer.
Expand Down Expand Up @@ -606,6 +634,13 @@ impl<'a> Handshake<'a> {
rng: &mut impl Rng,
curve: &Secp256k1<C>,
) -> Result<Self, Error> {
if garbage
.as_ref()
.map_or(false, |g| g.len() > MAX_NUM_GARBAGE_BYTES)
{
return Err(Error::TooMuchGarbage);
};

let mut secret_key_buffer = [0u8; 32];
rng.fill(&mut secret_key_buffer[..]);
let sk = SecretKey::from_slice(&secret_key_buffer)?;
Expand Down Expand Up @@ -899,7 +934,7 @@ impl<'a> Handshake<'a> {
{
Ok((&buffer[..index], &buffer[(index + garbage_term.len())..]))
} else if buffer.len() >= (MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES) {
Err(Error::MaxGarbageLength)
Err(Error::NoGarbageTerminator)
} else {
// Terminator not found, the buffer needs more information.
Err(Error::CiphertextTooSmall)
Expand Down Expand Up @@ -1011,19 +1046,28 @@ where
{
/// New protocol session which completes the initial handshake and returns a handler.
///
/// # Arguments
///
/// * `network` - Network which both parties are operating on.
/// * `role` - Role in handshake, initiator or responder.
/// * `garbage` - Optional garbage bytes to send in handshake.
/// * `decoys` - Optional decoy packet contents bytes to send in handshake.
/// * `reader` - Asynchronous buffer to read packets sent by peer.
/// * `writer` - Asynchronous buffer to write packets to peer.
///
/// # Returns
///
/// A `Result` containing:
/// * `Ok(AsyncProtocol)`: An initialized protocol handler.
/// * `Err(ProtocolError)`: An error that occurred during the handshake.
pub async fn new(
pub async fn new<'a>(
network: Network,
role: Role,
garbage: Option<&[u8]>,
garbage: Option<&'a [u8]>,
decoys: Option<&'a [&'a [u8]]>,
mut reader: R,
mut writer: W,
) -> Result<Self, ProtocolError> {
// Initialize buffer.
let garbage_len = match garbage {
Some(slice) => slice.len(),
None => 0,
Expand All @@ -1039,14 +1083,26 @@ where
let mut remote_ellswift_buffer = [0u8; 64];
reader.read_exact(&mut remote_ellswift_buffer).await?;

let num_version_packet_bytes = PacketWriter::required_packet_allocation(&VERSION_CONTENT);
let num_decoy_packets_bytes: usize = match decoys {
Some(decoys) => decoys
.iter()
.map(|decoy| PacketWriter::required_packet_allocation(decoy))
.sum(),
None => 0,
};

// Complete materials and send terminator to remote.
// Not exposing decoy packets yet.
let mut terminator_and_version_buffer =
vec![0u8; NUM_GARBAGE_TERMINTOR_BYTES + NUM_PACKET_OVERHEAD_BYTES];
vec![
0u8;
NUM_GARBAGE_TERMINTOR_BYTES + num_version_packet_bytes + num_decoy_packets_bytes
];
handshake.complete_materials(
remote_ellswift_buffer,
&mut terminator_and_version_buffer,
None,
decoys,
)?;
writer.write_all(&terminator_and_version_buffer).await?;
writer.flush().await?;
Expand Down Expand Up @@ -1439,7 +1495,50 @@ mod tests {

#[test]
#[cfg(feature = "std")]
fn test_handshake_v1_protocol() {
fn test_handshake_garbage_length_check() {
let mut rng = rand::thread_rng();
let curve = Secp256k1::new();
let mut handshake_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES + MAX_NUM_GARBAGE_BYTES];

// Test with valid garbage length.
let valid_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES];
let result = Handshake::new_with_rng(
Network::Bitcoin,
Role::Initiator,
Some(&valid_garbage),
&mut handshake_buffer,
&mut rng,
&curve,
);
assert!(result.is_ok());

// Test with garbage length exceeding MAX_NUM_GARBAGE_BYTES.
let invalid_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES + 1];
let result = Handshake::new_with_rng(
Network::Bitcoin,
Role::Initiator,
Some(&invalid_garbage),
&mut handshake_buffer,
&mut rng,
&curve,
);
assert!(matches!(result, Err(Error::TooMuchGarbage)));

// Test with no garbage.
let result = Handshake::new_with_rng(
Network::Bitcoin,
Role::Initiator,
None,
&mut handshake_buffer,
&mut rng,
&curve,
);
assert!(result.is_ok());
}

#[test]
#[cfg(feature = "std")]
fn test_handshake_no_garbage_terminator() {
let mut handshake_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES];
let mut rng = rand::thread_rng();
let curve = Secp256k1::signing_only();
Expand All @@ -1454,17 +1553,23 @@ mod tests {
)
.expect("Handshake creation should succeed");

// Emulate remote sending network magic for start of V1 protocol.
let mut v1_protocol = [0u8; NUM_ELLIGATOR_SWIFT_BYTES];
v1_protocol[..4].copy_from_slice(&Network::Bitcoin.magic().to_bytes()[..]);
let mut response_buffer = [0u8; 0];
let result = handshake.complete_materials(v1_protocol, &mut response_buffer, None);
assert!(matches!(result, Err(Error::V1Protocol)));
// Skipping material creation and just placing a mock terminator.
handshake.remote_garbage_terminator = Some([0xFF; NUM_GARBAGE_TERMINTOR_BYTES]);

// Test with a buffer that is too long.
let test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES];
let result = handshake.split_garbage(&test_buffer);
assert!(matches!(result, Err(Error::NoGarbageTerminator)));

// Test with a buffer that's just short of the required length.
let short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1];
let result = handshake.split_garbage(&short_buffer);
assert!(matches!(result, Err(Error::CiphertextTooSmall)));
}

#[test]
#[cfg(feature = "std")]
fn test_full_handshake_with_garbage_and_decoys() {
fn test_handshake_with_garbage_and_decoys() {
// Define the garbage and decoys that the initiator is sending to the responder.
let initiator_garbage = vec![1u8, 2u8, 3u8];
let initiator_decoys: Vec<&[u8]> = vec![&[6u8, 7u8], &[8u8, 0u8]];
Expand Down
13 changes: 10 additions & 3 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,16 @@ async fn proxy_conn(client: TcpStream, network: Network) -> Result<(), bip324_pr
let remote_reader = remote_reader.compat();
let remote_writer = remote_writer.compat_write();

let protocol = AsyncProtocol::new(network, Role::Initiator, None, remote_reader, remote_writer)
.await
.expect("protocol establishment");
let protocol = AsyncProtocol::new(
network,
Role::Initiator,
None,
None,
remote_reader,
remote_writer,
)
.await
.expect("protocol establishment");

let (client_reader, client_writer) = client.into_split();
let mut v1_client_reader = V1ProtocolReader::new(client_reader);
Expand Down

0 comments on commit b7de828

Please sign in to comment.