Skip to content

Commit

Permalink
Prevent collection overallocation
Browse files Browse the repository at this point in the history
Summary:
In both the binary and compact protocols, previously it was possible for the `Vec::with_capacity` and `HashSet::with_capacity_and_hasher` and `HashMap::with_capacity_and_hasher` in these types' `fbthrift::Deserialize` impls to allocate arbitrarily big collections when fed a small Thrift message containing a large collection size. This makes Rust Thrift servers vulnerable to Denial Of Service from maliciously crafted Thrift messages, or even just unlucky garbage messages.

This diff makes these Deserialize impls short-circuit with `ProtocolError::EOF` if the remaining input data cannot possibly contain the number of collection elements claimed by the collection size.

For example, in compact protocol which represents `f32` using 4 bytes, deserializing `Vec<f32>` with 4000 bytes remaining in the input data will not bother trying to allocate a Vec with capacity larger than 1000. In particular, it will no longer try to allocate a Vec with capacity 10¹⁵, even if the serialized collection size claims that the input contains 10¹⁵ list elements.

Reviewed By: zertosh

Differential Revision: D68986442

fbshipit-source-id: a3885833d6d49f581912ffcb1365af381671a28e
  • Loading branch information
David Tolnay authored and facebook-github-bot committed Feb 1, 2025
1 parent 5b7b835 commit e68f20c
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 10 deletions.
26 changes: 26 additions & 0 deletions third-party/thrift/src/thrift/lib/rust/src/binary_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use crate::protocol::ProtocolWriter;
use crate::serialize::Serialize;
use crate::thrift_protocol::MessageType;
use crate::thrift_protocol::ProtocolID;
use crate::ttype::GetTType;
use crate::ttype::TType;

pub const BINARY_VERSION_MASK: u32 = 0xffff_0000;
Expand Down Expand Up @@ -460,6 +461,31 @@ impl<B: BufExt> ProtocolReader for BinaryProtocolDeserializer<B> {
ensure_err!(self.buffer.remaining() >= received_len, ProtocolError::EOF);
Ok(V::copy_from_buf(&mut self.buffer, received_len))
}

fn min_size<T: GetTType>() -> usize {
match T::TTYPE {
TType::Void => 0,
TType::Bool => 1,
TType::Byte => 1,
TType::Double => 8,
TType::I16 => 2,
TType::I32 => 4,
TType::I64 => 8,
TType::String => 4,
TType::Struct => 1,
TType::Map => 6,
TType::Set => 5,
TType::List => 5,
TType::UTF8 => 4,
TType::UTF16 => 4,
TType::Float => 4,
TType::Stop | TType::Stream => unreachable!(),
}
}

fn can_advance(&self, bytes: usize) -> bool {
self.buffer.can_advance(bytes)
}
}

/// How large an item will be when `serialize()` is called
Expand Down
31 changes: 30 additions & 1 deletion third-party/thrift/src/thrift/lib/rust/src/bufext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@ pub trait BufExt: Buf {
// Default is to just copy.
self.copy_to_bytes(len)
}

/// Whether there are enough remaining bytes to advance this buffer's
/// internal cursor by `n`.
///
/// Disjoint buffers for which `remaining()` is not O(1) should override
/// this method with a more efficient implementation.
fn can_advance(&self, n: usize) -> bool {
n <= self.remaining()
}

/// Number of more bytes needed in order to be able to advance by `n`.
///
/// If `n <= remaining()`, this is 0. Otherwise `n - remaining()`.
///
/// Disjoint buffers for which `remaining()` is not O(1) should override
/// this method with a more efficient implementation.
fn shortfall(&self, n: usize) -> usize {
n.saturating_sub(self.remaining())
}
}

