diff --git a/third-party/thrift/src/thrift/lib/rust/src/binary_protocol.rs b/third-party/thrift/src/thrift/lib/rust/src/binary_protocol.rs
index 2315cc6dd96f3f..40bdd893d743ab 100644
--- a/third-party/thrift/src/thrift/lib/rust/src/binary_protocol.rs
+++ b/third-party/thrift/src/thrift/lib/rust/src/binary_protocol.rs
@@ -37,6 +37,7 @@ use crate::protocol::ProtocolWriter;
 use crate::serialize::Serialize;
 use crate::thrift_protocol::MessageType;
 use crate::thrift_protocol::ProtocolID;
+use crate::ttype::GetTType;
 use crate::ttype::TType;
 
 pub const BINARY_VERSION_MASK: u32 = 0xffff_0000;
@@ -460,6 +461,31 @@ impl<B: BufExt> ProtocolReader for BinaryProtocolDeserializer<B> {
         ensure_err!(self.buffer.remaining() >= received_len, ProtocolError::EOF);
         Ok(V::copy_from_buf(&mut self.buffer, received_len))
     }
+
+    fn min_size<T: GetTType>() -> usize {
+        match T::TTYPE {
+            TType::Void => 0,
+            TType::Bool => 1,
+            TType::Byte => 1,
+            TType::Double => 8,
+            TType::I16 => 2,
+            TType::I32 => 4,
+            TType::I64 => 8,
+            TType::String => 4,
+            TType::Struct => 1,
+            TType::Map => 6,
+            TType::Set => 5,
+            TType::List => 5,
+            TType::UTF8 => 4,
+            TType::UTF16 => 4,
+            TType::Float => 4,
+            TType::Stop | TType::Stream => unreachable!(),
+        }
+    }
+
+    fn can_advance(&self, bytes: usize) -> bool {
+        self.buffer.can_advance(bytes)
+    }
 }
 
 /// How large an item will be when `serialize()` is called
diff --git a/third-party/thrift/src/thrift/lib/rust/src/bufext.rs b/third-party/thrift/src/thrift/lib/rust/src/bufext.rs
index 98559d52b2b751..d1f22e6c0e4f5f 100644
--- a/third-party/thrift/src/thrift/lib/rust/src/bufext.rs
+++ b/third-party/thrift/src/thrift/lib/rust/src/bufext.rs
@@ -32,6 +32,25 @@ pub trait BufExt: Buf {
         // Default is to just copy.
         self.copy_to_bytes(len)
     }
+
+    /// Whether there are enough remaining bytes to advance this buffer's
+    /// internal cursor by `n`.
+    ///
+    /// Disjoint buffers for which `remaining()` is not O(1) should override
+    /// this method with a more efficient implementation.
+    fn can_advance(&self, n: usize) -> bool {
+        n <= self.remaining()
+    }
+
+    /// Number of more bytes needed in order to be able to advance by `n`.
+    ///
+    /// If `n <= remaining()`, this is 0. Otherwise `n - remaining()`.
+    ///
+    /// Disjoint buffers for which `remaining()` is not O(1) should override
+    /// this method with a more efficient implementation.
+    fn shortfall(&self, n: usize) -> usize {
+        n.saturating_sub(self.remaining())
+    }
 }
 
 impl BufExt for Bytes {}
@@ -50,7 +69,17 @@ impl BufExt for Cursor<Bytes> {
 
 impl<T: AsRef<[u8]> + ?Sized> BufExt for Cursor<&T> {}
 
-impl<T: BufExt, U: BufExt> BufExt for Chain<T, U> {}
+impl<T: BufExt, U: BufExt> BufExt for Chain<T, U> {
+    fn can_advance(&self, n: usize) -> bool {
+        let rest = self.first_ref().shortfall(n);
+        self.last_ref().can_advance(rest)
+    }
+
+    fn shortfall(&self, n: usize) -> usize {
+        let rest = self.first_ref().shortfall(n);
+        self.last_ref().shortfall(rest)
+    }
+}
 
 pub trait BufMutExt: BufMut {
     type Final: Send + 'static;
diff --git a/third-party/thrift/src/thrift/lib/rust/src/compact_protocol.rs b/third-party/thrift/src/thrift/lib/rust/src/compact_protocol.rs
index 046812271f0121..6c1c5446365d70 100644
--- a/third-party/thrift/src/thrift/lib/rust/src/compact_protocol.rs
+++ b/third-party/thrift/src/thrift/lib/rust/src/compact_protocol.rs
@@ -37,6 +37,7 @@ use crate::protocol::ProtocolWriter;
 use crate::serialize::Serialize;
 use crate::thrift_protocol::MessageType;
 use crate::thrift_protocol::ProtocolID;
+use crate::ttype::GetTType;
 use crate::ttype::TType;
 use crate::varint;
 
@@ -840,6 +841,31 @@ impl<B: BufExt> ProtocolReader for CompactProtocolDeserializer<B> {
 
         Ok(V::copy_from_buf(&mut self.buffer, received_len))
     }
+
+    fn min_size<T: GetTType>() -> usize {
+        match T::TTYPE {
+            TType::Void => 0,
+            TType::Bool => 1,
+            TType::Byte => 1,
+            TType::Double => 8,
+            TType::I16 => 1,
+            TType::I32 => 1,
+            TType::I64 => 1,
+            TType::String => 1,
+            TType::Struct => 1,
+            TType::Map => 1,
+            TType::Set => 1,
+            TType::List => 1,
+            TType::UTF8 => 1,
+            TType::UTF16 => 1,
+            TType::Float => 4,
+            TType::Stop | TType::Stream => unreachable!(),
+        }
+    }
+
+    fn can_advance(&self, bytes: usize) -> bool {
+        self.buffer.can_advance(bytes)
+    }
 }
 
 /// How large an item will be when `serialize()` is called
diff --git a/third-party/thrift/src/thrift/lib/rust/src/deserialize.rs b/third-party/thrift/src/thrift/lib/rust/src/deserialize.rs
index 08efcdef53751e..9e40f188ceb784 100644
--- a/third-party/thrift/src/thrift/lib/rust/src/deserialize.rs
+++ b/third-party/thrift/src/thrift/lib/rust/src/deserialize.rs
@@ -21,11 +21,14 @@ use std::collections::HashSet;
 use std::hash::Hash;
 use std::sync::Arc;
 
+use anyhow::bail;
 use bytes::Bytes;
 use ordered_float::OrderedFloat;
 
+use crate::errors::ProtocolError;
 use crate::protocol::should_break;
 use crate::protocol::ProtocolReader;
+use crate::ttype::GetTType;
 use crate::Result;
 
 // Read trait. Every type that needs to be deserialized will implement this trait.
@@ -223,13 +226,21 @@ where
 impl<P, T, S> Deserialize<P> for HashSet<T, S>
 where
     P: ProtocolReader,
-    T: Deserialize<P> + Hash + Eq,
+    T: Deserialize<P> + GetTType + Hash + Eq,
     S: std::hash::BuildHasher + Default,
 {
     fn read(p: &mut P) -> Result<Self> {
         let (_elem_ty, len) = p.read_set_begin()?;
-        let mut hset =
-            HashSet::with_capacity_and_hasher(len.unwrap_or_default(), Default::default());
+        let mut hset = {
+            let cap = len.unwrap_or(0);
+            if match cap.checked_mul(P::min_size::<T>()) {
+                None => true,
+                Some(total) => !p.can_advance(total),
+            } {
+                bail!(ProtocolError::EOF);
+            }
+            HashSet::with_capacity_and_hasher(cap, S::default())
+        };
 
         if let Some(0) = len {
             return Ok(hset);
@@ -294,14 +305,22 @@ where
 impl<P, K, V, S> Deserialize<P> for HashMap<K, V, S>
 where
     P: ProtocolReader,
-    K: Deserialize<P> + Hash + Eq,
-    V: Deserialize<P>,
+    K: Deserialize<P> + GetTType + Hash + Eq,
+    V: Deserialize<P> + GetTType,
     S: std::hash::BuildHasher + Default,
 {
     fn read(p: &mut P) -> Result<Self> {
         let (_key_ty, _val_ty, len) = p.read_map_begin()?;
-        let mut hmap =
-            HashMap::with_capacity_and_hasher(len.unwrap_or_default(), Default::default());
+        let mut hmap = {
+            let cap = len.unwrap_or(0);
+            if match cap.checked_mul(P::min_size::<K>() + P::min_size::<V>()) {
+                None => true,
+                Some(total) => !p.can_advance(total),
+            } {
+                bail!(ProtocolError::EOF);
+            }
+            HashMap::with_capacity_and_hasher(cap, S::default())
+        };
 
         if let Some(0) = len {
             return Ok(hmap);
@@ -332,12 +351,21 @@ where
 impl<P, T> Deserialize<P> for Vec<T>
 where
     P: ProtocolReader,
-    T: Deserialize<P> + crate::ttype::GetTType, // GetTType just to exclude Vec<u8>
+    T: Deserialize<P> + GetTType,
 {
     /// Vec<T> is Thrift List type
     fn read(p: &mut P) -> Result<Self> {
         let (_elem_ty, len) = p.read_list_begin()?;
-        let mut list = Vec::with_capacity(len.unwrap_or_default());
+        let mut list = {
+            let cap = len.unwrap_or(0);
+            if match cap.checked_mul(P::min_size::<T>()) {
+                None => true,
+                Some(total) => !p.can_advance(total),
+            } {
+                bail!(ProtocolError::EOF);
+            }
+            Vec::with_capacity(cap)
+        };
 
         if let Some(0) = len {
             return Ok(list);
diff --git a/third-party/thrift/src/thrift/lib/rust/src/protocol.rs b/third-party/thrift/src/thrift/lib/rust/src/protocol.rs
index 41570d81598999..ca1a67228bda6f 100644
--- a/third-party/thrift/src/thrift/lib/rust/src/protocol.rs
+++ b/third-party/thrift/src/thrift/lib/rust/src/protocol.rs
@@ -24,6 +24,7 @@ use crate::framing::FramingEncoded;
 use crate::framing::FramingEncodedFinal;
 use crate::thrift_protocol::MessageType;
 use crate::thrift_protocol::ProtocolID;
+use crate::ttype::GetTType;
 use crate::ttype::TType;
 use crate::Result;
 
@@ -263,6 +264,23 @@ pub trait ProtocolReader {
     fn read_string(&mut self) -> Result<String>;
     fn read_binary<V: CopyFromBuf>(&mut self) -> Result<V>;
 
+    /// The smallest number of bytes in which a collection element of type `T`
+    /// could be represented.
+    //
+    // TODO: once feature(adt_const_params) is stable, make this generic
+    // over `const TTYPE: TType` instead of T to reduce monomorphization.
+    // https://github.com/rust-lang/rust/issues/95174
+    fn min_size<T: GetTType>() -> usize {
+        0
+    }
+
+    /// Whether there is enough input data available to bother trying to read
+    /// collection elements that would occupy a minimum of `n` bytes.
+    fn can_advance(&self, bytes: usize) -> bool {
+        let _ = bytes;
+        true
+    }
+
     /// Skip over the next data element from the provided input Protocol object
     fn skip(&mut self, field_type: TType) -> Result<()> {
         skip_inner(self, field_type, DEFAULT_RECURSION_DEPTH)
diff --git a/third-party/thrift/src/thrift/lib/rust/src/tests/binary.rs b/third-party/thrift/src/thrift/lib/rust/src/tests/binary.rs
index 2f548ccb094cd7..735dbde6580357 100644
--- a/third-party/thrift/src/thrift/lib/rust/src/tests/binary.rs
+++ b/third-party/thrift/src/thrift/lib/rust/src/tests/binary.rs
@@ -14,10 +14,14 @@
  * limitations under the License.
  */
 
+use std::collections::HashMap;
+use std::collections::HashSet;
 use std::io::Cursor;
 
 use bytes::Buf;
+use bytes::BufMut;
 use bytes::Bytes;
+use bytes::BytesMut;
 
 use super::BOOL_VALUES;
 use super::BYTE_VALUES;
@@ -26,6 +30,8 @@ use super::FLOAT_VALUES;
 use super::INT16_VALUES;
 use super::INT32_VALUES;
 use super::INT64_VALUES;
+use crate::deserialize::Deserialize;
+use crate::errors::ProtocolError;
 use crate::thrift_protocol::MessageType;
 use crate::ttype::TType;
 use crate::BinaryProtocol;
@@ -508,3 +514,37 @@ fn read_binary_from_chained_buffer() {
         .expect("read \" world\" from the buffer");
     assert_eq!(result.as_slice(), b" world");
 }
+
+#[test]
+fn test_overallocation() {
+    let mut malicious = BytesMut::new();
+    malicious.put_u8(TType::I16 as u8);
+    malicious.put_i32(1_000_000_000);
+    malicious.put_bytes(0, 10);
+    let malicious = malicious.freeze();
+    let mut deserializer = <BinaryProtocol>::deserializer(Cursor::new(malicious.clone()));
+    let err = <Vec<i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
+    assert_eq!(
+        err.downcast_ref::<ProtocolError>(),
+        Some(&ProtocolError::EOF),
+    );
+
+    let mut deserializer = <BinaryProtocol>::deserializer(Cursor::new(malicious));
+    let err = <HashSet<i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
+    assert_eq!(
+        err.downcast_ref::<ProtocolError>(),
+        Some(&ProtocolError::EOF),
+    );
+
+    let mut malicious = BytesMut::new();
+    malicious.put_u8(TType::String as u8);
+    malicious.put_u8(TType::I16 as u8);
+    malicious.put_i32(1_000_000_000);
+    malicious.put_bytes(0, 10);
+    let mut deserializer = <BinaryProtocol>::deserializer(Cursor::new(malicious.freeze()));
+    let err = <HashMap<String, i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
+    assert_eq!(
+        err.downcast_ref::<ProtocolError>(),
+        Some(&ProtocolError::EOF),
+    );
+}
diff --git a/third-party/thrift/src/thrift/lib/rust/src/tests/compact.rs b/third-party/thrift/src/thrift/lib/rust/src/tests/compact.rs
index 2196c12253915a..5e20ff6bf772b3 100644
--- a/third-party/thrift/src/thrift/lib/rust/src/tests/compact.rs
+++ b/third-party/thrift/src/thrift/lib/rust/src/tests/compact.rs
@@ -14,10 +14,14 @@
  * limitations under the License.
  */
 
+use std::collections::HashMap;
+use std::collections::HashSet;
 use std::io::Cursor;
 
 use bytes::Buf;
+use bytes::BufMut;
 use bytes::Bytes;
+use bytes::BytesMut;
 
 use super::BOOL_VALUES;
 use super::BYTE_VALUES;
@@ -26,6 +30,10 @@ use super::FLOAT_VALUES;
 use super::INT16_VALUES;
 use super::INT32_VALUES;
 use super::INT64_VALUES;
+use crate::bufext::BufMutExt as _;
+use crate::compact_protocol::CType;
+use crate::deserialize::Deserialize;
+use crate::errors::ProtocolError;
 use crate::thrift_protocol::MessageType;
 use crate::ttype::TType;
 use crate::CompactProtocol;
@@ -599,3 +607,36 @@ fn skip_stop_in_container() {
         },
     }
 }
+
+#[test]
+fn test_overallocation() {
+    let mut malicious = BytesMut::new();
+    malicious.put_u8(0xf0 | CType::I16 as u8);
+    malicious.put_varint_i64(1_000_000_000_000_000);
+    malicious.put_bytes(0, 10);
+    let malicious = malicious.freeze();
+    let mut deserializer = <CompactProtocol>::deserializer(Cursor::new(malicious.clone()));
+    let err = <Vec<i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
+    assert_eq!(
+        err.downcast_ref::<ProtocolError>(),
+        Some(&ProtocolError::EOF),
+    );
+
+    let mut deserializer = <CompactProtocol>::deserializer(Cursor::new(malicious));
+    let err = <HashSet<i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
+    assert_eq!(
+        err.downcast_ref::<ProtocolError>(),
+        Some(&ProtocolError::EOF),
+    );
+
+    let mut malicious = BytesMut::new();
+    malicious.put_varint_i64(1_000_000_000_000_000);
+    malicious.put_u8(((CType::Binary as u8) << 4) | (CType::I16 as u8));
+    malicious.put_bytes(0, 10);
+    let mut deserializer = <CompactProtocol>::deserializer(Cursor::new(malicious.freeze()));
+    let err = <HashMap<String, i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
+    assert_eq!(
+        err.downcast_ref::<ProtocolError>(),
+        Some(&ProtocolError::EOF),
+    );
+}