Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-44393: [C++][Compute] Swizzle vector functions #44394

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
1e141d7
WIP
zanmato1984 Sep 29, 2024
216e217
WIP
zanmato1984 Sep 30, 2024
f3c73ea
Add permute function options
zanmato1984 Oct 2, 2024
be88f0c
WIP
zanmato1984 Oct 4, 2024
b445c36
Implementation done and basic tests
zanmato1984 Oct 6, 2024
3707fa6
Implement permute
zanmato1984 Oct 7, 2024
4e9d3a6
Reorg reverse_index
zanmato1984 Oct 10, 2024
c4c5c41
Fix API and doc
zanmato1984 Oct 10, 2024
78bb335
Fix API and doc
zanmato1984 Oct 10, 2024
38bcc5d
Merge remote-tracking branch 'origin/main' into vector-placement
zanmato1984 Oct 10, 2024
d88877a
Init docs
zanmato1984 Oct 10, 2024
cc6a0ef
Merge branch 'vector-permute' into vector-placement
zanmato1984 Oct 10, 2024
2bbf44b
Refine
zanmato1984 Oct 10, 2024
b31f9f2
Update docs
zanmato1984 Oct 10, 2024
b951348
Refine doc
zanmato1984 Oct 11, 2024
520b952
Add comments for the implementation
zanmato1984 Oct 11, 2024
b450f5e
Refine docs
zanmato1984 Oct 11, 2024
4ea1465
Fix uint64 overflow check
zanmato1984 Oct 11, 2024
cbdce2f
Reverse indices tests
zanmato1984 Oct 11, 2024
d2e118a
Forbit non-array-like argument
zanmato1984 Oct 11, 2024
7128a28
Fix permute option default
zanmato1984 Oct 11, 2024
034d3b7
Refine
zanmato1984 Oct 11, 2024
9f93e5c
WIP permute tests
zanmato1984 Oct 11, 2024
c320002
Refine tests
zanmato1984 Oct 12, 2024
3e438e8
More permute tests
zanmato1984 Oct 12, 2024
0811b2b
Add if-else tests using permute
zanmato1984 Oct 13, 2024
154ad95
Update some comments
zanmato1984 Oct 13, 2024
66d977a
Fix lint
zanmato1984 Oct 14, 2024
a4c292c
Merge remote-tracking branch 'origin/main' into vector-placement
zanmato1984 Oct 14, 2024
846039d
Update comment
zanmato1984 Oct 14, 2024
2f2ae47
Fix typo
zanmato1984 Oct 14, 2024
3af49a8
Typo
zanmato1984 Oct 14, 2024
944609c
Refine
zanmato1984 Oct 17, 2024
e132f0d
Update cpp/src/arrow/compute/kernels/vector_placement_test.cc
zanmato1984 Oct 31, 2024
220598b
Rename function category to swizzle
zanmato1984 Nov 4, 2024
c03f6e0
reverse_indices -> inverse_permutation
zanmato1984 Nov 4, 2024
705c7b2
output_length -> max_index
zanmato1984 Nov 4, 2024
9e9ccb0
Permute -> Scatter
zanmato1984 Nov 4, 2024
bd334fe
Fixing some renamings
zanmato1984 Nov 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -786,13 +786,14 @@ if(ARROW_COMPUTE)
compute/kernels/scalar_validity.cc
compute/kernels/vector_array_sort.cc
compute/kernels/vector_cumulative_ops.cc
compute/kernels/vector_pairwise.cc
compute/kernels/vector_nested.cc
compute/kernels/vector_pairwise.cc
compute/kernels/vector_rank.cc
compute/kernels/vector_replace.cc
compute/kernels/vector_run_end_encode.cc
compute/kernels/vector_select_k.cc
compute/kernels/vector_sort.cc)
compute/kernels/vector_sort.cc
compute/kernels/vector_swizzle.cc)

