Add blob size limits. (#2705)
* Add a blob size limit.

* Add a bytecode size limit.

* Add unit tests for limits.

* Don't enforce the limit for already published bytecode.

* Simplify LimitedWriter; add unit test.

* Add decompressed_size_at_most.

* Update and copy comment about #2710.
afck committed Oct 26, 2024
1 parent f667fb4 commit 2109142
Showing 9 changed files with 172 additions and 36 deletions.
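In short: the worker now rejects any published blob whose raw content exceeds MAXIMUM_BLOB_SIZE, and additionally rejects contract or service bytecode that would decompress to more than MAXIMUM_BYTECODE_SIZE, without ever buffering the decompressed output. The sketch below condenses that flow into one standalone function; it is an illustration, not the committed code. The constants are copied from the diff, `is_bytecode_blob` stands in for the `BlobType` match in temporary_changes.rs below, and byte-by-byte counting replaces the commit's `io::copy` into a `LimitedWriter` (assumes the `zstd` crate):

use std::io::Read;

// Limits introduced in linera-core/src/client/mod.rs by this commit.
const MAXIMUM_BLOB_SIZE: u64 = 3 * 1024 * 1024;
const MAXIMUM_BYTECODE_SIZE: u64 = 30 * 1024 * 1024;

fn validate_blob(content: &[u8], is_bytecode_blob: bool) -> Result<(), &'static str> {
    // Check 1: the raw blob content must fit the blob size budget.
    if content.len() as u64 > MAXIMUM_BLOB_SIZE {
        return Err("blob exceeds size limit");
    }
    // Check 2: bytecode must also fit the decompressed budget. Streaming the
    // decoder and merely counting bytes means a "zstd bomb" fails fast,
    // without a large allocation.
    if is_bytecode_blob {
        let decoder =
            zstd::stream::Decoder::new(content).map_err(|_| "invalid compressed bytecode")?;
        let mut remaining = MAXIMUM_BYTECODE_SIZE;
        for byte in decoder.bytes() {
            byte.map_err(|_| "invalid compressed bytecode")?;
            remaining = remaining
                .checked_sub(1)
                .ok_or("bytecode exceeds size limit")?;
        }
    }
    Ok(())
}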
72 changes: 43 additions & 29 deletions linera-base/src/data_types.rs
@@ -29,6 +29,7 @@ use crate::{
         ApplicationId, BlobId, BlobType, BytecodeId, Destination, GenericApplicationId, MessageId,
         UserApplicationId,
     },
+    limited_writer::{LimitedWriter, LimitedWriterError},
     time::{Duration, SystemTime},
 };
@@ -842,14 +843,8 @@ impl fmt::Debug for Bytecode {
 #[derive(Error, Debug)]
 pub enum DecompressionError {
     /// Compressed bytecode is invalid, and could not be decompressed.
-    #[cfg(not(target_arch = "wasm32"))]
-    #[error("Bytecode could not be decompressed")]
-    InvalidCompressedBytecode(#[source] io::Error),
-
-    /// Compressed bytecode is invalid, and could not be decompressed.
-    #[cfg(target_arch = "wasm32")]
-    #[error("Bytecode could not be decompressed")]
-    InvalidCompressedBytecode(#[from] ruzstd::frame_decoder::FrameDecoderError),
+    #[error("Bytecode could not be decompressed: {0}")]
+    InvalidCompressedBytecode(#[from] io::Error),
 }

 /// A compressed WebAssembly module's bytecode.
@@ -878,30 +873,57 @@ impl From<Bytecode> for CompressedBytecode {
 }

 #[cfg(not(target_arch = "wasm32"))]
-impl TryFrom<&CompressedBytecode> for Bytecode {
-    type Error = DecompressionError;
-
-    fn try_from(compressed_bytecode: &CompressedBytecode) -> Result<Self, Self::Error> {
-        let bytes = zstd::stream::decode_all(&*compressed_bytecode.compressed_bytes)
-            .map_err(DecompressionError::InvalidCompressedBytecode)?;
+impl CompressedBytecode {
+    /// Returns `true` if the decompressed size does not exceed the limit.
+    pub fn decompressed_size_at_most(&self, limit: u64) -> Result<bool, DecompressionError> {
+        let mut decoder = zstd::stream::Decoder::new(&*self.compressed_bytes)?;
+        let limit = usize::try_from(limit).unwrap_or(usize::MAX);
+        let mut writer = LimitedWriter::new(io::sink(), limit);
+        match io::copy(&mut decoder, &mut writer) {
+            Ok(_) => Ok(true),
+            Err(error) => {
+                error.downcast::<LimitedWriterError>()?;
+                Ok(false)
+            }
+        }
+    }
+
+    /// Decompresses a [`CompressedBytecode`] into a [`Bytecode`].
+    pub fn decompress(&self) -> Result<Bytecode, DecompressionError> {
+        let bytes = zstd::stream::decode_all(&*self.compressed_bytes)?;
         Ok(Bytecode { bytes })
     }
 }

 #[cfg(target_arch = "wasm32")]
-impl TryFrom<&CompressedBytecode> for Bytecode {
-    type Error = DecompressionError;
+impl CompressedBytecode {
+    /// Returns `true` if the decompressed size does not exceed the limit.
+    pub fn decompressed_size_at_most(&self, limit: u64) -> Result<bool, DecompressionError> {
+        let compressed_bytes = &*self.compressed_bytes;
+        let limit = usize::try_from(limit).unwrap_or(usize::MAX);
+        let mut writer = LimitedWriter::new(io::sink(), limit);
+        let mut decoder = ruzstd::streaming_decoder::StreamingDecoder::new(compressed_bytes)
+            .map_err(io::Error::other)?;
+
+        // TODO(#2710): Decode multiple frames, if present
+        match io::copy(&mut decoder, &mut writer) {
+            Ok(_) => Ok(true),
+            Err(error) => {
+                error.downcast::<LimitedWriterError>()?;
+                Ok(false)
+            }
+        }
+    }

-    fn try_from(compressed_bytecode: &CompressedBytecode) -> Result<Self, Self::Error> {
+    /// Decompresses a [`CompressedBytecode`] into a [`Bytecode`].
+    pub fn decompress(&self) -> Result<Bytecode, DecompressionError> {
         use ruzstd::{io::Read, streaming_decoder::StreamingDecoder};

-        let compressed_bytes = &*compressed_bytecode.compressed_bytes;
+        let compressed_bytes = &*self.compressed_bytes;
         let mut bytes = Vec::new();
-        let mut decoder = StreamingDecoder::new(compressed_bytes)?;
+        let mut decoder = StreamingDecoder::new(compressed_bytes).map_err(io::Error::other)?;

-        // Decode multiple frames, if present
-        // (https://github.com/KillingSpark/zstd-rs/issues/57)
+        // TODO(#2710): Decode multiple frames, if present
         while !decoder.get_ref().is_empty() {
             decoder
                 .read_to_end(&mut bytes)
@@ -912,14 +934,6 @@ impl TryFrom<&CompressedBytecode> for Bytecode {
     }
 }

-impl TryFrom<CompressedBytecode> for Bytecode {
-    type Error = DecompressionError;
-
-    fn try_from(compressed_bytecode: CompressedBytecode) -> Result<Self, Self::Error> {
-        Bytecode::try_from(&compressed_bytecode)
-    }
-}
-
 impl fmt::Debug for CompressedBytecode {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("CompressedBytecode").finish_non_exhaustive()
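The two inherent methods above replace the old `TryFrom` conversions (linera-storage switches from `.try_into()?` to `.decompress()?` later in this diff). A minimal usage sketch, with a hypothetical caller and an illustrative 30 MiB budget mirroring `MAXIMUM_BYTECODE_SIZE` from linera-core:

use linera_base::data_types::{Bytecode, CompressedBytecode, DecompressionError};

fn load_if_small_enough(
    compressed: &CompressedBytecode,
) -> Result<Option<Bytecode>, DecompressionError> {
    // Cheap streaming check first: nothing is buffered, so a hostile
    // "zstd bomb" is rejected before any large allocation happens.
    if !compressed.decompressed_size_at_most(30 * 1024 * 1024)? {
        return Ok(None);
    }
    // Only now pay for full decompression.
    Ok(Some(compressed.decompress()?))
}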
1 change: 1 addition & 0 deletions linera-base/src/lib.rs
@@ -18,6 +18,7 @@ pub mod crypto;
 pub mod data_types;
 mod graphql;
 pub mod identifiers;
+mod limited_writer;
 pub mod ownership;
 #[cfg(with_metrics)]
 pub mod prometheus_util;
68 changes: 68 additions & 0 deletions linera-base/src/limited_writer.rs
@@ -0,0 +1,68 @@
+// Copyright (c) Zefchain Labs, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::io::{self, Write};
+
+use thiserror::Error;
+
+use crate::ensure;
+
+#[derive(Error, Debug)]
+#[error("Writer limit exceeded")]
+pub struct LimitedWriterError;
+
+/// Custom writer that enforces a byte limit.
+pub struct LimitedWriter<W: Write> {
+    inner: W,
+    limit: usize,
+    written: usize,
+}
+
+impl<W: Write> LimitedWriter<W> {
+    pub fn new(inner: W, limit: usize) -> Self {
+        Self {
+            inner,
+            limit,
+            written: 0,
+        }
+    }
+}
+
+impl<W: Write> Write for LimitedWriter<W> {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        // Calculate the number of bytes we can write without exceeding the limit.
+        // Fail if the buffer doesn't fit.
+        ensure!(
+            self.limit
+                .checked_sub(self.written)
+                .is_some_and(|remaining| buf.len() <= remaining),
+            io::Error::other(LimitedWriterError)
+        );
+        // Forward to the inner writer.
+        let n = self.inner.write(buf)?;
+        self.written += n;
+        Ok(n)
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        self.inner.flush()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_limited_writer() {
+        let mut out_buffer = Vec::new();
+        let mut writer = LimitedWriter::new(&mut out_buffer, 5);
+        assert_eq!(writer.write(b"foo").unwrap(), 3);
+        assert_eq!(writer.write(b"ba").unwrap(), 2);
+        assert!(writer
+            .write(b"r")
+            .unwrap_err()
+            .downcast::<LimitedWriterError>()
+            .is_ok());
+    }
+}
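One subtlety: `LimitedWriter` signals overflow as an `io::Error` wrapping `LimitedWriterError`, so callers can use `io::Error::downcast` to tell "limit exceeded" apart from a genuine I/O failure, exactly as `decompressed_size_at_most` does in data_types.rs above. A sketch of that consumer-side pattern, assuming `LimitedWriter` and `LimitedWriterError` are in scope:

use std::io::{self, Read};

/// Returns Ok(true) if `reader` yields at most `limit` bytes, Ok(false) if
/// the limit is exceeded, and Err only for real I/O errors.
fn fits_within(reader: &mut impl Read, limit: usize) -> io::Result<bool> {
    let mut writer = LimitedWriter::new(io::sink(), limit);
    match io::copy(reader, &mut writer) {
        Ok(_) => Ok(true),
        Err(error) => {
            // downcast() yields Ok(LimitedWriterError) only for our marker
            // error; any other error kind is re-raised by `?`.
            error.downcast::<LimitedWriterError>()?;
            Ok(false)
        }
    }
}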
23 changes: 21 additions & 2 deletions linera-core/src/chain_worker/state/temporary_changes.rs
@@ -6,9 +6,9 @@
 use std::borrow::Cow;

 use linera_base::{
-    data_types::{ArithmeticError, Timestamp, UserApplicationDescription},
+    data_types::{ArithmeticError, CompressedBytecode, Timestamp, UserApplicationDescription},
     ensure,
-    identifiers::{GenericApplicationId, UserApplicationId},
+    identifiers::{BlobType, GenericApplicationId, UserApplicationId},
 };
 use linera_chain::{
     data_types::{
@@ -28,6 +28,7 @@ use {

 use super::{check_block_epoch, ChainWorkerState};
 use crate::{
+    client::{MAXIMUM_BLOB_SIZE, MAXIMUM_BYTECODE_SIZE},
     data_types::{ChainInfo, ChainInfoQuery, ChainInfoResponse},
     worker::WorkerError,
 };
@@ -213,6 +214,24 @@
         for blob in blobs {
             self.0.cache_recent_blob(Cow::Borrowed(blob)).await;
         }
+        for blob in self.0.get_blobs(block.published_blob_ids()).await? {
+            match blob.id().blob_type {
+                BlobType::Data => {}
+                BlobType::ContractBytecode | BlobType::ServiceBytecode => {
+                    ensure!(
+                        CompressedBytecode::from(blob.content().clone())
+                            .decompressed_size_at_most(MAXIMUM_BYTECODE_SIZE)?,
+                        WorkerError::BytecodeTooLarge
+                    );
+                }
+            }
+            ensure!(
+                u64::try_from(blob.content().bytes.len())
+                    .ok()
+                    .is_some_and(|size| size <= MAXIMUM_BLOB_SIZE),
+                WorkerError::BlobTooLarge
+            )
+        }

         let local_time = self.0.storage.clock().current_time();
         ensure!(
5 changes: 5 additions & 0 deletions linera-core/src/client/mod.rs
@@ -83,6 +83,11 @@ mod chain_state;
 #[path = "../unit_tests/client_tests.rs"]
 mod client_tests;

+/// The maximum size of a data or bytecode blob, in bytes.
+pub(crate) const MAXIMUM_BLOB_SIZE: u64 = 3 * 1024 * 1024;
+/// The maximum size of decompressed bytecode, in bytes.
+pub(crate) const MAXIMUM_BYTECODE_SIZE: u64 = 30 * 1024 * 1024;
+
 #[cfg(with_metrics)]
 mod metrics {
     use std::sync::LazyLock;
8 changes: 7 additions & 1 deletion linera-core/src/unit_tests/client_tests.rs
@@ -41,7 +41,7 @@ use crate::test_utils::ServiceStorageBuilder;
 use crate::{
     client::{
         BlanketMessagePolicy, ChainClient, ChainClientError, ClientOutcome, MessageAction,
-        MessagePolicy,
+        MessagePolicy, MAXIMUM_BLOB_SIZE,
     },
     local_node::LocalNodeError,
     node::{
@@ -2560,5 +2560,11 @@
     assert_eq!(executed_block.block.incoming_bundles.len(), 1);
     assert_eq!(executed_block.required_blob_ids().len(), 1);

+    let large_blob_bytes = vec![0; MAXIMUM_BLOB_SIZE as usize + 1];
+    let result = client1
+        .publish_data_blob(BlobContent::new(large_blob_bytes))
+        .await;
+    assert_matches!(result, Err(ChainClientError::LocalNodeError(_)));
+
     Ok(())
 }
17 changes: 16 additions & 1 deletion linera-core/src/unit_tests/wasm_client_tests.rs
@@ -41,7 +41,10 @@ use crate::client::client_tests::RocksDbStorageBuilder;
 use crate::client::client_tests::ScyllaDbStorageBuilder;
 #[cfg(feature = "storage-service")]
 use crate::client::client_tests::ServiceStorageBuilder;
-use crate::client::client_tests::{MemoryStorageBuilder, StorageBuilder, TestBuilder};
+use crate::client::{
+    client_tests::{MemoryStorageBuilder, StorageBuilder, TestBuilder},
+    ChainClientError, MAXIMUM_BYTECODE_SIZE,
+};

 #[cfg_attr(feature = "wasmer", test_case(WasmRuntime::Wasmer ; "wasmer"))]
 #[cfg_attr(feature = "wasmtime", test_case(WasmRuntime::Wasmtime ; "wasmtime"))]
@@ -195,6 +198,18 @@
     let balance_after_init = creator.local_balance().await?;
     assert!(balance_after_init < balance_after_messaging);

+    let large_bytecode = Bytecode::new(vec![0; MAXIMUM_BYTECODE_SIZE as usize + 1]);
+    let small_bytecode = Bytecode::new(vec![]);
+    // Publishing bytecode that exceeds the limit fails.
+    let result = publisher
+        .publish_bytecode(large_bytecode.clone(), small_bytecode.clone())
+        .await;
+    assert_matches!(result, Err(ChainClientError::LocalNodeError(_)));
+    let result = publisher
+        .publish_bytecode(small_bytecode, large_bytecode)
+        .await;
+    assert_matches!(result, Err(ChainClientError::LocalNodeError(_)));
+
     Ok(())
 }
10 changes: 9 additions & 1 deletion linera-core/src/worker.rs
@@ -14,7 +14,9 @@ use std::{
 use linera_base::crypto::PublicKey;
 use linera_base::{
     crypto::{CryptoHash, KeyPair},
-    data_types::{ArithmeticError, Blob, BlockHeight, Round, UserApplicationDescription},
+    data_types::{
+        ArithmeticError, Blob, BlockHeight, DecompressionError, Round, UserApplicationDescription,
+    },
     doc_scalar,
     identifiers::{BlobId, ChainId, Owner, UserApplicationId},
 };
@@ -215,6 +217,12 @@ pub enum WorkerError {
     FullChainWorkerCache,
     #[error("Failed to join spawned worker task")]
     JoinError,
+    #[error("Blob exceeds size limit")]
+    BlobTooLarge,
+    #[error("Bytecode exceeds size limit")]
+    BytecodeTooLarge,
+    #[error(transparent)]
+    Decompression(#[from] DecompressionError),
 }

 impl From<linera_chain::ChainError> for WorkerError {
4 changes: 2 additions & 2 deletions linera-storage/src/lib.rs
@@ -296,7 +296,7 @@ pub trait Storage: Sized {
                 contract_blob
                     .into_inner_contract_bytecode()
                     .expect("Contract Bytecode Blob is of the wrong Blob type!")
-                    .try_into()?,
+                    .decompress()?,
                 wasm_runtime,
             )
             .await?,
@@ -335,7 +335,7 @@
                 service_blob
                     .into_inner_service_bytecode()
                     .expect("Service Bytecode Blob is of the wrong Blob type!")
-                    .try_into()?,
+                    .decompress()?,
                 wasm_runtime,
             )
             .await?,
