Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-44393: [C++][Compute] Swizzle vector functions #44394

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
1e141d7
WIP
zanmato1984 Sep 29, 2024
216e217
WIP
zanmato1984 Sep 30, 2024
f3c73ea
Add permute function options
zanmato1984 Oct 2, 2024
be88f0c
WIP
zanmato1984 Oct 4, 2024
b445c36
Implementation done and basic tests
zanmato1984 Oct 6, 2024
3707fa6
Implement permute
zanmato1984 Oct 7, 2024
4e9d3a6
Reorg reverse_index
zanmato1984 Oct 10, 2024
c4c5c41
Fix API and doc
zanmato1984 Oct 10, 2024
78bb335
Fix API and doc
zanmato1984 Oct 10, 2024
38bcc5d
Merge remote-tracking branch 'origin/main' into vector-placement
zanmato1984 Oct 10, 2024
d88877a
Init docs
zanmato1984 Oct 10, 2024
cc6a0ef
Merge branch 'vector-permute' into vector-placement
zanmato1984 Oct 10, 2024
2bbf44b
Refine
zanmato1984 Oct 10, 2024
b31f9f2
Update docs
zanmato1984 Oct 10, 2024
b951348
Refine doc
zanmato1984 Oct 11, 2024
520b952
Add comments for the implementation
zanmato1984 Oct 11, 2024
b450f5e
Refine docs
zanmato1984 Oct 11, 2024
4ea1465
Fix uint64 overflow check
zanmato1984 Oct 11, 2024
cbdce2f
Reverse indices tests
zanmato1984 Oct 11, 2024
d2e118a
Forbit non-array-like argument
zanmato1984 Oct 11, 2024
7128a28
Fix permute option default
zanmato1984 Oct 11, 2024
034d3b7
Refine
zanmato1984 Oct 11, 2024
9f93e5c
WIP permute tests
zanmato1984 Oct 11, 2024
c320002
Refine tests
zanmato1984 Oct 12, 2024
3e438e8
More permute tests
zanmato1984 Oct 12, 2024
0811b2b
Add if-else tests using permute
zanmato1984 Oct 13, 2024
154ad95
Update some comments
zanmato1984 Oct 13, 2024
66d977a
Fix lint
zanmato1984 Oct 14, 2024
a4c292c
Merge remote-tracking branch 'origin/main' into vector-placement
zanmato1984 Oct 14, 2024
846039d
Update comment
zanmato1984 Oct 14, 2024
2f2ae47
Fix typo
zanmato1984 Oct 14, 2024
3af49a8
Typo
zanmato1984 Oct 14, 2024
944609c
Refine
zanmato1984 Oct 17, 2024
e132f0d
Update cpp/src/arrow/compute/kernels/vector_placement_test.cc
zanmato1984 Oct 31, 2024
220598b
Rename function category to swizzle
zanmato1984 Nov 4, 2024
c03f6e0
reverse_indices -> inverse_permutation
zanmato1984 Nov 4, 2024
705c7b2
output_length -> max_index
zanmato1984 Nov 4, 2024
9e9ccb0
Permute -> Scatter
zanmato1984 Nov 4, 2024
bd334fe
Fixing some renamings
zanmato1984 Nov 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ if(ARROW_COMPUTE)
compute/kernels/vector_array_sort.cc
compute/kernels/vector_cumulative_ops.cc
compute/kernels/vector_pairwise.cc
compute/kernels/vector_placement.cc
compute/kernels/vector_nested.cc
compute/kernels/vector_rank.cc
compute/kernels/vector_replace.cc
Expand Down
31 changes: 31 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
DataMember("periods", &PairwiseOptions::periods));
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
DataMember("recursive", &ListFlattenOptions::recursive));
static auto kReverseIndicesOptionsType = GetFunctionOptionsType<ReverseIndicesOptions>(
DataMember("output_length", &ReverseIndicesOptions::output_length),
DataMember("output_type", &ReverseIndicesOptions::output_type));
static auto kPermuteOptionsType = GetFunctionOptionsType<PermuteOptions>(
DataMember("output_length", &PermuteOptions::output_length));
zanmato1984 marked this conversation as resolved.
Show resolved Hide resolved
} // namespace
} // namespace internal

Expand Down Expand Up @@ -230,6 +235,17 @@ ListFlattenOptions::ListFlattenOptions(bool recursive)
: FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {}
constexpr char ListFlattenOptions::kTypeName[];

ReverseIndicesOptions::ReverseIndicesOptions(int64_t output_length,
std::shared_ptr<DataType> output_type)
: FunctionOptions(internal::kReverseIndicesOptionsType),
output_length(output_length),
output_type(std::move(output_type)) {}
constexpr char ReverseIndicesOptions::kTypeName[];

PermuteOptions::PermuteOptions(int64_t output_length)
: FunctionOptions(internal::kPermuteOptionsType), output_length(output_length) {}
constexpr char PermuteOptions::kTypeName[];

namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
Expand All @@ -244,6 +260,8 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kReverseIndicesOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPermuteOptionsType));
}
} // namespace internal

Expand Down Expand Up @@ -429,5 +447,18 @@ Result<Datum> CumulativeMean(const Datum& values, const CumulativeOptions& optio
return CallFunction("cumulative_mean", {Datum(values)}, &options, ctx);
}

// ----------------------------------------------------------------------
// Placement functions

Result<Datum> ReverseIndices(const Datum& indices, const ReverseIndicesOptions& options,
ExecContext* ctx) {
return CallFunction("reverse_indices", {indices}, &options, ctx);
}

