From e64f8895eaa4b336d8aba9cb5c494703f88b7a3a Mon Sep 17 00:00:00 2001 From: Agaev Guseyn <60943542+agscpp@users.noreply.github.com> Date: Tue, 22 Oct 2024 18:05:54 +0300 Subject: [PATCH] Fix functions with Volatility::Volatile and parameters (#13001) Co-authored-by: Agaev Huseyn --- .../user_defined_scalar_functions.rs | 181 ++++++++++++++++++ datafusion/expr/src/udf.rs | 31 ++- .../physical-expr/src/scalar_function.rs | 5 +- 3 files changed, 212 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 0f1c3b8e53c4..607974ffa886 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -16,9 +16,11 @@ // under the License. use std::any::Any; +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; +use arrow::array::as_string_array; use arrow::compute::kernels::numeric::add; use arrow_array::builder::BooleanBuilder; use arrow_array::cast::AsArray; @@ -483,6 +485,185 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { Ok(()) } +/// Volatile UDF that should append a different value to each row +#[derive(Debug)] +struct AddIndexToStringVolatileScalarUDF { + name: String, + signature: Signature, + return_type: DataType, +} + +impl AddIndexToStringVolatileScalarUDF { + fn new() -> Self { + Self { + name: "add_index_to_string".to_string(), + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), + return_type: DataType::Utf8, + } + } +} + +impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!("index_with_offset function does not accept arguments") + } + + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + let answer = match &args[0] { + // When called with static arguments, the result is returned as an array. + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => { + let mut answer = vec![]; + for index in 1..=number_rows { + // When calling a function with immutable arguments, the result is returned with ")". + // Example: SELECT add_index_to_string('const_value') FROM table; + answer.push(index.to_string() + ") " + value); + } + answer + } + // The result is returned as an array when called with dynamic arguments. + ColumnarValue::Array(array) => { + let string_array = as_string_array(array); + let mut counter = HashMap::<&str, u64>::new(); + string_array + .iter() + .map(|value| { + let value = value.expect("Unexpected null"); + let index = counter.get(value).unwrap_or(&0) + 1; + counter.insert(value, index); + + // When calling a function with mutable arguments, the result is returned with ".". + // Example: SELECT add_index_to_string(table.value) FROM table; + index.to_string() + ". " + value + }) + .collect() + } + _ => unimplemented!(), + }; + Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer)))) + } +} + +#[tokio::test] +async fn volatile_scalar_udf_with_params() -> Result<()> { + { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from(vec![ + "test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2", + ]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); + + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); + + let result = + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters + .await?; + let expected = [ + "+-----------+", + "| str |", + "+-----------+", + "| 1. test_1 |", + "| 2. test_1 |", + "| 3. test_1 |", + "| 1. test_2 |", + "| 2. test_2 |", + "| 4. test_1 |", + "| 3. test_2 |", + "+-----------+", + ]; + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters + .await?; + let expected = [ + "+---------+", + "| str |", + "+---------+", + "| 1) test |", + "| 2) test |", + "| 3) test |", + "| 4) test |", + "| 5) test |", + "| 6) test |", + "| 7) test |", + "+---------+", + ]; + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters + .await?; + let expected = [ + "+---------------+", + "| str |", + "+---------------+", + "| 1) test_value |", + "+---------------+", + ]; + assert_batches_eq!(expected, &result); + } + { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from(vec![ + "test_1", "test_1", "test_1", + ]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); + + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); + + let result = + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") + .await?; + let expected = [ + "+-----------+", // + "| str |", // + "+-----------+", // + "| 1. test_1 |", // + "| 2. test_1 |", // + "| 3. test_1 |", // + "+-----------+", + ]; + assert_batches_eq!(expected, &result); + } + Ok(()) +} + #[derive(Debug)] struct CastToI64UDF { signature: Signature, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index be3f811dbe51..3326fa7a9d65 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -209,6 +209,17 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } + /// Invoke the function with `args` and number of rows, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke_batch`] for more details. + pub fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + self.inner.invoke_batch(args, number_rows) + } + /// Invoke the function without `args` but number of rows, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke_no_args`] for more details. @@ -446,7 +457,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// to arrays, which will likely be simpler code, but be slower. /// /// [invoke_no_args]: ScalarUDFImpl::invoke_no_args - fn invoke(&self, _args: &[ColumnarValue]) -> Result; + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!( + "Function {} does not implement invoke but called", + self.name() + ) + } + + /// Invoke the function with `args` and the number of rows, + /// returning the appropriate result. + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + match args.is_empty() { + true => self.invoke_no_args(number_rows), + false => self.invoke(args), + } + } /// Invoke the function without `args`, instead the number of rows are provided, /// returning the appropriate result. diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 130c335d1c95..9b85ef9d0229 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -140,10 +140,7 @@ impl PhysicalExpr for ScalarFunctionExpr { .collect::>>()?; // evaluate the function - let output = match self.args.is_empty() { - true => self.fun.invoke_no_args(batch.num_rows()), - false => self.fun.invoke(&inputs), - }?; + let output = self.fun.invoke_batch(&inputs, batch.num_rows())?; if let ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() {