impl BufExt for Bytes {}
Expand All @@ -50,7 +69,17 @@ impl BufExt for Cursor<Bytes> {

impl<T: AsRef<[u8]> + ?Sized> BufExt for Cursor<&T> {}

impl<T: BufExt, U: BufExt> BufExt for Chain<T, U> {}
impl<T: BufExt, U: BufExt> BufExt for Chain<T, U> {
fn can_advance(&self, n: usize) -> bool {
let rest = self.first_ref().shortfall(n);
self.last_ref().can_advance(rest)
}

fn shortfall(&self, n: usize) -> usize {
let rest = self.first_ref().shortfall(n);
self.last_ref().shortfall(rest)
}
}

pub trait BufMutExt: BufMut {
type Final: Send + 'static;
Expand Down
26 changes: 26 additions & 0 deletions third-party/thrift/src/thrift/lib/rust/src/compact_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use crate::protocol::ProtocolWriter;
use crate::serialize::Serialize;
use crate::thrift_protocol::MessageType;
use crate::thrift_protocol::ProtocolID;
use crate::ttype::GetTType;
use crate::ttype::TType;
use crate::varint;

Expand Down Expand Up @@ -840,6 +841,31 @@ impl<B: BufExt> ProtocolReader for CompactProtocolDeserializer<B> {

Ok(V::copy_from_buf(&mut self.buffer, received_len))
}

fn min_size<T: GetTType>() -> usize {
match T::TTYPE {
TType::Void => 0,
TType::Bool => 1,
TType::Byte => 1,
TType::Double => 8,
TType::I16 => 1,
TType::I32 => 1,
TType::I64 => 1,
TType::String => 1,
TType::Struct => 1,
TType::Map => 1,
TType::Set => 1,
TType::List => 1,
TType::UTF8 => 1,
TType::UTF16 => 1,
TType::Float => 4,
TType::Stop | TType::Stream => unreachable!(),
}
}

fn can_advance(&self, bytes: usize) -> bool {
self.buffer.can_advance(bytes)
}
}

/// How large an item will be when `serialize()` is called
Expand Down
46 changes: 37 additions & 9 deletions third-party/thrift/src/thrift/lib/rust/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ use std::collections::HashSet;
use std::hash::Hash;
use std::sync::Arc;

use anyhow::bail;
use bytes::Bytes;
use ordered_float::OrderedFloat;

use crate::errors::ProtocolError;
use crate::protocol::should_break;
use crate::protocol::ProtocolReader;
use crate::ttype::GetTType;
use crate::Result;

// Read trait. Every type that needs to be deserialized will implement this trait.
Expand Down Expand Up @@ -223,13 +226,21 @@ where
impl<P, T, S> Deserialize<P> for HashSet<T, S>
where
P: ProtocolReader,
T: Deserialize<P> + Hash + Eq,
T: Deserialize<P> + GetTType + Hash + Eq,
S: std::hash::BuildHasher + Default,
{
fn read(p: &mut P) -> Result<Self> {
let (_elem_ty, len) = p.read_set_begin()?;
let mut hset =
HashSet::with_capacity_and_hasher(len.unwrap_or_default(), Default::default());
let mut hset = {
let cap = len.unwrap_or(0);
if match cap.checked_mul(P::min_size::<T>()) {
None => true,
Some(total) => !p.can_advance(total),
} {
bail!(ProtocolError::EOF);
}
HashSet::with_capacity_and_hasher(cap, S::default())
};

if let Some(0) = len {
return Ok(hset);
Expand Down Expand Up @@ -294,14 +305,22 @@ where
impl<P, K, V, S> Deserialize<P> for HashMap<K, V, S>
where
P: ProtocolReader,
K: Deserialize<P> + Hash + Eq,
V: Deserialize<P>,
K: Deserialize<P> + GetTType + Hash + Eq,
V: Deserialize<P> + GetTType,
S: std::hash::BuildHasher + Default,
{
fn read(p: &mut P) -> Result<Self> {
let (_key_ty, _val_ty, len) = p.read_map_begin()?;
let mut hmap =
HashMap::with_capacity_and_hasher(len.unwrap_or_default(), Default::default());
let mut hmap = {
let cap = len.unwrap_or(0);
if match cap.checked_mul(P::min_size::<K>() + P::min_size::<V>()) {
None => true,
Some(total) => !p.can_advance(total),
} {
bail!(ProtocolError::EOF);
}
HashMap::with_capacity_and_hasher(cap, S::default())
};

if let Some(0) = len {
return Ok(hmap);
Expand Down Expand Up @@ -332,12 +351,21 @@ where
impl<P, T> Deserialize<P> for Vec<T>
where
P: ProtocolReader,
T: Deserialize<P> + crate::ttype::GetTType, // GetTType just to exclude Vec<u8>
T: Deserialize<P> + GetTType,
{
/// Vec<T> is Thrift List type
fn read(p: &mut P) -> Result<Self> {
let (_elem_ty, len) = p.read_list_begin()?;
let mut list = Vec::with_capacity(len.unwrap_or_default());
let mut list = {
let cap = len.unwrap_or(0);
if match cap.checked_mul(P::min_size::<T>()) {
None => true,
Some(total) => !p.can_advance(total),
} {
bail!(ProtocolError::EOF);
}
Vec::with_capacity(cap)
};

if let Some(0) = len {
return Ok(list);
Expand Down
18 changes: 18 additions & 0 deletions third-party/thrift/src/thrift/lib/rust/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::framing::FramingEncoded;
use crate::framing::FramingEncodedFinal;
use crate::thrift_protocol::MessageType;
use crate::thrift_protocol::ProtocolID;
use crate::ttype::GetTType;
use crate::ttype::TType;
use crate::Result;

Expand Down Expand Up @@ -263,6 +264,23 @@ pub trait ProtocolReader {
fn read_string(&mut self) -> Result<String>;
fn read_binary<V: CopyFromBuf>(&mut self) -> Result<V>;

/// The smallest number of bytes in which a collection element of type `T`
/// could be represented.
//
// TODO: once feature(adt_const_params) is stable, make this generic
// over `const TTYPE: TType` instead of T to reduce monomorphization.
// https://github.com/rust-lang/rust/issues/95174
fn min_size<T: GetTType>() -> usize {
0
}

/// Whether there is enough input data available to bother trying to read
/// collection elements that would occupy a minimum of `n` bytes.
fn can_advance(&self, bytes: usize) -> bool {
let _ = bytes;
true
}

/// Skip over the next data element from the provided input Protocol object
fn skip(&mut self, field_type: TType) -> Result<()> {
skip_inner(self, field_type, DEFAULT_RECURSION_DEPTH)
Expand Down
40 changes: 40 additions & 0 deletions third-party/thrift/src/thrift/lib/rust/src/tests/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
* limitations under the License.
*/

use std::collections::HashMap;
use std::collections::HashSet;
use std::io::Cursor;

use bytes::Buf;
use bytes::BufMut;
use bytes::Bytes;
use bytes::BytesMut;

use super::BOOL_VALUES;
use super::BYTE_VALUES;
Expand All @@ -26,6 +30,8 @@ use super::FLOAT_VALUES;
use super::INT16_VALUES;
use super::INT32_VALUES;
use super::INT64_VALUES;
use crate::deserialize::Deserialize;
use crate::errors::ProtocolError;
use crate::thrift_protocol::MessageType;
use crate::ttype::TType;
use crate::BinaryProtocol;
Expand Down Expand Up @@ -508,3 +514,37 @@ fn read_binary_from_chained_buffer() {
.expect("read \" world\" from the buffer");
assert_eq!(result.as_slice(), b" world");
}

#[test]
fn test_overallocation() {
let mut malicious = BytesMut::new();
malicious.put_u8(TType::I16 as u8);
malicious.put_i32(1_000_000_000);
malicious.put_bytes(0, 10);
let malicious = malicious.freeze();
let mut deserializer = <BinaryProtocol>::deserializer(Cursor::new(malicious.clone()));
let err = <Vec<i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
assert_eq!(
err.downcast_ref::<ProtocolError>(),
Some(&ProtocolError::EOF),
);

let mut deserializer = <BinaryProtocol>::deserializer(Cursor::new(malicious));
let err = <HashSet<i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
assert_eq!(
err.downcast_ref::<ProtocolError>(),
Some(&ProtocolError::EOF),
);

let mut malicious = BytesMut::new();
malicious.put_u8(TType::String as u8);
malicious.put_u8(TType::I16 as u8);
malicious.put_i32(1_000_000_000);
malicious.put_bytes(0, 10);
let mut deserializer = <BinaryProtocol>::deserializer(Cursor::new(malicious.freeze()));
let err = <HashMap<String, i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
assert_eq!(
err.downcast_ref::<ProtocolError>(),
Some(&ProtocolError::EOF),
);
}
41 changes: 41 additions & 0 deletions third-party/thrift/src/thrift/lib/rust/src/tests/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
* limitations under the License.
*/

use std::collections::HashMap;
use std::collections::HashSet;
use std::io::Cursor;

use bytes::Buf;
use bytes::BufMut;
use bytes::Bytes;
use bytes::BytesMut;

use super::BOOL_VALUES;
use super::BYTE_VALUES;
Expand All @@ -26,6 +30,10 @@ use super::FLOAT_VALUES;
use super::INT16_VALUES;
use super::INT32_VALUES;
use super::INT64_VALUES;
use crate::bufext::BufMutExt as _;
use crate::compact_protocol::CType;
use crate::deserialize::Deserialize;
use crate::errors::ProtocolError;
use crate::thrift_protocol::MessageType;
use crate::ttype::TType;
use crate::CompactProtocol;
Expand Down Expand Up @@ -599,3 +607,36 @@ fn skip_stop_in_container() {
},
}
}

#[test]
fn test_overallocation() {
let mut malicious = BytesMut::new();
malicious.put_u8(0xf0 | CType::I16 as u8);
malicious.put_varint_i64(1_000_000_000_000_000);
malicious.put_bytes(0, 10);
let malicious = malicious.freeze();
let mut deserializer = <CompactProtocol>::deserializer(Cursor::new(malicious.clone()));
let err = <Vec<i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
assert_eq!(
err.downcast_ref::<ProtocolError>(),
Some(&ProtocolError::EOF),
);

let mut deserializer = <CompactProtocol>::deserializer(Cursor::new(malicious));
let err = <HashSet<i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
assert_eq!(
err.downcast_ref::<ProtocolError>(),
Some(&ProtocolError::EOF),
);

let mut malicious = BytesMut::new();
malicious.put_varint_i64(1_000_000_000_000_000);
malicious.put_u8(((CType::Binary as u8) << 4) | (CType::I16 as u8));
malicious.put_bytes(0, 10);
let mut deserializer = <CompactProtocol>::deserializer(Cursor::new(malicious.freeze()));
let err = <HashMap<String, i16> as Deserialize<_>>::read(&mut deserializer).unwrap_err();
assert_eq!(
err.downcast_ref::<ProtocolError>(),
Some(&ProtocolError::EOF),
);
}

0 comments on commit e68f20c

Please sign in to comment.