diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 0fffd84b7047..2398f8154311 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -131,7 +131,7 @@ pub enum TypeSignature { Numeric(usize), /// Fixed number of arguments of all the same string types. /// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8. - /// Null is considerd as `Utf8` by default + /// Null is considered as `Utf8` by default /// Dictionary with string value type is also handled. String(usize), /// Zero argument diff --git a/datafusion/functions-nested/src/string.rs b/datafusion/functions-nested/src/string.rs index ce555c36274e..851aeac7f6cf 100644 --- a/datafusion/functions-nested/src/string.rs +++ b/datafusion/functions-nested/src/string.rs @@ -32,43 +32,25 @@ use std::any::{type_name, Any}; use crate::utils::{downcast_arg, make_scalar_function}; use arrow::compute::cast; +use arrow_array::builder::{ArrayBuilder, StringViewBuilder}; +use arrow_array::cast::AsArray; +use arrow_array::{GenericStringArray, StringViewArray}; use arrow_schema::DataType::{ - Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, -}; -use datafusion_common::cast::{ - as_generic_string_array, as_large_list_array, as_list_array, as_string_array, + Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, Utf8View, }; +use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array}; use datafusion_common::exec_err; use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_functions::strings::StringArrayType; use std::sync::{Arc, OnceLock}; -macro_rules! to_string { - ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - for x in arr { - match x { - Some(x) => { - $ARG.push_str(&x.to_string()); - $ARG.push_str($DELIMITER); - } - None => { - if $WITH_NULL_STRING { - $ARG.push_str($NULL_STRING); - $ARG.push_str($DELIMITER); - } - } - } - } - Ok($ARG) - }}; -} - macro_rules! call_array_function { ($DATATYPE:expr, false) => { match $DATATYPE { + DataType::Utf8View => array_function!(StringViewArray), DataType::Utf8 => array_function!(StringArray), DataType::LargeUtf8 => array_function!(LargeStringArray), DataType::Boolean => array_function!(BooleanArray), @@ -89,6 +71,7 @@ macro_rules! call_array_function { match $DATATYPE { DataType::List(_) => array_function!(ListArray), DataType::Utf8 => array_function!(StringArray), + DataType::Utf8View => array_function!(StringViewArray), DataType::LargeUtf8 => array_function!(LargeStringArray), DataType::Boolean => array_function!(BooleanArray), DataType::Float32 => array_function!(Float32Array), @@ -106,6 +89,27 @@ macro_rules! call_array_function { }}; } +macro_rules! to_string { + ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ + let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); + for x in arr { + match x { + Some(x) => { + $ARG.push_str(&x.to_string()); + $ARG.push_str($DELIMITER); + } + None => { + if $WITH_NULL_STRING { + $ARG.push_str($NULL_STRING); + $ARG.push_str($DELIMITER); + } + } + } + } + Ok($ARG) + }}; +} + // Create static instances of ScalarUDFs for each function make_udf_expr_and_func!( ArrayToString, @@ -222,10 +226,7 @@ impl StringToArray { pub fn new() -> Self { Self { signature: Signature::one_of( - vec![ - TypeSignature::Uniform(2, vec![Utf8, LargeUtf8]), - TypeSignature::Uniform(3, vec![Utf8, LargeUtf8]), - ], + vec![TypeSignature::String(2), TypeSignature::String(3)], Volatility::Immutable, ), aliases: vec![String::from("string_to_list")], @@ -248,12 +249,12 @@ impl ScalarUDFImpl for StringToArray { fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match arg_types[0] { - Utf8 | LargeUtf8 => { + Utf8 | Utf8View | LargeUtf8 => { List(Arc::new(Field::new("item", arg_types[0].clone(), true))) } _ => { return plan_err!( - "The string_to_array function can only accept Utf8 or LargeUtf8." + "The string_to_array function can only accept Utf8, Utf8View or LargeUtf8." ); } }) @@ -261,10 +262,10 @@ impl ScalarUDFImpl for StringToArray { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - Utf8 => make_scalar_function(string_to_array_inner::)(args), + Utf8View | Utf8 => make_scalar_function(string_to_array_inner::)(args), LargeUtf8 => make_scalar_function(string_to_array_inner::)(args), other => { - exec_err!("unsupported type for string_to_array function as {other}") + exec_err!("unsupported type for string_to_array function as {other:?}") } } } @@ -499,16 +500,214 @@ pub fn string_to_array_inner(args: &[ArrayRef]) -> Result 3 { return exec_err!("string_to_array expects two or three arguments"); } - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - - let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( - string_array.len(), - string_array.get_buffer_memory_size(), - )); + match (args[0].data_type(), args[1].data_type()) { + (Utf8View, Utf8View) => { + let string_array = args[0].as_string_view(); + let delimiter_array = args[1].as_string_view(); + let builder = StringViewBuilder::with_capacity(string_array.len()); + + if args.len() == 3 { + match args[2].data_type() { + Utf8View => { + let null_type_array = Some(args[2].as_string_view()); + string_to_array_impl::< + &StringViewArray, + &StringViewArray, + &StringViewArray, + StringViewBuilder, + >( + string_array, delimiter_array, null_type_array, builder + ) + } + Utf8 | LargeUtf8 => { + let null_type_array = Some(args[2].as_string::()); + string_to_array_impl::< + &StringViewArray, + &StringViewArray, + &GenericStringArray, + StringViewBuilder, + >( + string_array, delimiter_array, null_type_array, builder + ) + } + other => { + exec_err!( + "unsupported type for string_to_array function as {other:?}" + ) + } + } + } else { + string_to_array_impl::< + &StringViewArray, + &StringViewArray, + &GenericStringArray, + StringViewBuilder, + >(string_array, delimiter_array, None, builder) + } + } + (Utf8View, Utf8 | LargeUtf8) => { + let string_array = args[0].as_string_view(); + let delimiter_array = args[1].as_string::(); + let builder = StringViewBuilder::with_capacity(string_array.len()); + if args.len() == 3 { + match args[2].data_type() { + Utf8View => { + let null_type_array = Some(args[2].as_string_view()); + string_to_array_impl::< + &StringViewArray, + &GenericStringArray, + &StringViewArray, + StringViewBuilder, + >( + string_array, delimiter_array, null_type_array, builder + ) + } + Utf8 | LargeUtf8 => { + let null_type_array = Some(args[2].as_string::()); + string_to_array_impl::< + &StringViewArray, + &GenericStringArray, + &GenericStringArray, + StringViewBuilder, + >( + string_array, delimiter_array, null_type_array, builder + ) + } + other => { + exec_err!( + "unsupported type for string_to_array function as {other:?}" + ) + } + } + } else { + string_to_array_impl::< + &StringViewArray, + &GenericStringArray, + &GenericStringArray, + StringViewBuilder, + >(string_array, delimiter_array, None, builder) + } + } + (Utf8 | LargeUtf8, Utf8 | LargeUtf8) => { + let string_array = args[0].as_string::(); + let delimiter_array = args[1].as_string::(); + let builder = StringBuilder::with_capacity( + string_array.len(), + string_array.get_buffer_memory_size(), + ); + if args.len() == 3 { + match args[2].data_type() { + Utf8View => { + let null_type_array = Some(args[2].as_string_view()); + string_to_array_impl::< + &GenericStringArray, + &GenericStringArray, + &StringViewArray, + StringBuilder, + >( + string_array, delimiter_array, null_type_array, builder + ) + } + Utf8 | LargeUtf8 => { + let null_type_array = Some(args[2].as_string::()); + string_to_array_impl::< + &GenericStringArray, + &GenericStringArray, + &GenericStringArray, + StringBuilder, + >( + string_array, delimiter_array, null_type_array, builder + ) + } + other => { + exec_err!( + "unsupported type for string_to_array function as {other:?}" + ) + } + } + } else { + string_to_array_impl::< + &GenericStringArray, + &GenericStringArray, + &GenericStringArray, + StringBuilder, + >(string_array, delimiter_array, None, builder) + } + } + (Utf8 | LargeUtf8, Utf8View) => { + let string_array = args[0].as_string::(); + let delimiter_array = args[1].as_string_view(); + let builder = StringBuilder::with_capacity( + string_array.len(), + string_array.get_buffer_memory_size(), + ); + if args.len() == 3 { + match args[2].data_type() { + Utf8View => { + let null_type_array = Some(args[2].as_string_view()); + string_to_array_impl::< + &GenericStringArray, + &StringViewArray, + &StringViewArray, + StringBuilder, + >( + string_array, delimiter_array, null_type_array, builder + ) + } + Utf8 | LargeUtf8 => { + let null_type_array = Some(args[2].as_string::()); + string_to_array_impl::< + &GenericStringArray, + &StringViewArray, + &GenericStringArray, + StringBuilder, + >( + string_array, delimiter_array, null_type_array, builder + ) + } + other => { + exec_err!( + "unsupported type for string_to_array function as {other:?}" + ) + } + } + } else { + string_to_array_impl::< + &GenericStringArray, + &StringViewArray, + &GenericStringArray, + StringBuilder, + >(string_array, delimiter_array, None, builder) + } + } + other => { + exec_err!("unsupported type for string_to_array function as {other:?}") + } + } +} - match args.len() { - 2 => { +fn string_to_array_impl< + 'a, + StringArrType, + DelimiterArrType, + NullValueArrType, + StringBuilderType, +>( + string_array: StringArrType, + delimiter_array: DelimiterArrType, + null_value_array: Option, + string_builder: StringBuilderType, +) -> Result +where + StringArrType: StringArrayType<'a>, + DelimiterArrType: StringArrayType<'a>, + NullValueArrType: StringArrayType<'a>, + StringBuilderType: StringArrayBuilderType, +{ + let mut list_builder = ListBuilder::new(string_builder); + + match null_value_array { + None => { string_array.iter().zip(delimiter_array.iter()).for_each( |(string, delimiter)| { match (string, delimiter) { @@ -524,63 +723,80 @@ pub fn string_to_array_inner(args: &[ArrayRef]) -> Result { string.chars().map(|c| c.to_string()).for_each(|c| { - list_builder.values().append_value(c); + list_builder.values().append_value(c.as_str()); }); list_builder.append(true); } _ => list_builder.append(false), // null value } }, - ); + ) } - - 3 => { - let null_value_array = as_generic_string_array::(&args[2])?; - string_array - .iter() - .zip(delimiter_array.iter()) - .zip(null_value_array.iter()) - .for_each(|((string, delimiter), null_value)| { - match (string, delimiter) { - (Some(string), Some("")) => { - if Some(string) == null_value { + Some(null_value_array) => string_array + .iter() + .zip(delimiter_array.iter()) + .zip(null_value_array.iter()) + .for_each(|((string, delimiter), null_value)| { + match (string, delimiter) { + (Some(string), Some("")) => { + if Some(string) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(string); + } + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + if Some(s) == null_value { list_builder.values().append_null(); } else { - list_builder.values().append_value(string); + list_builder.values().append_value(s); } - list_builder.append(true); - } - (Some(string), Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - if Some(s) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(s); - } - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - if Some(c.as_str()) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(c); - } - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value + }); + list_builder.append(true); } - }); - } - _ => { - return exec_err!( - "Expect string_to_array function to take two or three parameters" - ) - } - } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + if Some(c.as_str()) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(c.as_str()); + } + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }), + }; let list_array = list_builder.finish(); Ok(Arc::new(list_array) as ArrayRef) } + +trait StringArrayBuilderType: ArrayBuilder { + fn append_value(&mut self, val: &str); + + fn append_null(&mut self); +} + +impl StringArrayBuilderType for StringBuilder { + fn append_value(&mut self, val: &str) { + StringBuilder::append_value(self, val); + } + + fn append_null(&mut self) { + StringBuilder::append_null(self); + } +} + +impl StringArrayBuilderType for StringViewBuilder { + fn append_value(&mut self, val: &str) { + StringViewBuilder::append_value(self, val) + } + + fn append_null(&mut self) { + StringViewBuilder::append_null(self) + } +} diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 1e60699a1f65..ecc5649173f6 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3973,6 +3973,18 @@ ORDER BY column1; 3 [bar] bar NULL [baz] baz +# verify make_array does force to Utf8View +query T +SELECT arrow_typeof(make_array(arrow_cast('a', 'Utf8View'), 'b', 'c', 'd')); +---- +List(Field { name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +# expect a,b,c,d. make_array forces all types to be of a common type (see above) +query T +SELECT array_to_string(make_array(arrow_cast('a', 'Utf8View'), 'b', 'c', 'd'), ','); +---- +a,b,c,d + statement ok drop table table1; @@ -6916,6 +6928,34 @@ select string_to_array(e, ',') from values; [adipiscing] NULL +# string view tests for string_to_array + +# string_to_array scalar function +query ? +SELECT string_to_array(arrow_cast('abcxxxdef', 'Utf8View'), 'xxx') +---- +[abc, def] + +query ? +SELECT string_to_array(arrow_cast('abc', 'Utf8View'), NULL) +---- +[a, b, c] + +query ? +select string_to_array(arrow_cast(e, 'Utf8View'), ',') from values; +---- +[Lorem] +[ipsum] +[dolor] +[sit] +[amet] +[, ] +[consectetur] +[adipiscing] +NULL + +# test string_to_array aliases + query ? select string_to_list(e, 'm') from values; ---- diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 5a08f3f5447a..98ba8181397c 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -1023,6 +1023,27 @@ logical_plan 01)Projection: digest(test.column1_utf8view, Utf8View("md5")) AS c 02)--TableScan: test projection=[column1_utf8view] +## Ensure no unexpected casts for string_to_array +query TT +EXPLAIN SELECT + string_to_array(column1_utf8view, ',') AS c +FROM test; +---- +logical_plan +01)Projection: string_to_array(test.column1_utf8view, Utf8View(",")) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no unexpected casts for array_to_string +query TT +EXPLAIN SELECT + array_to_string(string_to_array(column1_utf8view, NULL), ',') AS c +FROM test; +---- +logical_plan +01)Projection: array_to_string(string_to_array(test.column1_utf8view, Utf8View(NULL)), Utf8(",")) AS c +02)--TableScan: test projection=[column1_utf8view] + + ## Ensure no casts for binary operators # `~` operator (regex match) query TT