Skip to content

Commit

Permalink
IGNITE-23813 Sql. Provide correct implementation of LITERAL_AGG aggre…
Browse files Browse the repository at this point in the history
…gate function
  • Loading branch information
zstan committed Jan 27, 2025
1 parent abaa624 commit 89de7e1
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import static org.apache.calcite.sql.type.SqlTypeName.ANY;
import static org.apache.calcite.sql.type.SqlTypeName.BIGINT;
import static org.apache.calcite.sql.type.SqlTypeName.BOOLEAN;
import static org.apache.calcite.sql.type.SqlTypeName.DECIMAL;
import static org.apache.calcite.sql.type.SqlTypeName.DOUBLE;
import static org.apache.calcite.sql.type.SqlTypeName.VARBINARY;
Expand All @@ -35,6 +34,7 @@
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.fun.SqlLiteralAggFunction;
import org.apache.ignite.internal.catalog.commands.CatalogUtils;
import org.apache.ignite.internal.sql.engine.exec.exp.IgniteSqlFunctions;
import org.apache.ignite.internal.sql.engine.type.IgniteCustomType;
Expand All @@ -61,11 +61,11 @@ public Accumulators(IgniteTypeFactory typeFactory) {
/**
* Returns a supplier that creates a accumulator functions for the given aggregate call.
*/
public Supplier<Accumulator> accumulatorFactory(AggregateCall call) {
return accumulatorFunctionFactory(call);
public Supplier<Accumulator> accumulatorFactory(AggregateCall call, RelDataType inputType) {
return accumulatorFunctionFactory(call, inputType);
}

private Supplier<Accumulator> accumulatorFunctionFactory(AggregateCall call) {
private Supplier<Accumulator> accumulatorFunctionFactory(AggregateCall call, RelDataType inputType) {
// Update documentation in IgniteCustomType when you add an aggregate
// that can work for any type out of the box.
switch (call.getAggregation().getName()) {
Expand All @@ -88,7 +88,8 @@ private Supplier<Accumulator> accumulatorFunctionFactory(AggregateCall call) {
case "ANY_VALUE":
return anyValueFactory(call);
case "LITERAL_AGG":
return LiteralVal.newAccumulator(typeFactory.createSqlType(BOOLEAN));
assert call.rexList.size() == 1 : "Incorrect number of pre-operands for LiteralAgg: " + call + ", input: " + inputType;
return LiteralVal.newAccumulator(call.rexList.get(0).getType());
default:
throw new AssertionError(call.getAggregation().getName());
}
Expand Down Expand Up @@ -254,23 +255,47 @@ public RelDataType returnType(IgniteTypeFactory typeFactory) {
}

/**
* {@code LITERAL_AGG} accumulator, return {@code true} if incoming data is not empty, {@code false} otherwise. Calcite`s implementation
* RexImpTable#LiteralAggImplementor.
* {@code LITERAL_AGG} accumulator. Pseudo accumulator that accepts a single literal as an operand and returns that literal.
*
* @see SqlLiteralAggFunction
*/
public static class LiteralVal extends AnyVal {
public static class LiteralVal implements Accumulator {

private final RelDataType type;

private LiteralVal(RelDataType type) {
super(type);
this.type = type;
}

public static Supplier<Accumulator> newAccumulator(RelDataType type) {
return () -> new LiteralVal(type);
}

/** {@inheritDoc} */
@Override
public void add(AccumulatorsState state, Object[] args) {
assert args.length == 1 : args.length;
// Literal Agg is called with the same argument.
state.set(args[0]);
}

/** {@inheritDoc} */
@Override
public void end(AccumulatorsState state, AccumulatorsState result) {
Object val = state.get();
result.set(val != null);
result.set(val);
}

/** {@inheritDoc} */
@Override
public List<RelDataType> argumentTypes(IgniteTypeFactory typeFactory) {
return List.of(typeFactory.createTypeWithNullability(type, true));
}

/** {@inheritDoc} */
@Override
public RelDataType returnType(IgniteTypeFactory typeFactory) {
return type;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.ignite.internal.sql.engine.exec.ExecutionContext;
import org.apache.ignite.internal.sql.engine.exec.RowHandler;
import org.apache.ignite.internal.sql.engine.exec.exp.RexToLixTranslator;
import org.apache.ignite.internal.sql.engine.exec.exp.SqlScalar;
import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory;
import org.apache.ignite.internal.sql.engine.util.Commons;
import org.apache.ignite.internal.sql.engine.util.Primitives;
Expand Down Expand Up @@ -188,7 +189,7 @@ private WrapperPrototype(Accumulators accumulators, AggregateCall call) {
public AccumulatorWrapper<RowT> apply(ExecutionContext<RowT> context) {
Accumulator accumulator = accumulator();

return new AccumulatorWrapperImpl<>(context.rowHandler(), accumulator, call, inAdapter, outAdapter);
return new AccumulatorWrapperImpl<>(context, accumulator, call, inAdapter, outAdapter);
}

private Accumulator accumulator() {
Expand All @@ -197,7 +198,7 @@ private Accumulator accumulator() {
}

// init factory and adapters
accFactory = accumulators.accumulatorFactory(call);
accFactory = accumulators.accumulatorFactory(call, inputRowType);
Accumulator accumulator = accFactory.get();

inAdapter = createInAdapter(accumulator);
Expand Down Expand Up @@ -264,6 +265,8 @@ private static final class AccumulatorWrapperImpl<RowT> implements AccumulatorWr

private final boolean literalAgg;

private final Object preOperand;

private final int filterArg;

private final boolean ignoreNulls;
Expand All @@ -273,23 +276,31 @@ private static final class AccumulatorWrapperImpl<RowT> implements AccumulatorWr
private final boolean distinct;

AccumulatorWrapperImpl(
RowHandler<RowT> handler,
ExecutionContext<RowT> ctx,
Accumulator accumulator,
AggregateCall call,
Function<Object[], Object[]> inAdapter,
Function<Object, Object> outAdapter
) {
this.handler = handler;
this.handler = ctx.rowHandler();
this.accumulator = accumulator;
this.inAdapter = inAdapter;
this.outAdapter = outAdapter;
this.distinct = call.isDistinct();

// need to be refactored after https://issues.apache.org/jira/browse/IGNITE-22320
literalAgg = call.getAggregation() == LITERAL_AGG;
argList = call.getArgList();
ignoreNulls = call.ignoreNulls();
filterArg = call.hasFilter() ? call.filterArg : -1;

if (literalAgg) {
assert call.getArgList().isEmpty() : "LiteralAgg should have no operands: " + call;

SqlScalar<RowT, Object> litAggArg = ctx.expressionFactory().scalar(call.rexList.get(0));
preOperand = litAggArg.get(ctx);
} else {
preOperand = null;
}
}

@Override
Expand All @@ -308,11 +319,15 @@ public Accumulator accumulator() {
return null;
}

if (literalAgg) {
return new Object[]{preOperand};
}

int params = argList.size();

List<Integer> argList0 = argList;

if ((distinct && argList.isEmpty()) || literalAgg) {
if ((distinct && argList.isEmpty())) {
int cnt = handler.columnCount(row);
assert cnt <= 1;
argList0 = List.of(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public static RelDataType createSortAggRowType(ImmutableBitSet grpKeys,
builder.add(fld);
}

addAccumulatorFields(typeFactory, aggregateCalls, builder);
addAccumulatorFields(typeFactory, aggregateCalls, inputType, builder);

return builder.build();
}
Expand Down Expand Up @@ -115,19 +115,24 @@ public static RelDataType createHashAggRowType(List<ImmutableBitSet> groupSets,
builder.add(fld);
}

addAccumulatorFields(typeFactory, aggregateCalls, builder);
addAccumulatorFields(typeFactory, aggregateCalls, inputType, builder);

builder.add("_GROUP_ID", SqlTypeName.TINYINT);

return builder.build();
}

private static void addAccumulatorFields(IgniteTypeFactory typeFactory, List<AggregateCall> aggregateCalls, Builder builder) {
private static void addAccumulatorFields(
IgniteTypeFactory typeFactory,
List<AggregateCall> aggregateCalls,
RelDataType inputType,
Builder builder
) {
Accumulators accumulators = new Accumulators(typeFactory);

for (int i = 0; i < aggregateCalls.size(); i++) {
AggregateCall call = aggregateCalls.get(i);
Accumulator acc = accumulators.accumulatorFactory(call).get();
Accumulator acc = accumulators.accumulatorFactory(call, inputType).get();
RelDataType fieldType;
// For a decimal type Accumulator::returnType returns a type with default precision and scale,
// that can cause precision loss when a tuple is sent over the wire by an exchanger/outbox.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ public class LiteralValAccumulatorTest extends BaseIgniteAbstractTest {
public void test() {
StatefulAccumulator accumulator = newCall();

// Literal agg accepts the same value.
accumulator.add("1");
accumulator.add("2");

assertEquals(true, accumulator.end());
assertEquals("1", accumulator.end());
}

@Test
public void empty() {
StatefulAccumulator accumulator = newCall();

assertEquals(false, accumulator.end());
assertEquals(null, accumulator.end());
}

private StatefulAccumulator newCall() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public void testHashAggRowType() {
.build();

AggregateCall call1 = newCall(typeFactory.createSqlType(SqlTypeName.BIGINT));
Accumulator acc1 = accumulators.accumulatorFactory(call1).get();
Accumulator acc1 = accumulators.accumulatorFactory(call1, inputType).get();

RelDataType expectedType = new RelDataTypeFactory.Builder(typeFactory)
.add("f1", typeFactory.createSqlType(SqlTypeName.INTEGER))
Expand Down Expand Up @@ -83,7 +83,7 @@ public void testSortAggRowType() {
.build();

AggregateCall call1 = newCall(typeFactory.createSqlType(SqlTypeName.BIGINT));
Accumulator acc1 = accumulators.accumulatorFactory(call1).get();
Accumulator acc1 = accumulators.accumulatorFactory(call1, inputType).get();

RelDataType expectedType = new RelDataTypeFactory.Builder(typeFactory)
.add("f1", typeFactory.createSqlType(SqlTypeName.INTEGER))
Expand Down

0 comments on commit 89de7e1

Please sign in to comment.