Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate CoralRelToSqlNodeConverter in CoralRelNode to trino SQL translation path. #315

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
import org.apache.calcite.rel.rel2sql.RelToSqlConverter;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
Expand All @@ -35,7 +36,9 @@
import org.apache.calcite.sql.SqlLateralOperator;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
Expand Down Expand Up @@ -75,6 +78,31 @@ private static SqlDialect returnInstance() {
return new SqlDialect(context);
}

/**
 * Converts a {@link Project} into a SELECT clause.
 *
 * <p>This override is required to prevent select nodes such as {@code SELECT CAST(NULL AS NULL)}:
 * a NULL literal is wrapped in a CAST to the target field's type only when that type is not
 * {@link SqlTypeName#NULL}.
 *
 * @param e the projection to convert
 * @return the child's result unchanged when the projection is a star projection, otherwise a
 *         result whose select list holds one entry per projected expression
 */
@Override
public Result visit(Project e) {
  // The return value is discarded; the call is kept in its original position.
  e.getVariablesSet();
  final Result inputResult = visitChild(0, e.getInput());
  parseCorrelTable(e, inputResult);
  if (isStar(e.getChildExps(), e.getInput().getRowType(), e.getRowType())) {
    return inputResult;
  }

  final Builder selectBuilder = inputResult.builder(e, Clause.SELECT);
  final List<SqlNode> projections = new ArrayList<>();
  for (RexNode expression : e.getChildExps()) {
    SqlNode projection = selectBuilder.context.toSql(null, expression);

    // The output field for this expression is looked up by the number of items already collected.
    final RelDataTypeField targetField = e.getRowType().getFieldList().get(projections.size());
    if (SqlUtil.isNullLiteral(projection, false)) {
      final boolean targetIsNullType = targetField.getValue().getSqlTypeName().equals(SqlTypeName.NULL);
      if (!targetIsNullType) {
        projection =
            SqlStdOperatorTable.CAST.createCall(POS, projection, dialect.getCastSpec(targetField.getType()));
      }
    }

    addSelect(projections, projection, e.getRowType());
  }
  selectBuilder.setSelect(new SqlNodeList(projections, POS));
  return selectBuilder.result();
}

