Skip to content

Commit

Permalink
FilterMask Optimizations (#1950)
Browse files: browse the repository at this point in the history.
Allow FilterMask to be cheaply created from various formats, and add
specialized implementations of `Eq` and `BitAnd`.

Fixes #1848
  • Loading branch information
gatesn authored Jan 15, 2025
1 parent 3ecb285 commit f30b8dd
Show file tree
Hide file tree
Showing 27 changed files with 613 additions and 371 deletions.
4 changes: 2 additions & 2 deletions encodings/alp/src/alp/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ impl SliceFn<ALPArray> for ALPEncoding {
}

impl FilterFn<ALPArray> for ALPEncoding {
fn filter(&self, array: &ALPArray, mask: FilterMask) -> VortexResult<ArrayData> {
fn filter(&self, array: &ALPArray, mask: &FilterMask) -> VortexResult<ArrayData> {
let patches = array
.patches()
.map(|p| p.filter(mask.clone()))
.map(|p| p.filter(mask))
.transpose()?
.flatten();

Expand Down
17 changes: 10 additions & 7 deletions encodings/alp/src/alp_rd/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ use vortex_error::VortexResult;
use crate::{ALPRDArray, ALPRDEncoding};

impl FilterFn<ALPRDArray> for ALPRDEncoding {
fn filter(&self, array: &ALPRDArray, mask: FilterMask) -> VortexResult<ArrayData> {
fn filter(&self, array: &ALPRDArray, mask: &FilterMask) -> VortexResult<ArrayData> {
let left_parts_exceptions = array
.left_parts_patches()
.map(|patches| patches.filter(mask.clone()))
.map(|patches| patches.filter(mask))
.transpose()?
.flatten();

Ok(ALPRDArray::try_new(
array.dtype().clone(),
filter(&array.left_parts(), mask.clone())?,
filter(&array.left_parts(), mask)?,
array.left_parts_dict(),
filter(&array.right_parts(), mask)?,
array.right_bit_width(),
Expand Down Expand Up @@ -46,10 +46,13 @@ mod test {
assert!(encoded.left_parts_patches().is_some());

// The first two values need no patching
let filtered = filter(encoded.as_ref(), FilterMask::from_iter([true, false, true]))
.unwrap()
.into_primitive()
.unwrap();
let filtered = filter(
encoded.as_ref(),
&FilterMask::from_iter([true, false, true]),
)
.unwrap()
.into_primitive()
.unwrap();
assert_eq!(filtered.as_slice::<T>(), &[a, outlier]);
}
}
6 changes: 3 additions & 3 deletions encodings/datetime-parts/src/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use vortex_error::VortexResult;
use crate::{DateTimePartsArray, DateTimePartsEncoding};

impl FilterFn<DateTimePartsArray> for DateTimePartsEncoding {
fn filter(&self, array: &DateTimePartsArray, mask: FilterMask) -> VortexResult<ArrayData> {
fn filter(&self, array: &DateTimePartsArray, mask: &FilterMask) -> VortexResult<ArrayData> {
Ok(DateTimePartsArray::try_new(
array.dtype().clone(),
filter(array.days().as_ref(), mask.clone())?,
filter(array.seconds().as_ref(), mask.clone())?,
filter(array.days().as_ref(), mask)?,
filter(array.seconds().as_ref(), mask)?,
filter(array.subsecond().as_ref(), mask)?,
)?
.into_array())
Expand Down
2 changes: 1 addition & 1 deletion encodings/dict/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl TakeFn<DictArray> for DictEncoding {
}

impl FilterFn<DictArray> for DictEncoding {
fn filter(&self, array: &DictArray, mask: FilterMask) -> VortexResult<ArrayData> {
fn filter(&self, array: &DictArray, mask: &FilterMask) -> VortexResult<ArrayData> {
let codes = filter(&array.codes(), mask)?;
DictArray::try_new(codes, array.values()).map(|a| a.into_array())
}
Expand Down
28 changes: 15 additions & 13 deletions encodings/fastlanes/src/bitpacking/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::bitpacking::compute::take::UNPACK_CHUNK_THRESHOLD;
use crate::{BitPackedArray, BitPackedEncoding};

impl FilterFn<BitPackedArray> for BitPackedEncoding {
fn filter(&self, array: &BitPackedArray, mask: FilterMask) -> VortexResult<ArrayData> {
fn filter(&self, array: &BitPackedArray, mask: &FilterMask) -> VortexResult<ArrayData> {
let primitive = match_each_unsigned_integer_ptype!(array.ptype().to_unsigned(), |$I| {
filter_primitive::<$I>(array, mask)
});
Expand All @@ -31,13 +31,13 @@ impl FilterFn<BitPackedArray> for BitPackedEncoding {
/// dictates the final `PType` of the result.
fn filter_primitive<T: NativePType + BitPacking + ArrowNativeType>(
array: &BitPackedArray,
mask: FilterMask,
mask: &FilterMask,
) -> VortexResult<PrimitiveArray> {
let validity = array.validity().filter(&mask)?;
let validity = array.validity().filter(mask)?;

let patches = array
.patches()
.map(|patches| patches.filter(mask.clone()))
.map(|patches| patches.filter(mask))
.transpose()?
.flatten();

Expand All @@ -47,15 +47,13 @@ fn filter_primitive<T: NativePType + BitPacking + ArrowNativeType>(
.and_then(|a| a.into_primitive());
}

let values: Buffer<T> = match mask.iter()? {
let values: Buffer<T> = match mask.iter() {
FilterIter::Indices(indices) => {
filter_indices(array, mask.true_count(), indices.iter().copied())
}
FilterIter::IndicesIter(iter) => filter_indices(array, mask.true_count(), iter),
FilterIter::Slices(slices) => {
filter_slices(array, mask.true_count(), slices.iter().copied())
}
FilterIter::SlicesIter(iter) => filter_slices(array, mask.true_count(), iter),
};

let mut values = PrimitiveArray::new(values, validity).reinterpret_cast(array.ptype());
Expand Down Expand Up @@ -143,9 +141,9 @@ mod test {
let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8));
let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap();

let mask = FilterMask::from_indices(bitpacked.len(), [0, 125, 2047, 2049, 2151, 2790]);
let mask = FilterMask::from_indices(bitpacked.len(), vec![0, 125, 2047, 2049, 2151, 2790]);

let primitive_result = filter(bitpacked.as_ref(), mask)
let primitive_result = filter(bitpacked.as_ref(), &mask)
.unwrap()
.into_primitive()
.unwrap();
Expand All @@ -160,9 +158,9 @@ mod test {
let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap();
let sliced = slice(bitpacked.as_ref(), 128, 2050).unwrap();

let mask = FilterMask::from_indices(sliced.len(), [1919, 1921]);
let mask = FilterMask::from_indices(sliced.len(), vec![1919, 1921]);

let primitive_result = filter(&sliced, mask).unwrap().into_primitive().unwrap();
let primitive_result = filter(&sliced, &mask).unwrap().into_primitive().unwrap();
let res_bytes = primitive_result.as_slice::<u8>();
assert_eq!(res_bytes, &[31, 33]);
}
Expand All @@ -171,7 +169,11 @@ mod test {
fn filter_bitpacked() {
let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8));
let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap();
let filtered = filter(bitpacked.as_ref(), FilterMask::from_indices(4096, 0..1024)).unwrap();
let filtered = filter(
bitpacked.as_ref(),
&FilterMask::from_indices(4096, (0..1024).collect()),
)
.unwrap();
assert_eq!(
filtered.into_primitive().unwrap().as_slice::<u8>(),
(0..1024).map(|i| (i % 63) as u8).collect::<Vec<_>>()
Expand All @@ -185,7 +187,7 @@ mod test {
let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 9).unwrap();
let filtered = filter(
bitpacked.as_ref(),
FilterMask::from_indices(values.len(), 0..250),
&FilterMask::from_indices(values.len(), (0..250).collect()),
)
.unwrap()
.into_primitive()
Expand Down
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/for/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl TakeFn<FoRArray> for FoREncoding {
}

impl FilterFn<FoRArray> for FoREncoding {
fn filter(&self, array: &FoRArray, mask: FilterMask) -> VortexResult<ArrayData> {
fn filter(&self, array: &FoRArray, mask: &FilterMask) -> VortexResult<ArrayData> {
FoRArray::try_new(
filter(&array.encoded(), mask)?,
array.reference_scalar(),
Expand Down
4 changes: 2 additions & 2 deletions encodings/fsst/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ impl ScalarAtFn<FSSTArray> for FSSTEncoding {

impl FilterFn<FSSTArray> for FSSTEncoding {
// Filtering an FSSTArray filters the codes array, leaving the symbols array untouched
fn filter(&self, array: &FSSTArray, mask: FilterMask) -> VortexResult<ArrayData> {
fn filter(&self, array: &FSSTArray, mask: &FilterMask) -> VortexResult<ArrayData> {
Ok(FSSTArray::try_new(
array.dtype().clone(),
array.symbols(),
array.symbol_lengths(),
filter(&array.codes(), mask.clone())?,
filter(&array.codes(), mask)?,
filter(&array.uncompressed_lengths(), mask)?,
)?
.into_array())
Expand Down
2 changes: 1 addition & 1 deletion encodings/fsst/tests/fsst_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ fn test_fsst_array_ops() {
// test filter
let mask = FilterMask::from_iter([false, true, false]);

let fsst_filtered = filter(&fsst_array, mask).unwrap();
let fsst_filtered = filter(&fsst_array, &mask).unwrap();
assert_eq!(fsst_filtered.encoding().id(), FSSTEncoding::ID);
assert_eq!(fsst_filtered.len(), 1);
assert_nth_scalar!(
Expand Down
12 changes: 6 additions & 6 deletions encodings/runend/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ impl SliceFn<RunEndArray> for RunEndEncoding {
}

impl FilterFn<RunEndArray> for RunEndEncoding {
fn filter(&self, array: &RunEndArray, mask: FilterMask) -> VortexResult<ArrayData> {
fn filter(&self, array: &RunEndArray, mask: &FilterMask) -> VortexResult<ArrayData> {
let primitive_run_ends = array.ends().into_primitive()?;
let (run_ends, values_mask) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| {
filter_run_ends(primitive_run_ends.as_slice::<$P>(), array.offset() as u64, array.len() as u64, mask)?
});
let values = filter(&array.values(), values_mask)?;
let values = filter(&array.values(), &values_mask)?;

RunEndArray::try_new(run_ends.into_array(), values).map(|a| a.into_array())
}
Expand All @@ -108,14 +108,14 @@ fn filter_run_ends<R: NativePType + AddAssign + From<bool> + AsPrimitive<u64>>(
run_ends: &[R],
offset: u64,
length: u64,
mask: FilterMask,
mask: &FilterMask,
) -> VortexResult<(PrimitiveArray, FilterMask)> {
let mut new_run_ends = buffer_mut![R::zero(); run_ends.len()];

let mut start = 0u64;
let mut j = 0;
let mut count = R::zero();
let filter_values = mask.to_boolean_buffer()?;
let filter_values = mask.boolean_buffer();

let new_mask: FilterMask = BooleanBuffer::collect_bool(run_ends.len(), |i| {
let mut keep = false;
Expand Down Expand Up @@ -278,7 +278,7 @@ mod test {
let arr = ree_array();
let filtered = filter(
arr.as_ref(),
FilterMask::from_iter([
&FilterMask::from_iter([
true, true, false, false, false, false, false, false, false, false, true, true,
]),
)
Expand Down Expand Up @@ -308,7 +308,7 @@ mod test {
let arr = slice(ree_array(), 2, 7).unwrap();
let filtered = filter(
&arr,
FilterMask::from_iter([true, false, false, true, true]),
&FilterMask::from_iter([true, false, false, true, true]),
)
.unwrap();
let filtered_run_end = RunEndArray::try_from(filtered).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions encodings/zigzag/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl ComputeVTable for ZigZagEncoding {
}

impl FilterFn<ZigZagArray> for ZigZagEncoding {
fn filter(&self, array: &ZigZagArray, mask: FilterMask) -> VortexResult<ArrayData> {
fn filter(&self, array: &ZigZagArray, mask: &FilterMask) -> VortexResult<ArrayData> {
let encoded = filter(&array.encoded(), mask)?;
Ok(ZigZagArray::try_new(encoded)?.into_array())
}
Expand Down Expand Up @@ -145,7 +145,7 @@ mod tests {
fn filter_zigzag() {
let zigzag = ZigZagArray::encode(&buffer![-189, -160, 1].into_array()).unwrap();
let filter_mask = BooleanBuffer::from(vec![true, false, true]).into();
let actual = filter(&zigzag.into_array(), filter_mask)
let actual = filter(&zigzag.into_array(), &filter_mask)
.unwrap()
.into_primitive()
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion fuzz/fuzz_targets/array_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
assert_search_sorted(sorted, s, side, expected.search(), i)
}
Action::Filter(mask) => {
current_array = filter(&current_array, mask).unwrap();
current_array = filter(&current_array, &mask).unwrap();
assert_array_eq(&expected.array(), &current_array, i);
}
}
Expand Down
2 changes: 1 addition & 1 deletion pyvortex/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ impl PyArray {
fn filter(&self, filter: &Bound<PyArray>) -> PyResult<PyArray> {
let filter = filter.borrow();
let inner =
vortex::compute::filter(&self.inner, FilterMask::try_from(filter.inner.clone())?)?;
vortex::compute::filter(&self.inner, &FilterMask::try_from(filter.inner.clone())?)?;
Ok(PyArray { inner })
}

Expand Down
14 changes: 4 additions & 10 deletions vortex-array/src/array/bool/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,16 @@ use crate::compute::{FilterFn, FilterIter, FilterMask};
use crate::{ArrayData, IntoArrayData};

impl FilterFn<BoolArray> for BoolEncoding {
fn filter(&self, array: &BoolArray, mask: FilterMask) -> VortexResult<ArrayData> {
let validity = array.validity().filter(&mask)?;
fn filter(&self, array: &BoolArray, mask: &FilterMask) -> VortexResult<ArrayData> {
let validity = array.validity().filter(mask)?;

let buffer = match mask.iter()? {
let buffer = match mask.iter() {
FilterIter::Indices(indices) => filter_indices_slice(&array.boolean_buffer(), indices),
FilterIter::IndicesIter(iter) => {
filter_indices(&array.boolean_buffer(), mask.true_count(), iter)
}
FilterIter::Slices(slices) => filter_slices(
&array.boolean_buffer(),
mask.true_count(),
slices.iter().copied(),
),
FilterIter::SlicesIter(iter) => {
filter_slices(&array.boolean_buffer(), mask.true_count(), iter)
}
};

Ok(BoolArray::try_new(buffer, validity)?.into_array())
Expand Down Expand Up @@ -84,7 +78,7 @@ mod test {
let arr = BoolArray::from_iter([true, true, false]);
let mask = FilterMask::from_iter([true, false, true]);

let filtered = filter(&arr.into_array(), mask)
let filtered = filter(&arr.into_array(), &mask)
.unwrap()
.into_bool()
.unwrap();
Expand Down
Loading

0 comments on commit f30b8dd

Please sign in to comment.