Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Utf8View to crypto functions #13407

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 44 additions & 13 deletions datafusion/functions/src/crypto/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@

//! "crypto" DataFusion functions

use arrow::array::StringArray;
use arrow::array::{Array, ArrayRef, BinaryArray, OffsetSizeTrait};
use arrow::array::{AsArray, GenericStringArray, StringArray, StringViewArray};
use arrow::datatypes::DataType;
use blake2::{Blake2b512, Blake2s256, Digest};
use blake3::Hasher as Blake3;
use datafusion_common::cast::as_binary_array;

use arrow::compute::StringArrayType;
use datafusion_common::plan_err;
use datafusion_common::{
cast::{as_generic_binary_array, as_generic_string_array},
exec_err, internal_err, DataFusionError, Result, ScalarValue,
cast::as_generic_binary_array, exec_err, internal_err, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::ColumnarValue;
use md5::Md5;
Expand Down Expand Up @@ -121,9 +122,9 @@ pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
}
let digest_algorithm = match &args[1] {
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => {
method.parse::<DigestAlgorithm>()
}
ScalarValue::Utf8View(Some(method))
| ScalarValue::Utf8(Some(method))
| ScalarValue::LargeUtf8(Some(method)) => method.parse::<DigestAlgorithm>(),
other => exec_err!("Unsupported data type {other:?} for function digest"),
},
ColumnarValue::Array(_) => {
Expand All @@ -132,6 +133,7 @@ pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
}?;
digest_process(&args[0], digest_algorithm)
}

impl FromStr for DigestAlgorithm {
type Err = DataFusionError;
fn from_str(name: &str) -> Result<DigestAlgorithm> {
Expand Down Expand Up @@ -166,12 +168,14 @@ impl FromStr for DigestAlgorithm {
})
}
}

impl fmt::Display for DigestAlgorithm {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", format!("{self:?}").to_lowercase())
}
}
// /// computes md5 hash digest of the given input

