From 89de7e1f9128e0cfa9197bb570e9afa8b4034782 Mon Sep 17 00:00:00 2001 From: zstan Date: Thu, 23 Jan 2025 11:24:55 +0300 Subject: [PATCH] IGNITE-23813 Sql. Provide correct implementation of LITERAL_AGG aggregate function --- .../sql/engine/exec/exp/agg/Accumulators.java | 45 ++++++++++++++----- .../exec/exp/agg/AccumulatorsFactory.java | 27 ++++++++--- .../internal/sql/engine/util/PlanUtils.java | 13 ++++-- .../exp/agg/LiteralValAccumulatorTest.java | 6 +-- .../sql/engine/util/PlanUtilsTest.java | 4 +- 5 files changed, 70 insertions(+), 25 deletions(-) diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java index f08f73c4b5f..b52e6031d23 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java @@ -19,7 +19,6 @@ import static org.apache.calcite.sql.type.SqlTypeName.ANY; import static org.apache.calcite.sql.type.SqlTypeName.BIGINT; -import static org.apache.calcite.sql.type.SqlTypeName.BOOLEAN; import static org.apache.calcite.sql.type.SqlTypeName.DECIMAL; import static org.apache.calcite.sql.type.SqlTypeName.DOUBLE; import static org.apache.calcite.sql.type.SqlTypeName.VARBINARY; @@ -35,6 +34,7 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.fun.SqlLiteralAggFunction; import org.apache.ignite.internal.catalog.commands.CatalogUtils; import org.apache.ignite.internal.sql.engine.exec.exp.IgniteSqlFunctions; import org.apache.ignite.internal.sql.engine.type.IgniteCustomType; @@ -61,11 +61,11 @@ public Accumulators(IgniteTypeFactory typeFactory) { /** * Returns a supplier that creates a accumulator functions for the given aggregate call. */ - public Supplier accumulatorFactory(AggregateCall call) { - return accumulatorFunctionFactory(call); + public Supplier accumulatorFactory(AggregateCall call, RelDataType inputType) { + return accumulatorFunctionFactory(call, inputType); } - private Supplier accumulatorFunctionFactory(AggregateCall call) { + private Supplier accumulatorFunctionFactory(AggregateCall call, RelDataType inputType) { // Update documentation in IgniteCustomType when you add an aggregate // that can work for any type out of the box. switch (call.getAggregation().getName()) { @@ -88,7 +88,8 @@ private Supplier accumulatorFunctionFactory(AggregateCall call) { case "ANY_VALUE": return anyValueFactory(call); case "LITERAL_AGG": - return LiteralVal.newAccumulator(typeFactory.createSqlType(BOOLEAN)); + assert call.rexList.size() == 1 : "Incorrect number of pre-operands for LiteralAgg: " + call + ", input: " + inputType; + return LiteralVal.newAccumulator(call.rexList.get(0).getType()); default: throw new AssertionError(call.getAggregation().getName()); } @@ -254,23 +255,47 @@ public RelDataType returnType(IgniteTypeFactory typeFactory) { } /** - * {@code LITERAL_AGG} accumulator, return {@code true} if incoming data is not empty, {@code false} otherwise. Calcite`s implementation - * RexImpTable#LiteralAggImplementor. + * {@code LITERAL_AGG} accumulator. Pseudo accumulator that accepts a single literal as an operand and returns that literal. + * + * @see SqlLiteralAggFunction */ - public static class LiteralVal extends AnyVal { + public static class LiteralVal implements Accumulator { + + private final RelDataType type; + private LiteralVal(RelDataType type) { - super(type); + this.type = type; } public static Supplier newAccumulator(RelDataType type) { return () -> new LiteralVal(type); } + /** {@inheritDoc} */ + @Override + public void add(AccumulatorsState state, Object[] args) { + assert args.length == 1 : args.length; + // Literal Agg is called with the same argument. + state.set(args[0]); + } + /** {@inheritDoc} */ @Override public void end(AccumulatorsState state, AccumulatorsState result) { Object val = state.get(); - result.set(val != null); + result.set(val); + } + + /** {@inheritDoc} */ + @Override + public List argumentTypes(IgniteTypeFactory typeFactory) { + return List.of(typeFactory.createTypeWithNullability(type, true)); + } + + /** {@inheritDoc} */ + @Override + public RelDataType returnType(IgniteTypeFactory typeFactory) { + return type; } } diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsFactory.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsFactory.java index ef6ebed1741..5f504eac53f 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsFactory.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/AccumulatorsFactory.java @@ -49,6 +49,7 @@ import org.apache.ignite.internal.sql.engine.exec.ExecutionContext; import org.apache.ignite.internal.sql.engine.exec.RowHandler; import org.apache.ignite.internal.sql.engine.exec.exp.RexToLixTranslator; +import org.apache.ignite.internal.sql.engine.exec.exp.SqlScalar; import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory; import org.apache.ignite.internal.sql.engine.util.Commons; import org.apache.ignite.internal.sql.engine.util.Primitives; @@ -188,7 +189,7 @@ private WrapperPrototype(Accumulators accumulators, AggregateCall call) { public AccumulatorWrapper apply(ExecutionContext context) { Accumulator accumulator = accumulator(); - return new AccumulatorWrapperImpl<>(context.rowHandler(), accumulator, call, inAdapter, outAdapter); + return new AccumulatorWrapperImpl<>(context, accumulator, call, inAdapter, outAdapter); } private Accumulator accumulator() { @@ -197,7 +198,7 @@ private Accumulator accumulator() { } // init factory and adapters - accFactory = accumulators.accumulatorFactory(call); + accFactory = accumulators.accumulatorFactory(call, inputRowType); Accumulator accumulator = accFactory.get(); inAdapter = createInAdapter(accumulator); @@ -264,6 +265,8 @@ private static final class AccumulatorWrapperImpl implements AccumulatorWr private final boolean literalAgg; + private final Object preOperand; + private final int filterArg; private final boolean ignoreNulls; @@ -273,23 +276,31 @@ private static final class AccumulatorWrapperImpl implements AccumulatorWr private final boolean distinct; AccumulatorWrapperImpl( - RowHandler handler, + ExecutionContext ctx, Accumulator accumulator, AggregateCall call, Function inAdapter, Function outAdapter ) { - this.handler = handler; + this.handler = ctx.rowHandler(); this.accumulator = accumulator; this.inAdapter = inAdapter; this.outAdapter = outAdapter; this.distinct = call.isDistinct(); - // need to be refactored after https://issues.apache.org/jira/browse/IGNITE-22320 literalAgg = call.getAggregation() == LITERAL_AGG; argList = call.getArgList(); ignoreNulls = call.ignoreNulls(); filterArg = call.hasFilter() ? call.filterArg : -1; + + if (literalAgg) { + assert call.getArgList().isEmpty() : "LiteralAgg should have no operands: " + call; + + SqlScalar litAggArg = ctx.expressionFactory().scalar(call.rexList.get(0)); + preOperand = litAggArg.get(ctx); + } else { + preOperand = null; + } } @Override @@ -308,11 +319,15 @@ public Accumulator accumulator() { return null; } + if (literalAgg) { + return new Object[]{preOperand}; + } + int params = argList.size(); List argList0 = argList; - if ((distinct && argList.isEmpty()) || literalAgg) { + if ((distinct && argList.isEmpty())) { int cnt = handler.columnCount(row); assert cnt <= 1; argList0 = List.of(0); diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java index 90fe27e6146..4d7af9d0c14 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java @@ -80,7 +80,7 @@ public static RelDataType createSortAggRowType(ImmutableBitSet grpKeys, builder.add(fld); } - addAccumulatorFields(typeFactory, aggregateCalls, builder); + addAccumulatorFields(typeFactory, aggregateCalls, inputType, builder); return builder.build(); } @@ -115,19 +115,24 @@ public static RelDataType createHashAggRowType(List groupSets, builder.add(fld); } - addAccumulatorFields(typeFactory, aggregateCalls, builder); + addAccumulatorFields(typeFactory, aggregateCalls, inputType, builder); builder.add("_GROUP_ID", SqlTypeName.TINYINT); return builder.build(); } - private static void addAccumulatorFields(IgniteTypeFactory typeFactory, List aggregateCalls, Builder builder) { + private static void addAccumulatorFields( + IgniteTypeFactory typeFactory, + List aggregateCalls, + RelDataType inputType, + Builder builder + ) { Accumulators accumulators = new Accumulators(typeFactory); for (int i = 0; i < aggregateCalls.size(); i++) { AggregateCall call = aggregateCalls.get(i); - Accumulator acc = accumulators.accumulatorFactory(call).get(); + Accumulator acc = accumulators.accumulatorFactory(call, inputType).get(); RelDataType fieldType; // For a decimal type Accumulator::returnType returns a type with default precision and scale, // that can cause precision loss when a tuple is sent over the wire by an exchanger/outbox. diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/LiteralValAccumulatorTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/LiteralValAccumulatorTest.java index b766d303b8d..077dd04add8 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/LiteralValAccumulatorTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/LiteralValAccumulatorTest.java @@ -35,17 +35,17 @@ public class LiteralValAccumulatorTest extends BaseIgniteAbstractTest { public void test() { StatefulAccumulator accumulator = newCall(); + // Literal agg accepts the same value. accumulator.add("1"); - accumulator.add("2"); - assertEquals(true, accumulator.end()); + assertEquals("1", accumulator.end()); } @Test public void empty() { StatefulAccumulator accumulator = newCall(); - assertEquals(false, accumulator.end()); + assertEquals(null, accumulator.end()); } private StatefulAccumulator newCall() { diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/PlanUtilsTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/PlanUtilsTest.java index c35dbaa40b3..ac13578f9d8 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/PlanUtilsTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/PlanUtilsTest.java @@ -53,7 +53,7 @@ public void testHashAggRowType() { .build(); AggregateCall call1 = newCall(typeFactory.createSqlType(SqlTypeName.BIGINT)); - Accumulator acc1 = accumulators.accumulatorFactory(call1).get(); + Accumulator acc1 = accumulators.accumulatorFactory(call1, inputType).get(); RelDataType expectedType = new RelDataTypeFactory.Builder(typeFactory) .add("f1", typeFactory.createSqlType(SqlTypeName.INTEGER)) @@ -83,7 +83,7 @@ public void testSortAggRowType() { .build(); AggregateCall call1 = newCall(typeFactory.createSqlType(SqlTypeName.BIGINT)); - Accumulator acc1 = accumulators.accumulatorFactory(call1).get(); + Accumulator acc1 = accumulators.accumulatorFactory(call1, inputType).get(); RelDataType expectedType = new RelDataTypeFactory.Builder(typeFactory) .add("f1", typeFactory.createSqlType(SqlTypeName.INTEGER))