append_runtime_avx2_src(ARROW_COMPUTE_SRCS compute/kernels/aggregate_basic_avx2.cc)
append_runtime_avx512_src(ARROW_COMPUTE_SRCS compute/kernels/aggregate_basic_avx512.cc)
Expand Down
33 changes: 33 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
DataMember("periods", &PairwiseOptions::periods));
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
DataMember("recursive", &ListFlattenOptions::recursive));
// Options type for InversePermutationOptions, described through its "max_index" and
// "output_type" data members.
static auto kInversePermutationOptionsType =
GetFunctionOptionsType<InversePermutationOptions>(
DataMember("max_index", &InversePermutationOptions::max_index),
DataMember("output_type", &InversePermutationOptions::output_type));
// Options type for ScatterOptions, described through its "max_index" data member.
static auto kScatterOptionsType = GetFunctionOptionsType<ScatterOptions>(
DataMember("max_index", &ScatterOptions::max_index));
} // namespace
} // namespace internal

Expand Down Expand Up @@ -230,6 +236,17 @@ ListFlattenOptions::ListFlattenOptions(bool recursive)
: FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {}
constexpr char ListFlattenOptions::kTypeName[];

InversePermutationOptions::InversePermutationOptions(
int64_t max_index, std::shared_ptr<DataType> output_type)
: FunctionOptions(internal::kInversePermutationOptionsType),
max_index(max_index),
output_type(std::move(output_type)) {}
constexpr char InversePermutationOptions::kTypeName[];

ScatterOptions::ScatterOptions(int64_t max_index)
: FunctionOptions(internal::kScatterOptionsType), max_index(max_index) {}
constexpr char ScatterOptions::kTypeName[];

namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
Expand All @@ -244,6 +261,8 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kScatterOptionsType));
}
} // namespace internal

Expand Down Expand Up @@ -429,5 +448,19 @@ Result<Datum> CumulativeMean(const Datum& values, const CumulativeOptions& optio
return CallFunction("cumulative_mean", {Datum(values)}, &options, ctx);
}

// ----------------------------------------------------------------------
// Swizzle functions

Result<Datum> InversePermutation(const Datum& indices,
                                 const InversePermutationOptions& options,
                                 ExecContext* ctx) {
  // Delegates to CallFunction under the name "inverse_permutation", with `indices` as
  // the only argument and the options passed through their FunctionOptions base.
  const FunctionOptions* function_options = &options;
  return CallFunction("inverse_permutation", {indices}, function_options, ctx);
}

Result<Datum> Scatter(const Datum& values, const Datum& indices,
                      const ScatterOptions& options, ExecContext* ctx) {
  // Delegates to CallFunction under the name "scatter"; the arguments are forwarded
  // in (values, indices) order and the options through their FunctionOptions base.
  const FunctionOptions* function_options = &options;
  return CallFunction("scatter", {values, indices}, function_options, ctx);
}

} // namespace compute
} // namespace arrow
79 changes: 79 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,38 @@ class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
bool recursive = false;
};

/// \brief Options for inverse_permutation function
class ARROW_EXPORT InversePermutationOptions : public FunctionOptions {
public:
explicit InversePermutationOptions(int64_t max_index = -1,
std::shared_ptr<DataType> output_type = NULLPTR);
static constexpr char const kTypeName[] = "InversePermutationOptions";
static InversePermutationOptions Defaults() { return InversePermutationOptions(); }

/// \brief The max value in the input indices to process. Any indices that are greater
/// than this value will be ignored. If negative, this value will be set to the length
/// of the input indices minus 1.
int64_t max_index = -1;
/// \brief The type of the output inverse permutation. If null, the output will be of
/// the same type as the input indices; otherwise it must be an integer type. An Invalid
/// error will be reported if this type is not able to store the length of the input
/// indices.
std::shared_ptr<DataType> output_type = NULLPTR;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should nullability be considered here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'll be decided during the actual computation. If there are "holes" in the output, validity buffer will be allocated and filled on demand.

};

/// \brief Options for the scatter function
class ARROW_EXPORT ScatterOptions : public FunctionOptions {
 public:
  explicit ScatterOptions(int64_t max_index = -1);

  static constexpr const char kTypeName[] = "ScatterOptions";

  /// \brief Options with every field left at its default value.
  static ScatterOptions Defaults() { return ScatterOptions{}; }

  /// \brief The largest index value in the input indices that is processed. Any values
  /// whose index is greater than this bound are ignored. If negative, the length of the
  /// input minus 1 is used instead.
  int64_t max_index = -1;
};

/// @}

/// \brief Filter with a boolean selection filter
Expand Down Expand Up @@ -705,5 +737,52 @@ Result<std::shared_ptr<Array>> PairwiseDiff(const Array& array,
bool check_overflow = false,
ExecContext* ctx = NULLPTR);

