diff --git a/velox/docs/develop/aggregate-functions.rst b/velox/docs/develop/aggregate-functions.rst index 8c60890dfcf4..0ca05fa5a332 100644 --- a/velox/docs/develop/aggregate-functions.rst +++ b/velox/docs/develop/aggregate-functions.rst @@ -157,8 +157,24 @@ A simple aggregation function is implemented as a class as the following. // Optional. static bool toIntermediate( - exec::out_type>>& out, - exec::optional_arg_type> in); + exec::out_type>>& out, + exec::optional_arg_type> in); + + // Optional. Define some function-level variables. + TypePtr inputType; + TypePtr resultType; + + // Optional. Defined only when the aggregation function needs to use function-level variables. + // This method is called once when the aggregation function is created. + static void initialize( + std::shared_ptr fn, + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType) { + VELOX_CHECK_EQ(argTypes.size(), 1); + fn->inputType_ = argTypes[0]; + fn->resultType_ = resultType; + } struct AccumulatorType { ... }; }; @@ -169,6 +185,14 @@ function's argument type(s) wrapped in a Row<> even if the function only takes one argument. This is needed for the SimpleAggregateAdapter to parse input types for arbitrary aggregation functions properly. +Some function-level variables needs to be declared in the simple aggregation +function class. These variables are initialized once when the aggregation +function is created and used at every row when adding inputs to accumulators +or extracting values from accumulators. For example, if the aggregation +function needs to get the result type or the raw input type of the aggregaiton +function, the author can hold them in the aggregate class variables, and +initialize them in the initialize() method. + The author can define an optional flag `default_null_behavior_` indicating whether the aggregation function has default-null behavior. This flag is true by default. Next, the class can have an optional method `toIntermediate()` @@ -248,6 +272,9 @@ For aggregaiton functions of default-null behavior, the author defines an // Author defines data members ... + // Define a pointer to the UDAF class. + std::shared_ptr fn_; + // Optional. Default is true. static constexpr bool is_fixed_size_ = false; @@ -257,7 +284,9 @@ For aggregaiton functions of default-null behavior, the author defines an // Optional. Default is false. static constexpr bool is_aligned_ = true; - explicit AccumulatorType(HashStringAllocator* allocator); + explicit AccumulatorType( + HashStringAllocator* allocator, ArrayAggAggregate fn) + : fn_(fn); void addInput(HashStringAllocator* allocator, exec::arg_type value1, ...); @@ -353,7 +382,7 @@ For aggregaiton functions of non-default-null behavior, the author defines an // Optional. Default is false. static constexpr bool is_aligned_ = true; - explicit AccumulatorType(HashStringAllocator* allocator); + explicit AccumulatorType(HashStringAllocator* allocator, ArrayAggAggregate* fn); bool addInput(HashStringAllocator* allocator, exec::optional_arg_type value1, ...); @@ -873,7 +902,7 @@ unique pointers. Below is an example. name, std::move(signatures), [name]( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& /*config*/) @@ -881,7 +910,7 @@ unique pointers. Below is an example. VELOX_CHECK_EQ( argTypes.size(), 1, "{} takes at most one argument", name); return std::make_unique>( - resultType); + step, argTypes, resultType); }); } diff --git a/velox/exec/SimpleAggregateAdapter.h b/velox/exec/SimpleAggregateAdapter.h index d92584c6d647..1223a90e510e 100644 --- a/velox/exec/SimpleAggregateAdapter.h +++ b/velox/exec/SimpleAggregateAdapter.h @@ -41,8 +41,15 @@ using optional_arg_type = OptionalAccessor; template class SimpleAggregateAdapter : public Aggregate { public: - explicit SimpleAggregateAdapter(TypePtr resultType) - : Aggregate(std::move(resultType)) {} + explicit SimpleAggregateAdapter( + core::AggregationNode::Step step, + const std::vector& argTypes, + TypePtr resultType) + : Aggregate(std::move(resultType)), fn_{std::make_unique()} { + if constexpr (support_initialize_) { + FUNC::initialize(fn_.get(), step, argTypes, resultType_); + } + } // Assume most aggregate functions have fixed-size accumulators. Functions // that @@ -103,6 +110,7 @@ class SimpleAggregateAdapter : public Aggregate { // These functions are called on groups of both non-null and null // accumulators. These functions also return a bool indicating whether the // current group should be a NULL in the result vector. + std::unique_ptr fn_; template struct aggregate_default_null_behavior : std::true_type {}; @@ -145,6 +153,13 @@ class SimpleAggregateAdapter : public Aggregate { struct support_to_intermediate> : std::true_type {}; + template + struct support_initialize : std::false_type {}; + + template + struct support_initialize> + : std::true_type {}; + // Whether the accumulator requires aligned access. If it is defined, // SimpleAggregateAdapter::accumulatorAlignmentSize() returns // alignof(typename FUNC::AccumulatorType). @@ -172,6 +187,8 @@ class SimpleAggregateAdapter : public Aggregate { static constexpr bool support_to_intermediate_ = support_to_intermediate::value; + static constexpr bool support_initialize_ = support_initialize::value; + static constexpr bool accumulator_is_aligned_ = accumulator_is_aligned::value; @@ -350,7 +367,8 @@ class SimpleAggregateAdapter : public Aggregate { folly::Range indices) override { setAllNulls(groups, indices); for (auto i : indices) { - new (groups[i] + offset_) typename FUNC::AccumulatorType(allocator_); + new (groups[i] + offset_) + typename FUNC::AccumulatorType(allocator_, fn_.get()); } } diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index 0a71f843a3f1..7da664d5b666 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -362,7 +362,9 @@ class CountNullsAggregate { Accumulator() = delete; - explicit Accumulator(HashStringAllocator* /*allocator*/) { + explicit Accumulator( + HashStringAllocator* /*allocator*/, + CountNullsAggregate* /*fn*/) { nullsCount_ = 0; } @@ -423,7 +425,7 @@ exec::AggregateRegistrationResult registerSimpleCountNullsAggregate( name, std::move(signatures), [name]( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& /*config*/) @@ -431,7 +433,7 @@ exec::AggregateRegistrationResult registerSimpleCountNullsAggregate( VELOX_CHECK_LE( argTypes.size(), 1, "{} takes at most one argument", name); return std::make_unique>( - resultType); + step, argTypes, resultType); }, false /*registerCompanionFunctions*/, true /*overwrite*/); @@ -469,5 +471,135 @@ TEST_F(SimpleCountNullsAggregationTest, basic) { testAggregations({vectors}, {}, {"simple_count_nulls(c2)"}, {expected}); } +// A testing simple avg aggregate function, and it is used to check for +// expectations for function-level variables. The validation logic is in the +// Accumulator::addInput method. +class FuncLevelVariableTestAggregate { + public: + using InputType = Row; + using IntermediateType = Row; + using OutputType = double; + + TypePtr inputType_; + TypePtr resultType_; + + static void initialize( + FuncLevelVariableTestAggregate* fn, + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType) { + VELOX_CHECK_EQ(argTypes.size(), 1); + fn->inputType_ = argTypes[0]; + fn->resultType_ = resultType; + } + + struct Accumulator { + int64_t sum{0}; + double count{0}; + FuncLevelVariableTestAggregate* fn_; + + explicit Accumulator( + HashStringAllocator* /*allocator*/, + FuncLevelVariableTestAggregate* fn) + : fn_(fn) {} + + void addInput( + HashStringAllocator* /*allocator*/, + exec::arg_type data) { + VELOX_CHECK_NOT_NULL(fn_->inputType_); + VELOX_CHECK_NOT_NULL(fn_->resultType_); + if (fn_->inputType_->isRow()) { + VELOX_CHECK_EQ(fn_->inputType_->size(), 2); + VELOX_CHECK_EQ(fn_->inputType_->childAt(0), BIGINT()); + VELOX_CHECK_EQ(fn_->inputType_->childAt(1), DOUBLE()); + } else { + VELOX_CHECK_EQ(fn_->inputType_, BIGINT()); + } + if (fn_->resultType_->isRow()) { + VELOX_CHECK_EQ(fn_->resultType_->size(), 2); + VELOX_CHECK_EQ(fn_->resultType_->childAt(0), BIGINT()); + VELOX_CHECK_EQ(fn_->resultType_->childAt(1), DOUBLE()); + } else { + VELOX_CHECK_EQ(fn_->resultType_, DOUBLE()); + } + sum += data; + count = checkedPlus(count, 1); + } + + void combine( + HashStringAllocator* /*allocator*/, + exec::arg_type other) { + VELOX_CHECK(other.at<0>().has_value()); + VELOX_CHECK(other.at<1>().has_value()); + sum += other.at<0>().value(); + count += other.at<1>().value(); + } + + bool writeIntermediateResult(exec::out_type& out) { + out = std::make_tuple(sum, count); + return true; + } + + bool writeFinalResult(exec::out_type& out) { + out = sum / count; + return true; + } + }; + + using AccumulatorType = Accumulator; +}; + +exec::AggregateRegistrationResult registerFuncLevelVariableTestAggregate( + const std::string& name) { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .returnType("DOUBLE") + .intermediateType("ROW(BIGINT, DOUBLE)") + .argumentType("BIGINT") + .build()}; + + return exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + VELOX_CHECK_LE(argTypes.size(), 1, "{} takes at most 1 argument", name); + return std::make_unique< + SimpleAggregateAdapter>( + step, argTypes, resultType); + }, + true /*registerCompanionFunctions*/, + true /*overwrite*/); +} + +class SimpleFuncLevelVariableAggregationTest : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + registerFuncLevelVariableTestAggregate("simple_func_level_variable_agg"); + } +}; + +TEST_F(SimpleFuncLevelVariableAggregationTest, simpleAggregateVariables) { + auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); + std::vector finalResult = {2.5}; + auto expected = makeRowVector({makeFlatVector(finalResult)}); + testAggregations( + {inputVectors}, {}, {"simple_func_level_variable_agg(c0)"}, {expected}); + testAggregationsWithCompanion( + {inputVectors}, + [](auto& /*builder*/) {}, + {}, + {"simple_func_level_variable_agg(c0)"}, + {{BIGINT()}}, + {}, + {expected}, + {}); +} + } // namespace } // namespace facebook::velox::aggregate::test diff --git a/velox/exec/tests/SimpleArrayAggAggregate.cpp b/velox/exec/tests/SimpleArrayAggAggregate.cpp index e3998aadf2a2..40a4f4e583e0 100644 --- a/velox/exec/tests/SimpleArrayAggAggregate.cpp +++ b/velox/exec/tests/SimpleArrayAggAggregate.cpp @@ -55,7 +55,9 @@ class ArrayAggAggregate { AccumulatorType() = delete; // Constructor used in initializeNewGroups(). - explicit AccumulatorType(HashStringAllocator* /*allocator*/) + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + ArrayAggAggregate* /*fn*/) : elements_{} {} static constexpr bool is_fixed_size_ = false; @@ -127,7 +129,7 @@ exec::AggregateRegistrationResult registerSimpleArrayAggAggregate( name, std::move(signatures), [name]( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& /*config*/) @@ -135,7 +137,7 @@ exec::AggregateRegistrationResult registerSimpleArrayAggAggregate( VELOX_CHECK_EQ( argTypes.size(), 1, "{} takes at most one argument", name); return std::make_unique>( - resultType); + step, argTypes, resultType); }, true /*registerCompanionFunctions*/, true /*overwrite*/); diff --git a/velox/exec/tests/SimpleAverageAggregate.cpp b/velox/exec/tests/SimpleAverageAggregate.cpp index 9f887f34eadf..fea9254cb1cf 100644 --- a/velox/exec/tests/SimpleAverageAggregate.cpp +++ b/velox/exec/tests/SimpleAverageAggregate.cpp @@ -56,7 +56,9 @@ class AverageAggregate { AccumulatorType() = delete; // Constructor used in initializeNewGroups(). - explicit AccumulatorType(HashStringAllocator* /*allocator*/) { + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + AverageAggregate* /*fn*/) { sum_ = 0; count_ = 0; } @@ -130,21 +132,23 @@ exec::AggregateRegistrationResult registerSimpleAverageAggregate( case TypeKind::SMALLINT: return std::make_unique< SimpleAggregateAdapter>>( - resultType); + step, argTypes, resultType); case TypeKind::INTEGER: return std::make_unique< SimpleAggregateAdapter>>( - resultType); + step, argTypes, resultType); case TypeKind::BIGINT: return std::make_unique< SimpleAggregateAdapter>>( - resultType); + step, argTypes, resultType); case TypeKind::REAL: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); case TypeKind::DOUBLE: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); default: VELOX_FAIL( "Unknown input type for {} aggregation {}", @@ -155,11 +159,13 @@ exec::AggregateRegistrationResult registerSimpleAverageAggregate( switch (resultType->kind()) { case TypeKind::REAL: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); case TypeKind::DOUBLE: case TypeKind::ROW: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); default: VELOX_FAIL( "Unsupported result type for final aggregation: {}", diff --git a/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp b/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp index 35f9af9e3886..3594dc9b547a 100644 --- a/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp @@ -375,7 +375,9 @@ class ApproxMostFrequentBooleanAggregate { int64_t numTrue{0}; int64_t numFalse{0}; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + ApproxMostFrequentBooleanAggregate* /*fn*/) {} void addInput( HashStringAllocator* /*allocator*/, @@ -447,7 +449,9 @@ template std::unique_ptr makeApproxMostFrequentAggregate( const TypePtr& resultType, const std::string& name, - const TypePtr& valueType) { + const TypePtr& valueType, + core::AggregationNode::Step step, + const std::vector& argTypes) { if constexpr ( kKind == TypeKind::TINYINT || kKind == TypeKind::SMALLINT || kKind == TypeKind::INTEGER || kKind == TypeKind::BIGINT || @@ -460,7 +464,7 @@ std::unique_ptr makeApproxMostFrequentAggregate( if (kKind == TypeKind::BOOLEAN) { return std::make_unique< exec::SimpleAggregateAdapter>( - resultType); + step, argTypes, resultType); } VELOX_USER_FAIL( @@ -494,7 +498,7 @@ void registerApproxMostFrequentAggregate( std::move(signatures), [name]( core::AggregationNode::Step step, - const std::vector&, + const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& /*config*/) -> std::unique_ptr { @@ -506,7 +510,9 @@ void registerApproxMostFrequentAggregate( valueType->kind(), resultType, name, - valueType); + valueType, + step, + argTypes); }, withCompanionFunctions, overwrite); diff --git a/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp b/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp index aa3a52d5fa28..f76e8103cf11 100644 --- a/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp +++ b/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp @@ -42,7 +42,9 @@ class BitwiseXorAggregate { AccumulatorType() = delete; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + BitwiseXorAggregate* /*fn*/) {} void addInput(HashStringAllocator* /*allocator*/, exec::arg_type data) { xor_ ^= data; @@ -101,19 +103,19 @@ void registerBitwiseXorAggregate( case TypeKind::TINYINT: return std::make_unique< SimpleAggregateAdapter>>( - resultType); + step, argTypes, resultType); case TypeKind::SMALLINT: return std::make_unique< SimpleAggregateAdapter>>( - resultType); + step, argTypes, resultType); case TypeKind::INTEGER: return std::make_unique< SimpleAggregateAdapter>>( - resultType); + step, argTypes, resultType); case TypeKind::BIGINT: return std::make_unique< SimpleAggregateAdapter>>( - resultType); + step, argTypes, resultType); default: VELOX_USER_FAIL( "Unknown input type for {} aggregation {}", diff --git a/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp b/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp index 9bbe9ec8c73b..f63f8f506753 100644 --- a/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp +++ b/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp @@ -50,7 +50,9 @@ class GeometricMeanAggregate { AccumulatorType() = delete; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + GeometricMeanAggregate* /*fn*/) {} void addInput( HashStringAllocator* /*allocator*/, @@ -120,15 +122,16 @@ void registerGeometricMeanAggregate( switch (inputType->kind()) { case TypeKind::BIGINT: return std::make_unique>>(resultType); + GeometricMeanAggregate>>( + step, argTypes, resultType); case TypeKind::DOUBLE: return std::make_unique< SimpleAggregateAdapter>>( - resultType); + step, argTypes, resultType); case TypeKind::REAL: return std::make_unique< SimpleAggregateAdapter>>( - resultType); + step, argTypes, resultType); default: VELOX_USER_FAIL( "Unknown input type for {} aggregation {}", diff --git a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp index e2c14cfa7969..c85d7eff3584 100644 --- a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp +++ b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp @@ -51,7 +51,9 @@ class CollectListAggregate { struct AccumulatorType { ValueList elements_; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + CollectListAggregate* /*fn*/) : elements_{} {} static constexpr bool is_fixed_size_ = false; @@ -117,7 +119,7 @@ AggregateRegistrationResult registerCollectList( name, std::move(signatures), [name]( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& /*config*/) @@ -125,7 +127,7 @@ AggregateRegistrationResult registerCollectList( VELOX_CHECK_EQ( argTypes.size(), 1, "{} takes at most one argument", name); return std::make_unique>( - resultType); + step, argTypes, resultType); }, withCompanionFunctions, overwrite); diff --git a/velox/functions/sparksql/aggregates/DecimalSumAggregate.h b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h index 4d10e9411547..90600879daf8 100644 --- a/velox/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -72,7 +72,9 @@ class DecimalSumAggregate { AccumulatorType() = delete; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + DecimalSumAggregate* /*fn*/) {} std::optional computeFinalResult() const { if (!sum.has_value()) { diff --git a/velox/functions/sparksql/aggregates/ModeAggregate.cpp b/velox/functions/sparksql/aggregates/ModeAggregate.cpp index 1d027fec4f1c..f43b2ebe4ae1 100644 --- a/velox/functions/sparksql/aggregates/ModeAggregate.cpp +++ b/velox/functions/sparksql/aggregates/ModeAggregate.cpp @@ -51,7 +51,9 @@ class ModeAggregate { // A map of T -> count. ValueMap values; - explicit AccumulatorType(HashStringAllocator* allocator) + explicit AccumulatorType( + HashStringAllocator* allocator, + ModeAggregate* /*fn*/) : values{AlignedStlAllocator, 16>( allocator)} {} @@ -120,7 +122,9 @@ class StringModeAggregate { // Stores unique non-null non-inline strings. Strings strings; - explicit AccumulatorType(HashStringAllocator* allocator) + explicit AccumulatorType( + HashStringAllocator* allocator, + StringModeAggregate* /*fn*/) : values{AlignedStlAllocator, 16>( allocator)} {} @@ -520,46 +524,56 @@ void registerModeAggregate( switch (inputType->kind()) { case TypeKind::BOOLEAN: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); case TypeKind::TINYINT: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); case TypeKind::SMALLINT: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); case TypeKind::INTEGER: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); case TypeKind::BIGINT: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); case TypeKind::HUGEINT: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); case TypeKind::REAL: return std::make_unique, - util::floating_point::NaNAwareEquals>>>(resultType); + util::floating_point::NaNAwareEquals>>>( + step, argTypes, resultType); case TypeKind::DOUBLE: return std::make_unique, - util::floating_point::NaNAwareEquals>>>(resultType); + util::floating_point::NaNAwareEquals>>>( + step, argTypes, resultType); case TypeKind::TIMESTAMP: return std::make_unique< - SimpleAggregateAdapter>>(resultType); + SimpleAggregateAdapter>>( + step, argTypes, resultType); case TypeKind::UNKNOWN: // Regitsers Mode function with unknown type, this needs hasher for // UnknownValue, we use folly::hasher for it. return std::make_unique, - std::equal_to>>>(resultType); + std::equal_to>>>(step, argTypes, resultType); case TypeKind::VARCHAR: case TypeKind::VARBINARY: return std::make_unique< - SimpleAggregateAdapter>(resultType); + SimpleAggregateAdapter>( + step, argTypes, resultType); case TypeKind::ARRAY: case TypeKind::MAP: case TypeKind::ROW: diff --git a/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp b/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp index 85f69d1f7c3f..05c841e20cb1 100644 --- a/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp +++ b/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp @@ -41,7 +41,9 @@ class RegrReplacementAggregate { double avg{0.0}; double m2{0.0}; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + RegrReplacementAggregate* /*fn*/) {} void addInput( HashStringAllocator* /*allocator*/, @@ -101,7 +103,7 @@ exec::AggregateRegistrationResult registerRegrReplacement( name, std::move(signatures), [name]( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& /*config*/) @@ -109,7 +111,8 @@ exec::AggregateRegistrationResult registerRegrReplacement( VELOX_CHECK_EQ( argTypes.size(), 1, "{} takes at most one argument", name); return std::make_unique< - exec::SimpleAggregateAdapter>(resultType); + exec::SimpleAggregateAdapter>( + step, argTypes, resultType); }, withCompanionFunctions, overwrite); diff --git a/velox/functions/sparksql/aggregates/SumAggregate.cpp b/velox/functions/sparksql/aggregates/SumAggregate.cpp index 65b0c78a730d..0d88d98a7ead 100644 --- a/velox/functions/sparksql/aggregates/SumAggregate.cpp +++ b/velox/functions/sparksql/aggregates/SumAggregate.cpp @@ -45,22 +45,27 @@ void checkAccumulatorRowType(const TypePtr& type) { std::unique_ptr constructDecimalSumAgg( const TypePtr& inputType, const TypePtr& sumType, - const TypePtr& resultType) { + const TypePtr& resultType, + core::AggregationNode::Step step, + const std::vector& argTypes) { uint8_t precision = getDecimalPrecisionScale(*sumType).first; switch (precision) { // The sum precision is calculated from the input precision with the formula // min(p + 10, 38). Therefore, the sum precision must >= 11. -#define PRECISION_CASE(precision) \ - case precision: \ - if (inputType->isShortDecimal() && sumType->isShortDecimal()) { \ - return std::make_unique>>(resultType); \ - } else if (inputType->isShortDecimal() && sumType->isLongDecimal()) { \ - return std::make_unique>>(resultType); \ - } else { \ - return std::make_unique>>(resultType); \ +#define PRECISION_CASE(precision) \ + case precision: \ + if (inputType->isShortDecimal() && sumType->isShortDecimal()) { \ + return std::make_unique>>( \ + step, argTypes, resultType); \ + } else if (inputType->isShortDecimal() && sumType->isLongDecimal()) { \ + return std::make_unique>>( \ + step, argTypes, resultType); \ + } else { \ + return std::make_unique>>( \ + step, argTypes, resultType); \ } PRECISION_CASE(11) PRECISION_CASE(12) @@ -155,7 +160,11 @@ exec::AggregateRegistrationResult registerSum( case TypeKind::BIGINT: { if (inputType->isShortDecimal()) { return constructDecimalSumAgg( - inputType, getDecimalSumType(resultType), resultType); + inputType, + getDecimalSumType(resultType), + resultType, + step, + argTypes); } return std::make_unique>( BIGINT()); @@ -165,7 +174,11 @@ exec::AggregateRegistrationResult registerSum( // If inputType is long decimal, // its output type is always long decimal. return constructDecimalSumAgg( - inputType, getDecimalSumType(resultType), resultType); + inputType, + getDecimalSumType(resultType), + resultType, + step, + argTypes); } case TypeKind::REAL: if (resultType->kind() == TypeKind::REAL) { @@ -187,7 +200,11 @@ exec::AggregateRegistrationResult registerSum( // For the intermediate aggregation step, input intermediate sum // type is equal to final result sum type. return constructDecimalSumAgg( - inputType->childAt(0), inputType->childAt(0), resultType); + inputType->childAt(0), + inputType->childAt(0), + resultType, + step, + argTypes); } [[fallthrough]]; default: