diff --git a/arrow-ipc/benches/ipc_reader.rs b/arrow-ipc/benches/ipc_reader.rs index 7fc14664b4a7..43e391d82172 100644 --- a/arrow-ipc/benches/ipc_reader.rs +++ b/arrow-ipc/benches/ipc_reader.rs @@ -24,7 +24,7 @@ use arrow_ipc::writer::{FileWriter, IpcWriteOptions, StreamWriter}; use arrow_ipc::{root_as_footer, Block, CompressionType}; use arrow_schema::{DataType, Field, Schema}; use criterion::{criterion_group, criterion_main, Criterion}; -use std::io::Cursor; +use std::io::{Cursor, Write}; use std::sync::Arc; use tempfile::tempdir; @@ -32,17 +32,26 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("arrow_ipc_reader"); group.bench_function("StreamReader/read_10", |b| { - let batch = create_batch(8192, true); - let mut buffer = Vec::with_capacity(2 * 1024 * 1024); - let mut writer = StreamWriter::try_new(&mut buffer, batch.schema().as_ref()).unwrap(); - for _ in 0..10 { - writer.write(&batch).unwrap(); - } - writer.finish().unwrap(); + let buffer = ipc_stream(); + b.iter(move || { + let projection = None; + let mut reader = StreamReader::try_new(buffer.as_slice(), projection).unwrap(); + for _ in 0..10 { + reader.next().unwrap().unwrap(); + } + assert!(reader.next().is_none()); + }) + }); + group.bench_function("StreamReader/no_validation/read_10", |b| { + let buffer = ipc_stream(); b.iter(move || { let projection = None; let mut reader = StreamReader::try_new(buffer.as_slice(), projection).unwrap(); + unsafe { + // safety: we created a valid IPC file + reader = reader.with_skip_validation(true); + } for _ in 0..10 { reader.next().unwrap().unwrap(); } @@ -51,22 +60,26 @@ fn criterion_benchmark(c: &mut Criterion) { }); group.bench_function("StreamReader/read_10/zstd", |b| { - let batch = create_batch(8192, true); - let mut buffer = Vec::with_capacity(2 * 1024 * 1024); - let options = IpcWriteOptions::default() - .try_with_compression(Some(CompressionType::ZSTD)) - .unwrap(); - let mut writer = - StreamWriter::try_new_with_options(&mut buffer, batch.schema().as_ref(), options) - .unwrap(); - for _ in 0..10 { - writer.write(&batch).unwrap(); - } - writer.finish().unwrap(); + let buffer = ipc_stream_zstd(); + b.iter(move || { + let projection = None; + let mut reader = StreamReader::try_new(buffer.as_slice(), projection).unwrap(); + for _ in 0..10 { + reader.next().unwrap().unwrap(); + } + assert!(reader.next().is_none()); + }) + }); + group.bench_function("StreamReader/no_validation/read_10/zstd", |b| { + let buffer = ipc_stream_zstd(); b.iter(move || { let projection = None; let mut reader = StreamReader::try_new(buffer.as_slice(), projection).unwrap(); + unsafe { + // safety: we created a valid IPC file + reader = reader.with_skip_validation(true); + } for _ in 0..10 { reader.next().unwrap().unwrap(); } @@ -74,19 +87,30 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + // --- Create IPC File --- group.bench_function("FileReader/read_10", |b| { - let batch = create_batch(8192, true); - let mut buffer = Vec::with_capacity(2 * 1024 * 1024); - let mut writer = FileWriter::try_new(&mut buffer, batch.schema().as_ref()).unwrap(); - for _ in 0..10 { - writer.write(&batch).unwrap(); - } - writer.finish().unwrap(); + let buffer = ipc_file(); + b.iter(move || { + let projection = None; + let cursor = Cursor::new(buffer.as_slice()); + let mut reader = FileReader::try_new(cursor, projection).unwrap(); + for _ in 0..10 { + reader.next().unwrap().unwrap(); + } + assert!(reader.next().is_none()); + }) + }); + group.bench_function("FileReader/no_validation/read_10", |b| { + let buffer = ipc_file(); b.iter(move || { let projection = None; let cursor = Cursor::new(buffer.as_slice()); let mut reader = FileReader::try_new(cursor, projection).unwrap(); + unsafe { + // safety: we created a valid IPC file + reader = reader.with_skip_validation(true); + } for _ in 0..10 { reader.next().unwrap().unwrap(); } @@ -94,26 +118,42 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + // write to an actual file + let dir = tempdir().unwrap(); + let path = dir.path().join("test.arrow"); + let mut file = std::fs::File::create(&path).unwrap(); + file.write_all(&ipc_file()).unwrap(); + drop(file); + group.bench_function("FileReader/read_10/mmap", |b| { - let batch = create_batch(8192, true); - // write to an actual file - let dir = tempdir().unwrap(); - let path = dir.path().join("test.arrow"); - let file = std::fs::File::create(&path).unwrap(); - let mut writer = FileWriter::try_new(file, batch.schema().as_ref()).unwrap(); - for _ in 0..10 { - writer.write(&batch).unwrap(); - } - writer.finish().unwrap(); + let path = &path; + b.iter(move || { + let ipc_file = std::fs::File::open(path).expect("failed to open file"); + let mmap = unsafe { memmap2::Mmap::map(&ipc_file).expect("failed to mmap file") }; + + // Convert the mmap region to an Arrow `Buffer` to back the arrow arrays. + let bytes = bytes::Bytes::from_owner(mmap); + let buffer = Buffer::from(bytes); + let decoder = IPCBufferDecoder::new(buffer); + assert_eq!(decoder.num_batches(), 10); + for i in 0..decoder.num_batches() { + decoder.get_batch(i); + } + }) + }); + + group.bench_function("FileReader/no_validation/read_10/mmap", |b| { + let path = &path; b.iter(move || { - let ipc_file = std::fs::File::open(&path).expect("failed to open file"); + let ipc_file = std::fs::File::open(path).expect("failed to open file"); let mmap = unsafe { memmap2::Mmap::map(&ipc_file).expect("failed to mmap file") }; // Convert the mmap region to an Arrow `Buffer` to back the arrow arrays. let bytes = bytes::Bytes::from_owner(mmap); let buffer = Buffer::from(bytes); let decoder = IPCBufferDecoder::new(buffer); + let decoder = unsafe { decoder.with_skip_validation(true) }; assert_eq!(decoder.num_batches(), 10); for i in 0..decoder.num_batches() { @@ -123,6 +163,46 @@ fn criterion_benchmark(c: &mut Criterion) { }); } +/// Return an IPC stream with 10 record batches +fn ipc_stream() -> Vec { + let batch = create_batch(8192, true); + let mut buffer = Vec::with_capacity(2 * 1024 * 1024); + let mut writer = StreamWriter::try_new(&mut buffer, batch.schema().as_ref()).unwrap(); + for _ in 0..10 { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + buffer +} + +/// Return an IPC stream with ZSTD compression with 10 record batches +fn ipc_stream_zstd() -> Vec { + let batch = create_batch(8192, true); + let mut buffer = Vec::with_capacity(2 * 1024 * 1024); + let options = IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::ZSTD)) + .unwrap(); + let mut writer = + StreamWriter::try_new_with_options(&mut buffer, batch.schema().as_ref(), options).unwrap(); + for _ in 0..10 { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + buffer +} + +/// Return an IPC file with 10 record batches +fn ipc_file() -> Vec { + let batch = create_batch(8192, true); + let mut buffer = Vec::with_capacity(2 * 1024 * 1024); + let mut writer = FileWriter::try_new(&mut buffer, batch.schema().as_ref()).unwrap(); + for _ in 0..10 { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + buffer +} + // copied from the zero_copy_ipc example. // should we move this to an actual API? /// Wrapper around the example in the `FileDecoder` which handles the @@ -166,6 +246,11 @@ impl IPCBufferDecoder { } } + unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self { + self.decoder = self.decoder.with_skip_validation(skip_validation); + self + } + fn num_batches(&self) -> usize { self.batches.len() } diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index ca0d09e2282f..2fd6c14dd704 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -36,7 +36,7 @@ use std::sync::Arc; use arrow_array::*; use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, ScalarBuffer}; -use arrow_data::ArrayData; +use arrow_data::{ArrayData, ArrayDataBuilder, UnsafeFlag}; use arrow_schema::*; use crate::compression::CompressionCodec; @@ -136,24 +136,7 @@ impl RecordBatchDecoder<'_> { let child = self.create_array(struct_field, variadic_counts)?; struct_arrays.push(child); } - let null_count = struct_node.null_count() as usize; - let struct_array = if struct_arrays.is_empty() { - // `StructArray::from` can't infer the correct row count - // if we have zero fields - let len = struct_node.length() as usize; - StructArray::new_empty_fields( - len, - (null_count > 0).then(|| BooleanBuffer::new(null_buffer, 0, len).into()), - ) - } else if null_count > 0 { - // create struct array from fields, arrays and null data - let len = struct_node.length() as usize; - let nulls = BooleanBuffer::new(null_buffer, 0, len).into(); - StructArray::try_new(struct_fields.clone(), struct_arrays, Some(nulls))? - } else { - StructArray::try_new(struct_fields.clone(), struct_arrays, None)? - }; - Ok(Arc::new(struct_array)) + self.create_struct_array(struct_node, null_buffer, struct_fields, struct_arrays) } RunEndEncoded(run_ends_field, values_field) => { let run_node = self.next_node(field)?; @@ -161,15 +144,12 @@ impl RecordBatchDecoder<'_> { let values = self.create_array(values_field, variadic_counts)?; let run_array_length = run_node.length() as usize; - let array_data = ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(run_array_length) .offset(0) .add_child_data(run_ends.into_data()) - .add_child_data(values.into_data()) - .align_buffers(!self.require_alignment) - .build()?; - - Ok(make_array(array_data)) + .add_child_data(values.into_data()); + self.create_array_from_builder(builder) } // Create dictionary array from RecordBatch Dictionary(_, _) => { @@ -223,7 +203,14 @@ impl RecordBatchDecoder<'_> { children.push(child); } - let array = UnionArray::try_new(fields.clone(), type_ids, value_offsets, children)?; + let array = if self.skip_validation.get() { + // safety: flag can only be set via unsafe code + unsafe { + UnionArray::new_unchecked(fields.clone(), type_ids, value_offsets, children) + } + } else { + UnionArray::try_new(fields.clone(), type_ids, value_offsets, children)? + }; Ok(Arc::new(array)) } Null => { @@ -237,14 +224,10 @@ impl RecordBatchDecoder<'_> { ))); } - let array_data = ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(length as usize) - .offset(0) - .align_buffers(!self.require_alignment) - .build()?; - - // no buffer increases - Ok(Arc::new(NullArray::from(array_data))) + .offset(0); + self.create_array_from_builder(builder) } _ => { let field_node = self.next_node(field)?; @@ -286,9 +269,17 @@ impl RecordBatchDecoder<'_> { t => unreachable!("Data type {:?} either unsupported or not primitive", t), }; - let array_data = builder.align_buffers(!self.require_alignment).build()?; + self.create_array_from_builder(builder) + } - Ok(make_array(array_data)) + /// Update the ArrayDataBuilder based on settings in this decoder + fn create_array_from_builder(&self, builder: ArrayDataBuilder) -> Result { + let mut builder = builder.align_buffers(!self.require_alignment); + if self.skip_validation.get() { + // SAFETY: flag can only be set via unsafe code + unsafe { builder = builder.skip_validation(true) } + }; + Ok(make_array(builder.build()?)) } /// Reads the correct number of buffers based on list type and null_count, and creates a @@ -318,9 +309,42 @@ impl RecordBatchDecoder<'_> { _ => unreachable!("Cannot create list or map array from {:?}", data_type), }; - let array_data = builder.align_buffers(!self.require_alignment).build()?; + self.create_array_from_builder(builder) + } - Ok(make_array(array_data)) + fn create_struct_array( + &self, + struct_node: &FieldNode, + null_buffer: Buffer, + struct_fields: &Fields, + struct_arrays: Vec, + ) -> Result { + let null_count = struct_node.null_count() as usize; + let len = struct_node.length() as usize; + + if struct_arrays.is_empty() { + // `StructArray::from` can't infer the correct row count + // if we have zero fields + return Ok(Arc::new(StructArray::new_empty_fields( + len, + (null_count > 0).then(|| BooleanBuffer::new(null_buffer, 0, len).into()), + ))); + } + + let nulls = if null_count > 0 { + Some(BooleanBuffer::new(null_buffer, 0, len).into()) + } else { + None + }; + + let struct_array = if self.skip_validation.get() { + // safety: flag can only be set via unsafe code + unsafe { StructArray::new_unchecked(struct_fields.clone(), struct_arrays, nulls) } + } else { + StructArray::try_new(struct_fields.clone(), struct_arrays, nulls)? + }; + + Ok(Arc::new(struct_array)) } /// Reads the correct number of buffers based on list type and null_count, and creates a @@ -334,15 +358,12 @@ impl RecordBatchDecoder<'_> { ) -> Result { if let Dictionary(_, _) = *data_type { let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); - let array_data = ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(field_node.length() as usize) .add_buffer(buffers[1].clone()) .add_child_data(value_array.into_data()) - .null_bit_buffer(null_buffer) - .align_buffers(!self.require_alignment) - .build()?; - - Ok(make_array(array_data)) + .null_bit_buffer(null_buffer); + self.create_array_from_builder(builder) } else { unreachable!("Cannot create dictionary array from {:?}", data_type) } @@ -376,6 +397,10 @@ struct RecordBatchDecoder<'a> { /// Are buffers required to already be aligned? See /// [`RecordBatchDecoder::with_require_alignment`] for details require_alignment: bool, + /// Should validation be skipped when reading data? Defaults to false. + /// + /// See [`FileDecoder::with_skip_validation`] for details. + skip_validation: UnsafeFlag, } impl<'a> RecordBatchDecoder<'a> { @@ -410,6 +435,7 @@ impl<'a> RecordBatchDecoder<'a> { buffers: buffers.iter(), projection: None, require_alignment: false, + skip_validation: UnsafeFlag::new(), }) } @@ -432,6 +458,17 @@ impl<'a> RecordBatchDecoder<'a> { self } + /// Specifies if validation should be skipped when reading data (defaults to `false`) + /// + /// # Safety + /// + /// Relies on the caller only passing a flag with `true` value if they are + /// certain that the data is valid + pub fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self { + self.skip_validation = skip_validation; + self + } + /// Read the record batch, consuming the reader fn read_record_batch(mut self) -> Result { let mut variadic_counts: VecDeque = self @@ -601,7 +638,15 @@ pub fn read_dictionary( dictionaries_by_id: &mut HashMap, metadata: &MetadataVersion, ) -> Result<(), ArrowError> { - read_dictionary_impl(buf, batch, schema, dictionaries_by_id, metadata, false) + read_dictionary_impl( + buf, + batch, + schema, + dictionaries_by_id, + metadata, + false, + UnsafeFlag::new(), + ) } fn read_dictionary_impl( @@ -611,6 +656,7 @@ fn read_dictionary_impl( dictionaries_by_id: &mut HashMap, metadata: &MetadataVersion, require_alignment: bool, + skip_validation: UnsafeFlag, ) -> Result<(), ArrowError> { if batch.isDelta() { return Err(ArrowError::InvalidArgumentError( @@ -642,6 +688,7 @@ fn read_dictionary_impl( metadata, )? .with_require_alignment(require_alignment) + .with_skip_validation(skip_validation) .read_record_batch()?; Some(record_batch.column(0).clone()) @@ -772,6 +819,7 @@ pub struct FileDecoder { version: MetadataVersion, projection: Option>, require_alignment: bool, + skip_validation: UnsafeFlag, } impl FileDecoder { @@ -783,6 +831,7 @@ impl FileDecoder { dictionaries: Default::default(), projection: None, require_alignment: false, + skip_validation: UnsafeFlag::new(), } } @@ -792,7 +841,7 @@ impl FileDecoder { self } - /// Specifies whether or not array data in input buffers is required to be properly aligned. + /// Specifies if the array data in input buffers is required to be properly aligned. /// /// If `require_alignment` is true, this decoder will return an error if any array data in the /// input `buf` is not properly aligned. @@ -809,6 +858,21 @@ impl FileDecoder { self } + /// Specifies if validation should be skipped when reading data (defaults to `false`) + /// + /// # Safety + /// + /// This flag must only be set to `true` when you trust the input data and are sure the data you are + /// reading is a valid Arrow IPC file, otherwise undefined behavior may + /// result. + /// + /// For example, some programs may wish to trust reading IPC files written + /// by the same process that created the files. + pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self { + self.skip_validation.set(skip_validation); + self + } + fn read_message<'a>(&self, buf: &'a [u8]) -> Result, ArrowError> { let message = parse_message(buf)?; @@ -834,6 +898,7 @@ impl FileDecoder { &mut self.dictionaries, &message.version(), self.require_alignment, + self.skip_validation.clone(), ) } t => Err(ArrowError::ParseError(format!( @@ -867,6 +932,7 @@ impl FileDecoder { )? .with_projection(self.projection.as_deref()) .with_require_alignment(self.require_alignment) + .with_skip_validation(self.skip_validation.clone()) .read_record_batch() .map(Some) } @@ -1177,6 +1243,16 @@ impl FileReader { pub fn get_mut(&mut self) -> &mut R { &mut self.reader } + + /// Specifies if validation should be skipped when reading data (defaults to `false`) + /// + /// # Safety + /// + /// See [`FileDecoder::with_skip_validation`] + pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self { + self.decoder = self.decoder.with_skip_validation(skip_validation); + self + } } impl Iterator for FileReader { @@ -1250,6 +1326,11 @@ pub struct StreamReader { /// Optional projection projection: Option<(Vec, Schema)>, + + /// Should validation be skipped when reading data? Defaults to false. + /// + /// See [`FileDecoder::with_skip_validation`] for details. + skip_validation: UnsafeFlag, } impl fmt::Debug for StreamReader { @@ -1329,6 +1410,7 @@ impl StreamReader { finished: false, dictionaries_by_id, projection, + skip_validation: UnsafeFlag::new(), }) } @@ -1417,6 +1499,7 @@ impl StreamReader { )? .with_projection(self.projection.as_ref().map(|x| x.0.as_ref())) .with_require_alignment(false) + .with_skip_validation(self.skip_validation.clone()) .read_record_batch() .map(Some) } @@ -1437,6 +1520,7 @@ impl StreamReader { &mut self.dictionaries_by_id, &message.version(), false, + self.skip_validation.clone(), )?; // read the next message until we encounter a RecordBatch @@ -1462,6 +1546,16 @@ impl StreamReader { pub fn get_mut(&mut self) -> &mut R { &mut self.reader } + + /// Specifies if validation should be skipped when reading data (defaults to `false`) + /// + /// # Safety + /// + /// See [`FileDecoder::with_skip_validation`] + pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self { + self.skip_validation.set(skip_validation); + self + } } impl Iterator for StreamReader { @@ -1740,6 +1834,15 @@ mod tests { reader.next().unwrap() } + /// Return the first record batch read from the IPC File buffer, disabling + /// validation + fn read_ipc_skip_validation(buf: &[u8]) -> Result { + let mut reader = unsafe { + FileReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true) + }; + reader.next().unwrap() + } + fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch { let buf = write_ipc(rb); read_ipc(&buf).unwrap() @@ -1748,6 +1851,19 @@ mod tests { /// Return the first record batch read from the IPC File buffer /// using the FileDecoder API fn read_ipc_with_decoder(buf: Vec) -> Result { + read_ipc_with_decoder_inner(buf, false) + } + + /// Return the first record batch read from the IPC File buffer + /// using the FileDecoder API, disabling validation + fn read_ipc_with_decoder_skip_validation(buf: Vec) -> Result { + read_ipc_with_decoder_inner(buf, true) + } + + fn read_ipc_with_decoder_inner( + buf: Vec, + skip_validation: bool, + ) -> Result { let buffer = Buffer::from_vec(buf); let trailer_start = buffer.len() - 10; let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap())?; @@ -1756,7 +1872,10 @@ mod tests { let schema = fb_to_schema(footer.schema().unwrap()); - let mut decoder = FileDecoder::new(Arc::new(schema), footer.version()); + let mut decoder = unsafe { + FileDecoder::new(Arc::new(schema), footer.version()) + .with_skip_validation(skip_validation) + }; // Read dictionaries for block in footer.dictionaries().iter().flatten() { let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; @@ -1789,6 +1908,15 @@ mod tests { reader.next().unwrap() } + /// Return the first record batch read from the IPC Stream buffer, + /// disabling validation + fn read_stream_skip_validation(buf: &[u8]) -> Result { + let mut reader = unsafe { + StreamReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true) + }; + reader.next().unwrap() + } + fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch { let buf = write_stream(rb); read_stream(&buf).unwrap() @@ -2456,6 +2584,57 @@ mod tests { ); } + #[test] + fn test_invalid_nested_array_ipc_read_errors() { + // one of the nested arrays has invalid data + let a_field = Field::new("a", DataType::Int32, false); + let b_field = Field::new("b", DataType::Utf8, false); + + let schema = Arc::new(Schema::new(vec![Field::new_struct( + "s", + vec![a_field.clone(), b_field.clone()], + false, + )])); + + let a_array_data = ArrayData::builder(a_field.data_type().clone()) + .len(4) + .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4])) + .build() + .unwrap(); + // invalid nested child array -- length is correct, but has invalid utf8 data + let b_array_data = { + let valid: &[u8] = b" "; + let mut invalid = vec![]; + invalid.extend_from_slice(b"ValidString"); + invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR); + let binary_array = + BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]); + let array = unsafe { + StringArray::new_unchecked( + binary_array.offsets().clone(), + binary_array.values().clone(), + binary_array.nulls().cloned(), + ) + }; + array.into_data() + }; + let struct_data_type = schema.field(0).data_type(); + + let invalid_struct_arr = unsafe { + make_array( + ArrayData::builder(struct_data_type.clone()) + .len(4) + .add_child_data(a_array_data) + .add_child_data(b_array_data) + .build_unchecked(), + ) + }; + expect_ipc_validation_error( + Arc::new(invalid_struct_arr), + "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..18): invalid utf-8 sequence of 1 bytes from index 11", + ); + } + #[test] fn test_same_dict_id_without_preserve() { let batch = RecordBatch::try_new( @@ -2592,6 +2771,32 @@ mod tests { ); } + #[test] + fn test_validation_of_invalid_union_array() { + let array = unsafe { + let fields = UnionFields::new( + vec![1, 3], // typeids : type id 2 is not valid + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ], + ); + let type_ids = ScalarBuffer::from(vec![1i8, 2, 3]); // 2 is invalid + let offsets = None; + let children: Vec = vec![ + Arc::new(Int32Array::from(vec![10, 20, 30])), + Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])), + ]; + + UnionArray::new_unchecked(fields, type_ids, offsets, children) + }; + + expect_ipc_validation_error( + Arc::new(array), + "Invalid argument error: Type Ids values must match one of the field type ids", + ); + } + /// Invalid Utf-8 sequence in the first character /// const INVALID_UTF8_FIRST_CHAR: &[u8] = &[0xa0, 0xa1, 0x20, 0x20]; @@ -2602,18 +2807,18 @@ mod tests { // IPC Stream format let buf = write_stream(&rb); // write is ok + read_stream_skip_validation(&buf).unwrap(); let err = read_stream(&buf).unwrap_err(); assert_eq!(err.to_string(), expected_err); // IPC File format let buf = write_ipc(&rb); // write is ok + read_ipc_skip_validation(&buf).unwrap(); let err = read_ipc(&buf).unwrap_err(); assert_eq!(err.to_string(), expected_err); - // TODO verify there is no error when validation is disabled - // see https://github.com/apache/arrow-rs/issues/3287 - // IPC Format with FileDecoder + read_ipc_with_decoder_skip_validation(buf.clone()).unwrap(); let err = read_ipc_with_decoder(buf).unwrap_err(); assert_eq!(err.to_string(), expected_err); } diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs index 174e69c1f670..5902cbe4e039 100644 --- a/arrow-ipc/src/reader/stream.rs +++ b/arrow-ipc/src/reader/stream.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow_array::{ArrayRef, RecordBatch}; use arrow_buffer::{Buffer, MutableBuffer}; +use arrow_data::UnsafeFlag; use arrow_schema::{ArrowError, SchemaRef}; use crate::convert::MessageBuffer; @@ -42,6 +43,12 @@ pub struct StreamDecoder { buf: MutableBuffer, /// Whether or not array data in input buffers are required to be aligned require_alignment: bool, + /// Should validation be skipped when reading data? Defaults to false. + /// + /// See [`FileDecoder::with_skip_validation`] for details. + /// + /// [`FileDecoder::with_skip_validation`]: crate::reader::FileDecoder::with_skip_validation + skip_validation: UnsafeFlag, } #[derive(Debug)] @@ -235,6 +242,7 @@ impl StreamDecoder { &mut self.dictionaries, &version, self.require_alignment, + self.skip_validation.clone(), )?; self.state = DecoderState::default(); }