Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Extend the simple UDAF interface with function-level variables #12067

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions velox/docs/develop/aggregate-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,23 @@ A simple aggregation function is implemented as a class as the following.

// Optional.
static bool toIntermediate(
exec::out_type<Array<Generic<T1>>>& out,
exec::optional_arg_type<Generic<T1>> in);
exec::out_type<Array<Generic<T1>>>& out,
exec::optional_arg_type<Generic<T1>> in);

// Optional. Define some function-level variables.
TypePtr inputType_;
TypePtr resultType_;

// Optional. Defined only when the aggregation function needs to use function-level variables.
// This method is called once when the aggregation function is created.
void initialize(
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType) {
VELOX_CHECK_EQ(argTypes.size(), 1);
inputType_ = argTypes[0];
resultType_ = resultType;
}

struct AccumulatorType { ... };
};
Expand All @@ -169,6 +184,14 @@ function's argument type(s) wrapped in a Row<> even if the function only takes
one argument. This is needed for the SimpleAggregateAdapter to parse input
types for arbitrary aggregation functions properly.

The author can optionally define function-level variables in the simple
aggregation function class. These variables are initialized once, when the
aggregation function is created, and can then be read for every row when
adding inputs to accumulators or extracting values from accumulators. For
example, if the aggregation function needs its result type or its raw input
type, these types can be declared as member variables of the aggregate class
and set in the optional `initialize()` method. Accumulators access these
variables through the pointer to the aggregate class that is passed to their
constructor.

The author can define an optional flag `default_null_behavior_` indicating
whether the aggregation function has default-null behavior. This flag is true
by default. Next, the class can have an optional method `toIntermediate()`
Expand Down Expand Up @@ -248,6 +271,10 @@ For aggregaiton functions of default-null behavior, the author defines an
// Author defines data members
...

// Optional. Define a pointer to the UDAF class if the aggregation
// function uses function-level variables.
ArrayAggAggregate* fn_;

// Optional. Default is true.
static constexpr bool is_fixed_size_ = false;

Expand All @@ -257,7 +284,9 @@ For aggregaiton functions of default-null behavior, the author defines an
// Optional. Default is false.
static constexpr bool is_aligned_ = true;

explicit AccumulatorType(HashStringAllocator* allocator);
// Binds the accumulator to the UDAF instance that owns it, so that
// function-level variables can be read through fn_.
explicit AccumulatorType(
    HashStringAllocator* allocator, ArrayAggAggregate* fn)
    : fn_(fn) {}

void addInput(HashStringAllocator* allocator, exec::arg_type<T1> value1, ...);

Expand Down Expand Up @@ -353,7 +382,7 @@ For aggregaiton functions of non-default-null behavior, the author defines an
// Optional. Default is false.
static constexpr bool is_aligned_ = true;

explicit AccumulatorType(HashStringAllocator* allocator);
explicit AccumulatorType(HashStringAllocator* allocator, ArrayAggAggregate* fn);

bool addInput(HashStringAllocator* allocator, exec::optional_arg_type<T1> value1, ...);

Expand Down Expand Up @@ -873,15 +902,15 @@ unique pointers. Below is an example.
name,
std::move(signatures),
[name](
core::AggregationNode::Step /*step*/,
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_EQ(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<SimpleAggregateAdapter<SimpleArrayAggAggregate>>(
resultType);
step, argTypes, resultType);
});
}

