Skip to content

Commit

Permalink
[Coral-Trino] Migrate 'collect_list' and 'collect_set' transformation…
Browse files Browse the repository at this point in the history
…s from RelNode to SqlNode layer (#375)
  • Loading branch information
ljfgem authored Mar 28, 2023
1 parent c546ff3 commit ec523ba
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.apache.calcite.sql.fun.SqlMapValueConstructor;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

Expand Down Expand Up @@ -213,13 +212,6 @@ public RexNode visitCall(RexCall call) {
}
}

if (operatorName.equalsIgnoreCase("collect_list") || operatorName.equalsIgnoreCase("collect_set")) {
Optional<RexNode> modifiedCall = visitCollectListOrSetFunction(call);
if (modifiedCall.isPresent()) {
return modifiedCall.get();
}
}

if (operatorName.equalsIgnoreCase("substr")) {
Optional<RexNode> modifiedCall = visitSubstring(call);
if (modifiedCall.isPresent()) {
Expand Down Expand Up @@ -262,18 +254,6 @@ private Optional<RexNode> visitConcat(RexCall call) {
return Optional.of(rexBuilder.makeCall(op, castOperands));
}

private Optional<RexNode> visitCollectListOrSetFunction(RexCall call) {
List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
final SqlOperator arrayAgg = createSqlOperatorOfFunction("array_agg", FunctionReturnTypes.ARRAY_OF_ARG0_TYPE);
final SqlOperator arrayDistinct = createSqlOperatorOfFunction("array_distinct", ReturnTypes.ARG0_NULLABLE);
final String operatorName = call.getOperator().getName();
if (operatorName.equalsIgnoreCase("collect_list")) {
return Optional.of(rexBuilder.makeCall(arrayAgg, convertedOperands));
} else {
return Optional.of(rexBuilder.makeCall(arrayDistinct, rexBuilder.makeCall(arrayAgg, convertedOperands)));
}
}

private Optional<RexNode> visitFromUnixtime(RexCall call) {
List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
SqlOperator formatDatetime = createSqlOperatorOfFunction("format_datetime", FunctionReturnTypes.STRING);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.linkedin.coral.common.transformers.SqlCallTransformers;
import com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry;
import com.linkedin.coral.trino.rel2trino.functions.TrinoElementAtFunction;
import com.linkedin.coral.trino.rel2trino.transformers.CollectListOrSetFunctionTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.CoralRegistryOperatorRenameSqlCallTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.GenericCoralRegistryOperatorRenameSqlCallTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.ToDateOperatorTransformer;
Expand Down Expand Up @@ -49,7 +50,7 @@ public CoralToTrinoSqlCallConverter(Map<String, Boolean> configs) {
protected SqlCall transform(SqlCall sqlCall) {
return TrinoElementAtFunction.INSTANCE.createCall(SqlParserPos.ZERO, sqlCall.getOperandList());
}
},
}, new CollectListOrSetFunctionTransformer(),
// math functions
new OperatorRenameSqlCallTransformer(SqlStdOperatorTable.RAND, 0, "RANDOM"),
new JsonTransformSqlCallTransformer(SqlStdOperatorTable.RAND, 1, "RANDOM", "[]", null, null),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.ArrayList;
import java.util.List;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.common.functions.FunctionReturnTypes;
import com.linkedin.coral.common.transformers.SqlCallTransformer;


/**
* This class implements the transformation from the operations "collect_list" and "collect_set" to their
* respective Trino-compatible versions.
*
* For example, "collect_list(col)" is transformed into "array_agg(col)", and
* "collect_set(col)" is transformed into "array_distinct(array_agg(col))".
*/
public class CollectListOrSetFunctionTransformer extends SqlCallTransformer {

private static final String COLLECT_LIST = "collect_list";
private static final String COLLECT_SET = "collect_set";
private static final String ARRAY_AGG = "array_agg";
private static final String ARRAY_DISTINCT = "array_distinct";

@Override
protected boolean condition(SqlCall sqlCall) {
final String operatorName = sqlCall.getOperator().getName();
return operatorName.equalsIgnoreCase(COLLECT_LIST) || operatorName.equalsIgnoreCase(COLLECT_SET);
}

@Override
protected SqlCall transform(SqlCall sqlCall) {
final String operatorName = sqlCall.getOperator().getName();
final SqlOperator arrayAgg = createSqlOperator(ARRAY_AGG, FunctionReturnTypes.ARRAY_OF_ARG0_TYPE);
final SqlOperator arrayDistinct = createSqlOperator(ARRAY_DISTINCT, FunctionReturnTypes.ARRAY_OF_ARG0_TYPE);

final List<SqlNode> operands = new ArrayList<>(sqlCall.getOperandList());

if (operatorName.equalsIgnoreCase(COLLECT_LIST)) {
return arrayAgg.createCall(SqlParserPos.ZERO, operands);
} else {
return arrayDistinct.createCall(SqlParserPos.ZERO, arrayAgg.createCall(SqlParserPos.ZERO, operands));
}
}
}

0 comments on commit ec523ba

Please sign in to comment.