From e68f20c577b842f8290854023e37f2eddc39c27e Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Fri, 31 Jan 2025 17:56:48 -0800 Subject: [PATCH] Prevent collection overallocation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: In both the binary and compact protocols, previously it was possible for the `Vec::with_capacity` and `HashSet::with_capacity_and_hasher` and `HashMap::with_capacity_and_hasher` in these types' `fbthrift::Deserialize` impls to allocate arbitrarily big collections when fed a small Thrift message containing a large collection size. This makes Rust Thrift servers vulnerable to Denial Of Service from maliciously crafted Thrift messages, or even just unlucky garbage messages. This diff makes these Deserialize impls short-circuit with `ProtocolError::EOF` if the remaining input data cannot possibly contain the number of collection elements claimed by the collection size. For example, in compact protocol which represents `f32` using 4 bytes, deserializing `Vec<f32>` with 4000 bytes remaining in the input data will not bother trying to allocate a Vec with capacity larger than 1000. In particular, it will no longer try to allocate a Vec with capacity 10¹⁵, even if the serialized collection size claims that the input contains 10¹⁵ list elements. 
Reviewed By: zertosh Differential Revision: D68986442 fbshipit-source-id: a3885833d6d49f581912ffcb1365af381671a28e --- .../thrift/lib/rust/src/binary_protocol.rs | 26 +++++++++++ .../thrift/src/thrift/lib/rust/src/bufext.rs | 31 ++++++++++++- .../thrift/lib/rust/src/compact_protocol.rs | 26 +++++++++++ .../src/thrift/lib/rust/src/deserialize.rs | 46 +++++++++++++++---- .../src/thrift/lib/rust/src/protocol.rs | 18 ++++++++ .../src/thrift/lib/rust/src/tests/binary.rs | 40 ++++++++++++++++ .../src/thrift/lib/rust/src/tests/compact.rs | 41 +++++++++++++++++ 7 files changed, 218 insertions(+), 10 deletions(-) diff --git a/third-party/thrift/src/thrift/lib/rust/src/binary_protocol.rs b/third-party/thrift/src/thrift/lib/rust/src/binary_protocol.rs index 2315cc6dd96f3f..40bdd893d743ab 100644 --- a/third-party/thrift/src/thrift/lib/rust/src/binary_protocol.rs +++ b/third-party/thrift/src/thrift/lib/rust/src/binary_protocol.rs @@ -37,6 +37,7 @@ use crate::protocol::ProtocolWriter; use crate::serialize::Serialize; use crate::thrift_protocol::MessageType; use crate::thrift_protocol::ProtocolID; +use crate::ttype::GetTType; use crate::ttype::TType; pub const BINARY_VERSION_MASK: u32 = 0xffff_0000; @@ -460,6 +461,31 @@ impl ProtocolReader for BinaryProtocolDeserializer { ensure_err!(self.buffer.remaining() >= received_len, ProtocolError::EOF); Ok(V::copy_from_buf(&mut self.buffer, received_len)) } + + fn min_size() -> usize { + match T::TTYPE { + TType::Void => 0, + TType::Bool => 1, + TType::Byte => 1, + TType::Double => 8, + TType::I16 => 2, + TType::I32 => 4, + TType::I64 => 8, + TType::String => 4, + TType::Struct => 1, + TType::Map => 6, + TType::Set => 5, + TType::List => 5, + TType::UTF8 => 4, + TType::UTF16 => 4, + TType::Float => 4, + TType::Stop | TType::Stream => unreachable!(), + } + } + + fn can_advance(&self, bytes: usize) -> bool { + self.buffer.can_advance(bytes) + } } /// How large an item will be when `serialize()` is called diff --git 
a/third-party/thrift/src/thrift/lib/rust/src/bufext.rs b/third-party/thrift/src/thrift/lib/rust/src/bufext.rs index 98559d52b2b751..d1f22e6c0e4f5f 100644 --- a/third-party/thrift/src/thrift/lib/rust/src/bufext.rs +++ b/third-party/thrift/src/thrift/lib/rust/src/bufext.rs @@ -32,6 +32,25 @@ pub trait BufExt: Buf { // Default is to just copy. self.copy_to_bytes(len) } + + /// Whether there are enough remaining bytes to advance this buffer's + /// internal cursor by `n`. + /// + /// Disjoint buffers for which `remaining()` is not O(1) should override + /// this method with a more efficient implementation. + fn can_advance(&self, n: usize) -> bool { + n <= self.remaining() + } + + /// Number of more bytes needed in order to be able to advance by `n`. + /// + /// If `n <= remaining()`, this is 0. Otherwise `n - remaining()`. + /// + /// Disjoint buffers for which `remaining()` is not O(1) should override + /// this method with a more efficient implementation. + fn shortfall(&self, n: usize) -> usize { + n.saturating_sub(self.remaining()) + } } impl BufExt for Bytes {} @@ -50,7 +69,17 @@ impl BufExt for Cursor { impl + ?Sized> BufExt for Cursor<&T> {} -impl BufExt for Chain {} +impl BufExt for Chain { + fn can_advance(&self, n: usize) -> bool { + let rest = self.first_ref().shortfall(n); + self.last_ref().can_advance(rest) + } + + fn shortfall(&self, n: usize) -> usize { + let rest = self.first_ref().shortfall(n); + self.last_ref().shortfall(rest) + } +} pub trait BufMutExt: BufMut { type Final: Send + 'static; diff --git a/third-party/thrift/src/thrift/lib/rust/src/compact_protocol.rs b/third-party/thrift/src/thrift/lib/rust/src/compact_protocol.rs index 046812271f0121..6c1c5446365d70 100644 --- a/third-party/thrift/src/thrift/lib/rust/src/compact_protocol.rs +++ b/third-party/thrift/src/thrift/lib/rust/src/compact_protocol.rs @@ -37,6 +37,7 @@ use crate::protocol::ProtocolWriter; use crate::serialize::Serialize; use crate::thrift_protocol::MessageType; use 
crate::thrift_protocol::ProtocolID; +use crate::ttype::GetTType; use crate::ttype::TType; use crate::varint; @@ -840,6 +841,31 @@ impl ProtocolReader for CompactProtocolDeserializer { Ok(V::copy_from_buf(&mut self.buffer, received_len)) } + + fn min_size() -> usize { + match T::TTYPE { + TType::Void => 0, + TType::Bool => 1, + TType::Byte => 1, + TType::Double => 8, + TType::I16 => 1, + TType::I32 => 1, + TType::I64 => 1, + TType::String => 1, + TType::Struct => 1, + TType::Map => 1, + TType::Set => 1, + TType::List => 1, + TType::UTF8 => 1, + TType::UTF16 => 1, + TType::Float => 4, + TType::Stop | TType::Stream => unreachable!(), + } + } + + fn can_advance(&self, bytes: usize) -> bool { + self.buffer.can_advance(bytes) + } } /// How large an item will be when `serialize()` is called diff --git a/third-party/thrift/src/thrift/lib/rust/src/deserialize.rs b/third-party/thrift/src/thrift/lib/rust/src/deserialize.rs index 08efcdef53751e..9e40f188ceb784 100644 --- a/third-party/thrift/src/thrift/lib/rust/src/deserialize.rs +++ b/third-party/thrift/src/thrift/lib/rust/src/deserialize.rs @@ -21,11 +21,14 @@ use std::collections::HashSet; use std::hash::Hash; use std::sync::Arc; +use anyhow::bail; use bytes::Bytes; use ordered_float::OrderedFloat; +use crate::errors::ProtocolError; use crate::protocol::should_break; use crate::protocol::ProtocolReader; +use crate::ttype::GetTType; use crate::Result; // Read trait. Every type that needs to be deserialized will implement this trait. @@ -223,13 +226,21 @@ where impl Deserialize

