diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h index cffb2629df292..292712d95f922 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h @@ -895,9 +895,9 @@ template struct DictionaryMinMaxImpl : public ScalarAggregator { using ThisType = DictionaryMinMaxImpl; - DictionaryMinMaxImpl(std::shared_ptr out_type, ScalarAggregateOptions options) - : out_type(std::move(out_type)), - options(std::move(options)), + DictionaryMinMaxImpl(ScalarAggregateOptions options) + : options(std::move(options)), + out_child_type(nullptr), has_nulls(false), count(0), min(nullptr), @@ -910,82 +910,88 @@ struct DictionaryMinMaxImpl : public ScalarAggregator { return Status::NotImplemented("No min/max implemented for DictionaryScalar"); } - const DictionaryArray& dict_array = - checked_cast(*batch[0].array.ToArray()); + DictionaryArray arr(batch[0].array.ToArrayData()); + std::shared_ptr dict_values = arr.dictionary(); + std::shared_ptr dict_indices = arr.indices(); - std::shared_ptr dict_values = dict_array.dictionary(); - std::shared_ptr dict_indices = dict_array.indices(); - has_nulls = dict_indices->null_count() > 0; - count += dict_indices->length() - dict_indices->null_count(); + this->out_child_type = dict_values->type(); + this->has_nulls = dict_indices->null_count() > 0; + this->count += dict_indices->length() - dict_indices->null_count(); Datum dict_values_(*dict_values); ARROW_ASSIGN_OR_RAISE(Datum result, MinMax(std::move(dict_values_))); const StructScalar& struct_result = checked_cast(*result.scalar()); + ARROW_ASSIGN_OR_RAISE(auto min_, struct_result.field(FieldRef("min"))); ARROW_ASSIGN_OR_RAISE(auto max_, struct_result.field(FieldRef("max"))); + ARROW_RETURN_NOT_OK(CompareMinMax(std::move(min_), std::move(max_))); return Status::OK(); } Status MergeFrom(KernelContext*, KernelState&& src) override { const auto& other = checked_cast(src); + ARROW_RETURN_NOT_OK(CompareMinMax(other.min, other.max)); - has_nulls = has_nulls || other.has_nulls; + if (this->out_child_type == nullptr) { + this->out_child_type = other.out_child_type; + } else if (other.out_child_type != nullptr) { + ARROW_CHECK_EQ(this->out_child_type->id(), other.out_child_type->id()); + } + this->has_nulls = this->has_nulls || other.has_nulls; this->count += other.count; return Status::OK(); } Status Finalize(KernelContext*, Datum* out) override { - const auto& struct_type = checked_cast(*out_type); - const auto& child_type = struct_type.field(0)->type(); - std::vector> values; // Physical type != result type if ((this->has_nulls && !options.skip_nulls) || (this->count < options.min_count) || - min == nullptr || min->type->id() == Type::NA) { + this->min == nullptr || this->min->type->id() == Type::NA) { // (null, null) - auto null_scalar = MakeNullScalar(child_type); + std::shared_ptr null_scalar = MakeNullScalar(this->out_child_type); values = {null_scalar, null_scalar}; } else { - ARROW_CHECK_EQ(child_type->id(), min->type->id()); - ARROW_CHECK_EQ(child_type->id(), max->type->id()); - values = {std::move(min), std::move(max)}; + values = {std::move(this->min), std::move(this->max)}; } - out->value = std::make_shared(std::move(values), this->out_type); + + out->value = std::make_shared( + std::move(values), struct_({field("min", this->out_child_type), + field("max", this->out_child_type)})); return Status::OK(); } - std::shared_ptr out_type; ScalarAggregateOptions options; + std::shared_ptr out_child_type; bool has_nulls; int64_t count; std::shared_ptr min; std::shared_ptr max; private: - Status CompareMinMax(std::shared_ptr min_, - std::shared_ptr max_) { - if (min == nullptr || min->type->id() == Type::NA) { - min = min_; + Status CompareMinMax(std::shared_ptr min_, std::shared_ptr max_) { + if (this->min == nullptr || this->min->type->id() == Type::NA) { + this->min = min_; } else if (min_ != nullptr && min_->type->id() != Type::NA) { - ARROW_ASSIGN_OR_RAISE(auto min_compare_result, + ARROW_ASSIGN_OR_RAISE(Datum min_compare_result, CallFunction("greater", {min, min_})); + const BooleanScalar& min_compare_result_scalar = checked_cast(*min_compare_result.scalar()); if (min_compare_result_scalar.value) { - min = min_; + this->min = min_; } } - if (max == nullptr || max->type->id() == Type::NA) { - max = max_; + if (this->max == nullptr || this->max->type->id() == Type::NA) { + this->max = max_; } else if (max_ != nullptr && max_->type->id() != Type::NA) { - ARROW_ASSIGN_OR_RAISE(auto max_compare_result, CallFunction("less", {max, max_})); + ARROW_ASSIGN_OR_RAISE(Datum max_compare_result, CallFunction("less", {max, max_})); const BooleanScalar& max_compare_result_scalar = checked_cast(*max_compare_result.scalar()); if (max_compare_result_scalar.value) { - max = max_; + this->max = max_; } } @@ -1084,7 +1090,7 @@ struct MinMaxInitState { } Status Visit(const DictionaryType&) { - state.reset(new DictionaryMinMaxImpl(out_type, options)); + state.reset(new DictionaryMinMaxImpl(options)); return Status::OK(); }