Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Various Improvements #55

Merged
merged 10 commits into from
Nov 25, 2024
Merged
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "grenad"
description = "Tools to sort, merge, write, and read immutable key-value pairs."
version = "0.4.7"
version = "0.5.0"
authors = ["Kerollmops <[email protected]>"]
repository = "https://github.com/meilisearch/grenad"
documentation = "https://docs.rs/grenad"
Expand All @@ -11,6 +11,7 @@ license = "MIT"
[dependencies]
bytemuck = { version = "1.16.1", features = ["derive"] }
byteorder = "1.5.0"
either = { version = "1.13.0", default-features = false }
flate2 = { version = "1.0", optional = true }
lz4_flex = { version = "0.11.3", optional = true }
rayon = { version = "1.10.0", optional = true }
Expand Down
2 changes: 1 addition & 1 deletion benches/index-levels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn index_levels(bytes: &[u8]) {

for x in (0..NUMBER_OF_ENTRIES).step_by(1_567) {
let num = x.to_be_bytes();
cursor.move_on_key_greater_than_or_equal_to(&num).unwrap().unwrap();
cursor.move_on_key_greater_than_or_equal_to(num).unwrap().unwrap();
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/block_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ impl BlockWriter {
/// Insert a key that must be greater than the previously added one.
pub fn insert(&mut self, key: &[u8], val: &[u8]) {
debug_assert!(self.index_key_counter <= self.index_key_interval.get());
assert!(key.len() <= u32::max_value() as usize);
assert!(val.len() <= u32::max_value() as usize);
assert!(key.len() <= u32::MAX as usize);
assert!(val.len() <= u32::MAX as usize);

if self.index_key_counter == self.index_key_interval.get() {
self.index_offsets.push(self.buffer.len() as u64);
Expand All @@ -106,7 +106,7 @@ impl BlockWriter {
// and save the current key to become the last key.
match &mut self.last_key {
Some(last_key) => {
assert!(key > last_key, "{:?} must be greater than {:?}", key, last_key);
assert!(key > last_key.as_slice(), "{:?} must be greater than {:?}", key, last_key);
last_key.clear();
last_key.extend_from_slice(key);
}
Expand Down
12 changes: 5 additions & 7 deletions src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use std::str::FromStr;
use std::{fmt, io};

/// The different supported types of compression.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum CompressionType {
/// Do not compress the blocks.
#[default]
None = 0,
/// Use the [`snap`] crate to de/compress the blocks.
///
Expand Down Expand Up @@ -55,12 +56,6 @@ impl FromStr for CompressionType {
}
}

impl Default for CompressionType {
fn default() -> CompressionType {
CompressionType::None
}
}

/// An invalid compression type has been read and the block can't be de/compressed.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct InvalidCompressionType;
Expand Down Expand Up @@ -107,6 +102,7 @@ fn zlib_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
}

#[cfg(not(feature = "zlib"))]
#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
fn zlib_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "unsupported zlib decompression"))
}
Expand Down Expand Up @@ -186,6 +182,7 @@ fn zstd_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
}

#[cfg(not(feature = "zstd"))]
#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
fn zstd_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "unsupported zstd decompression"))
}
Expand All @@ -211,6 +208,7 @@ fn lz4_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
}

