Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support List type coercion for CASE-WHEN-THEN expression #12490

Merged
merged 7 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,22 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(List(_), List(_)) => Some(lhs_type.clone()),
(LargeList(_), List(_)) => Some(lhs_type.clone()),
(List(_), LargeList(_)) => Some(rhs_type.clone()),
(LargeList(_), LargeList(_)) => Some(lhs_type.clone()),
(List(_), FixedSizeList(_, _)) => Some(lhs_type.clone()),
(FixedSizeList(_, _), List(_)) => Some(rhs_type.clone()),
// Coerce to the left side FixedSizeList type if the list lengths are the same,
// otherwise coerce to list with the left type for dynamic length
(FixedSizeList(lf, ls), FixedSizeList(_, rs)) => {
if ls == rs {
Some(lhs_type.clone())
} else {
Some(List(Arc::clone(lf)))
}
}
Comment on lines +1029 to +1037
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we can't cast a FixedSizeList to a different length FixedSizeList,

> select arrow_cast(arrow_cast([1,2], 'FixedSizeList(2, Int64)'), 'FixedSizeList(3, Int64)');
This feature is not implemented: Unsupported CAST from FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 2) to FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3)

I choose to use the List for the dynamic length.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense to me - the alternative is to just reject the query -- I think this is reasonable

(LargeList(_), FixedSizeList(_, _)) => Some(lhs_type.clone()),
(FixedSizeList(_, _), LargeList(_)) => Some(rhs_type.clone()),
_ => None,
}
}
Expand Down Expand Up @@ -1906,6 +1922,63 @@ mod tests {
DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into()))
);

// list
let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
test_coercion_binary_rule!(
DataType::List(Arc::clone(&inner_field)),
DataType::List(Arc::clone(&inner_field)),
Operator::Eq,
DataType::List(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::List(Arc::clone(&inner_field)),
DataType::LargeList(Arc::clone(&inner_field)),
Operator::Eq,
DataType::LargeList(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::LargeList(Arc::clone(&inner_field)),
DataType::List(Arc::clone(&inner_field)),
Operator::Eq,
DataType::LargeList(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::LargeList(Arc::clone(&inner_field)),
DataType::LargeList(Arc::clone(&inner_field)),
Operator::Eq,
DataType::LargeList(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
Operator::Eq,
DataType::FixedSizeList(Arc::clone(&inner_field), 10)
);
test_coercion_binary_rule!(
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
DataType::LargeList(Arc::clone(&inner_field)),
Operator::Eq,
DataType::LargeList(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::LargeList(Arc::clone(&inner_field)),
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
Operator::Eq,
DataType::LargeList(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::List(Arc::clone(&inner_field)),
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
Operator::Eq,
DataType::List(Arc::clone(&inner_field))
);
test_coercion_binary_rule!(
DataType::FixedSizeList(Arc::clone(&inner_field), 10),
DataType::List(Arc::clone(&inner_field)),
Operator::Eq,
DataType::List(Arc::clone(&inner_field))
);

// TODO add other data type
Ok(())
}
Expand Down
180 changes: 180 additions & 0 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1811,6 +1811,186 @@ mod test {
Ok(())
}

macro_rules! test_case_expression {
($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
let case = Case {
expr: $expr.map(|e| Box::new(col(e))),
when_then_expr: $when_then,
else_expr: None,
};

let expected =
cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);

let actual = coerce_case_expression(case, &$schema)?;
assert_eq!(expected, actual);
};
}

#[test]
fn tes_case_when_list() -> Result<()> {
let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
let schema = Arc::new(DFSchema::from_unqualified_fields(
vec![
Field::new(
"large_list",
DataType::LargeList(Arc::clone(&inner_field)),
true,
),
Field::new(
"fixed_list",
DataType::FixedSizeList(Arc::clone(&inner_field), 3),
true,
),
Field::new("list", DataType::List(inner_field), true),
]
.into(),
std::collections::HashMap::new(),
)?);

test_case_expression!(
Some("list"),
vec![(Box::new(col("large_list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
Some("large_list"),
vec![(Box::new(col("list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
Some("list"),
vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
Some("fixed_list"),
vec![(Box::new(col("list")), Box::new(lit("1")))],
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
Some("fixed_list"),
vec![(Box::new(col("large_list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);

test_case_expression!(
Some("large_list"),
vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
DataType::Utf8,
schema
);
Ok(())
}

#[test]
fn test_then_else_list() -> Result<()> {
let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
let schema = Arc::new(DFSchema::from_unqualified_fields(
vec![
Field::new("boolean", DataType::Boolean, true),
Field::new(
"large_list",
DataType::LargeList(Arc::clone(&inner_field)),
true,
),
Field::new(
"fixed_list",
DataType::FixedSizeList(Arc::clone(&inner_field), 3),
true,
),
Field::new("list", DataType::List(inner_field), true),
]
.into(),
std::collections::HashMap::new(),
)?);

// large list and list
test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("large_list"))),
(Box::new(col("boolean")), Box::new(col("list")))
],
DataType::Boolean,
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);

test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("list"))),
(Box::new(col("boolean")), Box::new(col("large_list")))
],
DataType::Boolean,
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);

// fixed list and list
test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("fixed_list"))),
(Box::new(col("boolean")), Box::new(col("list")))
],
DataType::Boolean,
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);

test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("list"))),
(Box::new(col("boolean")), Box::new(col("fixed_list")))
],
DataType::Boolean,
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);

// fixed list and large list
test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("fixed_list"))),
(Box::new(col("boolean")), Box::new(col("large_list")))
],
DataType::Boolean,
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);

test_case_expression!(
None::<String>,
vec![
(Box::new(col("boolean")), Box::new(col("large_list"))),
(Box::new(col("boolean")), Box::new(col("fixed_list")))
],
DataType::Boolean,
DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))),
schema
);
Ok(())
}

#[test]
fn interval_plus_timestamp() -> Result<()> {
// SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp;
Expand Down
10 changes: 8 additions & 2 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;

use arrow::array::*;
use arrow::compute::kernels::cmp::eq;
use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
use arrow::datatypes::{DataType, Schema};
Expand All @@ -33,6 +32,7 @@ use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarV
use datafusion_expr::ColumnarValue;

use super::{Column, Literal};
use datafusion_physical_expr_common::datum::compare_with_eq;
use itertools::Itertools;

type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
Expand Down Expand Up @@ -204,7 +204,13 @@ impl CaseExpr {
.evaluate_selection(batch, &remainder)?;
let when_value = when_value.into_array(batch.num_rows())?;
// build boolean array representing which rows match the "when" value
let when_match = eq(&when_value, &base_value)?;
let when_match = compare_with_eq(
&when_value,
&base_value,
// The types of case and when expressions will be coerced to match.
// We only need to check if the base_value is nested.
base_value.data_type().is_nested(),
)?;
// Treat nulls as false
let when_match = match when_match.null_count() {
0 => Cow::Borrowed(&when_match),
Expand Down
Loading