Skip to content

Commit

Permalink
feat: Add rounding for Decimal type
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Nov 13, 2024
1 parent 8cb7839 commit 8c5b6c1
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 0 deletions.
24 changes: 24 additions & 0 deletions crates/polars-core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pub(crate) const FMT_TABLE_INLINE_COLUMN_DATA_TYPE: &str =
pub(crate) const FMT_TABLE_ROUNDED_CORNERS: &str = "POLARS_FMT_TABLE_ROUNDED_CORNERS";
pub(crate) const FMT_TABLE_CELL_LIST_LEN: &str = "POLARS_FMT_TABLE_CELL_LIST_LEN";

#[cfg(feature = "dtype-decimal")]
pub(crate) const DECIMAL_ROUNDING_MODE: &str = "POLARS_DECIMAL_ROUNDING_MODE";

/// Reports whether verbose output has been requested.
///
/// Returns `true` only when the `POLARS_VERBOSE` environment variable is set
/// to exactly `"1"`; an unset, unreadable, or differently-valued variable
/// yields `false`.
pub fn verbose() -> bool {
    matches!(std::env::var("POLARS_VERBOSE").as_deref(), Ok("1"))
}
Expand All @@ -46,6 +49,27 @@ pub fn get_rg_prefetch_size() -> usize {
.unwrap_or_else(|_| std::cmp::max(get_file_prefetch_size(), 128))
}

/// Resolves the decimal rounding mode from the environment.
///
/// Reads the variable named by `DECIMAL_ROUNDING_MODE`. When it cannot be
/// read (for example because it is unset), the default rounding mode is
/// returned. Otherwise the value must be one of `ROUND_CEILING`,
/// `ROUND_DOWN`, `ROUND_FLOOR`, `ROUND_HALF_DOWN`, `ROUND_HALF_EVEN`,
/// `ROUND_HALF_UP`, `ROUND_UP` or `ROUND_05UP`.
///
/// # Panics
///
/// Panics when the variable is set to any other string.
#[cfg(feature = "dtype-decimal")]
pub fn get_decimal_rounding_mode() -> crate::datatypes::RoundingMode {
    use crate::datatypes::RoundingMode;

    match std::env::var(DECIMAL_ROUNDING_MODE) {
        // Unreadable variable: fall back to the type's default.
        Err(_) => RoundingMode::default(),
        Ok(value) => match value.as_str() {
            "ROUND_CEILING" => RoundingMode::Ceiling,
            "ROUND_DOWN" => RoundingMode::Down,
            "ROUND_FLOOR" => RoundingMode::Floor,
            "ROUND_HALF_DOWN" => RoundingMode::HalfDown,
            "ROUND_HALF_EVEN" => RoundingMode::HalfEven,
            "ROUND_HALF_UP" => RoundingMode::HalfUp,
            "ROUND_UP" => RoundingMode::Up,
            "ROUND_05UP" => RoundingMode::Up05,
            _ => panic!("Invalid rounding mode '{value}' given through `{DECIMAL_ROUNDING_MODE}` environment value."),
        },
    }
}

pub fn force_async() -> bool {
std::env::var("POLARS_FORCE_ASYNC")
.map(|value| value == "1")
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ mod into_scalar;
#[cfg(feature = "object")]
mod static_array_collect;
mod time_unit;
mod rounding_mode;

use std::cmp::Ordering;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub, SubAssign};

pub use aliases::*;
pub use rounding_mode::RoundingMode;
pub use any_value::*;
pub use arrow::array::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, StaticArray};
pub use arrow::datatypes::reshape::*;
Expand Down
20 changes: 20 additions & 0 deletions crates/polars-core/src/datatypes/rounding_mode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/// Strategy for choosing between the two nearest representable values when a
/// number has to be rounded.
///
/// [`RoundingMode::Ceiling`] is the `Default` variant.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoundingMode {
/// Round towards +Infinity.
#[default]
Ceiling,
/// Round towards zero (truncate).
Down,
/// Round towards -Infinity.
Floor,
/// Round to nearest, with ties going towards zero.
HalfDown,
/// Round to nearest, with ties going to the nearest even integer.
HalfEven,
/// Round to nearest, with ties going away from zero.
HalfUp,
/// Round away from zero.
Up,
/// Round away from zero if the last digit after rounding towards zero
/// would have been 0 or 5; otherwise round towards zero.
Up05,
}
4 changes: 4 additions & 0 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ mod rle;
mod rolling;
#[cfg(feature = "round_series")]
mod round;
#[cfg(all(feature = "round_series", feature = "dtype-decimal"))]
mod round_decimal;
#[cfg(feature = "search_sorted")]
mod search_sorted;
#[cfg(feature = "to_dummies")]
Expand Down Expand Up @@ -122,6 +124,8 @@ pub use rle::*;
pub use rolling::*;
#[cfg(feature = "round_series")]
pub use round::*;
#[cfg(all(feature = "round_series", feature = "dtype-decimal"))]
pub use round_decimal::*;
#[cfg(feature = "search_sorted")]
pub use search_sorted::*;
#[cfg(feature = "to_dummies")]
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-ops/src/series/ops/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ pub trait RoundSeries: SeriesSealed {
Ok(s)
};
}
#[cfg(feature = "dtype-decimal")]
if let Some(ca) = s.try_decimal() {
return Ok(super::round_decimal::dec_round(ca, decimals).into_series());
}

