Skip to content

Commit

Permalink
extend the simple UDAF interface with function-level variables
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Jan 13, 2025
1 parent 022cd87 commit 8d84155
Show file tree
Hide file tree
Showing 13 changed files with 308 additions and 72 deletions.
41 changes: 35 additions & 6 deletions velox/docs/develop/aggregate-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,24 @@ A simple aggregation function is implemented as a class as the following.

// Optional.
static bool toIntermediate(
exec::out_type<Array<Generic<T1>>>& out,
exec::optional_arg_type<Generic<T1>> in);
exec::out_type<Array<Generic<T1>>>& out,
exec::optional_arg_type<Generic<T1>> in);

// Optional. Define some function-level variables.
TypePtr inputType;
TypePtr resultType;

// Optional. Defined only when the aggregation function needs to use function-level variables.
// This method is called once when the aggregation function is created.
static void initialize(
std::shared_ptr<FuncLevelVariableTestAggregate> fn,
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType) {
VELOX_CHECK_EQ(argTypes.size(), 1);
fn->inputType_ = argTypes[0];
fn->resultType_ = resultType;
}

struct AccumulatorType { ... };
};
Expand All @@ -169,6 +185,14 @@ function's argument type(s) wrapped in a Row<> even if the function only takes
one argument. This is needed for the SimpleAggregateAdapter to parse input
types for arbitrary aggregation functions properly.

Some function-level variables may need to be declared in the simple aggregation
function class. These variables are initialized once when the aggregation
function is created and can be used for every row when adding inputs to
accumulators or extracting values from accumulators. For example, if the
aggregation function needs the result type or the raw input type of the
aggregation function, the author can hold them in member variables of the
aggregate class and initialize them in the initialize() method.

The author can define an optional flag `default_null_behavior_` indicating
whether the aggregation function has default-null behavior. This flag is true
by default. Next, the class can have an optional method `toIntermediate()`
Expand Down Expand Up @@ -248,6 +272,9 @@ For aggregaiton functions of default-null behavior, the author defines an
// Author defines data members
...

// Define a pointer to the UDAF class.
std::shared_ptr<ArrayAggAggregate> fn_;

// Optional. Default is true.
static constexpr bool is_fixed_size_ = false;

Expand All @@ -257,7 +284,9 @@ For aggregaiton functions of default-null behavior, the author defines an
// Optional. Default is false.
static constexpr bool is_aligned_ = true;

explicit AccumulatorType(HashStringAllocator* allocator);
explicit AccumulatorType(
HashStringAllocator* allocator, ArrayAggAggregate fn)
: fn_(fn);

void addInput(HashStringAllocator* allocator, exec::arg_type<T1> value1, ...);

Expand Down Expand Up @@ -353,7 +382,7 @@ For aggregaiton functions of non-default-null behavior, the author defines an
// Optional. Default is false.
static constexpr bool is_aligned_ = true;

explicit AccumulatorType(HashStringAllocator* allocator);
explicit AccumulatorType(HashStringAllocator* allocator, ArrayAggAggregate* fn);

bool addInput(HashStringAllocator* allocator, exec::optional_arg_type<T1> value1, ...);

Expand Down Expand Up @@ -873,15 +902,15 @@ unique pointers. Below is an example.
name,
std::move(signatures),
[name](
core::AggregationNode::Step /*step*/,
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_EQ(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<SimpleAggregateAdapter<SimpleArrayAggAggregate>>(
resultType);
step, argTypes, resultType);
});
}

