diff --git a/velox/docs/functions/spark/array.rst b/velox/docs/functions/spark/array.rst index b16397f7db0b6..f077c99db9eea 100644 --- a/velox/docs/functions/spark/array.rst +++ b/velox/docs/functions/spark/array.rst @@ -15,6 +15,12 @@ Array Functions SELECT array(1, 2, 3); -- [1,2,3] +.. spark:function:: array_append(array(E), value) -> array(E) + + Add the element at the end of the array passed as first argument. :: + + SELECT array_append(array(1, 2, 3), 2); -- [1, 2, 3, 2] + .. spark:function:: array_contains(array(E), value) -> boolean Returns true if the array contains the value. :: diff --git a/velox/functions/sparksql/ArrayAppend.h b/velox/functions/sparksql/ArrayAppend.h new file mode 100644 index 0000000000000..34bbf90338955 --- /dev/null +++ b/velox/functions/sparksql/ArrayAppend.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/functions/Macros.h" + +namespace facebook::velox::functions::sparksql { + +template +struct ArrayAppendFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + FOLLY_ALWAYS_INLINE bool callNullable( + out_type>>& out, + const arg_type>>* array, + const arg_type>* element) { + if (array == nullptr) { + return false; + } + out.reserve(array->size() + 1); + out.add_items(*array); + out.push_back(*element); + return true; + } +}; + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/registration/RegisterArray.cpp b/velox/functions/sparksql/registration/RegisterArray.cpp index c75a53ffc68d1..321d78dda98d9 100644 --- a/velox/functions/sparksql/registration/RegisterArray.cpp +++ b/velox/functions/sparksql/registration/RegisterArray.cpp @@ -18,6 +18,7 @@ #include "velox/functions/lib/Repeat.h" #include "velox/functions/lib/Slice.h" #include "velox/functions/prestosql/ArrayFunctions.h" +#include "velox/functions/sparksql/ArrayAppend.h" #include "velox/functions/sparksql/ArrayFlattenFunction.h" #include "velox/functions/sparksql/ArrayInsert.h" #include "velox/functions/sparksql/ArrayMinMaxFunction.h" @@ -141,6 +142,11 @@ void registerArrayFunctions(const std::string& prefix) { makeArrayShuffleWithCustomSeed, getMetadataForArrayShuffle()); registerIntegerSliceFunction(prefix); + registerFunction< + ArrayAppendFunction, + Array>, + Array>, + Generic>({prefix + "array_append"}); } } // namespace sparksql diff --git a/velox/functions/sparksql/tests/ArrayAppendTest.cpp b/velox/functions/sparksql/tests/ArrayAppendTest.cpp new file mode 100644 index 0000000000000..d54850b152df8 --- /dev/null +++ b/velox/functions/sparksql/tests/ArrayAppendTest.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +using namespace facebook::velox::test; + +namespace facebook::velox::functions::sparksql::test { +namespace { +class ArrayAppendTest : public SparkFunctionBaseTest { + protected: + void testExpression( + const std::string& expression, + const std::vector& input, + const VectorPtr& expected) { + auto result = evaluate(expression, makeRowVector(input)); + assertEqualVectors(expected, result); + } +}; + +TEST_F(ArrayAppendTest, intArrays) { + const auto arrayVector = makeArrayVector( + {{1, 2, 3, 4}, {3, 4, 5}, {7, 8, 9}, {10, 20, 30}}); + const auto elementVector = makeFlatVector({11, 22, 33, 44}); + VectorPtr expected; + + expected = makeArrayVector({ + {1, 2, 3, 4, 11}, + {3, 4, 5, 22}, + {7, 8, 9, 33}, + {10, 20, 30, 44}, + }); + testExpression( + "array_append(c0, c1)", {arrayVector, elementVector}, expected); +} + +TEST_F(ArrayAppendTest, nullArrays) { + const auto arrayVector = makeNullableArrayVector( + {{1, 2, 3, std::nullopt}, {3, 4, 5}, {7, 8, 9}, {10, 20, std::nullopt}}); + const auto elementVector = + makeNullableFlatVector({11, std::nullopt, 33, std::nullopt}); + VectorPtr expected; + + expected = makeNullableArrayVector({ + {1, 2, 3, std::nullopt, 11}, + {3, 4, 5, std::nullopt}, + {7, 8, 9, 33}, + {10, 20, std::nullopt, std::nullopt}, + }); + testExpression( + "array_append(c0, c1)", {arrayVector, elementVector}, expected); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index 39087bd8adb55..314eefae0a59e 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -15,6 +15,7 @@ add_executable( velox_functions_spark_test ArithmeticTest.cpp + ArrayAppendTest.cpp ArrayFlattenTest.cpp ArrayGetTest.cpp ArrayInsertTest.cpp