polars_ensure!(s.dtype().is_numeric(), InvalidOperation: "round can only be used on numeric types" );
Ok(s.clone())
Expand Down Expand Up @@ -70,6 +74,10 @@ pub trait RoundSeries: SeriesSealed {
let s = ca.apply_values(|val| val.floor()).into_series();
return Ok(s);
}
#[cfg(feature = "dtype-decimal")]
if let Some(ca) = s.try_decimal() {
return Ok(super::round_decimal::dec_round_floor(ca, 0).into_series());
}

polars_ensure!(s.dtype().is_numeric(), InvalidOperation: "floor can only be used on numeric types" );
Ok(s.clone())
Expand All @@ -87,6 +95,10 @@ pub trait RoundSeries: SeriesSealed {
let s = ca.apply_values(|val| val.ceil()).into_series();
return Ok(s);
}
#[cfg(feature = "dtype-decimal")]
if let Some(ca) = s.try_decimal() {
return Ok(super::round_decimal::dec_round_ceiling(ca, 0).into_series());
}

polars_ensure!(s.dtype().is_numeric(), InvalidOperation: "ceil can only be used on numeric types" );
Ok(s.clone())
Expand Down
130 changes: 130 additions & 0 deletions crates/polars-ops/src/series/ops/round_decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use polars_core::config;
use polars_core::prelude::*;

/// Rounds `ca` to `decimals` fractional digits using the rounding mode
/// returned by `config::get_decimal_rounding_mode()`.
pub fn dec_round(ca: &DecimalChunked, decimals: u32) -> DecimalChunked {
    let mode = config::get_decimal_rounding_mode();
    dec_round_with_rm(ca, decimals, mode)
}

pub fn dec_round_with_rm(ca: &DecimalChunked, decimals: u32, rm: RoundingMode) -> DecimalChunked {
match rm {
RoundingMode::Ceiling => dec_round_ceiling(ca, decimals),
RoundingMode::Down => dec_round_down(ca, decimals),
RoundingMode::Floor => dec_round_floor(ca, decimals),
RoundingMode::HalfDown => dec_round_half_down(ca, decimals),
RoundingMode::HalfEven => dec_round_half_even(ca, decimals),
RoundingMode::HalfUp => dec_round_half_up(ca, decimals),
RoundingMode::Up => dec_round_up(ca, decimals),
RoundingMode::Up05 => dec_round_up05(ca, decimals),
}
}

/// Shared driver for all decimal rounding modes.
///
/// If the array's scale is already at most `decimals`, a clone of `ca` is
/// returned. Otherwise `f` is applied to every value as
/// `f(value, multiplier, threshold)`, where `multiplier` is
/// `10^(scale - decimals)` and `threshold` is `multiplier / 2`. The result is
/// tagged with the same precision and scale as the input.
fn dec_round_generic(
    ca: &DecimalChunked,
    decimals: u32,
    f: impl Fn(i128, i128, i128) -> i128,
) -> DecimalChunked {
    let precision = ca.precision();
    let scale = ca.scale() as u32;

    // Number of trailing digits to round away; none means nothing to do.
    let Some(dropped_digits) = scale.checked_sub(decimals).filter(|&d| d != 0) else {
        return ca.clone();
    };

    let unit = 10i128.pow(dropped_digits);
    let half = unit / 2;

    let rounded = ca.apply_values(|value| f(value, unit, half));
    rounded.into_decimal_unchecked(precision, scale as usize)
}

/// Rounds every value of `ca` towards +Infinity, keeping `decimals`
/// fractional digits.
pub fn dec_round_ceiling(ca: &DecimalChunked, decimals: u32) -> DecimalChunked {
    dec_round_generic(ca, decimals, |value, unit, _| {
        // @TODO: Optimize
        // `%` keeps the sign of `value`, so `remainder` is <= 0 for negatives.
        let remainder = value % unit;
        match (value < 0, remainder == 0) {
            // Negative: dropping the remainder moves towards zero, i.e. upwards.
            (true, _) => value + remainder.abs(),
            // Already exact.
            (false, true) => value,
            // Positive and inexact: step up to the next multiple.
            (false, false) => value + (unit - remainder),
        }
    })
}