Expand Down
24 changes: 21 additions & 3 deletions velox/exec/SimpleAggregateAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,15 @@ using optional_arg_type = OptionalAccessor<T>;
template <typename FUNC>
class SimpleAggregateAdapter : public Aggregate {
public:
explicit SimpleAggregateAdapter(TypePtr resultType)
: Aggregate(std::move(resultType)) {}
// Creates the adapter together with the FUNC instance it owns (fn_).
// 'resultType' is moved into the Aggregate base. When FUNC declares
// initialize(), it is invoked here with a raw pointer to the FUNC instance,
// the aggregation step, the argument types and resultType_.
// NOTE(review): resultType_ is presumably the base-class member holding the
// moved-from 'resultType' parameter — confirm.
explicit SimpleAggregateAdapter(
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
TypePtr resultType)
: Aggregate(std::move(resultType)), fn_{std::make_unique<FUNC>()} {
if constexpr (support_initialize_) {
FUNC::initialize(fn_.get(), step, argTypes, resultType_);
}
}

// Assume most aggregate functions have fixed-size accumulators. Functions
// that
Expand Down Expand Up @@ -103,6 +110,7 @@ class SimpleAggregateAdapter : public Aggregate {
// These functions are called on groups of both non-null and null
// accumulators. These functions also return a bool indicating whether the
// current group should be a NULL in the result vector.
std::unique_ptr<FUNC> fn_;
template <typename T, typename = void>
struct aggregate_default_null_behavior : std::true_type {};

Expand Down Expand Up @@ -145,6 +153,13 @@ class SimpleAggregateAdapter : public Aggregate {
struct support_to_intermediate<T, std::void_t<decltype(&T::toIntermediate)>>
: std::true_type {};

// Whether the function class declares initialize(). The primary template
// yields false; the specialization below is selected when &T::initialize is a
// valid expression.
template <typename T, typename = void>
struct support_initialize : std::false_type {};

template <typename T>
struct support_initialize<T, std::void_t<decltype(&T::initialize)>>
: std::true_type {};

// Whether the accumulator requires aligned access. If it is defined,
// SimpleAggregateAdapter::accumulatorAlignmentSize() returns
// alignof(typename FUNC::AccumulatorType).
Expand Down Expand Up @@ -172,6 +187,8 @@ class SimpleAggregateAdapter : public Aggregate {
static constexpr bool support_to_intermediate_ =
support_to_intermediate<FUNC>::value;

static constexpr bool support_initialize_ = support_initialize<FUNC>::value;

static constexpr bool accumulator_is_aligned_ =
accumulator_is_aligned<typename FUNC::AccumulatorType>::value;

Expand Down Expand Up @@ -350,7 +367,8 @@ class SimpleAggregateAdapter : public Aggregate {
folly::Range<const vector_size_t*> indices) override {
setAllNulls(groups, indices);
for (auto i : indices) {
new (groups[i] + offset_) typename FUNC::AccumulatorType(allocator_);
new (groups[i] + offset_)
typename FUNC::AccumulatorType(allocator_, fn_.get());
}
}

Expand Down
138 changes: 135 additions & 3 deletions velox/exec/tests/SimpleAggregateAdapterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ class CountNullsAggregate {

Accumulator() = delete;

explicit Accumulator(HashStringAllocator* /*allocator*/) {
explicit Accumulator(
HashStringAllocator* /*allocator*/,
CountNullsAggregate* /*fn*/) {
nullsCount_ = 0;
}

Expand Down Expand Up @@ -423,15 +425,15 @@ exec::AggregateRegistrationResult registerSimpleCountNullsAggregate(
name,
std::move(signatures),
[name](
core::AggregationNode::Step /*step*/,
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_LE(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<SimpleAggregateAdapter<CountNullsAggregate>>(
resultType);
step, argTypes, resultType);
},
false /*registerCompanionFunctions*/,
true /*overwrite*/);
Expand Down Expand Up @@ -469,5 +471,135 @@ TEST_F(SimpleCountNullsAggregationTest, basic) {
testAggregations({vectors}, {}, {"simple_count_nulls(c2)"}, {expected});
}

// A testing simple avg aggregate function, and it is used to check for
// expectations for function-level variables. The validation logic is in the
// Accumulator::addInput method.
class FuncLevelVariableTestAggregate {
public:
using InputType = Row<int64_t>;
using IntermediateType = Row<int64_t, double>;
using OutputType = double;

TypePtr inputType_;
TypePtr resultType_;

static void initialize(
FuncLevelVariableTestAggregate* fn,
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType) {
VELOX_CHECK_EQ(argTypes.size(), 1);
fn->inputType_ = argTypes[0];
fn->resultType_ = resultType;
}

struct Accumulator {
int64_t sum{0};
double count{0};
FuncLevelVariableTestAggregate* fn_;

explicit Accumulator(
HashStringAllocator* /*allocator*/,
FuncLevelVariableTestAggregate* fn)
: fn_(fn) {}

void addInput(
HashStringAllocator* /*allocator*/,
exec::arg_type<int64_t> data) {
VELOX_CHECK_NOT_NULL(fn_->inputType_);
VELOX_CHECK_NOT_NULL(fn_->resultType_);
if (fn_->inputType_->isRow()) {
VELOX_CHECK_EQ(fn_->inputType_->size(), 2);
VELOX_CHECK_EQ(fn_->inputType_->childAt(0), BIGINT());
VELOX_CHECK_EQ(fn_->inputType_->childAt(1), DOUBLE());
} else {
VELOX_CHECK_EQ(fn_->inputType_, BIGINT());
}
if (fn_->resultType_->isRow()) {
VELOX_CHECK_EQ(fn_->resultType_->size(), 2);
VELOX_CHECK_EQ(fn_->resultType_->childAt(0), BIGINT());
VELOX_CHECK_EQ(fn_->resultType_->childAt(1), DOUBLE());
} else {
VELOX_CHECK_EQ(fn_->resultType_, DOUBLE());
}
sum += data;
count = checkedPlus<int64_t>(count, 1);
}

void combine(
HashStringAllocator* /*allocator*/,
exec::arg_type<IntermediateType> other) {
VELOX_CHECK(other.at<0>().has_value());
VELOX_CHECK(other.at<1>().has_value());
sum += other.at<0>().value();
count += other.at<1>().value();
}

bool writeIntermediateResult(exec::out_type<IntermediateType>& out) {
out = std::make_tuple(sum, count);
return true;
}

bool writeFinalResult(exec::out_type<OutputType>& out) {
out = sum / count;
return true;
}
};

using AccumulatorType = Accumulator;
};

// Registers FuncLevelVariableTestAggregate under 'name' with the signature
// BIGINT -> ROW(BIGINT, DOUBLE) -> DOUBLE. Companion functions are registered
// and an existing registration with the same name is overwritten.
exec::AggregateRegistrationResult registerFuncLevelVariableTestAggregate(
    const std::string& name) {
  std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
      exec::AggregateFunctionSignatureBuilder()
          .returnType("DOUBLE")
          .intermediateType("ROW(BIGINT, DOUBLE)")
          .argumentType("BIGINT")
          .build()};

  return exec::registerAggregateFunction(
      name,
      std::move(signatures),
      [name](
          core::AggregationNode::Step step,
          const std::vector<TypePtr>& argTypes,
          const TypePtr& resultType,
          const core::QueryConfig& /*config*/)
          -> std::unique_ptr<exec::Aggregate> {
        // initialize() requires exactly one argument type, so fail here with
        // a message that names the function.
        VELOX_CHECK_EQ(
            argTypes.size(), 1, "{} takes exactly one argument", name);
        return std::make_unique<
            SimpleAggregateAdapter<FuncLevelVariableTestAggregate>>(
            step, argTypes, resultType);
      },
      true /*registerCompanionFunctions*/,
      true /*overwrite*/);
}

// Fixture that makes the function-level-variable test aggregate available as
// "simple_func_level_variable_agg" before each test.
class SimpleFuncLevelVariableAggregationTest : public AggregationTestBase {
 protected:
  void SetUp() override {
    // Base setup first, then register the aggregate under test.
    AggregationTestBase::SetUp();
    const std::string functionName = "simple_func_level_variable_agg";
    registerFuncLevelVariableTestAggregate(functionName);
  }
};

// Averages four BIGINT rows without grouping keys and expects 2.5, both via
// testAggregations and via testAggregationsWithCompanion.
TEST_F(SimpleFuncLevelVariableAggregationTest, simpleAggregateVariables) {
  const std::string aggregateCall = "simple_func_level_variable_agg(c0)";

  auto data = makeRowVector({makeFlatVector<int64_t>({1, 2, 3, 4})});
  std::vector<double> average = {2.5};
  auto expectedResult = makeRowVector({makeFlatVector<double>(average)});

  testAggregations({data}, {}, {aggregateCall}, {expectedResult});

  testAggregationsWithCompanion(
      {data},
      [](auto& /*builder*/) {},
      {},
      {aggregateCall},
      {{BIGINT()}},
      {},
      {expectedResult},
      {});
}

} // namespace
} // namespace facebook::velox::aggregate::test
8 changes: 5 additions & 3 deletions velox/exec/tests/SimpleArrayAggAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ class ArrayAggAggregate {
AccumulatorType() = delete;

// Constructor used in initializeNewGroups().
explicit AccumulatorType(HashStringAllocator* /*allocator*/)
explicit AccumulatorType(
HashStringAllocator* /*allocator*/,
ArrayAggAggregate* /*fn*/)
: elements_{} {}

static constexpr bool is_fixed_size_ = false;
Expand Down Expand Up @@ -127,15 +129,15 @@ exec::AggregateRegistrationResult registerSimpleArrayAggAggregate(
name,
std::move(signatures),
[name](
core::AggregationNode::Step /*step*/,
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_EQ(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<SimpleAggregateAdapter<ArrayAggAggregate>>(
resultType);
step, argTypes, resultType);
},
true /*registerCompanionFunctions*/,
true /*overwrite*/);
Expand Down
22 changes: 14 additions & 8 deletions velox/exec/tests/SimpleAverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ class AverageAggregate {
AccumulatorType() = delete;

// Constructor used in initializeNewGroups().
explicit AccumulatorType(HashStringAllocator* /*allocator*/) {
explicit AccumulatorType(
HashStringAllocator* /*allocator*/,
AverageAggregate* /*fn*/) {
sum_ = 0;
count_ = 0;
}
Expand Down Expand Up @@ -130,21 +132,23 @@ exec::AggregateRegistrationResult registerSimpleAverageAggregate(
case TypeKind::SMALLINT:
return std::make_unique<
SimpleAggregateAdapter<AverageAggregate<int16_t>>>(
resultType);
step, argTypes, resultType);
case TypeKind::INTEGER:
return std::make_unique<
SimpleAggregateAdapter<AverageAggregate<int32_t>>>(
resultType);
step, argTypes, resultType);
case TypeKind::BIGINT:
return std::make_unique<
SimpleAggregateAdapter<AverageAggregate<int64_t>>>(
resultType);
step, argTypes, resultType);
case TypeKind::REAL:
return std::make_unique<
SimpleAggregateAdapter<AverageAggregate<float>>>(resultType);
SimpleAggregateAdapter<AverageAggregate<float>>>(
step, argTypes, resultType);
case TypeKind::DOUBLE:
return std::make_unique<
SimpleAggregateAdapter<AverageAggregate<double>>>(resultType);
SimpleAggregateAdapter<AverageAggregate<double>>>(
step, argTypes, resultType);
default:
VELOX_FAIL(
"Unknown input type for {} aggregation {}",
Expand All @@ -155,11 +159,13 @@ exec::AggregateRegistrationResult registerSimpleAverageAggregate(
switch (resultType->kind()) {
case TypeKind::REAL:
return std::make_unique<
SimpleAggregateAdapter<AverageAggregate<float>>>(resultType);
SimpleAggregateAdapter<AverageAggregate<float>>>(
step, argTypes, resultType);
case TypeKind::DOUBLE:
case TypeKind::ROW:
return std::make_unique<
SimpleAggregateAdapter<AverageAggregate<double>>>(resultType);
SimpleAggregateAdapter<AverageAggregate<double>>>(
step, argTypes, resultType);
default:
VELOX_FAIL(
"Unsupported result type for final aggregation: {}",
Expand Down
Loading

0 comments on commit 8d84155

Please sign in to comment.