/**
* TableScan RelNode represents a relational operator that returns the contents of a table.
* Super's implementation generates a table namespace with the catalog, schema, and table name.
Expand Down Expand Up @@ -140,11 +168,7 @@ public Result visit(Correlate e) {

final Result rightResult = visitChild(1, e.getRight());

SqlNode rightSqlNode = rightResult.asFrom();

if (e.getRight() instanceof LogicalTableFunctionScan || e.getRight() instanceof Uncollect) {
rightSqlNode = generateRightChildForSqlJoinWithLateralViews(e, rightResult);
}
SqlNode rightSqlNode = generateRightChildForSqlJoinWithLateralViews(e, rightResult);

SqlNode join = new SqlJoin(POS, leftResult.asFrom(), SqlLiteral.createBoolean(false, POS),
JoinType.COMMA.symbol(POS), rightSqlNode, JoinConditionType.NONE.symbol(POS), null);
Expand Down Expand Up @@ -333,15 +357,21 @@ public Result visit(Uncollect e) {

private SqlNode generateRightChildForSqlJoinWithLateralViews(BiRel e, Result rightResult) {
SqlNode rightSqlNode = rightResult.asFrom();
SqlNode lateralNode;

final SqlNode rightLateral = SqlStdOperatorTable.LATERAL.createCall(POS, rightSqlNode);
// Drop the AS operator from the rightSqlNode if it exists and append the LATERAL operator on the inner SqlNode.
if (rightSqlNode instanceof SqlCall && ((SqlCall) rightSqlNode).getOperator().kind == SqlKind.AS) {
lateralNode = SqlStdOperatorTable.LATERAL.createCall(POS, (SqlNode) ((SqlCall) rightSqlNode).operand(0));
} else {
lateralNode = SqlStdOperatorTable.LATERAL.createCall(POS, rightSqlNode);
}
Comment on lines +362 to +367
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add some context here (i.e. why we need to drop the AS operator from the rightSqlNode if it exists and append the LATERAL operator on the inner SqlNode), preferably with an example?


// Append the alias to unnestCall by generating SqlCall with AS operator
// Append the alias to lateralNode by generating SqlCall with AS operator
RelDataType relDataType = e.getRight().getRowType();
String alias = rightResult.aliases.entrySet().stream().filter(entry -> relDataType.equals(entry.getValue()))
.findFirst().map(Map.Entry::getKey).orElse("coralDefaultColumnAlias");

List<SqlNode> asOperands = createAsFullOperands(relDataType, rightLateral, alias);
List<SqlNode> asOperands = createAsFullOperands(relDataType, lateralNode, alias);

return SqlStdOperatorTable.AS.createCall(POS, asOperands);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
/**
* Copyright 2022-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.trino.rel2trino;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
import org.apache.calcite.rel.type.RelRecordType;
import org.apache.calcite.sql.JoinConditionType;
import org.apache.calcite.sql.JoinType;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCharStringLiteral;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlTypeNameSpec;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.util.SqlShuttle;
import org.apache.calcite.sql.validate.SqlValidatorUtil;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.functions.CoralSqlUnnestOperator;
import com.linkedin.coral.trino.rel2trino.functions.TrinoArrayTransformFunction;

import static org.apache.calcite.rel.rel2sql.SqlImplementor.*;
import static org.apache.calcite.sql.parser.SqlParserPos.*;


/**
* CoralSqlNodeToTrinoSqlNodeConverter rewrites the Coral SqlNode AST. It replaces Coral IR SqlCalls
* with Trino compatible SqlCalls to subsequently obtain a Trino compatible SqlNode AST representation.
* This enables generating SQL that can be accurately interpreted by the Trino engine.
*
* This is achieved by visiting the Coral SqlNode AST in a pre-order traversal manner and
* transforming each SqlNode (SqlCall), wherever required.
* The transformation may involve change in operator, reordering the operands
* or even re-constructing the SqlCall.
*/
public class CoralSqlNodeToTrinoSqlNodeConverter extends SqlShuttle {

/**
 * Creates a converter. The constructor takes no arguments and its body is empty.
 */
public CoralSqlNodeToTrinoSqlNodeConverter() {
}

/**
 * Rewrites {@code call} with {@link #getTransformedSqlCall(SqlCall)} first and then passes the
 * rewritten call to the {@link SqlShuttle} implementation of this method.
 *
 * @param call the call being visited
 * @return whatever the superclass visit returns for the rewritten call
 */
@Override
public SqlNode visit(final SqlCall call) {
  return super.visit(getTransformedSqlCall(call));
}

/**
 * Selects the rewrite to apply to {@code sqlCall} based on the kind of its operator.
 *
 * <ul>
 *   <li>SELECT, JOIN, AS and UNNEST calls each go to their dedicated transformation;</li>
 *   <li>the six comparison kinds have their operands cast to VARCHAR;</li>
 *   <li>every other call is returned as is.</li>
 * </ul>
 *
 * NOTE(review): all comparisons are routed to the VARCHAR cast regardless of operand types.
 * String ordering differs from numeric ordering, so confirm this is intended for the
 * GREATER_THAN / LESS_THAN family.
 *
 * @param sqlCall the call to rewrite
 * @return the rewritten call, or {@code sqlCall} itself when no rewrite applies
 */
public static SqlCall getTransformedSqlCall(SqlCall sqlCall) {
  final SqlCall transformed;
  switch (sqlCall.getOperator().kind) {
    case SELECT:
      transformed = getTransformedSqlSelectSqlCall(sqlCall);
      break;
    case JOIN:
      transformed = getTransformedJoinSqlCall(sqlCall);
      break;
    case AS:
      transformed = getTransformedAsSqlCall(sqlCall);
      break;
    case UNNEST:
      transformed = getTransformedUnnestSqlCall(sqlCall);
      break;
    case EQUALS:
    case NOT_EQUALS:
    case GREATER_THAN:
    case GREATER_THAN_OR_EQUAL:
    case LESS_THAN:
    case LESS_THAN_OR_EQUAL:
      transformed = castOperandsToVarchar(sqlCall);
      break;
    default:
      transformed = sqlCall;
      break;
  }
  return transformed;
}

/**
 * Rebuilds {@code sqlCall} with the same operator, wrapping every operand in a
 * {@code TrinoTryCastFunction} call whose target type is VARCHAR.
 *
 * @param sqlCall the call whose operands are to be wrapped
 * @return a new call of the same operator over the wrapped operands
 */
private static SqlCall castOperandsToVarchar(SqlCall sqlCall) {
  // A single VARCHAR type spec is shared by all the generated casts.
  final SqlDataTypeSpec varcharSpec =
      new SqlDataTypeSpec(new SqlBasicTypeNameSpec(SqlTypeName.VARCHAR, ZERO), ZERO);

  final List<SqlNode> wrappedOperands = new ArrayList<>();
  for (SqlNode operand : sqlCall.getOperandList()) {
    final List<SqlNode> castArguments = Arrays.asList(operand, varcharSpec);
    wrappedOperands.add(TrinoTryCastFunction.INSTANCE.createCall(POS, castArguments));
  }

  return sqlCall.getOperator().createCall(POS, wrappedOperands);
}

// Update unnest operand for trino engine to expand the unnest operand to a single column
private static SqlCall getTransformedUnnestSqlCall(SqlCall sqlCall) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the SUPPORT_LEGACY_UNNEST_ARRAY_OF_STRUCT is not used, which may cause regression for LI internal use? See #158 for more details.

if (!(sqlCall.getOperator() instanceof CoralSqlUnnestOperator)) {
return sqlCall;
}

CoralSqlUnnestOperator operator = (CoralSqlUnnestOperator) sqlCall.getOperator();
SqlNode unnestOperand = sqlCall.operand(0);

// Transform UNNEST(fieldName) to UNNEST(TRANSFORM(fieldName, x -> ROW(x)))
if (operator.getRelDataType() != null) {
String fieldName = "empty";

if (unnestOperand instanceof SqlIdentifier) {
SqlIdentifier operand = (SqlIdentifier) unnestOperand;
fieldName = operand.toSqlString(TrinoSqlDialect.INSTANCE).getSql();
} else if (unnestOperand instanceof SqlCall
&& ((SqlCall) unnestOperand).getOperator().getName().equalsIgnoreCase("if")) {
// for trino outer unnest, unnest has an inner SqlCall with "if" operator
fieldName = unnestOperand.toSqlString(TrinoSqlDialect.INSTANCE).getSql();
}
SqlCharStringLiteral transformArgsLiteral =
SqlLiteral.createCharString(String.format("%s, x -> ROW(x)", fieldName), POS);

// Generate expected recordType required for transformation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: transformation

RelDataType recordType = operator.getRelDataType();
RelRecordType transformDataType =
new RelRecordType(ImmutableList.of(new RelDataTypeFieldImpl("wrapper_field", 0, recordType)));

unnestOperand = new TrinoArrayTransformFunction(transformDataType).createCall(POS, transformArgsLiteral);
}

return operator.createCall(POS, new ArrayList<>(Collections.singletonList(unnestOperand)).toArray(new SqlNode[0]));
}