Expand Down
25 changes: 22 additions & 3 deletions velox/exec/SimpleAggregateAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,15 @@ using optional_arg_type = OptionalAccessor<T>;
template <typename FUNC>
class SimpleAggregateAdapter : public Aggregate {
public:
explicit SimpleAggregateAdapter(TypePtr resultType)
: Aggregate(std::move(resultType)) {}
// Creates the adapter together with the single FUNC instance whose pointer is
// later handed to every accumulator this adapter constructs.
//
// step, argTypes and the result type are forwarded to FUNC::initialize() when
// FUNC declares such a method; otherwise they are only used to build the base
// class.
explicit SimpleAggregateAdapter(
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
TypePtr resultType)
: Aggregate(std::move(resultType)), fn_{std::make_unique<FUNC>()} {
// initialize() is optional: the call is compiled only if FUNC provides it.
if constexpr (support_initialize_) {
// The `resultType` parameter was moved into the base class above, so
// resultType_ (presumably the base-class member that received it) is
// passed here instead of the moved-from parameter.
fn_->initialize(step, argTypes, resultType_);
}
}

// Assume most aggregate functions have fixed-size accumulators. Functions
// that
Expand Down Expand Up @@ -145,6 +152,13 @@ class SimpleAggregateAdapter : public Aggregate {
struct support_to_intermediate<T, std::void_t<decltype(&T::toIntermediate)>>
: std::true_type {};

// Detects whether T declares a member named `initialize` whose address can be
// taken unambiguously. Primary template: chosen when `&T::initialize` is
// ill-formed, e.g. the member is missing, overloaded, or a member template.
template <typename T, typename Enable = void>
struct support_initialize : std::bool_constant<false> {};

// Specialization: chosen when `&T::initialize` is a valid expression.
template <typename T>
struct support_initialize<T, std::void_t<decltype(&T::initialize)>>
    : std::bool_constant<true> {};

// Whether the accumulator requires aligned access. If it is defined,
// SimpleAggregateAdapter::accumulatorAlignmentSize() returns
// alignof(typename FUNC::AccumulatorType).
Expand Down Expand Up @@ -172,6 +186,8 @@ class SimpleAggregateAdapter : public Aggregate {
static constexpr bool support_to_intermediate_ =
support_to_intermediate<FUNC>::value;

static constexpr bool support_initialize_ = support_initialize<FUNC>::value;

static constexpr bool accumulator_is_aligned_ =
accumulator_is_aligned<typename FUNC::AccumulatorType>::value;

Expand Down Expand Up @@ -350,7 +366,8 @@ class SimpleAggregateAdapter : public Aggregate {
folly::Range<const vector_size_t*> indices) override {
setAllNulls(groups, indices);
for (auto i : indices) {
new (groups[i] + offset_) typename FUNC::AccumulatorType(allocator_);
new (groups[i] + offset_)
typename FUNC::AccumulatorType(allocator_, fn_.get());
}
}

Expand Down Expand Up @@ -571,6 +588,8 @@ class SimpleAggregateAdapter : public Aggregate {

std::vector<DecodedVector> inputDecoded_;
DecodedVector intermediateDecoded_;

std::unique_ptr<FUNC> fn_;
};

} // namespace facebook::velox::exec
139 changes: 136 additions & 3 deletions velox/exec/tests/SimpleAggregateAdapterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ class CountNullsAggregate {

Accumulator() = delete;

explicit Accumulator(HashStringAllocator* /*allocator*/) {
explicit Accumulator(
HashStringAllocator* /*allocator*/,
CountNullsAggregate* /*fn*/) {
nullsCount_ = 0;
}

Expand Down Expand Up @@ -423,15 +425,15 @@ exec::AggregateRegistrationResult registerSimpleCountNullsAggregate(
name,
std::move(signatures),
[name](
core::AggregationNode::Step /*step*/,
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_LE(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<SimpleAggregateAdapter<CountNullsAggregate>>(
resultType);
step, argTypes, resultType);
},
false /*registerCompanionFunctions*/,
true /*overwrite*/);
Expand Down Expand Up @@ -469,5 +471,136 @@ TEST_F(SimpleCountNullsAggregationTest, basic) {
testAggregations({vectors}, {}, {"simple_count_nulls(c2)"}, {expected});
}

// A simple avg-style aggregate function used only for testing. It verifies
// that function-level variables hold the expected values; the validation
// logic lives in the Accumulator::addInput method.
class FuncLevelVariableTestAggregate {
public:
using InputType = Row<int64_t>;
using IntermediateType = Row<int64_t, double>;
using OutputType = double;

// These two variables are used for testing, they are set during the creation
// of the aggregation function and will be checked in addInput().
TypePtr inputType_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Add a comment above to say explicitly that these are the function level variable we're going to test, and explain that they are set during the creation of the aggregation function and checked in addInput().

TypePtr resultType_;

// Captures the function-level variables that Accumulator::addInput() later
// validates. Requires exactly one argument type.
//
// The aggregation step is not needed by this test function, so its parameter
// name is commented out like the other unused parameters in this file.
void initialize(
    core::AggregationNode::Step /*step*/,
    const std::vector<TypePtr>& argTypes,
    const TypePtr& resultType) {
  VELOX_CHECK_EQ(argTypes.size(), 1);
  inputType_ = argTypes[0];
  resultType_ = resultType;
}

// Per-group accumulator that computes an average as sum / count. It also
// keeps a pointer to the enclosing aggregate so that addInput() can validate
// the function-level variables on every input row.
struct Accumulator {
int64_t sum{0};
double count{0};
// Pointer to the aggregate holding inputType_ and resultType_. The
// accumulator never frees it.
FuncLevelVariableTestAggregate* fn_;

// Stores the aggregate pointer; the allocator is not used by this
// accumulator.
explicit Accumulator(
HashStringAllocator* /*allocator*/,
FuncLevelVariableTestAggregate* fn)
: fn_(fn) {}

// Validates the function-level variables, then adds `data` to the running
// sum and increments the count.
void addInput(
HashStringAllocator* /*allocator*/,
exec::arg_type<int64_t> data) {
// Both variables must have been set by initialize().
VELOX_CHECK_NOT_NULL(fn_->inputType_);
VELOX_CHECK_NOT_NULL(fn_->resultType_);
// The input type is accepted either as ROW(BIGINT, DOUBLE) or as plain
// BIGINT. NOTE(review): the ROW form presumably occurs when the function
// is created over intermediate data (e.g. via companion functions) --
// confirm.
if (fn_->inputType_->isRow()) {
VELOX_CHECK_EQ(fn_->inputType_->size(), 2);
VELOX_CHECK_EQ(fn_->inputType_->childAt(0), BIGINT());
VELOX_CHECK_EQ(fn_->inputType_->childAt(1), DOUBLE());
} else {
VELOX_CHECK_EQ(fn_->inputType_, BIGINT());
}
// Likewise the result type is accepted either as ROW(BIGINT, DOUBLE) or
// as plain DOUBLE.
if (fn_->resultType_->isRow()) {
VELOX_CHECK_EQ(fn_->resultType_->size(), 2);
VELOX_CHECK_EQ(fn_->resultType_->childAt(0), BIGINT());
VELOX_CHECK_EQ(fn_->resultType_->childAt(1), DOUBLE());
} else {
VELOX_CHECK_EQ(fn_->resultType_, DOUBLE());
}
sum += data;
// `count` is a double but is incremented through the int64_t
// instantiation of checkedPlus, so it is converted to int64_t and back.
// NOTE(review): checkedPlus presumably raises on overflow -- confirm.
count = checkedPlus<int64_t>(count, 1);
}

// Merges an intermediate (sum, count) pair into this accumulator. Both
// fields of the pair are required to be present.
void combine(
HashStringAllocator* /*allocator*/,
exec::arg_type<IntermediateType> other) {
VELOX_CHECK(other.at<0>().has_value());
VELOX_CHECK(other.at<1>().has_value());
sum += other.at<0>().value();
count += other.at<1>().value();
}

// Writes the (sum, count) pair as the intermediate result and always
// returns true. NOTE(review): true presumably marks the result as non-null
// -- confirm against the adapter.
bool writeIntermediateResult(exec::out_type<IntermediateType>& out) {
out = std::make_tuple(sum, count);
return true;
}

// Writes sum / count as the final result and always returns true. The
// division is done in floating point because `count` is a double.
bool writeFinalResult(exec::out_type<OutputType>& out) {
out = sum / count;
return true;
}
};

using AccumulatorType = Accumulator;
};

// Registers FuncLevelVariableTestAggregate under `name` with the signature
// BIGINT -> ROW(BIGINT, DOUBLE) -> DOUBLE. Companion functions are registered
// as well, and an existing registration of the same name is overwritten.
//
// Returns the registration result reported by exec::registerAggregateFunction.
exec::AggregateRegistrationResult registerFuncLevelVariableTestAggregate(
    const std::string& name) {
  std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
      exec::AggregateFunctionSignatureBuilder()
          .returnType("DOUBLE")
          .intermediateType("ROW(BIGINT, DOUBLE)")
          .argumentType("BIGINT")
          .build()};

  return exec::registerAggregateFunction(
      name,
      std::move(signatures),
      [name](
          core::AggregationNode::Step step,
          const std::vector<TypePtr>& argTypes,
          const TypePtr& resultType,
          const core::QueryConfig& /*config*/)
          -> std::unique_ptr<exec::Aggregate> {
        // The only signature takes one argument and
        // FuncLevelVariableTestAggregate::initialize() requires exactly one,
        // so reject any other arity here with an error that names the
        // function instead of failing later inside initialize().
        VELOX_CHECK_EQ(
            argTypes.size(), 1, "{} takes exactly 1 argument", name);
        return std::make_unique<
            SimpleAggregateAdapter<FuncLevelVariableTestAggregate>>(
            step, argTypes, resultType);
      },
      true /*registerCompanionFunctions*/,
      true /*overwrite*/);
}

// Test fixture that registers the function-level-variable test aggregate
// before each test, after the base fixture has been set up.
class SimpleFuncLevelVariableAggregationTest : public AggregationTestBase {
 protected:
  void SetUp() override {
    AggregationTestBase::SetUp();
    const std::string functionName{"simple_func_level_variable_agg"};
    registerFuncLevelVariableTestAggregate(functionName);
  }
};

// Runs the test aggregate over the single column {1, 2, 3, 4}. The expected
// result is sum / count = 10 / 4 = 2.5. The checks on the function-level
// variables run inside Accumulator::addInput(), so a wrong inputType_ or
// resultType_ trips a VELOX_CHECK there (presumably failing the test).
TEST_F(SimpleFuncLevelVariableAggregationTest, simpleAggregateVariables) {
auto inputVectors = makeRowVector({makeFlatVector<int64_t>({1, 2, 3, 4})});
std::vector<double> finalResult = {2.5};
auto expected = makeRowVector({makeFlatVector<double>(finalResult)});
// Global aggregation (no grouping keys) through the regular function.
testAggregations(
{inputVectors}, {}, {"simple_func_level_variable_agg(c0)"}, {expected});
// Same aggregation through the companion functions. NOTE(review): the
// {{BIGINT()}} argument presumably lists the raw input types per aggregate
// -- confirm against AggregationTestBase.
testAggregationsWithCompanion(
{inputVectors},
[](auto& /*builder*/) {},
{},
{"simple_func_level_variable_agg(c0)"},
{{BIGINT()}},
{},
{expected},
{});
}

} // namespace
} // namespace facebook::velox::aggregate::test
8 changes: 5 additions & 3 deletions velox/exec/tests/SimpleArrayAggAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ class ArrayAggAggregate {
AccumulatorType() = delete;

// Constructor used in initializeNewGroups().
explicit AccumulatorType(HashStringAllocator* /*allocator*/)
explicit AccumulatorType(
HashStringAllocator* /*allocator*/,
ArrayAggAggregate* /*fn*/)
: elements_{} {}

static constexpr bool is_fixed_size_ = false;
Expand Down Expand Up @@ -127,15 +129,15 @@ exec::AggregateRegistrationResult registerSimpleArrayAggAggregate(
name,
std::move(signatures),
[name](
core::AggregationNode::Step /*step*/,
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_EQ(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<SimpleAggregateAdapter<ArrayAggAggregate>>(
resultType);
step, argTypes, resultType);
},
true /*registerCompanionFunctions*/,
true /*overwrite*/);
Expand Down
Loading
Loading