diff --git a/Cargo.toml b/Cargo.toml
index 6d3af17..049c291 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "grenad"
 description = "Tools to sort, merge, write, and read immutable key-value pairs."
-version = "0.4.7"
+version = "0.5.0"
 authors = ["Kerollmops "]
 repository = "https://github.com/meilisearch/grenad"
 documentation = "https://docs.rs/grenad"
@@ -11,6 +11,7 @@ license = "MIT"
 [dependencies]
 bytemuck = { version = "1.16.1", features = ["derive"] }
 byteorder = "1.5.0"
+either = { version = "1.13.0", default-features = false }
 flate2 = { version = "1.0", optional = true }
 lz4_flex = { version = "0.11.3", optional = true }
 rayon = { version = "1.10.0", optional = true }
diff --git a/benches/index-levels.rs b/benches/index-levels.rs
index fa6f815..1faddd5 100644
--- a/benches/index-levels.rs
+++ b/benches/index-levels.rs
@@ -16,7 +16,7 @@ fn index_levels(bytes: &[u8]) {
 
     for x in (0..NUMBER_OF_ENTRIES).step_by(1_567) {
         let num = x.to_be_bytes();
-        cursor.move_on_key_greater_than_or_equal_to(&num).unwrap().unwrap();
+        cursor.move_on_key_greater_than_or_equal_to(num).unwrap().unwrap();
     }
 }
diff --git a/src/block_writer.rs b/src/block_writer.rs
index 2af3568..d537e01 100644
--- a/src/block_writer.rs
+++ b/src/block_writer.rs
@@ -94,8 +94,8 @@ impl BlockWriter {
     /// Insert a key that must be greater than the previously added one.
     pub fn insert(&mut self, key: &[u8], val: &[u8]) {
         debug_assert!(self.index_key_counter <= self.index_key_interval.get());
-        assert!(key.len() <= u32::max_value() as usize);
-        assert!(val.len() <= u32::max_value() as usize);
+        assert!(key.len() <= u32::MAX as usize);
+        assert!(val.len() <= u32::MAX as usize);
 
         if self.index_key_counter == self.index_key_interval.get() {
             self.index_offsets.push(self.buffer.len() as u64);
@@ -106,7 +106,7 @@ impl BlockWriter {
         // and save the current key to become the last key.
         match &mut self.last_key {
             Some(last_key) => {
-                assert!(key > last_key, "{:?} must be greater than {:?}", key, last_key);
+                assert!(key > last_key.as_slice(), "{:?} must be greater than {:?}", key, last_key);
                 last_key.clear();
                 last_key.extend_from_slice(key);
             }
diff --git a/src/compression.rs b/src/compression.rs
index 653ee6a..3575070 100644
--- a/src/compression.rs
+++ b/src/compression.rs
@@ -4,10 +4,11 @@ use std::str::FromStr;
 use std::{fmt, io};
 
 /// The different supported types of compression.
-#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
 #[repr(u8)]
 pub enum CompressionType {
     /// Do not compress the blocks.
+    #[default]
     None = 0,
     /// Use the [`snap`] crate to de/compress the blocks.
     ///
@@ -55,12 +56,6 @@ impl FromStr for CompressionType {
     }
 }
 
-impl Default for CompressionType {
-    fn default() -> CompressionType {
-        CompressionType::None
-    }
-}
-
 /// An invalid compression type have been read and the block can't be de/compressed.
 #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
 pub struct InvalidCompressionType;
 
@@ -107,6 +102,7 @@ fn zlib_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
 }
 
 #[cfg(not(feature = "zlib"))]
+#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
 fn zlib_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
     Err(io::Error::new(io::ErrorKind::Other, "unsupported zlib decompression"))
 }
@@ -186,6 +182,7 @@ fn zstd_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
 }
 
 #[cfg(not(feature = "zstd"))]
+#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
 fn zstd_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
     Err(io::Error::new(io::ErrorKind::Other, "unsupported zstd decompression"))
 }
@@ -211,6 +208,7 @@ fn lz4_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
 }
 
 #[cfg(not(feature = "lz4"))]
+#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
 fn lz4_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
     Err(io::Error::new(io::ErrorKind::Other, "unsupported lz4 decompression"))
 }
diff --git a/src/lib.rs b/src/lib.rs
index 83f7143..8897178 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -72,23 +72,25 @@
 //! use std::convert::TryInto;
 //! use std::io::Cursor;
 //!
-//! use grenad::{MergerBuilder, Reader, Writer};
+//! use grenad::{MergerBuilder, MergeFunction, Reader, Writer};
 //!
 //! // This merge function:
 //! //  - parses u32s from native-endian bytes,
 //! //  - wrapping sums them and,
 //! //  - outputs the result as native-endian bytes.
-//! fn wrapping_sum_u32s<'a>(
-//!     _key: &[u8],
-//!     values: &[Cow<'a, [u8]>],
-//! ) -> Result<Cow<'a, [u8]>, TryFromSliceError>
-//! {
-//!     let mut output: u32 = 0;
-//!     for bytes in values.iter().map(AsRef::as_ref) {
-//!         let num = bytes.try_into().map(u32::from_ne_bytes)?;
-//!         output = output.wrapping_add(num);
+//! struct WrappingSumU32s;
+//!
+//! impl MergeFunction for WrappingSumU32s {
+//!     type Error = TryFromSliceError;
+//!
+//!     fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Self::Error> {
+//!         let mut output: u32 = 0;
+//!         for bytes in values.iter().map(AsRef::as_ref) {
+//!             let num = bytes.try_into().map(u32::from_ne_bytes)?;
+//!             output = output.wrapping_add(num);
+//!         }
+//!         Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
 //!     }
-//!     Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
 //! }
 //!
 //! # fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -115,7 +117,7 @@
 //!
 //! // We create a merger that will sum our u32s when necessary,
 //! // and we add our readers to the list of readers to merge.
-//! let merger_builder = MergerBuilder::new(wrapping_sum_u32s);
+//! let merger_builder = MergerBuilder::new(WrappingSumU32s);
 //! let merger = merger_builder.add(readera).add(readerb).add(readerc).build();
 //!
 //! // We can iterate over the entries in key-order.
@@ -142,28 +144,30 @@
 //! use std::borrow::Cow;
 //! use std::convert::TryInto;
 //!
-//! use grenad::{CursorVec, SorterBuilder};
+//! use grenad::{CursorVec, MergeFunction, SorterBuilder};
 //!
 //! // This merge function:
 //! //  - parses u32s from native-endian bytes,
 //! //  - wrapping sums them and,
 //! //  - outputs the result as native-endian bytes.
-//! fn wrapping_sum_u32s<'a>(
-//!     _key: &[u8],
-//!     values: &[Cow<'a, [u8]>],
-//! ) -> Result<Cow<'a, [u8]>, TryFromSliceError>
-//! {
-//!     let mut output: u32 = 0;
-//!     for bytes in values.iter().map(AsRef::as_ref) {
-//!         let num = bytes.try_into().map(u32::from_ne_bytes)?;
-//!         output = output.wrapping_add(num);
+//! struct WrappingSumU32s;
+//!
+//! impl MergeFunction for WrappingSumU32s {
+//!     type Error = TryFromSliceError;
+//!
+//!     fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Self::Error> {
+//!         let mut output: u32 = 0;
+//!         for bytes in values.iter().map(AsRef::as_ref) {
+//!             let num = bytes.try_into().map(u32::from_ne_bytes)?;
+//!             output = output.wrapping_add(num);
+//!         }
+//!         Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
 //!     }
-//!     Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
 //! }
 //!
 //! # fn main() -> Result<(), Box<dyn std::error::Error>> {
 //! // We create a sorter that will sum our u32s when necessary.
-//! let mut sorter = SorterBuilder::new(wrapping_sum_u32s).chunk_creator(CursorVec).build();
+//! let mut sorter = SorterBuilder::new(WrappingSumU32s).chunk_creator(CursorVec).build();
 //!
 //! // We insert multiple entries with the same key but different values
 //! // in arbitrary order, the sorter will take care of merging them for us.
@@ -187,7 +191,7 @@
 #[cfg(test)]
 #[macro_use]
 extern crate quickcheck;
-
+use std::convert::Infallible;
 use std::mem;
 
 mod block;
@@ -195,6 +199,7 @@ mod block_writer;
 mod compression;
 mod count_write;
 mod error;
+mod merge_function;
 mod merger;
 mod metadata;
 mod reader;
@@ -204,6 +209,7 @@ mod writer;
 
 pub use self::compression::CompressionType;
 pub use self::error::Error;
+pub use self::merge_function::MergeFunction;
 pub use self::merger::{Merger, MergerBuilder, MergerIter};
 pub use self::metadata::FileVersion;
 pub use self::reader::{PrefixIter, RangeIter, Reader, ReaderCursor, RevPrefixIter, RevRangeIter};
@@ -214,10 +220,12 @@ pub use self::sorter::{
 };
 pub use self::writer::{Writer, WriterBuilder};
 
+pub type Result<T, U = Infallible> = std::result::Result<T, Error<U>>;
+
 /// Sometimes we need to use an unsafe trick to make the compiler happy.
 /// You can read more about the issue [on the Rust's Github issues].
 ///
 /// [on the Rust's Github issues]: https://github.com/rust-lang/rust/issues/47680
 unsafe fn transmute_entry_to_static(key: &[u8], val: &[u8]) -> (&'static [u8], &'static [u8]) {
-    (mem::transmute(key), mem::transmute(val))
+    (mem::transmute::<&[u8], &'static [u8]>(key), mem::transmute::<&[u8], &'static [u8]>(val))
 }
diff --git a/src/merge_function.rs b/src/merge_function.rs
new file mode 100644
index 0000000..186a5c8
--- /dev/null
+++ b/src/merge_function.rs
@@ -0,0 +1,46 @@
+use std::borrow::Cow;
+use std::result::Result;
+
+use either::Either;
+
+/// A trait defining the way we merge multiple
+/// values sharing the same key.
+pub trait MergeFunction {
+    type Error;
+    fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>])
+        -> Result<Cow<'a, [u8]>, Self::Error>;
+}
+
+impl<MF> MergeFunction for &MF
+where
+    MF: MergeFunction,
+{
+    type Error = MF::Error;
+
+    fn merge<'a>(
+        &self,
+        key: &[u8],
+        values: &[Cow<'a, [u8]>],
+    ) -> Result<Cow<'a, [u8]>, Self::Error> {
+        (*self).merge(key, values)
+    }
+}
+
+impl<MFA, MFB> MergeFunction for Either<MFA, MFB>
+where
+    MFA: MergeFunction,
+    MFB: MergeFunction<Error = MFA::Error>,
+{
+    type Error = MFA::Error;
+
+    fn merge<'a>(
+        &self,
+        key: &[u8],
+        values: &[Cow<'a, [u8]>],
+    ) -> Result<Cow<'a, [u8]>, Self::Error> {
+        match self {
+            Either::Left(mfa) => mfa.merge(key, values),
+            Either::Right(mfb) => mfb.merge(key, values),
+        }
+    }
+}
diff --git a/src/merger.rs b/src/merger.rs
index 7ca7084..38310a2 100644
--- a/src/merger.rs
+++ b/src/merger.rs
@@ -4,7 +4,7 @@ use std::collections::BinaryHeap;
 use std::io;
 use std::iter::once;
 
-use crate::{Error, ReaderCursor, Writer};
+use crate::{Error, MergeFunction, ReaderCursor, Writer};
 
 /// A struct that is used to configure a [`Merger`] with the sources to merge.
 pub struct MergerBuilder<R, MF> {
@@ -20,6 +20,7 @@ impl<R, MF> MergerBuilder<R, MF> {
     }
 
     /// Add a source to merge, this function can be chained.
+    #[allow(clippy::should_implement_trait)] // We return interior references
     pub fn add(mut self, source: ReaderCursor<R>) -> Self {
         self.push(source);
         self
@@ -95,7 +96,7 @@
         }
 
         Ok(MergerIter {
-            merge: self.merge,
+            merge_function: self.merge,
             heap,
             current_key: Vec::new(),
             merged_value: Vec::new(),
@@ -104,16 +105,16 @@
     }
 }
 
-impl<R, MF, U> Merger<R, MF>
+impl<R, MF> Merger<R, MF>
 where
     R: io::Read + io::Seek,
-    MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
+    MF: MergeFunction,
 {
     /// Consumes this [`Merger`] and streams the entries to the [`Writer`] given in parameter.
     pub fn write_into_stream_writer<W: io::Write>(
         self,
         writer: &mut Writer<W>,
-    ) -> Result<(), Error<U>> {
+    ) -> crate::Result<(), MF::Error> {
         let mut iter = self.into_stream_merger_iter().map_err(Error::convert_merge_error)?;
         while let Some((key, val)) = iter.next()? {
             writer.insert(key, val)?;
@@ -124,7 +125,7 @@ where
 
 /// An iterator that yield the merged entries in key-order.
 pub struct MergerIter<R, MF> {
-    merge: MF,
+    merge_function: MF,
     heap: BinaryHeap<Entry<R>>,
     current_key: Vec<u8>,
     merged_value: Vec<u8>,
@@ -132,13 +133,15 @@ pub struct MergerIter<R, MF> {
     tmp_entries: Vec<Entry<R>>,
 }
 
-impl<R, MF, U> MergerIter<R, MF>
+impl<R, MF> MergerIter<R, MF>
 where
     R: io::Read + io::Seek,
-    MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
+    MF: MergeFunction,
 {
     /// Yield the entries in key-order where values have been merged when needed.
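// A minimal sketch of what the new trait buys the caller (not part of the
// patch; `KeepFirst` and `ConcatValues` are hypothetical mergers). Because
// `Either<L, R>` implements `MergeFunction` whenever both sides share the
// same `Error` type, a runtime choice between two strategies still yields a
// single concrete type for a `Sorter` or `Merger`:
//
//     use std::borrow::Cow;
//     use std::convert::Infallible;
//     use either::Either;
//     use grenad::MergeFunction;
//
//     struct KeepFirst;
//     impl MergeFunction for KeepFirst {
//         type Error = Infallible;
//         fn merge<'a>(&self, _key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Infallible> {
//             Ok(values[0].clone()) // grenad always passes at least one value
//         }
//     }
//
//     struct ConcatValues;
//     impl MergeFunction for ConcatValues {
//         type Error = Infallible;
//         fn merge<'a>(&self, _key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Infallible> {
//             Ok(values.iter().flat_map(|v| v.iter().copied()).collect())
//         }
//     }
//
//     fn strategy(concat: bool) -> Either<KeepFirst, ConcatValues> {
//         if concat { Either::Right(ConcatValues) } else { Either::Left(KeepFirst) }
//     }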
-    pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error<U>> {
+    #[allow(clippy::should_implement_trait)] // We return interior references
+    #[allow(clippy::type_complexity)] // Return type is not THAT complex
+    pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>, MF::Error> {
         let first_entry = match self.heap.pop() {
             Some(entry) => entry,
             None => return Ok(None),
@@ -167,7 +170,7 @@ where
             self.tmp_entries.iter().filter_map(|e| e.cursor.current().map(|(_, v)| v));
         let values: Vec<_> = once(first_value).chain(other_values).map(Cow::Borrowed).collect();
 
-        match (self.merge)(first_key, &values) {
+        match self.merge_function.merge(first_key, &values) {
             Ok(value) => {
                 self.current_key.clear();
                 self.current_key.extend_from_slice(first_key);
diff --git a/src/reader/prefix_iter.rs b/src/reader/prefix_iter.rs
index 3da1f95..964710e 100644
--- a/src/reader/prefix_iter.rs
+++ b/src/reader/prefix_iter.rs
@@ -1,6 +1,6 @@
 use std::io;
 
-use crate::{Error, ReaderCursor};
+use crate::ReaderCursor;
 
 /// An iterator that is able to yield all the entries with
 /// a key that starts with a given prefix.
@@ -18,7 +18,8 @@ impl<R: io::Read + io::Seek> PrefixIter<R> {
     }
 
     /// Returns the next entry that starts with the given prefix.
-    pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
+    #[allow(clippy::should_implement_trait)] // We return interior references
+    pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
         let entry = if self.move_on_first_prefix {
             self.move_on_first_prefix = false;
             self.cursor.move_on_key_greater_than_or_equal_to(&self.prefix)?
@@ -49,7 +50,8 @@ impl<R: io::Read + io::Seek> RevPrefixIter<R> {
     }
 
     /// Returns the next entry that starts with the given prefix.
-    pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
+    #[allow(clippy::should_implement_trait)] // We return interior references
+    pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
         let entry = if self.move_on_last_prefix {
             self.move_on_last_prefix = false;
             move_on_last_prefix(&mut self.cursor, self.prefix.clone())?
@@ -68,7 +70,7 @@ impl<R: io::Read + io::Seek> RevPrefixIter<R> {
 fn move_on_last_prefix<R: io::Read + io::Seek>(
     cursor: &mut ReaderCursor<R>,
     prefix: Vec<u8>,
-) -> Result<Option<(&[u8], &[u8])>, Error> {
+) -> crate::Result<Option<(&[u8], &[u8])>> {
     match advance_key(prefix) {
         Some(next_prefix) => match cursor.move_on_key_lower_than_or_equal_to(&next_prefix)? {
             Some((k, _)) if k == next_prefix => cursor.move_on_prev(),
@@ -108,7 +110,7 @@ mod tests {
         let mut writer = Writer::memory();
         for x in (10..24000u32).step_by(3) {
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
         let bytes = writer.into_inner().unwrap();
diff --git a/src/reader/range_iter.rs b/src/reader/range_iter.rs
index c4f92ad..133233a 100644
--- a/src/reader/range_iter.rs
+++ b/src/reader/range_iter.rs
@@ -1,7 +1,7 @@
 use std::io;
 use std::ops::{Bound, RangeBounds};
 
-use crate::{Error, ReaderCursor};
+use crate::ReaderCursor;
 
 /// An iterator that is able to yield all the entries lying in a specified range.
 #[derive(Clone)]
@@ -24,7 +24,8 @@ impl<R: io::Read + io::Seek, S: RangeBounds<Vec<u8>>> RangeIter<R, S> {
     }
 
     /// Returns the next entry that is inside of the given range.
-    pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
+    #[allow(clippy::should_implement_trait)] // We return interior references
+    pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
         let entry = if self.move_on_start {
             self.move_on_start = false;
             match self.range.start_bound() {
@@ -75,7 +76,8 @@ impl<R: io::Read + io::Seek, S: RangeBounds<Vec<u8>>> RevRangeIter<R, S> {
     }
 
     /// Returns the next entry that is inside of the given range.
-    pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
+    #[allow(clippy::should_implement_trait)] // We return interior references
+    pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
         let entry = if self.move_on_start {
             self.move_on_start = false;
             match self.range.end_bound() {
@@ -116,8 +118,8 @@ fn map_bound<T, U, F: FnOnce(T) -> U>(bound: Bound<T>, f: F) -> Bound<U> {
 fn end_contains(end: Bound<&Vec<u8>>, key: &[u8]) -> bool {
     match end {
         Bound::Unbounded => true,
-        Bound::Included(end) => key <= end,
-        Bound::Excluded(end) => key < end,
+        Bound::Included(end) => key <= end.as_slice(),
+        Bound::Excluded(end) => key < end.as_slice(),
     }
 }
 
@@ -125,8 +127,8 @@ fn end_contains(end: Bound<&Vec<u8>>, key: &[u8]) -> bool {
 fn start_contains(end: Bound<&Vec<u8>>, key: &[u8]) -> bool {
     match end {
         Bound::Unbounded => true,
-        Bound::Included(end) => key >= end,
-        Bound::Excluded(end) => key > end,
+        Bound::Included(end) => key >= end.as_slice(),
+        Bound::Excluded(end) => key > end.as_slice(),
     }
 }
 
@@ -149,7 +151,7 @@ mod tests {
         for x in (10..24000i32).step_by(3) {
             nums.insert(x);
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
         let bytes = writer.into_inner().unwrap();
@@ -186,7 +188,7 @@ mod tests {
         for x in (10..24000i32).step_by(3) {
             nums.insert(x);
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
         let bytes = writer.into_inner().unwrap();
diff --git a/src/reader/reader_cursor.rs b/src/reader/reader_cursor.rs
index 1ca6b19..94d6326 100644
--- a/src/reader/reader_cursor.rs
+++ b/src/reader/reader_cursor.rs
@@ -89,7 +89,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
     }
 
     /// Moves the cursor on the first entry and returns it.
-    pub fn move_on_first(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
+    pub fn move_on_first(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
         match self.index_block_cursor.move_on_first(&mut self.reader.reader)? {
             Some((_, offset_bytes)) => {
                 let offset = offset_bytes.try_into().map(u64::from_be_bytes).unwrap();
@@ -109,7 +109,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
     }
 
     /// Moves the cursor on the last entry and returns it.
-    pub fn move_on_last(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
+    pub fn move_on_last(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
         match self.index_block_cursor.move_on_last(&mut self.reader.reader)? {
             Some((_, offset_bytes)) => {
                 let offset = offset_bytes.try_into().map(u64::from_be_bytes).unwrap();
@@ -129,7 +129,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
     }
 
     /// Moves the cursor on the entry following the current one and returns it.
-    pub fn move_on_next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
+    pub fn move_on_next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
         match self.current_cursor.as_mut().map(BlockCursor::move_on_next) {
             Some(Some((key, val))) => {
                 let (key, val) = unsafe { crate::transmute_entry_to_static(key, val) };
@@ -147,7 +147,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
     }
 
     /// Moves the cursor on the entry preceding the current one and returns it.
-    pub fn move_on_prev(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
+    pub fn move_on_prev(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
         match self.current_cursor.as_mut().map(BlockCursor::move_on_prev) {
             Some(Some((key, val))) => {
                 let (key, val) = unsafe { crate::transmute_entry_to_static(key, val) };
@@ -169,7 +169,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
     pub fn move_on_key_lower_than_or_equal_to<A: AsRef<[u8]>>(
         &mut self,
         target_key: A,
-    ) -> Result<Option<(&[u8], &[u8])>, Error> {
+    ) -> crate::Result<Option<(&[u8], &[u8])>> {
         let target_key = target_key.as_ref();
         match self.move_on_key_greater_than_or_equal_to(target_key)? {
             Some((key, val)) if key == target_key => {
@@ -186,7 +186,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
     pub fn move_on_key_greater_than_or_equal_to<A: AsRef<[u8]>>(
         &mut self,
         key: A,
-    ) -> Result<Option<(&[u8], &[u8])>, Error> {
+    ) -> crate::Result<Option<(&[u8], &[u8])>> {
         // We move on the block which has a key greater than or equal to the key we are
         // searching for as the key stored in the index block is the last key of the block.
         let key = key.as_ref();
@@ -213,7 +213,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
     pub fn move_on_key_equal_to<A: AsRef<[u8]>>(
         &mut self,
         key: A,
-    ) -> Result<Option<(&[u8], &[u8])>, Error> {
+    ) -> crate::Result<Option<(&[u8], &[u8])>> {
         let key = key.as_ref();
         self.move_on_key_greater_than_or_equal_to(key).map(|opt| opt.filter(|(k, _)| *k == key))
     }
@@ -255,28 +255,28 @@ impl IndexBlockCursor {
     fn move_on_first<R: io::Read + io::Seek>(
         &mut self,
         reader: R,
-    ) -> Result<Option<(&[u8], &[u8])>, Error> {
+    ) -> crate::Result<Option<(&[u8], &[u8])>> {
         self.iter_index_blocks(reader, |c| c.move_on_first())
     }
 
     fn move_on_last<R: io::Read + io::Seek>(
         &mut self,
         reader: R,
-    ) -> Result<Option<(&[u8], &[u8])>, Error> {
+    ) -> crate::Result<Option<(&[u8], &[u8])>> {
         self.iter_index_blocks(reader, |c| c.move_on_last())
     }
 
     fn move_on_next<R: io::Read + io::Seek>(
         &mut self,
         reader: R,
-    ) -> Result<Option<(&[u8], &[u8])>, Error> {
+    ) -> crate::Result<Option<(&[u8], &[u8])>> {
         self.recursive_index_block(reader, |c| c.move_on_next())
     }
 
     fn move_on_prev<R: io::Read + io::Seek>(
         &mut self,
         reader: R,
-    ) -> Result<Option<(&[u8], &[u8])>, Error> {
+    ) -> crate::Result<Option<(&[u8], &[u8])>> {
         self.recursive_index_block(reader, |c| c.move_on_prev())
     }
 
@@ -284,7 +284,7 @@
         &mut self,
         key: &[u8],
         reader: R,
-    ) -> Result<Option<(&[u8], &[u8])>, Error> {
+    ) -> crate::Result<Option<(&[u8], &[u8])>> {
         self.iter_index_blocks(reader, |c| c.move_on_key_greater_than_or_equal_to(key))
     }
 
@@ -292,7 +292,7 @@
         &mut self,
         mut reader: R,
         mut mov: F,
-    ) -> Result<Option<(&[u8], &[u8])>, Error>
+    ) -> crate::Result<Option<(&[u8], &[u8])>>
     where
         R: io::Read + io::Seek,
        F: FnMut(&mut BlockCursor<Cursor<Vec<u8>>>) -> Option<(&[u8], &[u8])>,
@@ -334,7 +334,7 @@
         &mut self,
         mut reader: R,
         mut mov: FM,
-    ) -> Result<Option<(&[u8], &[u8])>, Error>
+    ) -> crate::Result<Option<(&[u8], &[u8])>>
     where
         R: io::Read + io::Seek,
         FM: FnMut(&mut BlockCursor<Cursor<Vec<u8>>>) -> Option<(&[u8], &[u8])>,
@@ -344,7 +344,7 @@
     {
         fn recursive_index_block<'a, S, FN>(
             compression_type: CompressionType,
             blocks: &'a mut [(u64, BlockCursor<S>)],
             mov: &mut FN,
-        ) -> Result<Option<(&'a [u8], &'a [u8])>, Error>
+        ) -> crate::Result<Option<(&'a [u8], &'a [u8])>>
         where
             S: io::Read + io::Seek,
             FN: FnMut(&mut BlockCursor<S>) -> Option<(&[u8], &[u8])>,
@@ -393,11 +393,12 @@
     }
 
     /// Returns the index block cursors by calling the user function to load the blocks.
+    #[allow(clippy::type_complexity)] // Return type is not THAT complex
     fn initial_index_blocks<R, FM>(
         &mut self,
         mut reader: R,
         mut mov: FM,
-    ) -> Result<Option<Vec<(u64, BlockCursor<Cursor<Vec<u8>>>)>>, Error>
+    ) -> crate::Result<Option<Vec<(u64, BlockCursor<Cursor<Vec<u8>>>)>>>
     where
         R: io::Read + io::Seek,
         FM: FnMut(&mut BlockCursor<Cursor<Vec<u8>>>) -> Option<(&[u8], &[u8])>,
@@ -441,7 +442,7 @@ mod tests {
         let reader = Reader::new(Cursor::new(bytes.as_slice())).unwrap();
         let mut cursor = reader.into_cursor().unwrap();
 
-        let result = cursor.move_on_key_greater_than_or_equal_to(&[0, 0, 0, 0]).unwrap();
+        let result = cursor.move_on_key_greater_than_or_equal_to([0, 0, 0, 0]).unwrap();
         assert_eq!(result, None);
     }
 
@@ -453,7 +454,7 @@ mod tests {
         for x in 0..2000u32 {
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
 
         let bytes = writer.into_inner().unwrap();
@@ -490,7 +491,7 @@ mod tests {
         for x in 0..2000u32 {
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
 
         let bytes = writer.into_inner().unwrap();
@@ -517,7 +518,7 @@ mod tests {
         for x in (10..24000i32).step_by(3) {
             nums.push(x);
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
 
         let bytes = writer.into_inner().unwrap();
@@ -531,7 +532,7 @@ mod tests {
                 Ok(i) => {
                     let n = nums[i];
                     let (k, _) = cursor
-                        .move_on_key_greater_than_or_equal_to(&n.to_be_bytes())
+                        .move_on_key_greater_than_or_equal_to(n.to_be_bytes())
                         .unwrap()
                         .unwrap();
                     let k = k.try_into().map(i32::from_be_bytes).unwrap();
@@ -539,7 +540,7 @@ mod tests {
                 }
                 Err(i) => {
                     let k = cursor
-                        .move_on_key_greater_than_or_equal_to(&n.to_be_bytes())
+                        .move_on_key_greater_than_or_equal_to(n.to_be_bytes())
                         .unwrap()
                         .map(|(k, _)| k.try_into().map(i32::from_be_bytes).unwrap());
                     assert_eq!(k, nums.get(i).copied());
@@ -556,7 +557,7 @@ mod tests {
         for x in (10..24000i32).step_by(3) {
             nums.push(x);
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
 
         let bytes = writer.into_inner().unwrap();
@@ -569,7 +570,7 @@ mod tests {
                 Ok(i) => {
                     let n = nums[i];
                     let (k, _) = cursor
-                        .move_on_key_lower_than_or_equal_to(&n.to_be_bytes())
+                        .move_on_key_lower_than_or_equal_to(n.to_be_bytes())
                         .unwrap()
                         .unwrap();
                     let k = k.try_into().map(i32::from_be_bytes).unwrap();
@@ -577,7 +578,7 @@ mod tests {
                 }
                 Err(i) => {
                     let k = cursor
-                        .move_on_key_lower_than_or_equal_to(&n.to_be_bytes())
+                        .move_on_key_lower_than_or_equal_to(n.to_be_bytes())
                         .unwrap()
                         .map(|(k, _)| k.try_into().map(i32::from_be_bytes).unwrap());
                     let expected = i.checked_sub(1).and_then(|i| nums.get(i)).copied();
@@ -597,7 +598,7 @@ mod tests {
         for x in (10..24000i32).step_by(3) {
             nums.push(x);
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
 
         let bytes = writer.into_inner().unwrap();
@@ -611,7 +612,7 @@ mod tests {
                 Ok(i) => {
                     let n = nums[i];
                     let (k, _) = cursor
-                        .move_on_key_greater_than_or_equal_to(&n.to_be_bytes())
+                        .move_on_key_greater_than_or_equal_to(n.to_be_bytes())
                         .unwrap()
                         .unwrap();
                     let k = k.try_into().map(i32::from_be_bytes).unwrap();
@@ -619,7 +620,7 @@ mod tests {
                 }
                 Err(i) => {
                     let k = cursor
-                        .move_on_key_greater_than_or_equal_to(&n.to_be_bytes())
+                        .move_on_key_greater_than_or_equal_to(n.to_be_bytes())
                         .unwrap()
                         .map(|(k, _)| k.try_into().map(i32::from_be_bytes).unwrap());
                     assert_eq!(k, nums.get(i).copied());
@@ -638,7 +639,7 @@ mod tests {
         for x in (10..24000i32).step_by(3) {
             nums.push(x);
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
 
         let bytes = writer.into_inner().unwrap();
@@ -651,7 +652,7 @@ mod tests {
                 Ok(i) => {
                     let n = nums[i];
                     let (k, _) = cursor
-                        .move_on_key_lower_than_or_equal_to(&n.to_be_bytes())
+                        .move_on_key_lower_than_or_equal_to(n.to_be_bytes())
                         .unwrap()
                         .unwrap();
                     let k = k.try_into().map(i32::from_be_bytes).unwrap();
@@ -659,7 +660,7 @@ mod tests {
                 }
                 Err(i) => {
                     let k = cursor
-                        .move_on_key_lower_than_or_equal_to(&n.to_be_bytes())
+                        .move_on_key_lower_than_or_equal_to(n.to_be_bytes())
                         .unwrap()
                         .map(|(k, _)| k.try_into().map(i32::from_be_bytes).unwrap());
                     let expected = i.checked_sub(1).and_then(|i| nums.get(i)).copied();
@@ -679,7 +680,7 @@ mod tests {
         let mut writer = Writer::builder().index_levels(2).memory();
         for &x in &nums {
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
 
         let bytes = writer.into_inner().unwrap();
@@ -691,7 +692,7 @@ mod tests {
                 Ok(i) => {
                     let q = nums[i];
                     let (k, _) = cursor
-                        .move_on_key_lower_than_or_equal_to(&q.to_be_bytes())
+                        .move_on_key_lower_than_or_equal_to(q.to_be_bytes())
                         .unwrap()
                         .unwrap();
                     let k = k.try_into().map(u32::from_be_bytes).unwrap();
@@ -701,7 +702,7 @@ mod tests {
                 }
                 Err(i) => {
                     let k = cursor
-                        .move_on_key_lower_than_or_equal_to(&q.to_be_bytes())
+                        .move_on_key_lower_than_or_equal_to(q.to_be_bytes())
                         .unwrap()
                         .map(|(k, _)| k.try_into().map(u32::from_be_bytes).unwrap());
                     let expected = i.checked_sub(1).and_then(|i| nums.get(i)).copied();
diff --git a/src/sorter.rs b/src/sorter.rs
index 8f022ed..cc42b16 100644
--- a/src/sorter.rs
+++ b/src/sorter.rs
@@ -1,11 +1,13 @@
 use std::alloc::{alloc, dealloc, Layout};
 use std::borrow::Cow;
 use std::convert::Infallible;
+use std::fmt::Debug;
 #[cfg(feature = "tempfile")]
 use std::fs::File;
 use std::io::{Cursor, Read, Seek, SeekFrom, Write};
 use std::mem::{align_of, size_of};
 use std::num::NonZeroUsize;
+use std::ptr::NonNull;
 use std::{cmp, io, ops, slice};
 
 use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
@@ -20,7 +22,8 @@ const DEFAULT_NB_CHUNKS: usize = 25;
 const MIN_NB_CHUNKS: usize = 1;
 
 use crate::{
-    CompressionType, Error, Merger, MergerIter, Reader, ReaderCursor, Writer, WriterBuilder,
+    CompressionType, Error, MergeFunction, Merger, MergerIter, Reader, ReaderCursor, Writer,
+    WriterBuilder,
 };
 
 /// The kind of sort algorithm used by the sorter to sort its internal vector.
@@ -194,7 +197,7 @@ impl<MF, CC> SorterBuilder<MF, CC> {
             chunk_creator: self.chunk_creator,
             sort_algorithm: self.sort_algorithm,
             sort_in_parallel: self.sort_in_parallel,
-            merge: self.merge,
+            merge_function: self.merge,
         }
     }
 }
@@ -238,8 +241,8 @@ impl Entries {
     /// Inserts a new entry into the buffer, if there is not
     /// enough space for it to be stored, we double the buffer size.
     pub fn insert(&mut self, key: &[u8], data: &[u8]) {
-        assert!(key.len() <= u32::max_value() as usize);
-        assert!(data.len() <= u32::max_value() as usize);
+        assert!(key.len() <= u32::MAX as usize);
+        assert!(data.len() <= u32::MAX as usize);
 
         if self.fits(key, data) {
             // We store the key and data bytes one after the other at the back of the buffer.
@@ -374,7 +377,10 @@ struct EntryBound {
 }
 
 /// Represents an `EntryBound` aligned buffer.
-struct EntryBoundAlignedBuffer(&'static mut [u8]);
+struct EntryBoundAlignedBuffer {
+    data: NonNull<u8>,
+    len: usize,
+}
 
 impl EntryBoundAlignedBuffer {
     /// Allocates a new buffer of the given size, it is correctly aligned to store `EntryBound`s.
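// Why the struct change above matters — a sketch of the reasoning, assuming
// the usual Stacked Borrows argument rather than anything stated in this
// patch: a `&'static mut [u8]` field claims a unique borrow that lives for
// the whole life of the buffer, so handing out `&[u8]`/`&mut [u8]` through
// `Deref`/`DerefMut` re-borrows memory that is already exclusively borrowed,
// which tools like Miri reject. An owning `NonNull<u8>` plus a `len` creates
// references only inside each accessor, for exactly as long as needed:
//
//     struct RawBuf {
//         data: std::ptr::NonNull<u8>, // owner; no outstanding borrow
//         len: usize,
//     }
//
//     impl RawBuf {
//         fn as_slice(&self) -> &[u8] {
//             // This borrow lasts only as long as `&self`.
//             unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.len) }
//         }
//     }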
@@ -383,13 +389,14 @@ impl EntryBoundAlignedBuffer {
         let size = (size + entry_bound_size - 1) / entry_bound_size * entry_bound_size;
         let layout = Layout::from_size_align(size, align_of::<EntryBound>()).unwrap();
         let ptr = unsafe { alloc(layout) };
-        assert!(
-            !ptr.is_null(),
-            "the allocator is unable to allocate that much memory ({} bytes requested)",
-            size
-        );
-        let slice = unsafe { slice::from_raw_parts_mut(ptr, size) };
-        EntryBoundAlignedBuffer(slice)
+        let Some(ptr) = NonNull::new(ptr) else {
+            panic!(
+                "the allocator is unable to allocate that much memory ({} bytes requested)",
+                size
+            );
+        };
+
+        EntryBoundAlignedBuffer { data: ptr, len: size }
     }
 }
 
@@ -397,20 +404,21 @@ impl ops::Deref for EntryBoundAlignedBuffer {
     type Target = [u8];
 
     fn deref(&self) -> &Self::Target {
-        self.0
+        unsafe { slice::from_raw_parts(self.data.as_ptr(), self.len) }
     }
 }
 
 impl ops::DerefMut for EntryBoundAlignedBuffer {
     fn deref_mut(&mut self) -> &mut Self::Target {
-        self.0
+        unsafe { slice::from_raw_parts_mut(self.data.as_ptr(), self.len) }
     }
 }
 
 impl Drop for EntryBoundAlignedBuffer {
     fn drop(&mut self) {
-        let layout = Layout::from_size_align(self.0.len(), align_of::<EntryBound>()).unwrap();
-        unsafe { dealloc(self.0.as_mut_ptr(), layout) }
+        let layout = Layout::from_size_align(self.len, align_of::<EntryBound>()).unwrap();
+
+        unsafe { dealloc(self.data.as_ptr(), layout) }
     }
 }
 
@@ -434,7 +442,7 @@ pub struct Sorter<MF, CC: ChunkCreator = DefaultChunkCreator> {
     chunk_creator: CC,
     sort_algorithm: SortAlgorithm,
     sort_in_parallel: bool,
-    merge: MF,
+    merge_function: MF,
 }
 
 impl<MF> Sorter<MF, DefaultChunkCreator> {
@@ -460,14 +468,14 @@ impl<MF> Sorter<MF, DefaultChunkCreator> {
     }
 }
 
-impl<MF, CC, U> Sorter<MF, CC>
+impl<MF, CC> Sorter<MF, CC>
 where
-    MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
+    MF: MergeFunction,
     CC: ChunkCreator,
 {
     /// Insert an entry into the [`Sorter`] making sure that conflicts
     /// are resolved by the provided merge function.
-    pub fn insert<K, V>(&mut self, key: K, val: V) -> Result<(), Error<U>>
+    pub fn insert<K, V>(&mut self, key: K, val: V) -> crate::Result<(), MF::Error>
     where
         K: AsRef<[u8]>,
         V: AsRef<[u8]>,
@@ -498,7 +506,7 @@ where
     ///
     /// Writes the in-memory entries to disk, using the specify settings
     /// to compress the block and entries. It clears the in-memory entries.
-    fn write_chunk(&mut self) -> Result<CC::Chunk, Error<U>> {
+    fn write_chunk(&mut self) -> crate::Result<CC::Chunk, MF::Error> {
         let count_write_chunk = self
             .chunk_creator
             .create()
@@ -536,7 +544,8 @@ where
                 None => current = Some((key, vec![Cow::Borrowed(value)])),
                 Some((current_key, vals)) => {
                     if current_key != &key {
-                        let merged_val = (self.merge)(current_key, vals).map_err(Error::Merge)?;
+                        let merged_val =
+                            self.merge_function.merge(current_key, vals).map_err(Error::Merge)?;
                         writer.insert(&current_key, &merged_val)?;
                         vals.clear();
                         *current_key = key;
@@ -547,7 +556,7 @@ where
         }
 
         if let Some((key, vals)) = current.take() {
-            let merged_val = (self.merge)(key, &vals).map_err(Error::Merge)?;
+            let merged_val = self.merge_function.merge(key, &vals).map_err(Error::Merge)?;
             writer.insert(key, &merged_val)?;
         }
 
@@ -569,7 +578,7 @@ where
     ///
     /// Merges all of the chunks into a final chunk that replaces them.
     /// It uses the user provided merge function to resolve merge conflicts.
-    fn merge_chunks(&mut self) -> Result<CC::Chunk, Error<U>> {
+    fn merge_chunks(&mut self) -> crate::Result<CC::Chunk, MF::Error> {
         let count_write_chunk = self
             .chunk_creator
             .create()
@@ -595,7 +604,7 @@ where
         }
 
         let mut writer = writer_builder.build(count_write_chunk);
-        let sources: Result<Vec<ReaderCursor<CC::Chunk>>, Error<U>> = self
+        let sources: crate::Result<Vec<ReaderCursor<CC::Chunk>>, MF::Error> = self
             .chunks
             .drain(..)
             .map(|mut chunk| {
@@ -605,7 +614,7 @@ where
             .collect();
 
         // Create a merger to merge all those chunks.
-        let mut builder = Merger::builder(&self.merge);
+        let mut builder = Merger::builder(&self.merge_function);
         builder.extend(sources?);
         let merger = builder.build();
 
@@ -628,7 +637,7 @@ where
     pub fn write_into_stream_writer<W: io::Write>(
         self,
         writer: &mut Writer<W>,
-    ) -> Result<(), Error<U>> {
+    ) -> crate::Result<(), MF::Error> {
         let mut iter = self.into_stream_merger_iter()?;
         while let Some((key, val)) = iter.next()? {
             writer.insert(key, val)?;
@@ -637,7 +646,7 @@ where
     }
 
     /// Consumes this [`Sorter`] and outputs a stream of the merged entries in key-order.
-    pub fn into_stream_merger_iter(self) -> Result<MergerIter<CC::Chunk, MF>, Error<U>> {
+    pub fn into_stream_merger_iter(self) -> crate::Result<MergerIter<CC::Chunk, MF>, MF::Error> {
         let (sources, merge) = self.extract_reader_cursors_and_merger()?;
         let mut builder = Merger::builder(merge);
         builder.extend(sources);
@@ -645,18 +654,19 @@ where
     }
 
     /// Consumes this [`Sorter`] and outputs the list of reader cursors.
-    pub fn into_reader_cursors(self) -> Result<Vec<ReaderCursor<CC::Chunk>>, Error<U>> {
+    pub fn into_reader_cursors(self) -> crate::Result<Vec<ReaderCursor<CC::Chunk>>, MF::Error> {
         self.extract_reader_cursors_and_merger().map(|(readers, _)| readers)
     }
 
     /// A helper function to extract the readers and the merge function.
+    #[allow(clippy::type_complexity)] // Return type is not THAT complex
     fn extract_reader_cursors_and_merger(
         mut self,
-    ) -> Result<(Vec<ReaderCursor<CC::Chunk>>, MF), Error<U>> {
+    ) -> crate::Result<(Vec<ReaderCursor<CC::Chunk>>, MF), MF::Error> {
         // Flush the pending unordered entries.
         self.chunks_total_size = self.write_chunk()?;
 
-        let Sorter { chunks, merge, .. } = self;
+        let Sorter { chunks, merge_function: merge, .. } = self;
 
         let result: Result<Vec<_>, _> = chunks
             .into_iter()
             .map(|mut chunk| {
@@ -669,6 +679,28 @@ where
     }
 }
 
+impl<MF, CC: ChunkCreator> Debug for Sorter<MF, CC> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Sorter")
+            .field("chunks_count", &self.chunks.len())
+            .field("remaining_entries", &self.entries.remaining())
+            .field("chunks_total_size", &self.chunks_total_size)
+            .field("allow_realloc", &self.allow_realloc)
+            .field("dump_threshold", &self.dump_threshold)
+            .field("max_nb_chunks", &self.max_nb_chunks)
+            .field("chunk_compression_type", &self.chunk_compression_type)
+            .field("chunk_compression_level", &self.chunk_compression_level)
+            .field("index_key_interval", &self.index_key_interval)
+            .field("block_size", &self.block_size)
+            .field("index_levels", &self.index_levels)
+            .field("chunk_creator", &"[chunk creator]")
+            .field("sort_algorithm", &self.sort_algorithm)
+            .field("sort_in_parallel", &self.sort_in_parallel)
+            .field("merge", &"[merge function]")
+            .finish()
+    }
+}
+
 /// A trait that represent a `ChunkCreator`.
 pub trait ChunkCreator {
     /// The generated chunk by this `ChunkCreator`.
@@ -733,14 +765,25 @@ mod tests {
     use super::*;
 
-    fn merge<'a>(_key: &[u8], vals: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Infallible> {
-        Ok(vals.iter().map(AsRef::as_ref).flatten().cloned().collect())
+    #[derive(Copy, Clone)]
+    struct ConcatMerger;
+
+    impl MergeFunction for ConcatMerger {
+        type Error = Infallible;
+
+        fn merge<'a>(
+            &self,
+            _key: &[u8],
+            values: &[Cow<'a, [u8]>],
+        ) -> std::result::Result<Cow<'a, [u8]>, Self::Error> {
+            Ok(values.iter().flat_map(AsRef::as_ref).cloned().collect())
+        }
     }
 
     #[test]
     #[cfg_attr(miri, ignore)]
     fn simple_cursorvec() {
-        let mut sorter = SorterBuilder::new(merge)
+        let mut sorter = SorterBuilder::new(ConcatMerger)
             .chunk_compression_type(CompressionType::Snappy)
             .chunk_creator(CursorVec)
             .build();
@@ -769,7 +812,7 @@ mod tests {
     #[test]
     #[cfg_attr(miri, ignore)]
     fn hard_cursorvec() {
-        let mut sorter = SorterBuilder::new(merge)
+        let mut sorter = SorterBuilder::new(ConcatMerger)
             .dump_threshold(1024) // 1KiB
             .allow_realloc(false)
             .chunk_compression_type(CompressionType::Snappy)
@@ -803,20 +846,27 @@ mod tests {
         use rand::prelude::{SeedableRng, SliceRandom};
         use rand::rngs::StdRng;
 
-        // This merge function concat bytes in the order they are received.
-        fn concat_bytes<'a>(
-            _key: &[u8],
-            values: &[Cow<'a, [u8]>],
-        ) -> Result<Cow<'a, [u8]>, Infallible> {
-            let mut output = Vec::new();
-            for value in values {
-                output.extend_from_slice(&value);
+        /// This merge function concatenates bytes in the order they are received.
+        struct ConcatBytesMerger;
+
+        impl MergeFunction for ConcatBytesMerger {
+            type Error = Infallible;
+
+            fn merge<'a>(
+                &self,
+                _key: &[u8],
+                values: &[Cow<'a, [u8]>],
+            ) -> std::result::Result<Cow<'a, [u8]>, Self::Error> {
+                let mut output = Vec::new();
+                for value in values {
+                    output.extend_from_slice(value);
+                }
+                Ok(Cow::from(output))
             }
-            Ok(Cow::from(output))
         }
 
         // We create a sorter that will sum our u32s when necessary.
-        let mut sorter = SorterBuilder::new(concat_bytes).chunk_creator(CursorVec).build();
+        let mut sorter = SorterBuilder::new(ConcatBytesMerger).chunk_creator(CursorVec).build();
 
         // We insert all the possible values of an u8 in ascending order
         // but we split them along different keys.
diff --git a/src/writer.rs b/src/writer.rs
index fe0a0bd..ba45e39 100644
--- a/src/writer.rs
+++ b/src/writer.rs
@@ -146,9 +146,9 @@ impl Writer<()> {
     }
 }
 
-impl<W> Writer<W> {
+impl<W> AsRef<W> for Writer<W> {
     /// Gets a reference to the underlying writer.
-    pub fn as_ref(&self) -> &W {
+    fn as_ref(&self) -> &W {
         self.writer.as_ref()
     }
 }
@@ -330,7 +330,7 @@ mod tests {
 
         for x in 0..2000u32 {
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
 
         let bytes = writer.into_inner().unwrap();
@@ -346,7 +346,7 @@ mod tests {
 
         for x in 0..2000u32 {
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
 
         let bytes = writer.into_inner().unwrap();
@@ -363,7 +363,7 @@ mod tests {
 
         for x in 0..2000u32 {
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
         }
 
         let bytes = writer.into_inner().unwrap();
@@ -378,11 +378,11 @@ mod tests {
             .compression_type(grenad_0_4::CompressionType::Snappy)
             .memory();
 
-        let total: u32 = 156_000;
+        let total: u32 = 1_500;
 
         for x in 0..total {
             let x = x.to_be_bytes();
-            writer.insert(&x, &x).unwrap();
+            writer.insert(x, x).unwrap();
        }
 
         let bytes = writer.into_inner().unwrap();
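Taken together, a caller migrating from 0.4 replaces its merge closure with a type that
implements `MergeFunction`; the rest of the pipeline keeps the same shape. A minimal
sketch assembled from the doc-tests above (the `ConcatMerger` name mirrors the test
merger in this patch; nothing else here is new API):

use std::borrow::Cow;
use std::convert::Infallible;

use grenad::{CursorVec, MergeFunction, SorterBuilder};

struct ConcatMerger;

impl MergeFunction for ConcatMerger {
    type Error = Infallible;

    fn merge<'a>(&self, _key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Infallible> {
        // Concatenate the conflicting values in insertion order.
        Ok(values.iter().flat_map(|v| v.iter().copied()).collect())
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 0.4 took a closure here; 0.5 takes any `MergeFunction` value.
    let mut sorter = SorterBuilder::new(ConcatMerger).chunk_creator(CursorVec).build();

    sorter.insert("hello", "kid")?;
    sorter.insert("hello", "dog")?;

    // Entries come back in key-order with conflicts already merged.
    let mut iter = sorter.into_stream_merger_iter()?;
    while let Some((key, val)) = iter.next()? {
        assert_eq!(key, b"hello");
        assert_eq!(val, b"kiddog");
    }
    Ok(())
}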