/// \brief Return the inverse permutation of the given indices.
///
/// For indices[i] = x, inverse_permutation[x] = i. inverse_permutation[x] is null if x
/// does not appear in the input indices. Any indices[i] = x with x < 0 or
/// x > max_index is ignored. If the same value x appears at multiple positions in the
/// input indices, the last position is used.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this explanation is confusing, but we can work on this later.

///
/// For example, with indices = [null, 0, 3, 2, 4, 1, 1], the inverse permutation is
/// [1, 6, 3] if max_index = 2,
/// [1, 6, 3, 2, 4, null, null] if max_index = 6.
///
/// \param[in] indices array-like indices
/// \param[in] options configures the max index and the output type
/// \param[in] ctx the function execution context, optional
/// \return the resulting inverse permutation
///
/// \since 19.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> InversePermutation(
const Datum& indices,
const InversePermutationOptions& options = InversePermutationOptions::Defaults(),
ExecContext* ctx = NULLPTR);

/// \brief Scatter the values into specified positions according to the indices.
///
/// For indices[i] = x, output[x] = values[i]. output[x] is null if x does not appear
/// in the input indices. For indices[i] = x where x < 0 or x > max_index, values[i]
/// is ignored. If the same value x appears at multiple positions in the input indices,
/// the value at the last such position is used.
Comment on lines +766 to +767
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// in the input indices. For indices[i] = x where x < 0 or x > max_index, values[i]
/// is ignored. If multiple indices point to the same value, the last one is used.
/// in the input indices. For indices[i] = x where x < 0 or x > max_index, values[i]
/// is ignored. If multiple indices point to the same value, the last one is used.

Can you explain the point of the lenient behavior wrt. negative indices and the max_index option?
Is there a use case that this enables?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No there isn't a particular case in my mind that index < 0 or index > max_index can be very useful, so we can alternatively throw exceptions if we see such values, or even not detect them at all.

(And just in case you are curious about the option max_index itself, it is essential for cases that the input indices are "sparse", please see my other comment #44394 (comment) for details of such usage.)

Thank you.

///
/// For example, with values = [a, b, c, d, e, f, g] and indices = [null, 0,
/// 3, 2, 4, 1, 1], the output is
/// [b, g, d] if max_index = 2,
/// [b, g, d, c, e, null, null] if max_index = 6.
///
/// \param[in] values datum to scatter
/// \param[in] indices array-like indices
/// \param[in] options configures the max index to scatter
/// \param[in] ctx the function execution context, optional
/// \return the resulting datum
///
/// \since 19.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> Scatter(const Datum& values, const Datum& indices,
const ScatterOptions& options = ScatterOptions::Defaults(),
ExecContext* ctx = NULLPTR);

} // namespace compute
} // namespace arrow
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ TEST(FunctionOptions, Equality) {
options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}}));
options.emplace_back(new Utf8NormalizeOptions());
options.emplace_back(new Utf8NormalizeOptions(Utf8NormalizeOptions::NFD));
options.emplace_back(
new InversePermutationOptions(/*max_index=*/42, /*output_type=*/int32()));
options.emplace_back(new ScatterOptions());
options.emplace_back(new ScatterOptions(/*max_index=*/42));

for (size_t i = 0; i < options.size(); i++) {
const size_t prev_i = i == 0 ? options.size() - 1 : i - 1;
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ add_arrow_compute_test(vector_selection_test
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_compute_test(vector_swizzle_test
SOURCES
vector_swizzle_test.cc
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1037,8 +1037,9 @@ ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) {
// Generate a kernel given a templated functor for integer types
//
// See "Numeric" above for description of the generator functor
template <template <typename...> class Generator, typename Type0, typename... Args>
ArrayKernelExec GenerateInteger(detail::GetTypeId get_id) {
template <template <typename...> class Generator, typename Type0,
typename KernelType = ArrayKernelExec, typename... Args>
KernelType GenerateInteger(detail::GetTypeId get_id) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this change is for generate exec_chunked?

Copy link
Contributor Author

@zanmato1984 zanmato1984 Nov 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Just slightly extending it like GenerateNumeric.

switch (get_id.id) {
case Type::INT8:
return Generator<Type0, Int8Type, Args...>::Exec;
Expand Down
Loading
Loading