for HashSet where P: ProtocolReader, - T: Deserialize

+ Hash + Eq, + T: Deserialize

+ GetTType + Hash + Eq, S: std::hash::BuildHasher + Default, { fn read(p: &mut P) -> Result { let (_elem_ty, len) = p.read_set_begin()?; - let mut hset = - HashSet::with_capacity_and_hasher(len.unwrap_or_default(), Default::default()); + let mut hset = { + let cap = len.unwrap_or(0); + if match cap.checked_mul(P::min_size::()) { + None => true, + Some(total) => !p.can_advance(total), + } { + bail!(ProtocolError::EOF); + } + HashSet::with_capacity_and_hasher(cap, S::default()) + }; if let Some(0) = len { return Ok(hset); @@ -294,14 +305,22 @@ where impl Deserialize

for HashMap where P: ProtocolReader, - K: Deserialize

+ Hash + Eq, - V: Deserialize

, + K: Deserialize

+ GetTType + Hash + Eq, + V: Deserialize

+ GetTType, S: std::hash::BuildHasher + Default, { fn read(p: &mut P) -> Result { let (_key_ty, _val_ty, len) = p.read_map_begin()?; - let mut hmap = - HashMap::with_capacity_and_hasher(len.unwrap_or_default(), Default::default()); + let mut hmap = { + let cap = len.unwrap_or(0); + if match cap.checked_mul(P::min_size::() + P::min_size::()) { + None => true, + Some(total) => !p.can_advance(total), + } { + bail!(ProtocolError::EOF); + } + HashMap::with_capacity_and_hasher(cap, S::default()) + }; if let Some(0) = len { return Ok(hmap); @@ -332,12 +351,21 @@ where impl Deserialize

for Vec where P: ProtocolReader, - T: Deserialize

+ crate::ttype::GetTType, // GetTType just to exclude Vec + T: Deserialize

+ GetTType, { /// Vec is Thrift List type fn read(p: &mut P) -> Result { let (_elem_ty, len) = p.read_list_begin()?; - let mut list = Vec::with_capacity(len.unwrap_or_default()); + let mut list = { + let cap = len.unwrap_or(0); + if match cap.checked_mul(P::min_size::()) { + None => true, + Some(total) => !p.can_advance(total), + } { + bail!(ProtocolError::EOF); + } + Vec::with_capacity(cap) + }; if let Some(0) = len { return Ok(list); diff --git a/third-party/thrift/src/thrift/lib/rust/src/protocol.rs b/third-party/thrift/src/thrift/lib/rust/src/protocol.rs index 41570d81598999..ca1a67228bda6f 100644 --- a/third-party/thrift/src/thrift/lib/rust/src/protocol.rs +++ b/third-party/thrift/src/thrift/lib/rust/src/protocol.rs @@ -24,6 +24,7 @@ use crate::framing::FramingEncoded; use crate::framing::FramingEncodedFinal; use crate::thrift_protocol::MessageType; use crate::thrift_protocol::ProtocolID; +use crate::ttype::GetTType; use crate::ttype::TType; use crate::Result; @@ -263,6 +264,23 @@ pub trait ProtocolReader { fn read_string(&mut self) -> Result; fn read_binary(&mut self) -> Result; + /// The smallest number of bytes in which a collection element of type `T` + /// could be represented. + // + // TODO: once feature(adt_const_params) is stable, make this generic + // over `const TTYPE: TType` instead of T to reduce monomorphization. + // https://github.com/rust-lang/rust/issues/95174 + fn min_size() -> usize { + 0 + } + + /// Whether there is enough input data available to bother trying to read + /// collection elements that would occupy a minimum of `n` bytes. 
+ fn can_advance(&self, bytes: usize) -> bool { + let _ = bytes; + true + } + /// Skip over the next data element from the provided input Protocol object fn skip(&mut self, field_type: TType) -> Result<()> { skip_inner(self, field_type, DEFAULT_RECURSION_DEPTH) diff --git a/third-party/thrift/src/thrift/lib/rust/src/tests/binary.rs b/third-party/thrift/src/thrift/lib/rust/src/tests/binary.rs index 2f548ccb094cd7..735dbde6580357 100644 --- a/third-party/thrift/src/thrift/lib/rust/src/tests/binary.rs +++ b/third-party/thrift/src/thrift/lib/rust/src/tests/binary.rs @@ -14,10 +14,14 @@ * limitations under the License. */ +use std::collections::HashMap; +use std::collections::HashSet; use std::io::Cursor; use bytes::Buf; +use bytes::BufMut; use bytes::Bytes; +use bytes::BytesMut; use super::BOOL_VALUES; use super::BYTE_VALUES; @@ -26,6 +30,8 @@ use super::FLOAT_VALUES; use super::INT16_VALUES; use super::INT32_VALUES; use super::INT64_VALUES; +use crate::deserialize::Deserialize; +use crate::errors::ProtocolError; use crate::thrift_protocol::MessageType; use crate::ttype::TType; use crate::BinaryProtocol; @@ -508,3 +514,37 @@ fn read_binary_from_chained_buffer() { .expect("read \" world\" from the buffer"); assert_eq!(result.as_slice(), b" world"); } + +#[test] +fn test_overallocation() { + let mut malicious = BytesMut::new(); + malicious.put_u8(TType::I16 as u8); + malicious.put_i32(1_000_000_000); + malicious.put_bytes(0, 10); + let malicious = malicious.freeze(); + let mut deserializer = ::deserializer(Cursor::new(malicious.clone())); + let err = as Deserialize<_>>::read(&mut deserializer).unwrap_err(); + assert_eq!( + err.downcast_ref::(), + Some(&ProtocolError::EOF), + ); + + let mut deserializer = ::deserializer(Cursor::new(malicious)); + let err = as Deserialize<_>>::read(&mut deserializer).unwrap_err(); + assert_eq!( + err.downcast_ref::(), + Some(&ProtocolError::EOF), + ); + + let mut malicious = BytesMut::new(); + malicious.put_u8(TType::String as u8); + 
malicious.put_u8(TType::I16 as u8); + malicious.put_i32(1_000_000_000); + malicious.put_bytes(0, 10); + let mut deserializer = ::deserializer(Cursor::new(malicious.freeze())); + let err = as Deserialize<_>>::read(&mut deserializer).unwrap_err(); + assert_eq!( + err.downcast_ref::(), + Some(&ProtocolError::EOF), + ); +} diff --git a/third-party/thrift/src/thrift/lib/rust/src/tests/compact.rs b/third-party/thrift/src/thrift/lib/rust/src/tests/compact.rs index 2196c12253915a..5e20ff6bf772b3 100644 --- a/third-party/thrift/src/thrift/lib/rust/src/tests/compact.rs +++ b/third-party/thrift/src/thrift/lib/rust/src/tests/compact.rs @@ -14,10 +14,14 @@ * limitations under the License. */ +use std::collections::HashMap; +use std::collections::HashSet; use std::io::Cursor; use bytes::Buf; +use bytes::BufMut; use bytes::Bytes; +use bytes::BytesMut; use super::BOOL_VALUES; use super::BYTE_VALUES; @@ -26,6 +30,10 @@ use super::FLOAT_VALUES; use super::INT16_VALUES; use super::INT32_VALUES; use super::INT64_VALUES; +use crate::bufext::BufMutExt as _; +use crate::compact_protocol::CType; +use crate::deserialize::Deserialize; +use crate::errors::ProtocolError; use crate::thrift_protocol::MessageType; use crate::ttype::TType; use crate::CompactProtocol; @@ -599,3 +607,36 @@ fn skip_stop_in_container() { }, } } + +#[test] +fn test_overallocation() { + let mut malicious = BytesMut::new(); + malicious.put_u8(0xf0 | CType::I16 as u8); + malicious.put_varint_i64(1_000_000_000_000_000); + malicious.put_bytes(0, 10); + let malicious = malicious.freeze(); + let mut deserializer = ::deserializer(Cursor::new(malicious.clone())); + let err = as Deserialize<_>>::read(&mut deserializer).unwrap_err(); + assert_eq!( + err.downcast_ref::(), + Some(&ProtocolError::EOF), + ); + + let mut deserializer = ::deserializer(Cursor::new(malicious)); + let err = as Deserialize<_>>::read(&mut deserializer).unwrap_err(); + assert_eq!( + err.downcast_ref::(), + Some(&ProtocolError::EOF), + ); + + let mut 
malicious = BytesMut::new(); + malicious.put_varint_i64(1_000_000_000_000_000); + malicious.put_u8(((CType::Binary as u8) << 4) | (CType::I16 as u8)); + malicious.put_bytes(0, 10); + let mut deserializer = ::deserializer(Cursor::new(malicious.freeze())); + let err = as Deserialize<_>>::read(&mut deserializer).unwrap_err(); + assert_eq!( + err.downcast_ref::(), + Some(&ProtocolError::EOF), + ); +}