/// Rounds every value of `ca` towards zero (truncation), keeping `decimals`
/// fractional digits.
pub fn dec_round_down(ca: &DecimalChunked, decimals: u32) -> DecimalChunked {
    dec_round_generic(ca, decimals, |value, unit, _| {
        // The remainder carries the sign of `value`; removing it truncates.
        let remainder = value % unit;
        value - remainder
    })
}

/// Rounds every value of `ca` towards -Infinity, keeping `decimals`
/// fractional digits.
pub fn dec_round_floor(ca: &DecimalChunked, decimals: u32) -> DecimalChunked {
    dec_round_generic(ca, decimals, |value, unit, _| {
        // @TODO: Optimize
        // `%` keeps the sign of `value`, so `remainder` is <= 0 for negatives.
        let remainder = value % unit;
        match (value < 0, remainder == 0) {
            // Non-negative: dropping the remainder moves downwards.
            (false, _) => value - remainder,
            // Already exact.
            (true, true) => value,
            // Negative and inexact: step down to the next multiple.
            (true, false) => value - (unit - remainder.abs()),
        }
    })
}

/// Rounds every value of `ca` to the nearest multiple, with exact ties going
/// towards zero, keeping `decimals` fractional digits.
pub fn dec_round_half_down(ca: &DecimalChunked, decimals: u32) -> DecimalChunked {
    dec_round_generic(ca, decimals, |value, unit, half| {
        let remainder = value % unit;
        let truncated = value - remainder;
        if remainder.abs() <= half {
            // At or below the midpoint: stay on the truncated value.
            truncated
        } else if value < 0 {
            truncated - unit
        } else {
            truncated + unit
        }
    })
}

/// Rounds every value of `ca` to the nearest multiple, with exact ties going
/// to the multiple whose last kept digit is even, keeping `decimals`
/// fractional digits.
pub fn dec_round_half_even(ca: &DecimalChunked, decimals: u32) -> DecimalChunked {
    dec_round_generic(ca, decimals, |value, unit, half| {
        let remainder = value % unit;
        let truncated = value - remainder;

        // When the truncated value is already even, a tie must stay put, so
        // the remainder has to strictly exceed the midpoint to move away.
        let truncated_is_even = (truncated / unit) % 2 == 0;
        let cutoff = if truncated_is_even { half + 1 } else { half };

        if remainder.abs() < cutoff {
            truncated
        } else if value < 0 {
            truncated - unit
        } else {
            truncated + unit
        }
    })
}

/// Rounds every value of `ca` to the nearest multiple, with exact ties going
/// away from zero, keeping `decimals` fractional digits.
pub fn dec_round_half_up(ca: &DecimalChunked, decimals: u32) -> DecimalChunked {
    dec_round_generic(ca, decimals, |value, unit, half| {
        let remainder = value % unit;
        let truncated = value - remainder;
        if remainder.abs() < half {
            // Strictly below the midpoint: stay on the truncated value.
            truncated
        } else if value < 0 {
            truncated - unit
        } else {
            truncated + unit
        }
    })
}

/// Rounds every value of `ca` away from zero, keeping `decimals` fractional
/// digits.
///
/// Values that are already exact at the requested number of digits are
/// returned unchanged; positive values move up and negative values move down
/// to the next multiple.
pub fn dec_round_up(ca: &DecimalChunked, decimals: u32) -> DecimalChunked {
    dec_round_generic(ca, decimals, |v, multiplier, _| {
        // `%` keeps the sign of `v`, so `rem` is negative for negative inputs.
        let rem = v % multiplier;
        if rem == 0 {
            // Already a multiple: must not be bumped by a whole unit.
            v
        } else if v < 0 {
            // Away from zero for negatives means further down.
            v - (multiplier + rem)
        } else {
            v + (multiplier - rem)
        }
    })
}

pub fn dec_round_up05(_ca: &DecimalChunked, _decimals: u32) -> DecimalChunked {
// assert_eq!(v.len(), target.len());
//
// if scale <= decimals {
// target.copy_from_slice(v);
// return;
// }
//
// let decimal_delta = scale - decimals;
// let multiplier = 10i128.pow(decimal_delta);
// let threshold = multiplier / 2;

todo!()
}

0 comments on commit 8c5b6c1

Please sign in to comment.