/// computes md5 hash digest of the given input
pub fn md5(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 1 {
return exec_err!(
Expand All @@ -180,7 +184,9 @@ pub fn md5(args: &[ColumnarValue]) -> Result<ColumnarValue> {
DigestAlgorithm::Md5
);
}

let value = digest_process(&args[0], DigestAlgorithm::Md5)?;

// md5 requires special handling because of its unique utf8 return type
Ok(match value {
ColumnarValue::Array(array) => {
Expand Down Expand Up @@ -214,7 +220,8 @@ pub fn utf8_or_binary_to_binary_type(
name: &str,
) -> Result<DataType> {
Ok(match arg_type {
DataType::LargeUtf8
DataType::Utf8View
| DataType::LargeUtf8
| DataType::Utf8
| DataType::Binary
| DataType::LargeBinary => DataType::Binary,
Expand Down Expand Up @@ -296,8 +303,30 @@ impl DigestAlgorithm {
where
T: OffsetSizeTrait,
{
let input_value = as_generic_string_array::<T>(value)?;
let array: ArrayRef = match self {
let array = match value.data_type() {
DataType::Utf8 | DataType::LargeUtf8 => {
let v = value.as_string::<T>();
self.digest_utf8_array_impl::<&GenericStringArray<T>>(v)
}
DataType::Utf8View => {
let v = value.as_string_view();
self.digest_utf8_array_impl::<&StringViewArray>(v)
}
other => {
return exec_err!("unsupported type for digest_utf_array: {other:?}")
}
};
Ok(ColumnarValue::Array(array))
}

pub fn digest_utf8_array_impl<'a, StringArrType>(
self,
input_value: StringArrType,
) -> ArrayRef
where
StringArrType: StringArrayType<'a>,
{
match self {
Self::Md5 => digest_to_array!(Md5, input_value),
Self::Sha224 => digest_to_array!(Sha224, input_value),
Self::Sha256 => digest_to_array!(Sha256, input_value),
Expand All @@ -318,8 +347,7 @@ impl DigestAlgorithm {
.collect();
Arc::new(binary_array)
}
};
Ok(ColumnarValue::Array(array))
}
}
}
pub fn digest_process(
Expand All @@ -328,6 +356,7 @@ pub fn digest_process(
) -> Result<ColumnarValue> {
match value {
ColumnarValue::Array(a) => match a.data_type() {
DataType::Utf8View => digest_algorithm.digest_utf8_array::<i32>(a.as_ref()),
DataType::Utf8 => digest_algorithm.digest_utf8_array::<i32>(a.as_ref()),
DataType::LargeUtf8 => digest_algorithm.digest_utf8_array::<i64>(a.as_ref()),
DataType::Binary => digest_algorithm.digest_binary_array::<i32>(a.as_ref()),
Expand All @@ -339,7 +368,9 @@ pub fn digest_process(
),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => {
ScalarValue::Utf8View(a)
| ScalarValue::Utf8(a)
| ScalarValue::LargeUtf8(a) => {
Ok(digest_algorithm
.digest_scalar(a.as_ref().map(|s: &String| s.as_bytes())))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/crypto/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ impl DigestFunc {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8View, Utf8View]),
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/crypto/md5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Md5Func {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Binary, LargeBinary],
vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary],
Volatility::Immutable,
),
}
Expand All @@ -65,7 +65,7 @@ impl ScalarUDFImpl for Md5Func {
use DataType::*;
Ok(match &arg_types[0] {
LargeUtf8 | LargeBinary => LargeUtf8,
Utf8 | Binary => Utf8,
Utf8View | Utf8 | Binary => Utf8,
Null => Null,
Dictionary(_, t) => match **t {
LargeUtf8 | LargeBinary => LargeUtf8,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/crypto/sha224.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl SHA224Func {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Binary, LargeBinary],
vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary],
Volatility::Immutable,
),
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/crypto/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl SHA256Func {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Binary, LargeBinary],
vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary],
Volatility::Immutable,
),
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/crypto/sha384.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl SHA384Func {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Binary, LargeBinary],
vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary],
Volatility::Immutable,
),
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/crypto/sha512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl SHA512Func {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Binary, LargeBinary],
vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary],
Volatility::Immutable,
),
}
Expand Down
5 changes: 5 additions & 0 deletions datafusion/sqllogictest/test_files/expr.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2225,6 +2225,11 @@ SELECT digest('','blake3');
----
af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262

# vverify utf8view
query ?
SELECT sha224(arrow_cast('tom', 'Utf8View'));
----
0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d

query T
SELECT substring('alphabet', 1)
Expand Down
60 changes: 60 additions & 0 deletions datafusion/sqllogictest/test_files/string/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,66 @@ logical_plan
01)Projection: nullif(test.column1_utf8view, test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for md5
query TT
EXPLAIN SELECT
md5(column1_utf8view) as c
FROM test;
----
logical_plan
01)Projection: md5(test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for sha224
query TT
EXPLAIN SELECT
sha224(column1_utf8view) as c
FROM test;
----
logical_plan
01)Projection: sha224(test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for sha256
query TT
EXPLAIN SELECT
sha256(column1_utf8view) as c
FROM test;
----
logical_plan
01)Projection: sha256(test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for sha384
query TT
EXPLAIN SELECT
sha384(column1_utf8view) as c
FROM test;
----
logical_plan
01)Projection: sha384(test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for sha512
query TT
EXPLAIN SELECT
sha512(column1_utf8view) as c
FROM test;
----
logical_plan
01)Projection: sha512(test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for digest
query TT
EXPLAIN SELECT
digest(column1_utf8view, 'md5') as c
FROM test;
----
logical_plan
01)Projection: digest(test.column1_utf8view, Utf8View("md5")) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for binary operators
# `~` operator (regex match)
query TT
Expand Down
Loading