/**
 * Adds an explicit {@code AS} alias to every select item that accesses a nested field, i.e. a
 * {@link SqlIdentifier} made of more than one name. The alias is the one reported by
 * {@link SqlValidatorUtil#getAlias(SqlNode, int)} for that identifier.
 *
 * <p>Star identifiers such as {@code t.*} also consist of more than one name but cannot carry an
 * alias, so they are left untouched.
 *
 * <p>The select list of the given call is replaced in place and the same instance is returned.
 *
 * @param sqlCall a call that must be a {@link SqlSelect}
 * @return {@code sqlCall}, with its select list rewritten when it is non-empty
 */
private static SqlCall getTransformedSqlSelectSqlCall(SqlCall sqlCall) {
  final SqlSelect select = (SqlSelect) sqlCall;
  final SqlNodeList selectList = select.getSelectList();
  if (selectList == null || selectList.size() == 0) {
    return sqlCall;
  }

  final List<SqlNode> modifiedSelectList = new ArrayList<>(selectList.size());
  for (SqlNode selectNode : selectList.getList()) {
    if (selectNode instanceof SqlIdentifier) {
      final SqlIdentifier identifier = (SqlIdentifier) selectNode;
      // always add "AS" when accessing nested fields, but never on a star
      if (identifier.names.size() > 1 && !identifier.isStar()) {
        // The alias is only needed here, so it is not computed for other select items.
        final String alias = SqlValidatorUtil.getAlias(identifier, -1);
        selectNode = SqlStdOperatorTable.AS.createCall(POS, identifier, new SqlIdentifier(alias, POS));
      }
    }
    modifiedSelectList.add(selectNode);
  }
  select.setSelectList(new SqlNodeList(modifiedSelectList, POS));
  return sqlCall;
}

