Skip to content

Commit

Permalink
api to build aggregator params
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Nov 18, 2024
1 parent a7f4362 commit ae92f3e
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 129 deletions.
150 changes: 128 additions & 22 deletions cpp-ch/local-engine/Common/AggregateUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
* limitations under the License.
*/

#include "AggregateUtil.h"
#include <Core/Settings.h>
#include <Poco/Logger.h>
#include <Common/AggregateUtil.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
#include <Common/Stopwatch.h>
#include <Common/formatReadable.h>
Expand All @@ -26,8 +29,26 @@ namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int UNKNOWN_AGGREGATED_DATA_VARIANT;
extern const int LOGICAL_ERROR;
extern const int UNKNOWN_AGGREGATED_DATA_VARIANT;
}

namespace Setting
{
extern const SettingsUInt64 max_bytes_before_external_group_by;
extern const SettingsBool optimize_group_by_constant_keys;
extern const SettingsUInt64 min_free_disk_space_for_temporary_data;
extern const SettingsMaxThreads max_threads;
extern const SettingsBool empty_result_for_aggregation_by_empty_set;
extern const SettingsUInt64 group_by_two_level_threshold_bytes;
extern const SettingsOverflowModeGroupBy group_by_overflow_mode;
extern const SettingsUInt64 max_rows_to_group_by;
extern const SettingsBool enable_memory_bound_merging_of_aggregation_results;
extern const SettingsUInt64 aggregation_in_order_max_block_bytes;
extern const SettingsUInt64 group_by_two_level_threshold;
extern const SettingsFloat min_hit_rate_to_use_consecutive_keys_optimization;
extern const SettingsMaxThreads max_threads;
extern const SettingsUInt64 max_block_size;
}

template <typename Method>
Expand All @@ -39,24 +60,23 @@ static Int32 extractMethodBucketsNum(Method & /*method*/)
Int32 GlutenAggregatorUtil::getBucketsNum(AggregatedDataVariants & data_variants)
{
if (!data_variants.isTwoLevel())
{
return 0;
}


Int32 buckets_num = 0;
#define M(NAME) \
else if (data_variants.type == AggregatedDataVariants::Type::NAME) \
buckets_num = extractMethodBucketsNum(*data_variants.NAME);
else if (data_variants.type == AggregatedDataVariants::Type::NAME) buckets_num = extractMethodBucketsNum(*data_variants.NAME);

if (false) {} // NOLINT
if (false)
{
} // NOLINT
APPLY_FOR_VARIANTS_TWO_LEVEL(M)
#undef M
else
throw Exception(ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT, "Unknown aggregated data variant");
else throw Exception(ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT, "Unknown aggregated data variant");
return buckets_num;
}

