Skip to content

Commit

Permalink
Merge pull request #55 from meilisearch/various-improvements
Browse files Browse the repository at this point in the history
Various Improvements
  • Loading branch information
Kerollmops authored Nov 25, 2024
2 parents c0fd7f7 + 9c666a1 commit 323a77e
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 148 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "grenad"
description = "Tools to sort, merge, write, and read immutable key-value pairs."
version = "0.4.7"
version = "0.5.0"
authors = ["Kerollmops <[email protected]>"]
repository = "https://github.com/meilisearch/grenad"
documentation = "https://docs.rs/grenad"
Expand All @@ -11,6 +11,7 @@ license = "MIT"
[dependencies]
bytemuck = { version = "1.16.1", features = ["derive"] }
byteorder = "1.5.0"
either = { version = "1.13.0", default-features = false }
flate2 = { version = "1.0", optional = true }
lz4_flex = { version = "0.11.3", optional = true }
rayon = { version = "1.10.0", optional = true }
Expand Down
2 changes: 1 addition & 1 deletion benches/index-levels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn index_levels(bytes: &[u8]) {

for x in (0..NUMBER_OF_ENTRIES).step_by(1_567) {
let num = x.to_be_bytes();
cursor.move_on_key_greater_than_or_equal_to(&num).unwrap().unwrap();
cursor.move_on_key_greater_than_or_equal_to(num).unwrap().unwrap();
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/block_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ impl BlockWriter {
/// Insert a key that must be greater than the previously added one.
pub fn insert(&mut self, key: &[u8], val: &[u8]) {
debug_assert!(self.index_key_counter <= self.index_key_interval.get());
assert!(key.len() <= u32::max_value() as usize);
assert!(val.len() <= u32::max_value() as usize);
assert!(key.len() <= u32::MAX as usize);
assert!(val.len() <= u32::MAX as usize);

if self.index_key_counter == self.index_key_interval.get() {
self.index_offsets.push(self.buffer.len() as u64);
Expand All @@ -106,7 +106,7 @@ impl BlockWriter {
// and save the current key to become the last key.
match &mut self.last_key {
Some(last_key) => {
assert!(key > last_key, "{:?} must be greater than {:?}", key, last_key);
assert!(key > last_key.as_slice(), "{:?} must be greater than {:?}", key, last_key);
last_key.clear();
last_key.extend_from_slice(key);
}
Expand Down
12 changes: 5 additions & 7 deletions src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use std::str::FromStr;
use std::{fmt, io};

/// The different supported types of compression.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum CompressionType {
/// Do not compress the blocks.
#[default]
None = 0,
/// Use the [`snap`] crate to de/compress the blocks.
///
Expand Down Expand Up @@ -55,12 +56,6 @@ impl FromStr for CompressionType {
}
}

impl Default for CompressionType {
fn default() -> CompressionType {
CompressionType::None
}
}

/// An invalid compression type has been read and the block can't be de/compressed.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct InvalidCompressionType;
Expand Down Expand Up @@ -107,6 +102,7 @@ fn zlib_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
}

#[cfg(not(feature = "zlib"))]
#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
fn zlib_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "unsupported zlib decompression"))
}
Expand Down Expand Up @@ -186,6 +182,7 @@ fn zstd_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
}

#[cfg(not(feature = "zstd"))]
#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
fn zstd_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "unsupported zstd decompression"))
}
Expand All @@ -211,6 +208,7 @@ fn lz4_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
}

