From 5e795c24953aa0f2820cc9f79091c4478071b13d Mon Sep 17 00:00:00 2001 From: nxs Date: Sun, 17 Mar 2024 19:35:35 +1100 Subject: [PATCH] perf: Coerce sorted flag of unit arrays during concat (#15104) --- crates/polars-core/src/chunked_array/mod.rs | 8 +- .../src/chunked_array/ops/append.rs | 64 +++++++++---- py-polars/tests/unit/operations/test_sort.py | 91 +++++++++++++++++-- 3 files changed, 139 insertions(+), 24 deletions(-) diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index 64fb025321b0..247c132f3a2e 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -229,7 +229,9 @@ impl ChunkedArray { None } // We now know there is at least 1 non-null item in the array, and self.len() > 0 - else if self.is_sorted_any() { + else if self.null_count() == 0 { + Some(0) + } else if self.is_sorted_any() { let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } { // nulls are all at the start self.null_count() @@ -256,7 +258,9 @@ impl ChunkedArray { None } // We now know there is at least 1 non-null item in the array, and self.len() > 0 - else if self.is_sorted_any() { + else if self.null_count() == 0 { + Some(self.len() - 1) + } else if self.is_sorted_any() { let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } { // nulls are all at the start self.len() - 1 diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index d2749c2764bb..3325519eddb6 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -55,10 +55,17 @@ where } }, (true, true) => { - // both arrays have non-null values - if !ca.is_sorted_any() - || !other.is_sorted_any() - || ca.is_sorted_flag() != other.is_sorted_flag() + // both arrays have non-null values. + // for arrays of unit length we can ignore the sorted flag, as it is + // not necessarily set. + if !(ca.is_sorted_any() || ca.len() == 1) + || !(other.is_sorted_any() || other.len() == 1) + || !( + // We will coerce for single values + ca.len() - ca.null_count() == 1 + || other.len() - other.null_count() == 1 + || ca.is_sorted_flag() == other.is_sorted_flag() + ) { IsSorted::Not } else { @@ -68,7 +75,7 @@ where let l_val = unsafe { ca.value_unchecked(l_idx) }; let r_val = unsafe { other.value_unchecked(r_idx) }; - let keep_sorted = + let null_pos_check = // check null positions // lhs does not end in nulls (1 + l_idx == ca.len()) @@ -77,18 +84,43 @@ where // if there are nulls, they are all on one end && !(ca.first_non_null().unwrap() != 0 && 1 + other.last_non_null().unwrap() != other.len()); - let keep_sorted = keep_sorted - // compare values - && if ca.is_sorted_ascending_flag() { - l_val.tot_le(&r_val) - } else { - l_val.tot_ge(&r_val) - }; - - if keep_sorted { - ca.is_sorted_flag() - } else { + if !null_pos_check { IsSorted::Not + } else { + #[allow(unused_assignments)] + let mut out = IsSorted::Not; + + #[allow(clippy::never_loop)] + loop { + match ( + ca.len() - ca.null_count() == 1, + other.len() - other.null_count() == 1, + ) { + (true, true) => { + out = [IsSorted::Descending, IsSorted::Ascending] + [l_val.tot_le(&r_val) as usize]; + break; + }, + (true, false) => out = other.is_sorted_flag(), + _ => out = ca.is_sorted_flag(), + } + + debug_assert!(!matches!(out, IsSorted::Not)); + + let check = if matches!(out, IsSorted::Ascending) { + l_val.tot_le(&r_val) + } else { + l_val.tot_ge(&r_val) + }; + + if !check { + out = IsSorted::Not + } + + break; + } + + out } } }, diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index b3ef5631e7b8..aef3fcd8a5d9 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -9,6 +9,14 @@ from polars.testing import assert_frame_equal, assert_series_equal +def is_sorted_any(s: pl.Series) -> bool: + return s.flags["SORTED_ASC"] or s.flags["SORTED_DESC"] + + +def is_not_sorted(s: pl.Series) -> bool: + return not is_sorted_any(s) + + def test_sort_dates_multiples() -> None: df = pl.DataFrame( [ @@ -799,12 +807,6 @@ def test_sorted_flag_14552() -> None: def test_sorted_flag_concat_15072() -> None: - def is_sorted_any(s: pl.Series) -> bool: - return s.flags["SORTED_ASC"] or s.flags["SORTED_DESC"] - - def is_not_sorted(s: pl.Series) -> bool: - return not is_sorted_any(s) - # Both all-null a = pl.Series("x", [None, None], dtype=pl.Int8) b = pl.Series("x", [None, None], dtype=pl.Int8) @@ -903,3 +905,80 @@ def is_not_sorted(s: pl.Series) -> bool: out = pl.concat((s, s.clear())) assert_series_equal(out, s) assert out.flags["SORTED_ASC"] + + +@pytest.mark.parametrize("unit_descending", [True, False]) +def test_sorted_flag_concat_unit(unit_descending: bool) -> None: + unit = pl.Series([1]).set_sorted(descending=unit_descending) + + a = unit + b = pl.Series([2, 3]).set_sorted() + + out = pl.concat((a, b)) + assert out.to_list() == [1, 2, 3] + assert out.flags["SORTED_ASC"] + + out = pl.concat((b, a)) + assert out.to_list() == [2, 3, 1] + assert is_not_sorted(out) + + a = unit + b = pl.Series([3, 2]).set_sorted(descending=True) + + out = pl.concat((a, b)) + assert out.to_list() == [1, 3, 2] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [3, 2, 1] + assert out.flags["SORTED_DESC"] + + # unit with nulls first + unit = pl.Series([None, 1]).set_sorted(descending=unit_descending) + + a = unit + b = pl.Series([2, 3]).set_sorted() + + out = pl.concat((a, b)) + assert out.to_list() == [None, 1, 2, 3] + assert out.flags["SORTED_ASC"] + + out = pl.concat((b, a)) + assert out.to_list() == [2, 3, None, 1] + assert is_not_sorted(out) + + a = unit + b = pl.Series([3, 2]).set_sorted(descending=True) + + out = pl.concat((a, b)) + assert out.to_list() == [None, 1, 3, 2] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [3, 2, None, 1] + assert is_not_sorted(out) + + # unit with nulls last + unit = pl.Series([1, None]).set_sorted(descending=unit_descending) + + a = unit + b = pl.Series([2, 3]).set_sorted() + + out = pl.concat((a, b)) + assert out.to_list() == [1, None, 2, 3] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [2, 3, 1, None] + assert is_not_sorted(out) + + a = unit + b = pl.Series([3, 2]).set_sorted(descending=True) + + out = pl.concat((a, b)) + assert out.to_list() == [1, None, 3, 2] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [3, 2, 1, None] + assert out.flags["SORTED_DESC"]