#[cfg(not(feature = "lz4"))]
#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
fn lz4_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "unsupported lz4 decompression"))
}
Expand Down
60 changes: 34 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,25 @@
//! use std::convert::TryInto;
//! use std::io::Cursor;
//!
//! use grenad::{MergerBuilder, Reader, Writer};
//! use grenad::{MergerBuilder, MergeFunction, Reader, Writer};
//!
//! // This merge function:
//! // - parses u32s from native-endian bytes,
//! // - wrapping sums them and,
//! // - outputs the result as native-endian bytes.
//! fn wrapping_sum_u32s<'a>(
//! _key: &[u8],
//! values: &[Cow<'a, [u8]>],
//! ) -> Result<Cow<'a, [u8]>, TryFromSliceError>
//! {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! struct WrappingSumU32s;
//!
//! impl MergeFunction for WrappingSumU32s {
//! type Error = TryFromSliceError;
//!
//! fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Self::Error> {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -115,7 +117,7 @@
//!
//! // We create a merger that will sum our u32s when necessary,
//! // and we add our readers to the list of readers to merge.
//! let merger_builder = MergerBuilder::new(wrapping_sum_u32s);
//! let merger_builder = MergerBuilder::new(WrappingSumU32s);
//! let merger = merger_builder.add(readera).add(readerb).add(readerc).build();
//!
//! // We can iterate over the entries in key-order.
Expand All @@ -142,28 +144,30 @@
//! use std::borrow::Cow;
//! use std::convert::TryInto;
//!
//! use grenad::{CursorVec, SorterBuilder};
//! use grenad::{CursorVec, MergeFunction, SorterBuilder};
//!
//! // This merge function:
//! // - parses u32s from native-endian bytes,
//! // - wrapping sums them and,
//! // - outputs the result as native-endian bytes.
//! fn wrapping_sum_u32s<'a>(
//! _key: &[u8],
//! values: &[Cow<'a, [u8]>],
//! ) -> Result<Cow<'a, [u8]>, TryFromSliceError>
//! {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! struct WrappingSumU32s;
//!
//! impl MergeFunction for WrappingSumU32s {
//! type Error = TryFromSliceError;
//!
//! fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Self::Error> {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // We create a sorter that will sum our u32s when necessary.
//! let mut sorter = SorterBuilder::new(wrapping_sum_u32s).chunk_creator(CursorVec).build();
//! let mut sorter = SorterBuilder::new(WrappingSumU32s).chunk_creator(CursorVec).build();
//!
//! // We insert multiple entries with the same key but different values
//! // in arbitrary order, the sorter will take care of merging them for us.
Expand All @@ -187,14 +191,15 @@
#[cfg(test)]
#[macro_use]
extern crate quickcheck;

use std::convert::Infallible;
use std::mem;

mod block;
mod block_writer;
mod compression;
mod count_write;
mod error;
mod merge_function;
mod merger;
mod metadata;
mod reader;
Expand All @@ -204,6 +209,7 @@ mod writer;

pub use self::compression::CompressionType;
pub use self::error::Error;
pub use self::merge_function::MergeFunction;
pub use self::merger::{Merger, MergerBuilder, MergerIter};
pub use self::metadata::FileVersion;
pub use self::reader::{PrefixIter, RangeIter, Reader, ReaderCursor, RevPrefixIter, RevRangeIter};
Expand All @@ -214,10 +220,12 @@ pub use self::sorter::{
};
pub use self::writer::{Writer, WriterBuilder};

pub type Result<T, U = Infallible> = std::result::Result<T, Error<U>>;

/// Sometimes we need to use an unsafe trick to make the compiler happy.
/// You can read more about the issue [on the Rust's Github issues].
///
/// [on the Rust's Github issues]: https://github.com/rust-lang/rust/issues/47680
unsafe fn transmute_entry_to_static(key: &[u8], val: &[u8]) -> (&'static [u8], &'static [u8]) {
(mem::transmute(key), mem::transmute(val))
(mem::transmute::<&[u8], &'static [u8]>(key), mem::transmute::<&[u8], &'static [u8]>(val))
}
46 changes: 46 additions & 0 deletions src/merge_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use std::borrow::Cow;
use std::result::Result;

use either::Either;

/// Defines how several values that share the same key
/// are combined into a single one.
pub trait MergeFunction {
    /// The error type returned when a merge fails.
    type Error;

    /// Merges all the `values` associated with the given `key` into a single value.
    fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>])
        -> Result<Cow<'a, [u8]>, Self::Error>;
}

/// A shared reference to a merge function is itself a merge function.
impl<MF> MergeFunction for &MF
where
    MF: MergeFunction,
{
    type Error = MF::Error;

    fn merge<'a>(
        &self,
        key: &[u8],
        values: &[Cow<'a, [u8]>],
    ) -> Result<Cow<'a, [u8]>, Self::Error> {
        // Explicitly delegate to the underlying implementation
        // rather than relying on auto-deref method resolution.
        MF::merge(*self, key, values)
    }
}

/// An `Either` of two merge functions that agree on the error type
/// is a merge function: it delegates to whichever side it holds.
impl<MFA, MFB> MergeFunction for Either<MFA, MFB>
where
    MFA: MergeFunction,
    MFB: MergeFunction<Error = MFA::Error>,
{
    type Error = MFA::Error;

    fn merge<'a>(
        &self,
        key: &[u8],
        values: &[Cow<'a, [u8]>],
    ) -> Result<Cow<'a, [u8]>, Self::Error> {
        // Both variants share MFA::Error, so the two arms unify.
        match self {
            Either::Left(left) => left.merge(key, values),
            Either::Right(right) => right.merge(key, values),
        }
    }
}
23 changes: 13 additions & 10 deletions src/merger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::BinaryHeap;
use std::io;
use std::iter::once;

