Skip to content

Commit

Permalink
fix output type
Browse files Browse the repository at this point in the history
  • Loading branch information
R-JunmingChen committed Aug 10, 2023
1 parent 422a02f commit 1a5da63
Showing 1 changed file with 37 additions and 31 deletions.
68 changes: 37 additions & 31 deletions cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -895,9 +895,9 @@ template <SimdLevel::type SimdLevel>
struct DictionaryMinMaxImpl : public ScalarAggregator {
using ThisType = DictionaryMinMaxImpl<SimdLevel>;

DictionaryMinMaxImpl(std::shared_ptr<DataType> out_type, ScalarAggregateOptions options)
: out_type(std::move(out_type)),
options(std::move(options)),
DictionaryMinMaxImpl(ScalarAggregateOptions options)
: options(std::move(options)),
out_child_type(nullptr),
has_nulls(false),
count(0),
min(nullptr),
Expand All @@ -910,82 +910,88 @@ struct DictionaryMinMaxImpl : public ScalarAggregator {
return Status::NotImplemented("No min/max implemented for DictionaryScalar");
}

const DictionaryArray& dict_array =
checked_cast<const DictionaryArray&>(*batch[0].array.ToArray());
DictionaryArray arr(batch[0].array.ToArrayData());
std::shared_ptr<Array> dict_values = arr.dictionary();
std::shared_ptr<Array> dict_indices = arr.indices();

std::shared_ptr<Array> dict_values = dict_array.dictionary();
std::shared_ptr<Array> dict_indices = dict_array.indices();
has_nulls = dict_indices->null_count() > 0;
count += dict_indices->length() - dict_indices->null_count();
this->out_child_type = dict_values->type();
this->has_nulls = dict_indices->null_count() > 0;
this->count += dict_indices->length() - dict_indices->null_count();

Datum dict_values_(*dict_values);
ARROW_ASSIGN_OR_RAISE(Datum result, MinMax(std::move(dict_values_)));
const StructScalar& struct_result =
checked_cast<const StructScalar&>(*result.scalar());

ARROW_ASSIGN_OR_RAISE(auto min_, struct_result.field(FieldRef("min")));
ARROW_ASSIGN_OR_RAISE(auto max_, struct_result.field(FieldRef("max")));

ARROW_RETURN_NOT_OK(CompareMinMax(std::move(min_), std::move(max_)));
return Status::OK();
}

Status MergeFrom(KernelContext*, KernelState&& src) override {
const auto& other = checked_cast<const ThisType&>(src);

ARROW_RETURN_NOT_OK(CompareMinMax(other.min, other.max));
has_nulls = has_nulls || other.has_nulls;
if (this->out_child_type == nullptr) {
this->out_child_type = other.out_child_type;
} else if (other.out_child_type != nullptr) {
ARROW_CHECK_EQ(this->out_child_type->id(), other.out_child_type->id());
}
this->has_nulls = this->has_nulls || other.has_nulls;
this->count += other.count;
return Status::OK();
}

Status Finalize(KernelContext*, Datum* out) override {
const auto& struct_type = checked_cast<const StructType&>(*out_type);
const auto& child_type = struct_type.field(0)->type();

std::vector<std::shared_ptr<Scalar>> values;
// Physical type != result type
if ((this->has_nulls && !options.skip_nulls) || (this->count < options.min_count) ||
min == nullptr || min->type->id() == Type::NA) {
this->min == nullptr || this->min->type->id() == Type::NA) {
// (null, null)
auto null_scalar = MakeNullScalar(child_type);
std::shared_ptr<Scalar> null_scalar = MakeNullScalar(this->out_child_type);
values = {null_scalar, null_scalar};
} else {
ARROW_CHECK_EQ(child_type->id(), min->type->id());
ARROW_CHECK_EQ(child_type->id(), max->type->id());
values = {std::move(min), std::move(max)};
values = {std::move(this->min), std::move(this->max)};
}
out->value = std::make_shared<StructScalar>(std::move(values), this->out_type);

out->value = std::make_shared<StructScalar>(
std::move(values), struct_({field("min", this->out_child_type),
field("max", this->out_child_type)}));
return Status::OK();
}

std::shared_ptr<DataType> out_type;
ScalarAggregateOptions options;
std::shared_ptr<DataType> out_child_type;
bool has_nulls;
int64_t count;
std::shared_ptr<Scalar> min;
std::shared_ptr<Scalar> max;

private:
Status CompareMinMax(std::shared_ptr<Scalar> min_,
std::shared_ptr<Scalar> max_) {
if (min == nullptr || min->type->id() == Type::NA) {
min = min_;
Status CompareMinMax(std::shared_ptr<Scalar> min_, std::shared_ptr<Scalar> max_) {
if (this->min == nullptr || this->min->type->id() == Type::NA) {
this->min = min_;
} else if (min_ != nullptr && min_->type->id() != Type::NA) {
ARROW_ASSIGN_OR_RAISE(auto min_compare_result,
ARROW_ASSIGN_OR_RAISE(Datum min_compare_result,
CallFunction("greater", {min, min_}));

const BooleanScalar& min_compare_result_scalar =
checked_cast<const BooleanScalar&>(*min_compare_result.scalar());
if (min_compare_result_scalar.value) {
min = min_;
this->min = min_;
}
}

if (max == nullptr || max->type->id() == Type::NA) {
max = max_;
if (this->max == nullptr || this->max->type->id() == Type::NA) {
this->max = max_;
} else if (max_ != nullptr && max_->type->id() != Type::NA) {
ARROW_ASSIGN_OR_RAISE(auto max_compare_result, CallFunction("less", {max, max_}));
ARROW_ASSIGN_OR_RAISE(Datum max_compare_result, CallFunction("less", {max, max_}));
const BooleanScalar& max_compare_result_scalar =
checked_cast<const BooleanScalar&>(*max_compare_result.scalar());
if (max_compare_result_scalar.value) {
max = max_;
this->max = max_;
}
}

Expand Down Expand Up @@ -1084,7 +1090,7 @@ struct MinMaxInitState {
}

Status Visit(const DictionaryType&) {
state.reset(new DictionaryMinMaxImpl<SimdLevel>(out_type, options));
state.reset(new DictionaryMinMaxImpl<SimdLevel>(options));
return Status::OK();
}

Expand Down

0 comments on commit 1a5da63

Please sign in to comment.