std::optional<Block> GlutenAggregatorUtil::safeConvertOneBucketToBlock(Aggregator & aggregator, AggregatedDataVariants & variants, Arena * arena, bool final, Int32 bucket)
std::optional<Block> GlutenAggregatorUtil::safeConvertOneBucketToBlock(
Aggregator & aggregator, AggregatedDataVariants & variants, Arena * arena, bool final, Int32 bucket)
{
if (!variants.isTwoLevel())
return {};
Expand All @@ -65,7 +85,7 @@ std::optional<Block> GlutenAggregatorUtil::safeConvertOneBucketToBlock(Aggregato
return aggregator.convertOneBucketToBlock(variants, arena, final, bucket);
}

template<typename Method>
template <typename Method>
static void releaseOneBucket(Method & method, Int32 bucket)
{
method.data.impls[bucket].clearAndShrink();
Expand All @@ -77,29 +97,26 @@ void GlutenAggregatorUtil::safeReleaseOneBucket(AggregatedDataVariants & variant
return;
if (bucket >= getBucketsNum(variants))
return;
#define M(NAME) \
else if (variants.type == AggregatedDataVariants::Type::NAME) \
releaseOneBucket(*variants.NAME, bucket);
#define M(NAME) else if (variants.type == AggregatedDataVariants::Type::NAME) releaseOneBucket(*variants.NAME, bucket);

if (false) {} // NOLINT
if (false)
{
} // NOLINT
APPLY_FOR_VARIANTS_TWO_LEVEL(M)
#undef M
else
throw Exception(ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT, "Unknown aggregated data variant");

else throw Exception(ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT, "Unknown aggregated data variant");
}

}

namespace local_engine
{
AggregateDataBlockConverter::AggregateDataBlockConverter(DB::Aggregator & aggregator_, DB::AggregatedDataVariantsPtr data_variants_, bool final_)
AggregateDataBlockConverter::AggregateDataBlockConverter(
DB::Aggregator & aggregator_, DB::AggregatedDataVariantsPtr data_variants_, bool final_)
: aggregator(aggregator_), data_variants(std::move(data_variants_)), final(final_)
{
if (data_variants->isTwoLevel())
{
buckets_num = DB::GlutenAggregatorUtil::getBucketsNum(*data_variants);
}
else if (data_variants->size())
buckets_num = 1;
else
Expand Down Expand Up @@ -168,4 +185,93 @@ DB::Block AggregateDataBlockConverter::next()
output_blocks.pop_front();
return block;
}

DB::Aggregator::Params AggregatorParamsHelper::buildParams(
DB::ContextPtr context,
const DB::Names & grouping_keys,
const DB::AggregateDescriptions & agg_descriptions,
Mode mode,
Algorithm algorithm)
{
const auto & settings = context->getSettingsRef();
size_t max_rows_to_group_by = mode == Mode::PARTIAL_TO_FINISHED ? 0 : settings[DB::Setting::max_rows_to_group_by];
DB::OverflowMode group_by_overflow_mode = settings[DB::Setting::group_by_overflow_mode];
size_t group_by_two_level_threshold
= algorithm == Algorithm::GlutenGraceAggregate ? static_cast<size_t>(settings[DB::Setting::group_by_two_level_threshold]) : 0;
size_t group_by_two_level_threshold_bytes = algorithm == Algorithm::GlutenGraceAggregate
? 0
: (mode == Mode::PARTIAL_TO_FINISHED ? 0 : static_cast<size_t>(settings[DB::Setting::group_by_two_level_threshold_bytes]));
size_t max_bytes_before_external_group_by = algorithm == Algorithm::GlutenGraceAggregate
? 0
: (mode == Mode::PARTIAL_TO_FINISHED ? 0 : static_cast<size_t>(settings[DB::Setting::max_bytes_before_external_group_by]));
bool empty_result_for_aggregation_by_empty_set = algorithm == Algorithm::GlutenGraceAggregate
? false
: (mode == Mode::PARTIAL_TO_FINISHED ? false : static_cast<bool>(settings[DB::Setting::empty_result_for_aggregation_by_empty_set]));
DB::TemporaryDataOnDiskScopePtr tmp_data_scope = algorithm == Algorithm::GlutenGraceAggregate ? nullptr : context->getTempDataOnDisk();
size_t max_threads = settings[DB::Setting::max_threads];
size_t min_free_disk_space
= algorithm == Algorithm::GlutenGraceAggregate ? 0 : settings[DB::Setting::min_free_disk_space_for_temporary_data];
bool compile_aggregate_expressions = mode == Mode::PARTIAL_TO_FINISHED ? false : true;
size_t min_count_to_compile_aggregate_expression = mode == Mode::PARTIAL_TO_FINISHED ? 0 : 3;
size_t max_block_size = PODArrayUtil::adjustMemoryEfficientSize(settings[DB::Setting::max_block_size]);
bool enable_prefetch = mode == Mode::PARTIAL_TO_FINISHED ? false : true;
bool only_merge = mode == Mode::PARTIAL_TO_FINISHED;
bool optimize_group_by_constant_keys
= mode == Mode::PARTIAL_TO_FINISHED ? false : settings[DB::Setting::optimize_group_by_constant_keys];
double min_hit_rate_to_use_consecutive_keys_optimization = settings[DB::Setting::min_hit_rate_to_use_consecutive_keys_optimization];
DB::Aggregator::Params params(
grouping_keys,
agg_descriptions,
false,
max_rows_to_group_by,
group_by_overflow_mode,
group_by_two_level_threshold,
group_by_two_level_threshold_bytes,
max_bytes_before_external_group_by,
empty_result_for_aggregation_by_empty_set,
tmp_data_scope,
max_threads,
min_free_disk_space,
compile_aggregate_expressions,
min_count_to_compile_aggregate_expression,
max_block_size,
enable_prefetch,
only_merge,
optimize_group_by_constant_keys,
min_hit_rate_to_use_consecutive_keys_optimization,
{});
return params;
}


#define COMPARE_FIELD(field) \
if (lhs.field != rhs.field) \
{ \
LOG_ERROR(getLogger("AggregatorParamsHelper"), "Aggregator::Params field " #field " is not equal. {}/{}", lhs.field, rhs.field); \
return false; \
}
bool AggregatorParamsHelper::compare(const DB::Aggregator::Params & lhs, const DB::Aggregator::Params & rhs)
{
COMPARE_FIELD(overflow_row);
COMPARE_FIELD(max_rows_to_group_by);
COMPARE_FIELD(group_by_overflow_mode);
COMPARE_FIELD(group_by_two_level_threshold);
COMPARE_FIELD(group_by_two_level_threshold_bytes);
COMPARE_FIELD(max_bytes_before_external_group_by);
COMPARE_FIELD(empty_result_for_aggregation_by_empty_set);
COMPARE_FIELD(max_threads);
COMPARE_FIELD(min_free_disk_space);
COMPARE_FIELD(compile_aggregate_expressions);
if ((lhs.tmp_data_scope == nullptr) != (rhs.tmp_data_scope == nullptr))
{
LOG_ERROR(getLogger("AggregatorParamsHelper"), "Aggregator::Params field tmp_data_scope is not equal.");
return false;
}
COMPARE_FIELD(min_count_to_compile_aggregate_expression);
COMPARE_FIELD(enable_prefetch);
COMPARE_FIELD(only_merge);
COMPARE_FIELD(optimize_group_by_constant_keys);
COMPARE_FIELD(min_hit_rate_to_use_consecutive_keys_optimization);
return true;
}
}
31 changes: 30 additions & 1 deletion cpp-ch/local-engine/Common/AggregateUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class GlutenAggregatorUtil
{
public:
static Int32 getBucketsNum(AggregatedDataVariants & data_variants);
static std::optional<Block> safeConvertOneBucketToBlock(Aggregator & aggregator, AggregatedDataVariants & variants, Arena * arena, bool final, Int32 bucket);
static std::optional<Block>
safeConvertOneBucketToBlock(Aggregator & aggregator, AggregatedDataVariants & variants, Arena * arena, bool final, Int32 bucket);
static void safeReleaseOneBucket(AggregatedDataVariants & variants, Int32 bucket);
};
}
Expand All @@ -41,6 +42,7 @@ class AggregateDataBlockConverter
~AggregateDataBlockConverter() = default;
bool hasNext();
DB::Block next();

private:
DB::Aggregator & aggregator;
DB::AggregatedDataVariantsPtr data_variants;
Expand All @@ -50,4 +52,31 @@ class AggregateDataBlockConverter
Int32 current_bucket = 0;
DB::BlocksList output_blocks;
};

class AggregatorParamsHelper
{
public:
enum class Algorithm
{
GlutenGraceAggregate,
CHTwoStageAggregate
};
enum class Mode
{
INIT_TO_PARTIAL,
INIT_TO_COMPLETED,
PARTIAL_TO_PARTIAL,
PARTIAL_TO_FINISHED,
};

// for using grace aggregating, never enable ch spill, otherwise there will be data lost.
static DB::Aggregator::Params buildParams(
DB::ContextPtr context,
const DB::Names & grouping_keys,
const DB::AggregateDescriptions & agg_descriptions,
Mode mode,
Algorithm algorithm = Algorithm::GlutenGraceAggregate);
static bool compare(const DB::Aggregator::Params & lhs, const DB::Aggregator::Params & rhs);
};

}
Loading

0 comments on commit ae92f3e

Please sign in to comment.