Skip to content

Commit

Permalink
feat: Implement is_in operation on decimal type (#17832)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 24, 2024
1 parent 987d9b2 commit 8a36878
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 6 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/logical/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl DecimalChunked {
}
}

pub(crate) fn to_scale(&self, scale: usize) -> PolarsResult<Cow<'_, Self>> {
pub fn to_scale(&self, scale: usize) -> PolarsResult<Cow<'_, Self>> {
if self.scale() == scale {
return Ok(Cow::Borrowed(self));
}
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-core/src/chunked_array/logical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ impl<K: PolarsDataType, T: PolarsDataType> Logical<K, T>
where
Self: LogicalType,
{
/// Returns a shared reference to the wrapped `ChunkedArray<T>` stored in field `0`
/// of this `Logical`, without cloning or converting it.
pub fn physical(&self) -> &ChunkedArray<T> {
&self.0
}
pub fn field(&self) -> Field {
let name = self.0.ref_field().name();
Field::new(name, LogicalType::dtype(self).clone())
Expand Down
20 changes: 15 additions & 5 deletions crates/polars-ops/src/series/ops/is_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -716,18 +716,28 @@ pub fn is_in(s: &Series, other: &Series) -> PolarsResult<BooleanChunked> {
let ca = s.bool().unwrap();
is_in_boolean(ca, other)
},
DataType::Null => {
let series_bool = s.cast(&DataType::Boolean)?;
let ca = series_bool.bool().unwrap();
Ok(ca.clone())
},
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(_, _) => {
let s = s.decimal()?;
let other = other.decimal()?;
let scale = s.scale().max(other.scale());
let s = s.to_scale(scale)?;
let other = other.to_scale(scale)?.into_owned().into_series();

is_in_numeric(s.physical(), &other)
},
dt if dt.to_physical().is_numeric() => {
let s = s.to_physical_repr();
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
is_in_numeric(ca, other)
})
},
DataType::Null => {
let series_bool = s.cast(&DataType::Boolean)?;
let ca = series_bool.bool().unwrap();
Ok(ca.clone())
},
dt => polars_bail!(opq = is_in, dt),
}
}
6 changes: 6 additions & 0 deletions crates/polars-plan/src/plans/conversion/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,12 @@ impl OptimizationRule for TypeCoercionRule {
return Ok(None)
},
#[cfg(feature = "dtype-decimal")]
(DataType::Decimal(_, _), dt) if dt.is_numeric() => AExpr::Cast {
expr: other_e.node(),
data_type: type_left,
options: CastOptions::NonStrict,
},
#[cfg(feature = "dtype-decimal")]
(DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => {
polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left)
},
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/operations/test_is_in.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from datetime import date
from decimal import Decimal as D
from typing import TYPE_CHECKING

import pytest
Expand Down Expand Up @@ -404,3 +405,15 @@ def test_is_in_struct_enum_17618() -> None:
)
)
).shape == (0, 1)


def test_is_in_decimal() -> None:
    """Check `is_in` on a Decimal column against float, Decimal and integer candidates."""
    column_values = [D("0.0"), D("0.2"), D("0.1")]
    # Each case pairs the candidate list given to `is_in` with the expected mask.
    cases = [
        ([0.0, 0.1], [True, False, True]),
        ([D("0.0"), D("0.1")], [True, False, True]),
        ([1, 0, 2], [True, False, False]),
    ]
    for candidates, expected in cases:
        # Build a fresh frame per case, as each assertion did originally.
        frame = pl.DataFrame({"a": column_values})
        mask = frame.select(pl.col("a").is_in(candidates))
        assert mask["a"].to_list() == expected

0 comments on commit 8a36878

Please sign in to comment.