/**
 * Replaces a COMMA join with a CROSS join unless its right side is an aliased
 * {@code LATERAL UNNEST(...)} that should keep the comma form.
 *
 * <p>Joins of any type other than COMMA are returned unchanged.
 *
 * @param sqlCall a call that must be a {@link SqlJoin}
 * @return {@code sqlCall} itself, or a new CROSS join over the same left and right inputs
 */
private static SqlCall getTransformedJoinSqlCall(SqlCall sqlCall) {
  final SqlJoin join = (SqlJoin) sqlCall;
  if (join.getJoinType() != JoinType.COMMA) {
    return sqlCall;
  }

  // Decision table for a COMMA join:
  //   no unnest on the right side                      -> CROSS JOIN
  //   unnest present and shouldSwapForCrossJoin says so -> CROSS JOIN
  //   unnest present otherwise (inline array)           -> keep the COMMA join
  final SqlNode rightInput = join.getRight();
  final boolean keepCommaJoin =
      isUnnestOperatorPresentInChildNode(rightInput) && !shouldSwapForCrossJoin(rightInput);

  return keepCommaJoin ? sqlCall : createCrossJoinSqlCall(join);
}

/**
 * Drops the LATERAL wrapper from an {@code AS} call when the LATERAL's child is an UNNEST.
 *
 * <p>The call is returned unchanged when it has two operands or fewer, or when its first operand
 * is not a {@link SqlBasicCall} of kind LATERAL. Otherwise a new {@code AS} call is built whose
 * first operand is either the UNNEST child (LATERAL removed) or the original LATERAL call, and
 * whose remaining operands are copied over as they were.
 *
 * @param sqlCall the {@code AS} call to inspect
 * @return {@code sqlCall} itself or the rebuilt {@code AS} call
 */
private static SqlCall getTransformedAsSqlCall(SqlCall sqlCall) {
  if (sqlCall.operandCount() <= 2) {
    return sqlCall;
  }
  final SqlNode aliasedNode = sqlCall.operand(0);
  if (!(aliasedNode instanceof SqlBasicCall) || aliasedNode.getKind() != SqlKind.LATERAL) {
    return sqlCall;
  }

  final SqlCall lateralCall = (SqlCall) aliasedNode;
  final SqlNode lateralChild = lateralCall.operand(0);
  final List<SqlNode> originalOperands = sqlCall.getOperandList();

  final List<SqlNode> rebuiltOperands = new ArrayList<>(originalOperands.size());
  // Keep LATERAL unless it directly wraps an UNNEST
  rebuiltOperands.add(lateralChild.getKind() == SqlKind.UNNEST ? lateralChild : lateralCall);
  rebuiltOperands.addAll(originalOperands.subList(1, originalOperands.size()));

  return SqlStdOperatorTable.AS.createCall(ZERO, rebuiltOperands);
}

/**
 * Tells whether {@code sqlNode} has the exact shape {@code AS(LATERAL(UNNEST(...)), ...)}.
 *
 * @param sqlNode the node to inspect
 * @return {@code true} only when the node is an AS call whose first operand is a LATERAL call
 *         whose first operand is in turn an UNNEST call
 */
private static boolean isUnnestOperatorPresentInChildNode(SqlNode sqlNode) {
  if (!(sqlNode instanceof SqlCall) || sqlNode.getKind() != SqlKind.AS) {
    return false;
  }

  final SqlNode aliasedNode = ((SqlCall) sqlNode).operand(0);
  if (!(aliasedNode instanceof SqlCall) || aliasedNode.getKind() != SqlKind.LATERAL) {
    return false;
  }

  final SqlNode lateralChild = ((SqlCall) aliasedNode).operand(0);
  return lateralChild instanceof SqlCall && lateralChild.getKind() == SqlKind.UNNEST;
}

/**
 * Decides whether a COMMA join whose right side is {@code AS(LATERAL(UNNEST(x)), ...)} should be
 * replaced by a CROSS join, based on what is being unnested.
 *
 * <p>Returns {@code true} when {@code x} is an identifier (a column reference such as
 * {@code table1.col1}) or a call to an operator named "if" (used for outer unnest). An inline
 * array as the unnest argument yields {@code false}.
 *
 * <p>NOTE(review): the last check tests the kind of the call under LATERAL, not of the unnest
 * argument. The only caller visible here invokes this method after confirming that call is an
 * UNNEST, so the SELECT check looks like it can never succeed - confirm whether the unnest
 * argument was meant instead.
 *
 * @param sqlNode the right input of the join; it and its first two nested first operands are
 *        cast to {@link SqlCall} without a check
 * @return whether the join should become a CROSS join
 */