#[cfg(not(feature = "lz4"))]
#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
fn lz4_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "unsupported lz4 decompression"))
}
Expand Down
60 changes: 34 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,25 @@
//! use std::convert::TryInto;
//! use std::io::Cursor;
//!
//! use grenad::{MergerBuilder, Reader, Writer};
//! use grenad::{MergerBuilder, MergeFunction, Reader, Writer};
//!
//! // This merge function:
//! // - parses u32s from native-endian bytes,
//! // - wrapping sums them and,
//! // - outputs the result as native-endian bytes.
//! fn wrapping_sum_u32s<'a>(
//! _key: &[u8],
//! values: &[Cow<'a, [u8]>],
//! ) -> Result<Cow<'a, [u8]>, TryFromSliceError>
//! {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! struct WrappingSumU32s;
//!
//! impl MergeFunction for WrappingSumU32s {
//! type Error = TryFromSliceError;
//!
//! fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Self::Error> {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -115,7 +117,7 @@
//!
//! // We create a merger that will sum our u32s when necessary,
//! // and we add our readers to the list of readers to merge.
//! let merger_builder = MergerBuilder::new(wrapping_sum_u32s);
//! let merger_builder = MergerBuilder::new(WrappingSumU32s);
//! let merger = merger_builder.add(readera).add(readerb).add(readerc).build();
//!
//! // We can iterate over the entries in key-order.
Expand All @@ -142,28 +144,30 @@
//! use std::borrow::Cow;
//! use std::convert::TryInto;
//!
//! use grenad::{CursorVec, SorterBuilder};
//! use grenad::{CursorVec, MergeFunction, SorterBuilder};
//!
//! // This merge function:
//! // - parses u32s from native-endian bytes,
//! // - wrapping sums them and,
//! // - outputs the result as native-endian bytes.
//! fn wrapping_sum_u32s<'a>(
//! _key: &[u8],
//! values: &[Cow<'a, [u8]>],
//! ) -> Result<Cow<'a, [u8]>, TryFromSliceError>
//! {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! struct WrappingSumU32s;
//!
//! impl MergeFunction for WrappingSumU32s {
//! type Error = TryFromSliceError;
//!
//! fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Self::Error> {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // We create a sorter that will sum our u32s when necessary.
//! let mut sorter = SorterBuilder::new(wrapping_sum_u32s).chunk_creator(CursorVec).build();
//! let mut sorter = SorterBuilder::new(WrappingSumU32s).chunk_creator(CursorVec).build();
//!
//! // We insert multiple entries with the same key but different values
//! // in arbitrary order, the sorter will take care of merging them for us.
Expand All @@ -187,14 +191,15 @@
#[cfg(test)]
#[macro_use]
extern crate quickcheck;

use std::convert::Infallible;
use std::mem;

mod block;
mod block_writer;
mod compression;
mod count_write;
mod error;
mod merge_function;
mod merger;
mod metadata;
mod reader;
Expand All @@ -204,6 +209,7 @@ mod writer;

pub use self::compression::CompressionType;
pub use self::error::Error;
pub use self::merge_function::MergeFunction;
pub use self::merger::{Merger, MergerBuilder, MergerIter};
pub use self::metadata::FileVersion;
pub use self::reader::{PrefixIter, RangeIter, Reader, ReaderCursor, RevPrefixIter, RevRangeIter};
Expand All @@ -214,10 +220,12 @@ pub use self::sorter::{
};
pub use self::writer::{Writer, WriterBuilder};

pub type Result<T, U = Infallible> = std::result::Result<T, Error<U>>;

/// Sometimes we need to use an unsafe trick to make the compiler happy.
/// You can read more about the issue [on the Rust's Github issues].
///
/// [on the Rust's Github issues]: https://github.com/rust-lang/rust/issues/47680
unsafe fn transmute_entry_to_static(key: &[u8], val: &[u8]) -> (&'static [u8], &'static [u8]) {
(mem::transmute(key), mem::transmute(val))
(mem::transmute::<&[u8], &'static [u8]>(key), mem::transmute::<&[u8], &'static [u8]>(val))
}
46 changes: 46 additions & 0 deletions src/merge_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use std::borrow::Cow;
use std::result::Result;

use either::Either;

/// A trait defining the way we merge multiple
/// values sharing the same key.
pub trait MergeFunction {
    /// The error type returned when merging the values fails.
    type Error;

    /// Merges the `values` associated with the same `key` into a single value.
    fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>])
        -> Result<Cow<'a, [u8]>, Self::Error>;
}

impl<MF> MergeFunction for &MF
where
MF: MergeFunction,
{
type Error = MF::Error;

fn merge<'a>(
&self,
key: &[u8],
values: &[Cow<'a, [u8]>],
) -> Result<Cow<'a, [u8]>, Self::Error> {
(*self).merge(key, values)
}
}

/// An `Either` of two merge functions with the same error type is a merge
/// function: the call is dispatched to whichever side is currently held.
impl<MFA, MFB> MergeFunction for Either<MFA, MFB>
where
    MFA: MergeFunction,
    MFB: MergeFunction<Error = MFA::Error>,
{
    type Error = MFA::Error;

    fn merge<'a>(
        &self,
        key: &[u8],
        values: &[Cow<'a, [u8]>],
    ) -> Result<Cow<'a, [u8]>, Self::Error> {
        // Borrow both sides and dispatch with `Either::either`; both closures
        // return the same `Result` type because the error types are equal.
        self.as_ref().either(|mfa| mfa.merge(key, values), |mfb| mfb.merge(key, values))
    }
}
23 changes: 13 additions & 10 deletions src/merger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::BinaryHeap;
use std::io;
use std::iter::once;

