Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Multi-output column expressions in frame sort method #17947

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ use crate::utils::_split_offsets;
pub(crate) fn args_validate<T: PolarsDataType>(
ca: &ChunkedArray<T>,
other: &[Series],
descending: &[bool],
param_value: &[bool],
param_name: &str,
) -> PolarsResult<()> {
for s in other {
assert_eq!(ca.len(), s.len());
}
polars_ensure!(other.len() == (descending.len() - 1),
polars_ensure!(other.len() == (param_value.len() - 1),
ComputeError:
"the amount of ordering booleans: {} does not match the number of series: {}",
descending.len(), other.len() + 1,
"the length of `{}` ({}) does not match the number of series ({})",
param_name, param_value.len(), other.len() + 1,
);
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ impl CategoricalChunked {
options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
if self.uses_lexical_ordering() {
args_validate(self.physical(), by, &options.descending)?;
args_validate(self.physical(), by, &options.descending, "descending")?;
args_validate(self.physical(), by, &options.nulls_last, "nulls_last")?;
let mut count: IdxSize = 0;

// we use bytes to save a monomorphized str impl
Expand Down
11 changes: 6 additions & 5 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ fn arg_sort_multiple_numeric<T: PolarsNumericType>(
by: &[Series],
options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
args_validate(ca, by, &options.descending)?;
args_validate(ca, by, &options.descending, "descending")?;
args_validate(ca, by, &options.nulls_last, "nulls_last")?;
let mut count: IdxSize = 0;

let no_nulls = ca.null_count() == 0;
Expand Down Expand Up @@ -426,8 +427,8 @@ impl ChunkSort<BinaryType> for BinaryChunked {
by: &[Series],
options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
args_validate(self, by, &options.descending)?;

args_validate(self, by, &options.descending, "descending")?;
args_validate(self, by, &options.nulls_last, "nulls_last")?;
let mut count: IdxSize = 0;

let mut vals = Vec::with_capacity(self.len());
Expand Down Expand Up @@ -568,8 +569,8 @@ impl ChunkSort<BinaryOffsetType> for BinaryOffsetChunked {
by: &[Series],
options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
args_validate(self, by, &options.descending)?;

args_validate(self, by, &options.descending, "descending")?;
args_validate(self, by, &options.nulls_last, "nulls_last")?;
let mut count: IdxSize = 0;

let mut vals = Vec::with_capacity(self.len());
Expand Down
8 changes: 3 additions & 5 deletions crates/polars-core/src/frame/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,9 @@ impl DataFrame {

// values will all be placed in single column, so we must find their supertype
let schema = self.schema();
let mut iter = on.iter().map(|v| {
schema
.get(v)
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", v))
});
let mut iter = on
.iter()
.map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v)));
let mut st = iter.next().unwrap()?.clone();
for dt in iter {
st = try_get_supertype(&st, dt?)?;
Expand Down
22 changes: 12 additions & 10 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ impl DataFrame {
/// Get the index of the column.
fn check_name_to_idx(&self, name: &str) -> PolarsResult<usize> {
self.get_column_index(name)
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", name))
.ok_or_else(|| polars_err!(col_not_found = name))
}

fn check_already_present(&self, name: &str) -> PolarsResult<()> {
Expand Down Expand Up @@ -1361,7 +1361,7 @@ impl DataFrame {
/// Get column index of a [`Series`] by name.
pub fn try_get_column_index(&self, name: &str) -> PolarsResult<usize> {
self.get_column_index(name)
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", name))
.ok_or_else(|| polars_err!(col_not_found = name))
}

/// Select a single column by name.
Expand Down Expand Up @@ -1560,7 +1560,7 @@ impl DataFrame {
.map(|name| {
let idx = *name_to_idx
.get(name.as_str())
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", name))?;
.ok_or_else(|| polars_err!(col_not_found = name))?;
Ok(self
.select_at_idx(idx)
.unwrap()
Expand Down Expand Up @@ -1588,7 +1588,7 @@ impl DataFrame {
.map(|name| {
let idx = *name_to_idx
.get(name.as_str())
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", name))?;
.ok_or_else(|| polars_err!(col_not_found = name))?;
Ok(self.select_at_idx(idx).unwrap().clone())
})
.collect::<PolarsResult<Vec<_>>>()?
Expand Down Expand Up @@ -1696,7 +1696,7 @@ impl DataFrame {
/// ```
pub fn rename(&mut self, column: &str, name: &str) -> PolarsResult<&mut Self> {
self.select_mut(column)
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", column))
.ok_or_else(|| polars_err!(col_not_found = column))
.map(|s| s.rename(name))?;
let unique_names: AHashSet<&str, ahash::RandomState> =
AHashSet::from_iter(self.columns.iter().map(|s| s.name()));
Expand Down Expand Up @@ -1728,11 +1728,13 @@ impl DataFrame {
mut sort_options: SortMultipleOptions,
slice: Option<(i64, usize)>,
) -> PolarsResult<Self> {
if by_column.is_empty() {
polars_bail!(ComputeError: "No columns selected for sorting");
}
// note that the by_column argument also contains evaluated expression from
// polars-lazy that may not even be present in this dataframe.

// therefore when we try to set the first columns as sorted, we ignore the error
// as expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i.
// polars-lazy that may not even be present in this dataframe. therefore
// when we try to set the first columns as sorted, we ignore the error as
// expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i).
let first_descending = sort_options.descending[0];
let first_by_column = by_column[0].name().to_string();

Expand Down Expand Up @@ -2966,7 +2968,7 @@ impl DataFrame {
for col in cols {
let _ = schema
.get(&col)
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", col))?;
.ok_or_else(|| polars_err!(col_not_found = col))?;
}
}
DataFrame::new(new_cols)
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ on startup."#.trim_start())
(duplicate = $name:expr) => {
polars_err!(Duplicate: "column with name '{}' has more than one occurrences", $name)
};
(col_not_found = $name:expr) => {
polars_err!(ColumnNotFound: "{:?} not found", $name)
};
(oob = $idx:expr, $len:expr) => {
polars_err!(OutOfBounds: "index {} is out of bounds for sequence of length {}", $idx, $len)
};
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ where
.iter()
.map(|x| {
let Some(i) = schema.index_of(x.as_ref()) else {
polars_bail!(ColumnNotFound: "{}", x.as_ref())
polars_bail!(col_not_found = x.as_ref())
};
Ok(i)
})
Expand Down
59 changes: 53 additions & 6 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,14 +389,58 @@ pub fn to_alp_impl(
input,
by_column,
slice,
sort_options,
mut sort_options,
} => {
// note: if given an Expr::Columns, count the individual cols
let n_by_exprs = if by_column.len() == 1 {
match &by_column[0] {
Expr::Columns(cols) => cols.len(),
_ => 1,
}
} else {
by_column.len()
};
let n_desc = sort_options.descending.len();
polars_ensure!(
n_desc == n_by_exprs || n_desc == 1,
ComputeError: "the length of `descending` ({}) does not match the length of `by` ({})", n_desc, by_column.len()
);
let n_nulls_last = sort_options.nulls_last.len();
polars_ensure!(
n_nulls_last == n_by_exprs || n_nulls_last == 1,
ComputeError: "the length of `nulls_last` ({}) does not match the length of `by` ({})", n_nulls_last, by_column.len()
);

let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert)
.map_err(|e| e.context(failed_input!(sort)))?;
let by_column = expand_expressions(input, by_column, lp_arena, expr_arena)
.map_err(|e| e.context(failed_here!(sort)))?;

convert.fill_scratch(&by_column, expr_arena);
let mut expanded_cols = Vec::new();
let mut nulls_last = Vec::new();
let mut descending = Vec::new();

// note: nulls_last/descending need to be matched to expanded multi-output expressions.
// when one of nulls_last/descending has not been updated from the default (single
// value true/false), 'cycle' ensures that "by_column" iter is not truncated.
for (c, (&n, &d)) in by_column.into_iter().zip(
sort_options
.nulls_last
.iter()
.cycle()
.zip(sort_options.descending.iter().cycle()),
) {
let exprs = expand_expressions(input, vec![c], lp_arena, expr_arena)
.map_err(|e| e.context(failed_here!(sort)))?;

nulls_last.extend(std::iter::repeat(n).take(exprs.len()));
descending.extend(std::iter::repeat(d).take(exprs.len()));
expanded_cols.extend(exprs);
}
sort_options.nulls_last = nulls_last;
sort_options.descending = descending;

convert.fill_scratch(&expanded_cols, expr_arena);
let by_column = expanded_cols;

let lp = IR::Sort {
input,
by_column,
Expand Down Expand Up @@ -479,7 +523,7 @@ pub fn to_alp_impl(
if turn_off_coalesce {
let options = Arc::make_mut(&mut options);
if matches!(options.args.coalesce, JoinCoalesce::CoalesceColumns) {
polars_warn!("Coalescing join requested but not all join keys are column references, turning off key coalescing");
polars_warn!("coalescing join requested but not all join keys are column references, turning off key coalescing");
}
options.args.coalesce = JoinCoalesce::KeepColumns;
}
Expand Down Expand Up @@ -604,7 +648,10 @@ pub fn to_alp_impl(
DslFunction::Drop(DropFunction { to_drop, strict }) => {
if strict {
for col_name in to_drop.iter() {
polars_ensure!(input_schema.contains(col_name), ColumnNotFound: "{col_name}");
polars_ensure!(
input_schema.contains(col_name),
col_not_found = col_name
);
}
}

Expand Down
4 changes: 1 addition & 3 deletions crates/polars-plan/src/plans/functions/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ impl DslFunction {
let Expr::Column(name) = e else {
polars_bail!(InvalidOperation: "expected column expression")
};

polars_ensure!(input_schema.contains(name), ColumnNotFound: "{name}");

polars_ensure!(input_schema.contains(name), col_not_found = name);
Ok(name.clone())
})
.collect::<PolarsResult<Arc<[Arc<str>]>>>()?;
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4814,8 +4814,8 @@ def sort(
Parameters
----------
by
Column(s) to sort by. Accepts expression input. Strings are parsed as column
names.
Column(s) to sort by. Accepts expression input, including selectors. Strings
are parsed as column names.
*more_by
Additional columns to sort by, specified as positional arguments.
descending
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -2394,7 +2394,7 @@ def contains_any(
patterns
String patterns to search.
ascii_case_insensitive
Enable ASCII-aware case insensitive matching.
Enable ASCII-aware case-insensitive matching.
When this option is enabled, searching will be performed without respect
to case for ASCII letters (a-z and A-Z) only.

Expand Down Expand Up @@ -2448,9 +2448,9 @@ def replace_many(
String patterns to search and replace.
replace_with
Strings to replace where a pattern was a match.
This can be broadcasted. So it supports many:one and many:many.
This can be broadcast, so it supports many:one and many:many.
ascii_case_insensitive
Enable ASCII-aware case insensitive matching.
Enable ASCII-aware case-insensitive matching.
When this option is enabled, searching will be performed without respect
to case for ASCII letters (a-z and A-Z) only.

Expand Down Expand Up @@ -2532,7 +2532,7 @@ def extract_many(
patterns
String patterns to search.
ascii_case_insensitive
Enable ASCII-aware case insensitive matching.
Enable ASCII-aware case-insensitive matching.
When this option is enabled, searching will be performed without respect
to case for ASCII letters (a-z and A-Z) only.
overlapping
Expand Down
5 changes: 3 additions & 2 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,8 +1275,8 @@ def sort(
Parameters
----------
by
Column(s) to sort by. Accepts expression input. Strings are parsed as column
names.
Column(s) to sort by. Accepts expression input, including selectors. Strings
are parsed as column names.
*more_by
Additional columns to sort by, specified as positional arguments.
descending
Expand Down Expand Up @@ -1368,6 +1368,7 @@ def sort(
by = parse_into_list_of_expressions(by, *more_by)
descending = extend_bool(descending, len(by), "descending", "by")
nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by")

return self._from_pyldf(
self._ldf.sort_by_exprs(
by, descending, nulls_last, maintain_order, multithreaded
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,7 +1813,7 @@ def contains_any(
patterns
String patterns to search.
ascii_case_insensitive
Enable ASCII-aware case insensitive matching.
Enable ASCII-aware case-insensitive matching.
When this option is enabled, searching will be performed without respect
to case for ASCII letters (a-z and A-Z) only.

Expand Down Expand Up @@ -1854,9 +1854,9 @@ def replace_many(
String patterns to search and replace.
replace_with
Strings to replace where a pattern was a match.
This can be broadcasted. So it supports many:one and many:many.
This can be broadcast, so it supports many:one and many:many.
ascii_case_insensitive
Enable ASCII-aware case insensitive matching.
Enable ASCII-aware case-insensitive matching.
When this option is enabled, searching will be performed without respect
to case for ASCII letters (a-z and A-Z) only.

Expand Down Expand Up @@ -1897,7 +1897,7 @@ def extract_many(
patterns
String patterns to search.
ascii_case_insensitive
Enable ASCII-aware case insensitive matching.
Enable ASCII-aware case-insensitive matching.
When this option is enabled, searching will be performed without respect
to case for ASCII letters (a-z and A-Z) only.
overlapping
Expand Down
Loading
Loading