From 8c5b6c14c0119b33a6ceb58c9d97ee68a313b2b2 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Wed, 13 Nov 2024 14:53:10 +0100 Subject: [PATCH] feat: Add rounding for Decimal type --- crates/polars-core/src/config.rs | 24 ++++ crates/polars-core/src/datatypes/mod.rs | 2 + .../src/datatypes/rounding_mode.rs | 20 +++ crates/polars-ops/src/series/ops/mod.rs | 4 + crates/polars-ops/src/series/ops/round.rs | 12 ++ .../src/series/ops/round_decimal.rs | 130 ++++++++++++++++++ 6 files changed, 192 insertions(+) create mode 100644 crates/polars-core/src/datatypes/rounding_mode.rs create mode 100644 crates/polars-ops/src/series/ops/round_decimal.rs diff --git a/crates/polars-core/src/config.rs b/crates/polars-core/src/config.rs index 919810811188..7a1477d8b57c 100644 --- a/crates/polars-core/src/config.rs +++ b/crates/polars-core/src/config.rs @@ -29,6 +29,9 @@ pub(crate) const FMT_TABLE_INLINE_COLUMN_DATA_TYPE: &str = pub(crate) const FMT_TABLE_ROUNDED_CORNERS: &str = "POLARS_FMT_TABLE_ROUNDED_CORNERS"; pub(crate) const FMT_TABLE_CELL_LIST_LEN: &str = "POLARS_FMT_TABLE_CELL_LIST_LEN"; +#[cfg(feature = "dtype-decimal")] +pub(crate) const DECIMAL_ROUNDING_MODE: &str = "POLARS_DECIMAL_ROUNDING_MODE"; + pub fn verbose() -> bool { std::env::var("POLARS_VERBOSE").as_deref().unwrap_or("") == "1" } @@ -46,6 +49,27 @@ pub fn get_rg_prefetch_size() -> usize { .unwrap_or_else(|_| std::cmp::max(get_file_prefetch_size(), 128)) } +#[cfg(feature = "dtype-decimal")] +pub fn get_decimal_rounding_mode() -> crate::datatypes::RoundingMode { + use crate::datatypes::RoundingMode as RM; + + let Ok(value) = std::env::var(DECIMAL_ROUNDING_MODE) else { + return RM::default(); + }; + + match &value[..] { + "ROUND_CEILING" => RM::Ceiling, + "ROUND_DOWN" => RM::Down, + "ROUND_FLOOR" => RM::Floor, + "ROUND_HALF_DOWN" => RM::HalfDown, + "ROUND_HALF_EVEN" => RM::HalfEven, + "ROUND_HALF_UP" => RM::HalfUp, + "ROUND_UP" => RM::Up, + "ROUND_05UP" => RM::Up05, + _ => panic!("Invalid rounding mode '{value}' given through `{DECIMAL_ROUNDING_MODE}` environment value."), + } +} + pub fn force_async() -> bool { std::env::var("POLARS_FORCE_ASYNC") .map(|value| value == "1") diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index 8d84d47be978..dd72a84bfdf7 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -16,6 +16,7 @@ mod into_scalar; #[cfg(feature = "object")] mod static_array_collect; mod time_unit; +mod rounding_mode; use std::cmp::Ordering; use std::fmt::{Display, Formatter}; @@ -23,6 +24,7 @@ use std::hash::{Hash, Hasher}; use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub, SubAssign}; pub use aliases::*; +pub use rounding_mode::RoundingMode; pub use any_value::*; pub use arrow::array::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, StaticArray}; pub use arrow::datatypes::reshape::*; diff --git a/crates/polars-core/src/datatypes/rounding_mode.rs b/crates/polars-core/src/datatypes/rounding_mode.rs new file mode 100644 index 000000000000..b71864c792af --- /dev/null +++ b/crates/polars-core/src/datatypes/rounding_mode.rs @@ -0,0 +1,20 @@ +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +pub enum RoundingMode { + /// Round towards Infinity. + #[default] + Ceiling, + /// Round towards zero. + Down, + /// Round towards -Infinity. + Floor, + /// Round to nearest with ties going towards zero. + HalfDown, + /// Round to nearest with ties going to nearest even integer. + HalfEven, + /// Round to nearest with ties going away from zero. + HalfUp, + /// Round away from zero. + Up, + /// Round away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise round towards zero. + Up05, +} diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index b684815238f7..5e60cd359abc 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -53,6 +53,8 @@ mod rle; mod rolling; #[cfg(feature = "round_series")] mod round; +#[cfg(all(feature = "round_series", feature = "dtype-decimal"))] +mod round_decimal; #[cfg(feature = "search_sorted")] mod search_sorted; #[cfg(feature = "to_dummies")] @@ -122,6 +124,8 @@ pub use rle::*; pub use rolling::*; #[cfg(feature = "round_series")] pub use round::*; +#[cfg(all(feature = "round_series", feature = "dtype-decimal"))] +pub use round_decimal::*; #[cfg(feature = "search_sorted")] pub use search_sorted::*; #[cfg(feature = "to_dummies")] diff --git a/crates/polars-ops/src/series/ops/round.rs b/crates/polars-ops/src/series/ops/round.rs index 7ed6b2e40eed..b4794d79ef8f 100644 --- a/crates/polars-ops/src/series/ops/round.rs +++ b/crates/polars-ops/src/series/ops/round.rs @@ -35,6 +35,10 @@ pub trait RoundSeries: SeriesSealed { Ok(s) }; } + #[cfg(feature = "dtype-decimal")] + if let Some(ca) = s.try_decimal() { + return Ok(super::round_decimal::dec_round(ca, decimals).into_series()); + } polars_ensure!(s.dtype().is_numeric(), InvalidOperation: "round can only be used on numeric types" ); Ok(s.clone()) @@ -70,6 +74,10 @@ pub trait RoundSeries: SeriesSealed { let s = ca.apply_values(|val| val.floor()).into_series(); return Ok(s); } + #[cfg(feature = "dtype-decimal")] + if let Some(ca) = s.try_decimal() { + return Ok(super::round_decimal::dec_round_floor(ca, 0).into_series()); + } polars_ensure!(s.dtype().is_numeric(), InvalidOperation: "floor can only be used on numeric types" ); Ok(s.clone()) @@ -87,6 +95,10 @@ pub trait RoundSeries: SeriesSealed { let s = ca.apply_values(|val| val.ceil()).into_series(); return Ok(s); } + #[cfg(feature = "dtype-decimal")] + if let Some(ca) = s.try_decimal() { + return Ok(super::round_decimal::dec_round_ceiling(ca, 0).into_series()); + } polars_ensure!(s.dtype().is_numeric(), InvalidOperation: "ceil can only be used on numeric types" ); Ok(s.clone()) diff --git a/crates/polars-ops/src/series/ops/round_decimal.rs b/crates/polars-ops/src/series/ops/round_decimal.rs new file mode 100644 index 000000000000..7baed97cc1a8 --- /dev/null +++ b/crates/polars-ops/src/series/ops/round_decimal.rs @@ -0,0 +1,130 @@ +use polars_core::config; +use polars_core::prelude::*; + +pub fn dec_round(ca: &DecimalChunked, decimals: u32) -> DecimalChunked { + dec_round_with_rm(ca, decimals, config::get_decimal_rounding_mode()) +} + +pub fn dec_round_with_rm(ca: &DecimalChunked, decimals: u32, rm: RoundingMode) -> DecimalChunked { + match rm { + RoundingMode::Ceiling => dec_round_ceiling(ca, decimals), + RoundingMode::Down => dec_round_down(ca, decimals), + RoundingMode::Floor => dec_round_floor(ca, decimals), + RoundingMode::HalfDown => dec_round_half_down(ca, decimals), + RoundingMode::HalfEven => dec_round_half_even(ca, decimals), + RoundingMode::HalfUp => dec_round_half_up(ca, decimals), + RoundingMode::Up => dec_round_up(ca, decimals), + RoundingMode::Up05 => dec_round_up05(ca, decimals), + } +} + +fn dec_round_generic( + ca: &DecimalChunked, + decimals: u32, + f: impl Fn(i128, i128, i128) -> i128, +) -> DecimalChunked { + let precision = ca.precision(); + let scale = ca.scale() as u32; + if scale <= decimals { + return ca.clone(); + } + + let decimal_delta = scale - decimals; + let multiplier = 10i128.pow(decimal_delta); + let threshold = multiplier / 2; + + ca.apply_values(|v| f(v, multiplier, threshold)) + .into_decimal_unchecked(precision, scale as usize) +} + +pub fn dec_round_ceiling(ca: &DecimalChunked, decimals: u32) -> DecimalChunked { + dec_round_generic(ca, decimals, |v, multiplier, _| { + // @TODO: Optimize + let rem = v % multiplier; + if v < 0 { + v + rem.abs() + } else { + if rem == 0 { + v + } else { + v + (multiplier - rem) + } + } + }) +} + +pub fn dec_round_down(ca: &DecimalChunked, decimals: u32) -> DecimalChunked { + dec_round_generic(ca, decimals, |v, multiplier, _| v - (v % multiplier)) +} + +pub fn dec_round_floor(ca: &DecimalChunked, decimals: u32) -> DecimalChunked { + dec_round_generic(ca, decimals, |v, multiplier, _| { + // @TODO: Optimize + let rem = v % multiplier; + if v < 0 { + if rem == 0 { + v + } else { + v - (multiplier - rem.abs()) + } + } else { + v - rem + } + }) +} + +pub fn dec_round_half_down(ca: &DecimalChunked, decimals: u32) -> DecimalChunked { + dec_round_generic(ca, decimals, |v, multiplier, threshold| { + let rem = v % multiplier; + let round_offset = if rem.abs() > threshold { multiplier } else { 0 }; + let round_offset = if v < 0 { -round_offset } else { round_offset }; + v - rem + round_offset + }) +} + +pub fn dec_round_half_even(ca: &DecimalChunked, decimals: u32) -> DecimalChunked { + dec_round_generic(ca, decimals, |v, multiplier, threshold| { + let rem = v % multiplier; + let is_v_floor_even = ((v - rem) / multiplier) % 2 == 0; + let threshold = threshold + i128::from(is_v_floor_even); + let round_offset = if rem.abs() >= threshold { + multiplier + } else { + 0 + }; + let round_offset = if v < 0 { -round_offset } else { round_offset }; + v - rem + round_offset + }) +} + +pub fn dec_round_half_up(ca: &DecimalChunked, decimals: u32) -> DecimalChunked { + dec_round_generic(ca, decimals, |v, multiplier, threshold| { + let rem = v % multiplier; + let round_offset = if rem.abs() >= threshold { + multiplier + } else { + 0 + }; + let round_offset = if v < 0 { -round_offset } else { round_offset }; + v - rem + round_offset + }) +} + +pub fn dec_round_up(ca: &DecimalChunked, decimals: u32) -> DecimalChunked { + dec_round_generic(ca, decimals, |v, multiplier, _| v + (multiplier - (v % multiplier))) +} + +pub fn dec_round_up05(_ca: &DecimalChunked, _decimals: u32) -> DecimalChunked { + // assert_eq!(v.len(), target.len()); + // + // if scale <= decimals { + // target.copy_from_slice(v); + // return; + // } + // + // let decimal_delta = scale - decimals; + // let multiplier = 10i128.pow(decimal_delta); + // let threshold = multiplier / 2; + + todo!() +}