Result<Datum> Permute(const Datum& values, const Datum& indices,
const PermuteOptions& options, ExecContext* ctx) {
return CallFunction("permute", {values, indices}, &options, ctx);
}

} // namespace compute
} // namespace arrow
82 changes: 82 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,37 @@ class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
bool recursive = false;
};

/// \brief Options for reverse_indices function
class ARROW_EXPORT ReverseIndicesOptions : public FunctionOptions {
public:
explicit ReverseIndicesOptions(int64_t output_length = -1,
std::shared_ptr<DataType> output_type = NULLPTR);
static constexpr char const kTypeName[] = "ReverseIndicesOptions";
static ReverseIndicesOptions Defaults() { return ReverseIndicesOptions(); }

/// \brief The length of the output reverse indices. If negative, the output will be of
/// the same length as the input indices. Any indices that are greater or equal to this
/// length will be ignored.
int64_t output_length = -1;
/// \brief The type of the output reverse indices. If null, the output will be of the
/// same type as the input indices, otherwise must be integer types. An invalid error
/// will be reported if this type is not able to store the length of the input indices.
std::shared_ptr<DataType> output_type = NULLPTR;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should nullable being considered here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'll be decided during the actual computation. If there are "holes" in the output, validity buffer will be allocated and filled on demand.

};

/// \brief Options for permute function
class ARROW_EXPORT PermuteOptions : public FunctionOptions {
public:
explicit PermuteOptions(int64_t output_length = -1);
static constexpr char const kTypeName[] = "PermuteOptions";
static PermuteOptions Defaults() { return PermuteOptions(); }

/// \brief The length of the output permutation. If negative, the output will be of the
/// same length as the input values (and indices). Any values with indices that are
/// greater or equal to this length will be ignored.
int64_t output_length = -1;
};

/// @}

/// \brief Filter with a boolean selection filter
Expand Down Expand Up @@ -705,5 +736,56 @@ Result<std::shared_ptr<Array>> PairwiseDiff(const Array& array,
bool check_overflow = false,
ExecContext* ctx = NULLPTR);

/// \brief Return the reverse indices of the given indices.
zanmato1984 marked this conversation as resolved.
Show resolved Hide resolved
///
/// For indices[i] = x, reverse_indices[x] = i. And reverse_indices[x] = null if x does
/// not appear in the input indices. For indices[i] = x where x < 0 or x >= output_length,
/// it is ignored. If multiple indices point to the same value, the last one is used.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this explanation is confusing, but we can work on this later.

///
/// For example, with indices = [null, 0, 3, 2, 4, 1, 1], the reverse indices is
/// [1, 6, 3] if output_length = 3,
/// [1, 6, 3, 2, 4, null, null] if output_length = 7.
zanmato1984 marked this conversation as resolved.
Show resolved Hide resolved
/// output_length can also be negative, in which case the reverse indices is of the same
/// length as the indices.
///
/// \param[in] indices array-like indices
/// \param[in] options configures the output length and the output type
/// \param[in] ctx the function execution context, optional
/// \return the resulting reverse indices
///
/// \since 19.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> ReverseIndices(
const Datum& indices,
const ReverseIndicesOptions& options = ReverseIndicesOptions::Defaults(),
ExecContext* ctx = NULLPTR);

/// \brief Permute the values into specified positions according to the indices.
///
/// For indices[i] = x, output[x] = values[i]. And output[x] = null if x does not appear
/// in the input indices. For indices[i] = x where x < 0 or x >= output_length, values[i]
/// is ignored. If multiple indices point to the same value, the last one is used.
///
/// For example, with values = [a, b, c, d, e, f, g] and indices = [null, 0,
/// 3, 2, 4, 1, 1], the permutation is
/// [b, g, d] if output_length = 3,
/// [b, g, d, c, e, null, null] if output_length = 7.
/// output_length can also be negative, in which case the permutation is of the same
/// length as the values (and indices).
///
/// \param[in] values datum to permute
/// \param[in] indices array-like indices
/// \param[in] options configures the output length of the permutation
/// \param[in] ctx the function execution context, optional
/// \return the resulting permutation
///
/// \since 19.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> Permute(const Datum& values, const Datum& indices,
const PermuteOptions& options = PermuteOptions::Defaults(),
ExecContext* ctx = NULLPTR);

} // namespace compute
} // namespace arrow
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ TEST(FunctionOptions, Equality) {
options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}}));
options.emplace_back(new Utf8NormalizeOptions());
options.emplace_back(new Utf8NormalizeOptions(Utf8NormalizeOptions::NFD));
options.emplace_back(
new ReverseIndicesOptions(/*output_length=*/42, /*output_type=*/int32()));
options.emplace_back(new PermuteOptions());
options.emplace_back(new PermuteOptions(/*output_length=*/42));

for (size_t i = 0; i < options.size(); i++) {
const size_t prev_i = i == 0 ? options.size() - 1 : i - 1;
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ add_arrow_compute_test(vector_selection_test
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_compute_test(vector_placement_test
SOURCES
vector_placement_test.cc
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1037,8 +1037,9 @@ ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) {
// Generate a kernel given a templated functor for integer types
//
// See "Numeric" above for description of the generator functor
template <template <typename...> class Generator, typename Type0, typename... Args>
ArrayKernelExec GenerateInteger(detail::GetTypeId get_id) {
template <template <typename...> class Generator, typename Type0,
typename KernelType = ArrayKernelExec, typename... Args>
KernelType GenerateInteger(detail::GetTypeId get_id) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this change is for generate exec_chunked?

Copy link
Contributor Author

@zanmato1984 zanmato1984 Nov 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Just slightly extending it like GenerateNumeric.

switch (get_id.id) {
case Type::INT8:
return Generator<Type0, Int8Type, Args...>::Exec;
Expand Down
Loading
Loading