refactor: Migrate polars-expr AggregationContext to use Column (#19736)
coastalwhite authored Nov 14, 2024
1 parent 869d1b9 commit 058491f
Showing 29 changed files with 531 additions and 357 deletions.
154 changes: 132 additions & 22 deletions crates/polars-core/src/frame/column/mod.rs
@@ -1,5 +1,7 @@
use std::borrow::Cow;

use arrow::bitmap::MutableBitmap;
use arrow::trusted_len::TrustMyLength;
use num_traits::{Num, NumCast};
use polars_error::PolarsResult;
use polars_utils::index::check_bounds;
@@ -8,6 +10,7 @@ pub use scalar::ScalarColumn;

use self::gather::check_bounds_ca;
use self::partitioned::PartitionedColumn;
use self::series::SeriesColumn;
use crate::chunked_array::cast::CastOptions;
use crate::chunked_array::metadata::{MetadataFlags, MetadataTrait};
use crate::datatypes::ReshapeDimension;
@@ -20,6 +23,7 @@ mod arithmetic;
mod compare;
mod partitioned;
mod scalar;
mod series;

/// A column within a [`DataFrame`].
///
@@ -35,7 +39,7 @@ mod scalar;
#[cfg_attr(feature = "serde", serde(from = "Series"))]
#[cfg_attr(feature = "serde", serde(into = "_SerdeSeries"))]
pub enum Column {
Series(Series),
Series(SeriesColumn),
Partitioned(PartitionedColumn),
Scalar(ScalarColumn),
}
@@ -47,12 +51,13 @@ pub trait IntoColumn: Sized {

impl Column {
#[inline]
#[track_caller]
pub fn new<T, Phantom>(name: PlSmallStr, values: T) -> Self
where
Phantom: ?Sized,
Series: NamedFrom<T, Phantom>,
{
Self::Series(NamedFrom::new(name, values))
Self::Series(SeriesColumn::new(NamedFrom::new(name, values)))
}

#[inline]
@@ -95,7 +100,7 @@ impl Column {
PartitionedColumn::new_empty(PlSmallStr::EMPTY, DataType::Null),
)
.take_materialized_series();
*self = Column::Series(series);
*self = Column::Series(series.into());
let Column::Series(s) = self else {
unreachable!();
};
@@ -107,7 +112,7 @@ impl Column {
ScalarColumn::new_empty(PlSmallStr::EMPTY, DataType::Null),
)
.take_materialized_series();
*self = Column::Series(series);
*self = Column::Series(series.into());
let Column::Series(s) = self else {
unreachable!();
};
@@ -121,7 +126,7 @@ impl Column {
#[inline]
pub fn take_materialized_series(self) -> Series {
match self {
Column::Series(s) => s,
Column::Series(s) => s.take(),
Column::Partitioned(s) => s.take_materialized_series(),
Column::Scalar(s) => s.take_materialized_series(),
}
@@ -586,31 +591,102 @@ impl Column {
}
}

/// General implementation for aggregation where a non-missing scalar would map to itself.
#[inline(always)]
#[cfg(any(feature = "algorithm_group_by", feature = "bitwise"))]
fn agg_with_unit_scalar(
&self,
groups: &GroupsProxy,
series_agg: impl Fn(&Series, &GroupsProxy) -> Series,
) -> Column {
match self {
Column::Series(s) => series_agg(s, groups).into_column(),
// @partition-opt
Column::Partitioned(s) => series_agg(s.as_materialized_series(), groups).into_column(),
Column::Scalar(s) => {
if s.is_empty() {
return self.clone();
}

// We utilize the aggregation on Series to see:
// 1. the output datatype of the aggregation
// 2. whether this aggregation is even defined
let series_aggregation = series_agg(
&s.as_single_value_series(),
&GroupsProxy::Slice {
// @NOTE: this group is always valid since s is non-empty.
groups: vec![[0, 1]],
rolling: false,
},
);

// If the aggregation is not defined, just return all nulls.
if series_aggregation.has_nulls() {
return Self::new_scalar(
series_aggregation.name().clone(),
Scalar::new(series_aggregation.dtype().clone(), AnyValue::Null),
groups.len(),
);
}

let mut scalar_col = s.resize(groups.len());
// The aggregation might change the type (e.g. mean changes int -> float), so we do
// a cast here to the output type.
if series_aggregation.dtype() != s.dtype() {
scalar_col = scalar_col.cast(series_aggregation.dtype()).unwrap();
}

let Some(first_empty_idx) = groups.iter().position(|g| g.is_empty()) else {
// Fast path: no empty groups. Keep the scalar intact.
return scalar_col.into_column();
};

// All empty groups produce a *missing* or `null` value.
let mut validity = MutableBitmap::with_capacity(groups.len());
validity.extend_constant(first_empty_idx, true);
// SAFETY: We trust the length of this iterator.
let iter = unsafe {
TrustMyLength::new(
groups.iter().skip(first_empty_idx).map(|g| !g.is_empty()),
groups.len() - first_empty_idx,
)
};
validity.extend_from_trusted_len_iter(iter);
let validity = validity.freeze();

let mut s = scalar_col.take_materialized_series().rechunk();
// SAFETY: We perform a compute_len afterwards.
let chunks = unsafe { s.chunks_mut() };
chunks[0].with_validity(Some(validity));
s.compute_len();

s.into_column()
},
}
}
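
The Scalar arm above is the interesting case: for an aggregation where a non-missing scalar maps to itself, the result is simply the scalar repeated once per group, with a null for every empty group, so only a validity bitmap has to be built. Below is a minimal standalone sketch of that semantics; it is not the polars-core API, and the i64 element type and the group sizes are assumptions chosen for illustration.

```rust
/// Illustrative only: mirrors the unit-scalar fast path of
/// `agg_with_unit_scalar`. One output per group: the scalar itself for
/// non-empty groups, null (None) for empty groups.
fn scalar_agg_per_group(scalar: i64, group_sizes: &[usize]) -> Vec<Option<i64>> {
    group_sizes
        .iter()
        .map(|&len| if len == 0 { None } else { Some(scalar) })
        .collect()
}

fn main() {
    // A scalar column holding 7, aggregated (e.g. min/first/last) over four
    // groups where the third group is empty.
    let out = scalar_agg_per_group(7, &[3, 1, 0, 2]);
    assert_eq!(out, vec![Some(7), Some(7), None, Some(7)]);
}
```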

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_min(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_min(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_min(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_max(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_max(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_max(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_mean(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_mean(g) })
}

/// # Safety
Expand All @@ -627,17 +703,15 @@ impl Column {
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_first(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_first(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_first(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_last(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_last(g) })
}

/// # Safety
@@ -672,8 +746,7 @@ impl Column {
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_median(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_median(g) })
}

/// # Safety
Expand All @@ -689,7 +762,7 @@ impl Column {
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub(crate) unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Self {
pub unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_std(groups, ddof) }.into()
}
@@ -713,6 +786,30 @@ impl Column {
unsafe { self.as_materialized_series().agg_valid_count(groups) }.into()
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "bitwise")]
pub fn agg_and(&self, groups: &GroupsProxy) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_and(g) })
}
/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "bitwise")]
pub fn agg_or(&self, groups: &GroupsProxy) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_or(g) })
}
/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "bitwise")]
pub fn agg_xor(&self, groups: &GroupsProxy) -> Self {
// @partition-opt
// @scalar-opt
unsafe { self.as_materialized_series().agg_xor(groups) }.into()
}

pub fn full_null(name: PlSmallStr, size: usize, dtype: &DataType) -> Self {
Self::new_scalar(name, Scalar::new(dtype.clone(), AnyValue::Null), size)
}
@@ -877,6 +974,13 @@ impl Column {
}
}

/// Packs every element into a list.
pub fn as_list(&self) -> ListChunked {
// @scalar-opt
// @partition-opt
self.as_materialized_series().as_list()
}

pub fn is_sorted_flag(&self) -> IsSorted {
// @scalar-opt
self.as_materialized_series().is_sorted_flag()
@@ -1105,19 +1209,25 @@ impl Column {

pub fn try_add_owned(self, other: Self) -> PolarsResult<Self> {
match (self, other) {
(Column::Series(lhs), Column::Series(rhs)) => lhs.try_add_owned(rhs).map(Column::from),
(Column::Series(lhs), Column::Series(rhs)) => {
lhs.take().try_add_owned(rhs.take()).map(Column::from)
},
(lhs, rhs) => lhs + rhs,
}
}
pub fn try_sub_owned(self, other: Self) -> PolarsResult<Self> {
match (self, other) {
(Column::Series(lhs), Column::Series(rhs)) => lhs.try_sub_owned(rhs).map(Column::from),
(Column::Series(lhs), Column::Series(rhs)) => {
lhs.take().try_sub_owned(rhs.take()).map(Column::from)
},
(lhs, rhs) => lhs - rhs,
}
}
pub fn try_mul_owned(self, other: Self) -> PolarsResult<Self> {
match (self, other) {
(Column::Series(lhs), Column::Series(rhs)) => lhs.try_mul_owned(rhs).map(Column::from),
(Column::Series(lhs), Column::Series(rhs)) => {
lhs.take().try_mul_owned(rhs.take()).map(Column::from)
},
(lhs, rhs) => lhs * rhs,
}
}
@@ -1443,7 +1553,7 @@ impl From<Series> for Column {
return Self::Scalar(ScalarColumn::unit_scalar_from_series(series));
}

Self::Series(series)
Self::Series(SeriesColumn::new(series))
}
}

2 changes: 1 addition & 1 deletion crates/polars-core/src/frame/column/partitioned.rs
@@ -124,7 +124,7 @@ impl PartitionedColumn {

fn _to_series(name: PlSmallStr, values: &Series, ends: &[IdxSize]) -> Series {
let dtype = values.dtype();
let mut column = Column::Series(Series::new_empty(name, dtype));
let mut column = Column::Series(Series::new_empty(name, dtype).into());

let mut prev_offset = 0;
for (i, &offset) in ends.iter().enumerate() {
5 changes: 5 additions & 0 deletions crates/polars-core/src/frame/column/scalar.rs
@@ -284,6 +284,11 @@ impl ScalarColumn {
self.scalar.update(AnyValue::Null);
self
}

pub fn map_scalar(&mut self, map_scalar: impl Fn(Scalar) -> Scalar) {
self.scalar = map_scalar(std::mem::take(&mut self.scalar));
self.materialized.take();
}
}
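
The new map_scalar helper is a take-transform-replace: it moves the old scalar out, stores the mapped one, and drops the cached materialized series so the cache cannot go stale. Below is a standalone sketch of that pattern; Cached and its OnceCell field are illustrative stand-ins, not the actual ScalarColumn internals.

```rust
use std::cell::OnceCell;

// Illustrative only: a value with a lazily computed cache that must be
// invalidated whenever the value is replaced in place, as `map_scalar` does.
struct Cached {
    value: i64,
    materialized: OnceCell<String>, // stand-in for the cached Series
}

impl Cached {
    fn map_value(&mut self, f: impl Fn(i64) -> i64) {
        // Move the old value out, store the mapped one, and drop the stale cache.
        self.value = f(std::mem::take(&mut self.value));
        self.materialized.take();
    }
}

fn main() {
    let mut c = Cached { value: 3, materialized: OnceCell::new() };
    let v = c.value;
    c.materialized.get_or_init(|| v.to_string());
    c.map_value(|v| v * 10);
    assert_eq!(c.value, 30);
    assert!(c.materialized.get().is_none()); // the cache was cleared
}
```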

impl IntoColumn for ScalarColumn {
71 changes: 71 additions & 0 deletions crates/polars-core/src/frame/column/series.rs
@@ -0,0 +1,71 @@
use std::ops::{Deref, DerefMut};

use super::Series;

/// A very thin wrapper around [`Series`] that represents a [`Column`]ized version of [`Series`].
///
/// At the moment this just conditionally tracks where it was created so that materialization
/// problems can be tracked down.
#[derive(Debug, Clone)]
pub struct SeriesColumn {
inner: Series,

#[cfg(debug_assertions)]
materialized_at: Option<std::sync::Arc<std::backtrace::Backtrace>>,
}

impl SeriesColumn {
#[track_caller]
pub fn new(series: Series) -> Self {
Self {
inner: series,

#[cfg(debug_assertions)]
materialized_at: if std::env::var("POLARS_TRACK_SERIES_MATERIALIZATION").as_deref()
== Ok("1")
{
Some(std::sync::Arc::new(
std::backtrace::Backtrace::force_capture(),
))
} else {
None
},
}
}

pub fn materialized_at(&self) -> Option<&std::backtrace::Backtrace> {
#[cfg(debug_assertions)]
{
self.materialized_at.as_ref().map(|v| v.as_ref())
}

#[cfg(not(debug_assertions))]
None
}

pub fn take(self) -> Series {
self.inner
}
}

impl From<Series> for SeriesColumn {
#[track_caller]
#[inline(always)]
fn from(value: Series) -> Self {
Self::new(value)
}
}

impl Deref for SeriesColumn {
type Target = Series;

fn deref(&self) -> &Self::Target {
&self.inner
}
}

impl DerefMut for SeriesColumn {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
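
In a debug build, the tracking above is opt-in via the POLARS_TRACK_SERIES_MATERIALIZATION environment variable. The following is a standalone sketch of the same conditional-backtrace pattern; Tracked and TRACK_CREATION are illustrative names, not polars-core API.

```rust
use std::backtrace::Backtrace;
use std::sync::Arc;

// Standalone sketch of the opt-in tracking pattern used by `SeriesColumn`:
// capture a backtrace at construction only when an env var asks for it.
struct Tracked<T> {
    inner: T,
    created_at: Option<Arc<Backtrace>>,
}

impl<T> Tracked<T> {
    #[track_caller]
    fn new(inner: T) -> Self {
        // Only pay for the capture when explicitly requested.
        let created_at = (std::env::var("TRACK_CREATION").as_deref() == Ok("1"))
            .then(|| Arc::new(Backtrace::force_capture()));
        Self { inner, created_at }
    }
}

fn main() {
    // Run with TRACK_CREATION=1 to see where the value was wrapped.
    let t = Tracked::new(vec![1, 2, 3]);
    if let Some(bt) = t.created_at.as_deref() {
        eprintln!("wrapped {} elements at:\n{bt}", t.inner.len());
    }
}
```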