use crate::{Error, ReaderCursor, Writer};
use crate::{Error, MergeFunction, ReaderCursor, Writer};

/// A struct that is used to configure a [`Merger`] with the sources to merge.
pub struct MergerBuilder<R, MF> {
Expand All @@ -20,6 +20,7 @@ impl<R, MF> MergerBuilder<R, MF> {
}

/// Add a source to merge, this function can be chained.
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn add(mut self, source: ReaderCursor<R>) -> Self {
self.push(source);
self
Expand Down Expand Up @@ -95,7 +96,7 @@ impl<R: io::Read + io::Seek, MF> Merger<R, MF> {
}

Ok(MergerIter {
merge: self.merge,
merge_function: self.merge,
heap,
current_key: Vec::new(),
merged_value: Vec::new(),
Expand All @@ -104,16 +105,16 @@ impl<R: io::Read + io::Seek, MF> Merger<R, MF> {
}
}

impl<R, MF, U> Merger<R, MF>
impl<R, MF> Merger<R, MF>
where
R: io::Read + io::Seek,
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
MF: MergeFunction,
{
/// Consumes this [`Merger`] and streams the entries to the [`Writer`] given in parameter.
pub fn write_into_stream_writer<W: io::Write>(
self,
writer: &mut Writer<W>,
) -> Result<(), Error<U>> {
) -> crate::Result<(), MF::Error> {
let mut iter = self.into_stream_merger_iter().map_err(Error::convert_merge_error)?;
while let Some((key, val)) = iter.next()? {
writer.insert(key, val)?;
Expand All @@ -124,21 +125,23 @@ where

/// An iterator that yield the merged entries in key-order.
pub struct MergerIter<R, MF> {
merge: MF,
merge_function: MF,
heap: BinaryHeap<Entry<R>>,
current_key: Vec<u8>,
merged_value: Vec<u8>,
/// We keep this buffer to avoid allocating a vec every time.
tmp_entries: Vec<Entry<R>>,
}

impl<R, MF, U> MergerIter<R, MF>
impl<R, MF> MergerIter<R, MF>
where
R: io::Read + io::Seek,
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
MF: MergeFunction,
{
/// Yield the entries in key-order where values have been merged when needed.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error<U>> {
#[allow(clippy::should_implement_trait)] // We return interior references
#[allow(clippy::type_complexity)] // Return type is not THAT complex
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>, MF::Error> {
let first_entry = match self.heap.pop() {
Some(entry) => entry,
None => return Ok(None),
Expand Down Expand Up @@ -167,7 +170,7 @@ where
self.tmp_entries.iter().filter_map(|e| e.cursor.current().map(|(_, v)| v));
let values: Vec<_> = once(first_value).chain(other_values).map(Cow::Borrowed).collect();

match (self.merge)(first_key, &values) {
match self.merge_function.merge(first_key, &values) {
Ok(value) => {
self.current_key.clear();
self.current_key.extend_from_slice(first_key);
Expand Down
12 changes: 7 additions & 5 deletions src/reader/prefix_iter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io;

use crate::{Error, ReaderCursor};
use crate::ReaderCursor;

/// An iterator that is able to yield all the entries with
/// a key that starts with a given prefix.
Expand All @@ -18,7 +18,8 @@ impl<R: io::Read + io::Seek> PrefixIter<R> {
}

/// Returns the next entry that starts with the given prefix.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
let entry = if self.move_on_first_prefix {
self.move_on_first_prefix = false;
self.cursor.move_on_key_greater_than_or_equal_to(&self.prefix)?
Expand Down Expand Up @@ -49,7 +50,8 @@ impl<R: io::Read + io::Seek> RevPrefixIter<R> {
}

/// Returns the next entry that starts with the given prefix.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
let entry = if self.move_on_last_prefix {
self.move_on_last_prefix = false;
move_on_last_prefix(&mut self.cursor, self.prefix.clone())?
Expand All @@ -68,7 +70,7 @@ impl<R: io::Read + io::Seek> RevPrefixIter<R> {
fn move_on_last_prefix<R: io::Read + io::Seek>(
cursor: &mut ReaderCursor<R>,
prefix: Vec<u8>,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
match advance_key(prefix) {
Some(next_prefix) => match cursor.move_on_key_lower_than_or_equal_to(&next_prefix)? {
Some((k, _)) if k == next_prefix => cursor.move_on_prev(),
Expand Down Expand Up @@ -108,7 +110,7 @@ mod tests {
let mut writer = Writer::memory();
for x in (10..24000u32).step_by(3) {
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
Expand Down
Loading

0 comments on commit 323a77e

Please sign in to comment.