Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Robust packet size calculations and garbage tests #71

Merged
merged 2 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 127 additions & 22 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub const NUM_INITIAL_HANDSHAKE_BUFFER_BYTES: usize = 4096;
const VERSION_CONTENT: [u8; 0] = [];
// Number of bytes for the authentication tag of a packet.
const NUM_TAG_BYTES: usize = 16;
// Maximum number of garbage bytes to read before the terminator.
// Maximum number of garbage bytes before the terminator.
const MAX_NUM_GARBAGE_BYTES: usize = 4095;
// Number of bytes for the garbage terminator.
const NUM_GARBAGE_TERMINTOR_BYTES: usize = 16;
Expand All @@ -85,8 +85,11 @@ pub enum Error {
/// total required bytes for the failed packet so the
/// caller can re-allocate and re-attempt.
BufferTooSmall { required_bytes: usize },
/// The maximum amount of garbage bytes was exceeded in the handshake.
MaxGarbageLength,
/// Tried to send more garbage bytes before terminator than allowed by spec.
TooMuchGarbage,
/// The remote sent the maximum amount of garbage bytes without
/// a garbage terminator in the handshake.
NoGarbageTerminator,
/// A handshake step was not completed in the proper order.
HandshakeOutOfOrder,
/// The remote peer is communicating on the V1 protocol.
Expand All @@ -111,13 +114,17 @@ impl fmt::Display for Error {
"Buffer memory allocation too small, need at least {} bytes.",
required_bytes
),
Error::MaxGarbageLength => {
write!(f, "More than 4095 bytes of garbage in the handshake.")
Error::NoGarbageTerminator => {
write!(f, "More than 4095 bytes of garbage recieved in the handshake before a terminator was sent.")
}
Error::HandshakeOutOfOrder => write!(f, "Handshake flow out of sequence."),
Error::SecretGeneration(e) => write!(f, "Cannot generate secrets: {:?}.", e),
Error::Decryption(e) => write!(f, "Decrytion error: {:?}.", e),
Error::V1Protocol => write!(f, "The remote peer is communicating on the V1 protocol."),
Error::TooMuchGarbage => write!(
f,
"Tried to send more than 4095 bytes of garbage in handshake."
),
}
}
}
Expand Down Expand Up @@ -397,9 +404,9 @@ impl PacketWriter {
packet_type: PacketType,
) -> Result<(), Error> {
// Validate buffer capacity.
if packet.len() < plaintext.len() + NUM_PACKET_OVERHEAD_BYTES {
if packet.len() < PacketWriter::required_packet_allocation(plaintext) {
return Err(Error::BufferTooSmall {
required_bytes: plaintext.len() + NUM_PACKET_OVERHEAD_BYTES,
required_bytes: PacketWriter::required_packet_allocation(plaintext),
});
}

Expand Down Expand Up @@ -433,20 +440,41 @@ impl PacketWriter {
/// Encrypt plaintext bytes and serialize into a packet to be sent over the wire
/// and handle necessary memory allocation.
///
/// # Arguments
///
/// * `plaintext` - Plaintext content to be encrypted.
/// * `aad` - Optional associated authenticated data.
/// * `packet_type` - Is this a genuine packet or a decoy.
///
/// # Returns
///
/// A `Result` containing:
/// * `Ok(Vec<u8>)`: Ciphertext packet.
/// * `Err(Error)`: An error that occurred encrypting plaintext.
#[cfg(feature = "alloc")]
pub fn encrypt_packet(
&mut self,
plaintext: &[u8],
aad: Option<&[u8]>,
packet_type: PacketType,
) -> Result<Vec<u8>, Error> {
let mut packet = vec![0u8; plaintext.len() + NUM_PACKET_OVERHEAD_BYTES];
let mut packet = vec![0u8; PacketWriter::required_packet_allocation(plaintext)];
self.encrypt_packet_no_alloc(plaintext, aad, &mut packet, packet_type)?;
Ok(packet)
}

/// Require bytes to encrpt given plaintext contents as a packet.
///
/// # Arguments
///
/// * `plaintext` - Plaintext contents.
///
/// # Returns
///
/// Number of bytes necessary to be allocated for packet.
pub fn required_packet_allocation(plaintext: &[u8]) -> usize {
plaintext.len() + NUM_PACKET_OVERHEAD_BYTES
}
}

/// Encrypt and decrypt packets with a peer.
Expand Down Expand Up @@ -606,6 +634,13 @@ impl<'a> Handshake<'a> {
rng: &mut impl Rng,
curve: &Secp256k1<C>,
) -> Result<Self, Error> {
if garbage
.as_ref()
.map_or(false, |g| g.len() > MAX_NUM_GARBAGE_BYTES)
{
return Err(Error::TooMuchGarbage);
};

let mut secret_key_buffer = [0u8; 32];
rng.fill(&mut secret_key_buffer[..]);
let sk = SecretKey::from_slice(&secret_key_buffer)?;
Expand Down Expand Up @@ -899,7 +934,7 @@ impl<'a> Handshake<'a> {
{
Ok((&buffer[..index], &buffer[(index + garbage_term.len())..]))
} else if buffer.len() >= (MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES) {
Err(Error::MaxGarbageLength)
Err(Error::NoGarbageTerminator)
} else {
// Terminator not found, the buffer needs more information.
Err(Error::CiphertextTooSmall)
Expand Down Expand Up @@ -1011,19 +1046,28 @@ where
{
/// New protocol session which completes the initial handshake and returns a handler.
///
/// # Arguments
///
/// * `network` - Network which both parties are operating on.
/// * `role` - Role in handshake, initiator or responder.
/// * `garbage` - Optional garbage bytes to send in handshake.
/// * `decoys` - Optional decoy packet contents bytes to send in handshake.
/// * `reader` - Asynchronous buffer to read packets sent by peer.
/// * `writer` - Asynchronous buffer to write packets to peer.
///
/// # Returns
///
/// A `Result` containing:
/// * `Ok(AsyncProtocol)`: An initialized protocol handler.
/// * `Err(ProtocolError)`: An error that occurred during the handshake.
pub async fn new(
pub async fn new<'a>(
network: Network,
role: Role,
garbage: Option<&[u8]>,
garbage: Option<&'a [u8]>,
decoys: Option<&'a [&'a [u8]]>,
mut reader: R,
mut writer: W,
) -> Result<Self, ProtocolError> {
// Initialize buffer.
let garbage_len = match garbage {
Some(slice) => slice.len(),
None => 0,
Expand All @@ -1039,14 +1083,26 @@ where
let mut remote_ellswift_buffer = [0u8; 64];
reader.read_exact(&mut remote_ellswift_buffer).await?;

let num_version_packet_bytes = PacketWriter::required_packet_allocation(&VERSION_CONTENT);
let num_decoy_packets_bytes: usize = match decoys {
Some(decoys) => decoys
.iter()
.map(|decoy| PacketWriter::required_packet_allocation(decoy))
.sum(),
None => 0,
};

// Complete materials and send terminator to remote.
// Not exposing decoy packets yet.
let mut terminator_and_version_buffer =
vec![0u8; NUM_GARBAGE_TERMINTOR_BYTES + NUM_PACKET_OVERHEAD_BYTES];
vec![
0u8;
NUM_GARBAGE_TERMINTOR_BYTES + num_version_packet_bytes + num_decoy_packets_bytes
];
handshake.complete_materials(
remote_ellswift_buffer,
&mut terminator_and_version_buffer,
None,
decoys,
)?;
writer.write_all(&terminator_and_version_buffer).await?;
writer.flush().await?;
Expand Down Expand Up @@ -1439,7 +1495,50 @@ mod tests {

#[test]
#[cfg(feature = "std")]
fn test_handshake_v1_protocol() {
fn test_handshake_garbage_length_check() {
let mut rng = rand::thread_rng();
let curve = Secp256k1::new();
let mut handshake_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES + MAX_NUM_GARBAGE_BYTES];

// Test with valid garbage length.
let valid_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES];
let result = Handshake::new_with_rng(
Network::Bitcoin,
Role::Initiator,
Some(&valid_garbage),
&mut handshake_buffer,
&mut rng,
&curve,
);
assert!(result.is_ok());

// Test with garbage length exceeding MAX_NUM_GARBAGE_BYTES.
let invalid_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES + 1];
let result = Handshake::new_with_rng(
Network::Bitcoin,
Role::Initiator,
Some(&invalid_garbage),
&mut handshake_buffer,
&mut rng,
&curve,
);
assert!(matches!(result, Err(Error::TooMuchGarbage)));

// Test with no garbage.
let result = Handshake::new_with_rng(
Network::Bitcoin,
Role::Initiator,
None,
&mut handshake_buffer,
&mut rng,
&curve,
);
assert!(result.is_ok());
}

#[test]
#[cfg(feature = "std")]
fn test_handshake_no_garbage_terminator() {
let mut handshake_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES];
let mut rng = rand::thread_rng();
let curve = Secp256k1::signing_only();
Expand All @@ -1454,17 +1553,23 @@ mod tests {
)
.expect("Handshake creation should succeed");

// Emulate remote sending network magic for start of V1 protocol.
let mut v1_protocol = [0u8; NUM_ELLIGATOR_SWIFT_BYTES];
v1_protocol[..4].copy_from_slice(&Network::Bitcoin.magic().to_bytes()[..]);
let mut response_buffer = [0u8; 0];
let result = handshake.complete_materials(v1_protocol, &mut response_buffer, None);
assert!(matches!(result, Err(Error::V1Protocol)));
// Skipping material creation and just placing a mock terminator.
handshake.remote_garbage_terminator = Some([0xFF; NUM_GARBAGE_TERMINTOR_BYTES]);

// Test with a buffer that is too long.
let test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES];
let result = handshake.split_garbage(&test_buffer);
assert!(matches!(result, Err(Error::NoGarbageTerminator)));

// Test with a buffer that's just short of the required length.
let short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1];
let result = handshake.split_garbage(&short_buffer);
assert!(matches!(result, Err(Error::CiphertextTooSmall)));
}

#[test]
#[cfg(feature = "std")]
fn test_full_handshake_with_garbage_and_decoys() {
fn test_handshake_with_garbage_and_decoys() {
// Define the garbage and decoys that the initiator is sending to the responder.
let initiator_garbage = vec![1u8, 2u8, 3u8];
let initiator_decoys: Vec<&[u8]> = vec![&[6u8, 7u8], &[8u8, 0u8]];
Expand Down
13 changes: 10 additions & 3 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,16 @@ async fn proxy_conn(client: TcpStream, network: Network) -> Result<(), bip324_pr
let remote_reader = remote_reader.compat();
let remote_writer = remote_writer.compat_write();

let protocol = AsyncProtocol::new(network, Role::Initiator, None, remote_reader, remote_writer)
.await
.expect("protocol establishment");
let protocol = AsyncProtocol::new(
network,
Role::Initiator,
None,
None,
remote_reader,
remote_writer,
)
.await
.expect("protocol establishment");

let (client_reader, client_writer) = client.into_split();
let mut v1_client_reader = V1ProtocolReader::new(client_reader);
Expand Down
Loading