Skip to content

Commit

Permalink
feat: Add Spark array_append function
Browse files Browse the repository at this point in the history
  • Loading branch information
leoluan2009 committed Jan 22, 2025
1 parent 523bf82 commit f404b3f
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 0 deletions.
6 changes: 6 additions & 0 deletions velox/docs/functions/spark/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ Array Functions

SELECT array(1, 2, 3); -- [1,2,3]

.. spark:function:: array_append(array(E), value) -> array(E)
Appends the element to the end of the array passed as the first argument. Returns NULL if the array is NULL. ::

SELECT array_append(array(1, 2, 3), 2); -- [1, 2, 3, 2]

.. spark:function:: array_contains(array(E), value) -> boolean
Returns true if the array contains the value. ::
Expand Down
40 changes: 40 additions & 0 deletions velox/functions/sparksql/ArrayAppend.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/functions/Macros.h"

namespace facebook::velox::functions::sparksql {

/// Spark-compatible array_append(array, element): produces a copy of `array`
/// with `element` added at the end.
///
/// NULL handling:
///  - NULL `array`   -> the result is NULL (callNullable returns false).
///  - NULL `element` -> a NULL entry is appended to the result array.
template <typename TExec>
struct ArrayAppendFunction {
  VELOX_DEFINE_FUNCTION_TYPES(TExec);

  /// @param out Writer for the result array.
  /// @param array Input array; nullptr when the argument is NULL.
  /// @param element Value to append; nullptr when the argument is NULL.
  /// @return false when the result is NULL (input array is NULL), true
  /// otherwise.
  FOLLY_ALWAYS_INLINE bool callNullable(
      out_type<Array<Generic<T1>>>& out,
      const arg_type<Array<Generic<T1>>>* array,
      const arg_type<Generic<T1>>* element) {
    if (array == nullptr) {
      return false;
    }

    // Reserve room for the existing items plus the appended one.
    out.reserve(array->size() + 1);
    out.add_items(*array);

    // A NULL element arrives as nullptr; it must be appended as a NULL entry
    // rather than dereferenced.
    if (element != nullptr) {
      out.push_back(*element);
    } else {
      out.add_null();
    }
    return true;
  }
};

} // namespace facebook::velox::functions::sparksql
6 changes: 6 additions & 0 deletions velox/functions/sparksql/registration/RegisterArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "velox/functions/lib/Repeat.h"
#include "velox/functions/lib/Slice.h"
#include "velox/functions/prestosql/ArrayFunctions.h"
#include "velox/functions/sparksql/ArrayAppend.h"
#include "velox/functions/sparksql/ArrayFlattenFunction.h"
#include "velox/functions/sparksql/ArrayInsert.h"
#include "velox/functions/sparksql/ArrayMinMaxFunction.h"
Expand Down Expand Up @@ -141,6 +142,11 @@ void registerArrayFunctions(const std::string& prefix) {
makeArrayShuffleWithCustomSeed,
getMetadataForArrayShuffle());
registerIntegerSliceFunction(prefix);
registerFunction<
ArrayAppendFunction,
Array<Generic<T1>>,
Array<Generic<T1>>,
Generic<T1>>({prefix + "array_append"});
}

} // namespace sparksql
Expand Down
68 changes: 68 additions & 0 deletions velox/functions/sparksql/tests/ArrayAppendTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions::sparksql::test {
namespace {
/// Test fixture for the Spark array_append function.
class ArrayAppendTest : public SparkFunctionBaseTest {
 protected:
  /// Evaluates `expression` against a row vector built from `input` (columns
  /// c0, c1, ...) and asserts that the outcome equals `expected`.
  void testExpression(
      const std::string& expression,
      const std::vector<VectorPtr>& input,
      const VectorPtr& expected) {
    assertEqualVectors(expected, evaluate(expression, makeRowVector(input)));
  }
};

// Appends a non-null bigint to each non-null array, row by row.
TEST_F(ArrayAppendTest, intArrays) {
  const auto arrays = makeArrayVector<int64_t>(
      {{1, 2, 3, 4}, {3, 4, 5}, {7, 8, 9}, {10, 20, 30}});
  const auto elements = makeFlatVector<int64_t>({11, 22, 33, 44});

  // Each input array followed by the element from the same row.
  const VectorPtr expected = makeArrayVector<int64_t>({
      {1, 2, 3, 4, 11},
      {3, 4, 5, 22},
      {7, 8, 9, 33},
      {10, 20, 30, 44},
  });

  testExpression("array_append(c0, c1)", {arrays, elements}, expected);
}

// Covers arrays that contain null entries and elements that are themselves
// null: a null element is expected to show up as a trailing null entry.
TEST_F(ArrayAppendTest, nullArrays) {
  const auto arrays = makeNullableArrayVector<int64_t>(
      {{1, 2, 3, std::nullopt}, {3, 4, 5}, {7, 8, 9}, {10, 20, std::nullopt}});
  const auto elements =
      makeNullableFlatVector<int64_t>({11, std::nullopt, 33, std::nullopt});

  // Each input array followed by the (possibly null) element of the same row.
  const VectorPtr expected = makeNullableArrayVector<int64_t>({
      {1, 2, 3, std::nullopt, 11},
      {3, 4, 5, std::nullopt},
      {7, 8, 9, 33},
      {10, 20, std::nullopt, std::nullopt},
  });

  testExpression("array_append(c0, c1)", {arrays, elements}, expected);
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
add_executable(
velox_functions_spark_test
ArithmeticTest.cpp
ArrayAppendTest.cpp
ArrayFlattenTest.cpp
ArrayGetTest.cpp
ArrayInsertTest.cpp
Expand Down

0 comments on commit f404b3f

Please sign in to comment.