private static boolean shouldSwapForCrossJoin(SqlNode sqlNode) {
  final SqlNode lateralCall = ((SqlCall) sqlNode).operand(0); // LATERAL unnest(x)
  final SqlNode unnestCall = ((SqlCall) lateralCall).operand(0); // unnest(x)
  final SqlNode unnestArgument = ((SqlCall) unnestCall).operand(0); // x

  // (1) a column reference
  if (unnestArgument.getKind() == SqlKind.IDENTIFIER) {
    return true;
  }
  // (2) an "if" call, as produced for outer unnest
  if (unnestArgument instanceof SqlCall
      && ((SqlCall) unnestArgument).getOperator().getName().equalsIgnoreCase("if")) {
    return true;
  }
  // (3) a SELECT directly under LATERAL; anything else (e.g. an inline array) stays a comma join
  return unnestCall.getKind() == SqlKind.SELECT;
}

/**
 * Builds a CROSS join over the left and right inputs of the given join, with no join condition.
 *
 * @param sqlCall a call that must be a {@link SqlJoin}
 * @return a new {@link SqlJoin} of type CROSS
 */
private static SqlCall createCrossJoinSqlCall(SqlCall sqlCall) {
  final SqlJoin sourceJoin = (SqlJoin) sqlCall;
  final SqlNode leftInput = sourceJoin.getLeft();
  final SqlNode rightInput = sourceJoin.getRight();

  return new SqlJoin(POS, leftInput, SqlLiteral.createBoolean(false, SqlParserPos.ZERO),
      JoinType.CROSS.symbol(POS), rightInput, JoinConditionType.NONE.symbol(SqlParserPos.ZERO), null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.hive.hive2rel.rel.HiveUncollect;
import com.linkedin.coral.transformers.CoralRelToSqlNodeConverter;
import com.linkedin.coral.trino.rel2trino.functions.TrinoArrayTransformFunction;

import static com.google.common.base.Preconditions.*;
Expand Down Expand Up @@ -76,8 +77,26 @@ public RelToTrinoConverter(Map<String, Boolean> configs) {
* @return SQL string
*/
public String convert(RelNode relNode) {
RelNode rel = convertRel(relNode, configs);
return convertToSqlNode(rel).accept(new TrinoSqlRewriter()).toSqlString(TrinoSqlDialect.INSTANCE).toString();
return convertDash(relNode);

// RelNode rel = convertRel(relNode, configs);
// SqlNode oldSqlNode = convertToSqlNode(rel);
// return oldSqlNode.accept(new TrinoSqlRewriter()).toSqlString(TrinoSqlDialect.INSTANCE).toString();
}

/**
* Temporary method to enable translations via CoralSqlNodeToTrinoSqlNodeConverter
*/
public String convertDash(RelNode relNode) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we give this method a more descriptive name? Not following what convertDash means.

RelNode trinoRelNode = convertRel(relNode, configs);
SqlNode coralSqlNode = convertToCoralSqlNode(trinoRelNode);
System.out.println("New coralSqlNode for trino: " + coralSqlNode);

SqlNode trinoSqlNode = coralSqlNode.accept(new CoralSqlNodeToTrinoSqlNodeConverter());
System.out.println("New trinoSqlNode for trino: " + trinoSqlNode);

SqlNode rewrittenTrinoSqlNode = trinoSqlNode.accept(new TrinoSqlRewriter());
return rewrittenTrinoSqlNode.toSqlString(TrinoSqlDialect.INSTANCE).toString();
}

/**
Expand All @@ -89,6 +108,15 @@ public SqlNode convertToSqlNode(RelNode relNode) {
return visitChild(0, relNode).asStatement();
}

/**
* Convert input relational algebra to CoralSqlNode
* @param relNode relation algebra
* @return CoralSqlNode representation for input
*/
public SqlNode convertToCoralSqlNode(RelNode relNode) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this explicit method given it's simple and only called in convertDash.

return new CoralRelToSqlNodeConverter().convert(relNode);
}

/**
* @see #dispatch(RelNode)
* @param window Relnode representing window clause
Expand Down
Loading