diff --git a/crates/polars-parquet/src/arrow/write/dictionary.rs b/crates/polars-parquet/src/arrow/write/dictionary.rs index d8a1866e6bc0..4a507557d36b 100644 --- a/crates/polars-parquet/src/arrow/write/dictionary.rs +++ b/crates/polars-parquet/src/arrow/write/dictionary.rs @@ -1,7 +1,13 @@ -use arrow::array::{Array, BinaryViewArray, DictionaryArray, DictionaryKey, Utf8ViewArray}; +use arrow::array::{ + Array, BinaryViewArray, DictionaryArray, DictionaryKey, PrimitiveArray, Utf8ViewArray, +}; use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::buffer::Buffer; use arrow::datatypes::{ArrowDataType, IntegerType}; +use arrow::types::NativeType; +use polars_compute::min_max::MinMaxKernel; use polars_error::{polars_bail, PolarsResult}; +use polars_utils::unwrap::UnwrapUncheckedRelease; use super::binary::{ build_statistics as binary_build_statistics, encode_plain as binary_encode_plain, @@ -24,12 +30,129 @@ use crate::parquet::statistics::ParquetStatistics; use crate::parquet::CowBuffer; use crate::write::DynIter; +trait MinMaxThreshold { + const DELTA_THRESHOLD: Self; +} + +macro_rules! minmaxthreshold_impls { + ($($t:ty => $threshold:literal,)+) => { + $( + impl MinMaxThreshold for $t { + const DELTA_THRESHOLD: Self = $threshold; + } + )+ + }; +} + +minmaxthreshold_impls! 
{ + i8 => 16, + i16 => 256, + i32 => 512, + i64 => 2048, + u8 => 16, + u16 => 256, + u32 => 512, + u64 => 2048, +} + +fn min_max_integer_encode_as_dictionary_optional<'a, E, T>( + array: &'a dyn Array, +) -> Option> +where + E: std::fmt::Debug, + T: NativeType + + MinMaxThreshold + + std::cmp::Ord + + TryInto + + std::ops::Sub + + num_traits::CheckedSub, + std::ops::RangeInclusive: Iterator, + PrimitiveArray: MinMaxKernel = T>, +{ + use ArrowDataType as DT; + let (min, max): (T, T) = as MinMaxKernel>::min_max_ignore_nan_kernel( + array.as_any().downcast_ref().unwrap(), + )?; + + debug_assert!(max >= min, "{max} >= {min}"); + if !max + .checked_sub(&min) + .is_some_and(|v| v <= T::DELTA_THRESHOLD) + { + return None; + } + + // @TODO: This currently overestimates the values; it might be interesting to use the unique + // kernel here. + let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), (min..=max).collect(), None); + let values = Box::new(values); + + let keys: Buffer = array + .as_any() + .downcast_ref::>() + .unwrap() + .values() + .iter() + .map(|v| unsafe { + // @NOTE: + // Since the values might contain nulls, which have an undefined value, we just + // clamp the values to between the min and max value. This way, they will still + // be valid dictionary keys. This is mostly to make the + // unwrap_unchecked_release not produce any unsafety. + (*v.clamp(&min, &max) - min) + .try_into() + .unwrap_unchecked_release() + }) + .collect(); + + let keys = PrimitiveArray::new(DT::UInt32, keys, array.validity().cloned()); + Some( + DictionaryArray::::try_new( + ArrowDataType::Dictionary( + IntegerType::UInt32, + Box::new(DT::from(T::PRIMITIVE)), + false, // @TODO: This might be able to be set to true? 
+ ), + keys, + values, + ) + .unwrap(), + ) +} + pub(crate) fn encode_as_dictionary_optional( array: &dyn Array, nested: &[Nested], type_: PrimitiveType, options: WriteOptions, ) -> Option>>> { + use ArrowDataType as DT; + let fast_dictionary = match array.data_type() { + DT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array), + DT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array), + DT::Int32 | DT::Date32 | DT::Time32(_) => { + min_max_integer_encode_as_dictionary_optional::<_, i32>(array) + }, + DT::Int64 | DT::Date64 | DT::Time64(_) | DT::Timestamp(_, _) | DT::Duration(_) => { + min_max_integer_encode_as_dictionary_optional::<_, i64>(array) + }, + DT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array), + DT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array), + DT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array), + DT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array), + _ => None, + }; + + if let Some(fast_dictionary) = fast_dictionary { + return Some(array_to_pages( + &fast_dictionary, + type_, + nested, + options, + Encoding::RleDictionary, + )); + } + let dtype = Box::new(array.data_type().clone()); let len_before = array.len();