use crate::{Error, ReaderCursor, Writer};
use crate::{Error, MergeFunction, ReaderCursor, Writer};

/// A struct that is used to configure a [`Merger`] with the sources to merge.
pub struct MergerBuilder<R, MF> {
Expand All @@ -20,6 +20,7 @@ impl<R, MF> MergerBuilder<R, MF> {
}

/// Add a source to merge, this function can be chained.
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn add(mut self, source: ReaderCursor<R>) -> Self {
self.push(source);
self
Expand Down Expand Up @@ -95,7 +96,7 @@ impl<R: io::Read + io::Seek, MF> Merger<R, MF> {
}

Ok(MergerIter {
merge: self.merge,
merge_function: self.merge,
heap,
current_key: Vec::new(),
merged_value: Vec::new(),
Expand All @@ -104,16 +105,16 @@ impl<R: io::Read + io::Seek, MF> Merger<R, MF> {
}
}

impl<R, MF, U> Merger<R, MF>
impl<R, MF> Merger<R, MF>
where
R: io::Read + io::Seek,
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
MF: MergeFunction,
{
/// Consumes this [`Merger`] and streams the entries to the [`Writer`] given in parameter.
pub fn write_into_stream_writer<W: io::Write>(
self,
writer: &mut Writer<W>,
) -> Result<(), Error<U>> {
) -> crate::Result<(), MF::Error> {
let mut iter = self.into_stream_merger_iter().map_err(Error::convert_merge_error)?;
while let Some((key, val)) = iter.next()? {
writer.insert(key, val)?;
Expand All @@ -124,21 +125,23 @@ where

/// An iterator that yield the merged entries in key-order.
pub struct MergerIter<R, MF> {
merge: MF,
merge_function: MF,
heap: BinaryHeap<Entry<R>>,
current_key: Vec<u8>,
merged_value: Vec<u8>,
/// We keep this buffer to avoid allocating a vec every time.
tmp_entries: Vec<Entry<R>>,
}

impl<R, MF, U> MergerIter<R, MF>
impl<R, MF> MergerIter<R, MF>
where
R: io::Read + io::Seek,
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
MF: MergeFunction,
{
/// Yield the entries in key-order where values have been merged when needed.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error<U>> {
#[allow(clippy::should_implement_trait)] // We return interior references
#[allow(clippy::type_complexity)] // Return type is not THAT complex
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>, MF::Error> {
let first_entry = match self.heap.pop() {
Some(entry) => entry,
None => return Ok(None),
Expand Down Expand Up @@ -167,7 +170,7 @@ where
self.tmp_entries.iter().filter_map(|e| e.cursor.current().map(|(_, v)| v));
let values: Vec<_> = once(first_value).chain(other_values).map(Cow::Borrowed).collect();

match (self.merge)(first_key, &values) {
match self.merge_function.merge(first_key, &values) {
Ok(value) => {
self.current_key.clear();
self.current_key.extend_from_slice(first_key);
Expand Down
12 changes: 7 additions & 5 deletions src/reader/prefix_iter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io;

use crate::{Error, ReaderCursor};
use crate::ReaderCursor;

/// An iterator that is able to yield all the entries with
/// a key that starts with a given prefix.
Expand All @@ -18,7 +18,8 @@ impl<R: io::Read + io::Seek> PrefixIter<R> {
}

/// Returns the next entry that starts with the given prefix.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
let entry = if self.move_on_first_prefix {
self.move_on_first_prefix = false;
self.cursor.move_on_key_greater_than_or_equal_to(&self.prefix)?
Expand Down Expand Up @@ -49,7 +50,8 @@ impl<R: io::Read + io::Seek> RevPrefixIter<R> {
}

/// Returns the next entry that starts with the given prefix.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
let entry = if self.move_on_last_prefix {
self.move_on_last_prefix = false;
move_on_last_prefix(&mut self.cursor, self.prefix.clone())?
Expand All @@ -68,7 +70,7 @@ impl<R: io::Read + io::Seek> RevPrefixIter<R> {
fn move_on_last_prefix<R: io::Read + io::Seek>(
cursor: &mut ReaderCursor<R>,
prefix: Vec<u8>,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
match advance_key(prefix) {
Some(next_prefix) => match cursor.move_on_key_lower_than_or_equal_to(&next_prefix)? {
Some((k, _)) if k == next_prefix => cursor.move_on_prev(),
Expand Down Expand Up @@ -108,7 +110,7 @@ mod tests {
let mut writer = Writer::memory();
for x in (10..24000u32).step_by(3) {
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
Expand Down
Loading
Loading