Skip to content

Commit

Permalink
Fix functions with Volatility::Volatile and parameters (apache#13001)
Browse files Browse the repository at this point in the history
Co-authored-by: Agaev Huseyn <[email protected]>
  • Loading branch information
2 people authored and 0x501D committed Nov 21, 2024
1 parent d53f727 commit e64f889
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 5 deletions.
181 changes: 181 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
// under the License.

use std::any::Any;
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::array::as_string_array;
use arrow::compute::kernels::numeric::add;
use arrow_array::builder::BooleanBuilder;
use arrow_array::cast::AsArray;
Expand Down Expand Up @@ -483,6 +485,185 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
Ok(())
}

/// Volatile UDF that should append a different value to each row
#[derive(Debug)]
struct AddIndexToStringVolatileScalarUDF {
name: String,
signature: Signature,
return_type: DataType,
}

impl AddIndexToStringVolatileScalarUDF {
fn new() -> Self {
Self {
name: "add_index_to_string".to_string(),
signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile),
return_type: DataType::Utf8,
}
}
}

impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(self.return_type.clone())
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!("index_with_offset function does not accept arguments")
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
let answer = match &args[0] {
// When called with static arguments, the result is returned as an array.
ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => {
let mut answer = vec![];
for index in 1..=number_rows {
// When calling a function with immutable arguments, the result is returned with ")".
// Example: SELECT add_index_to_string('const_value') FROM table;
answer.push(index.to_string() + ") " + value);
}
answer
}
// The result is returned as an array when called with dynamic arguments.
ColumnarValue::Array(array) => {
let string_array = as_string_array(array);
let mut counter = HashMap::<&str, u64>::new();
string_array
.iter()
.map(|value| {
let value = value.expect("Unexpected null");
let index = counter.get(value).unwrap_or(&0) + 1;
counter.insert(value, index);

// When calling a function with mutable arguments, the result is returned with ".".
// Example: SELECT add_index_to_string(table.value) FROM table;
index.to_string() + ". " + value
})
.collect()
}
_ => unimplemented!(),
};
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
}
}

#[tokio::test]
async fn volatile_scalar_udf_with_params() -> Result<()> {
{
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);

let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(StringArray::from(vec![
"test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2",
]))],
)?;
let ctx = SessionContext::new();

ctx.register_batch("t", batch)?;

let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new();

ctx.register_udf(ScalarUDF::from(get_new_str_udf));

let result =
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters
.await?;
let expected = [
"+-----------+",
"| str |",
"+-----------+",
"| 1. test_1 |",
"| 2. test_1 |",
"| 3. test_1 |",
"| 1. test_2 |",
"| 2. test_2 |",
"| 4. test_1 |",
"| 3. test_2 |",
"+-----------+",
];
assert_batches_eq!(expected, &result);

let result =
plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters
.await?;
let expected = [
"+---------+",
"| str |",
"+---------+",
"| 1) test |",
"| 2) test |",
"| 3) test |",
"| 4) test |",
"| 5) test |",
"| 6) test |",
"| 7) test |",
"+---------+",
];
assert_batches_eq!(expected, &result);

let result =
plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters
.await?;
let expected = [
"+---------------+",
"| str |",
"+---------------+",
"| 1) test_value |",
"+---------------+",
];
assert_batches_eq!(expected, &result);
}
{
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);

let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(StringArray::from(vec![
"test_1", "test_1", "test_1",
]))],
)?;
let ctx = SessionContext::new();

ctx.register_batch("t", batch)?;

let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new();

ctx.register_udf(ScalarUDF::from(get_new_str_udf));

let result =
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t")
.await?;
let expected = [
"+-----------+", //
"| str |", //
"+-----------+", //
"| 1. test_1 |", //
"| 2. test_1 |", //
"| 3. test_1 |", //
"+-----------+",
];
assert_batches_eq!(expected, &result);
}
Ok(())
}

#[derive(Debug)]
struct CastToI64UDF {
signature: Signature,
Expand Down
31 changes: 30 additions & 1 deletion datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ impl ScalarUDF {
self.inner.is_nullable(args, schema)
}

/// Invoke the function with `args` and number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
pub fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
self.inner.invoke_batch(args, number_rows)
}

/// Invoke the function without `args` but number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
Expand Down Expand Up @@ -446,7 +457,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// to arrays, which will likely be simpler code, but be slower.
///
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue>;
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke but called",
self.name()
)
}

/// Invoke the function with `args` and the number of rows,
/// returning the appropriate result.
fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
match args.is_empty() {
true => self.invoke_no_args(number_rows),
false => self.invoke(args),
}
}

/// Invoke the function without `args`, instead the number of rows are provided,
/// returning the appropriate result.
Expand Down
5 changes: 1 addition & 4 deletions datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
.collect::<Result<Vec<_>>>()?;

// evaluate the function
let output = match self.args.is_empty() {
true => self.fun.invoke_no_args(batch.num_rows()),
false => self.fun.invoke(&inputs),
}?;
let output = self.fun.invoke_batch(&inputs, batch.num_rows())?;

if let ColumnarValue::Array(array) = &output {
if array.len() != batch.num_rows() {
Expand Down

0 comments on commit e64f889

Please sign in to comment.