Skip to content

Commit

Permalink
[Coral-Trino] Migrate 'adjustReturnTypeWithCast' transformation from …
Browse files Browse the repository at this point in the history
…RelNode to SqlNode layer (#385)
  • Loading branch information
ljfgem authored Apr 7, 2023
1 parent 825943f commit 575e317
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

import com.google.common.collect.ImmutableMap;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.RelShuttleImpl;
Expand Down Expand Up @@ -213,7 +210,7 @@ public RexNode visitCall(RexCall call) {
}
}

return adjustReturnTypeWithCast(rexBuilder, super.visitCall(call));
return super.visitCall(call);
}

private Optional<RexNode> visitConcat(RexCall call) {
Expand Down Expand Up @@ -358,30 +355,6 @@ private Optional<RexNode> visitCast(RexCall call) {

return Optional.empty();
}

/**
* Wraps a converted call in a CAST so that its return type matches the type the corresponding
* Hive function returns.
*
* e.g. `datediff` in Hive returns int type, but the corresponding function `date_diff` in Trino returns bigint type.
* The type discrepancy would cause issues while querying the view on Trino, so a CAST is added for such calls.
*
* Operators adjusted unconditionally (matched case-insensitively by operator name):
* `datediff` and `cardinality` are cast to INTEGER; `ceil`, `ceiling` and `floor` are cast to BIGINT.
* `date_add` and `date_sub` are cast to VARCHAR only when the CAST_DATEADD_TO_STRING config is set to true.
*
* @param rexBuilder builder used to create the CAST expression
* @param call the already-converted expression to inspect
* @return `call` wrapped in a CAST when its operator needs adjusting, otherwise `call` itself
*/
private RexNode adjustReturnTypeWithCast(RexBuilder rexBuilder, RexNode call) {
// Only operator calls carry an operator name to match on; anything else is returned untouched.
if (!(call instanceof RexCall)) {
return call;
}
// Locale.ROOT keeps the lowercasing independent of the JVM's default locale.
final String lowercaseOperatorName = ((RexCall) call).getOperator().getName().toLowerCase(Locale.ROOT);
// Operator name -> target type of the CAST. Note that this map (and its types) is rebuilt on every invocation.
final ImmutableMap<String, RelDataType> operatorsToAdjust =
ImmutableMap.of("datediff", typeFactory.createSqlType(INTEGER), "cardinality",
typeFactory.createSqlType(INTEGER), "ceil", typeFactory.createSqlType(BIGINT), "ceiling",
typeFactory.createSqlType(BIGINT), "floor", typeFactory.createSqlType(BIGINT));
if (operatorsToAdjust.containsKey(lowercaseOperatorName)) {
return rexBuilder.makeCast(operatorsToAdjust.get(lowercaseOperatorName), call);
}
// The VARCHAR cast for date_add / date_sub is opt-in: a missing config entry is treated as false.
if (configs.getOrDefault(CAST_DATEADD_TO_STRING, false)
&& (lowercaseOperatorName.equals("date_add") || lowercaseOperatorName.equals("date_sub"))) {
return rexBuilder.makeCast(typeFactory.createSqlType(VARCHAR), call);
}
// No adjustment applies to this operator.
return call;
}
}

private static SqlOperator createSqlOperatorOfFunction(String functionName, SqlReturnTypeInference typeInference) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.linkedin.coral.trino.rel2trino.transformers.CurrentTimestampTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.GenericCoralRegistryOperatorRenameSqlCallTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.MapValueConstructorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.ReturnTypeAdjustmentTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.ToDateOperatorTransformer;

import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*;
Expand Down Expand Up @@ -116,7 +117,9 @@ protected SqlCall transform(SqlCall sqlCall) {
"com.linkedin.stdudfs.urnextractor.hive.UrnExtractorFunctionWrapper", 1, "urn_extractor"),
new CoralRegistryOperatorRenameSqlCallTransformer(
"com.linkedin.stdudfs.hive.daliudfs.UrnExtractorFunctionWrapper", 1, "urn_extractor"),
new GenericCoralRegistryOperatorRenameSqlCallTransformer());
new GenericCoralRegistryOperatorRenameSqlCallTransformer(),

new ReturnTypeAdjustmentTransformer(configs));
}

private SqlOperator hiveToCoralSqlOperator(String functionName) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.Locale;
import java.util.Map;

import com.google.common.collect.ImmutableMap;

import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;

import com.linkedin.coral.common.transformers.SqlCallTransformer;

import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*;


/**
 * Wraps the result of selected Trino function calls in a "CAST" so that the resulting type is the
 * one the equivalent Hive function would have produced.
 *
 * Example:
 * Hive's "DATEDIFF" returns int while Trino's "DATE_DIFF" returns bigint; to stay compatible the
 * call is rewritten to CAST(... AS INTEGER).
 *
 * Operator names are matched case-insensitively. The cast of "date_add" to VARCHAR is applied only
 * when the CAST_DATEADD_TO_STRING config is set to true.
 */
public class ReturnTypeAdjustmentTransformer extends SqlCallTransformer {

  /** Operator whose adjustment is gated by the CAST_DATEADD_TO_STRING config. */
  private static final String DATE_ADD = "date_add";

  /** Lowercase Trino operator name mapped to the type its result is cast to. */
  private static final Map<String, SqlTypeName> OPERATORS_TO_ADJUST =
      ImmutableMap.<String, SqlTypeName> builder()
          .put("date_diff", SqlTypeName.INTEGER)
          .put("cardinality", SqlTypeName.INTEGER)
          .put("ceil", SqlTypeName.BIGINT)
          .put("ceiling", SqlTypeName.BIGINT)
          .put("floor", SqlTypeName.BIGINT)
          .put(DATE_ADD, SqlTypeName.VARCHAR)
          .build();

  private final Map<String, Boolean> configs;

  /**
   * @param configs conversion flags; only CAST_DATEADD_TO_STRING is consulted, and a missing entry
   *                is treated as false
   */
  public ReturnTypeAdjustmentTransformer(Map<String, Boolean> configs) {
    this.configs = configs;
  }

  /**
   * Accepts calls whose operator is listed in the adjustment table, except that "date_add" is
   * accepted only when the CAST_DATEADD_TO_STRING config is enabled.
   */
  @Override
  protected boolean condition(SqlCall sqlCall) {
    final String operatorName = lowercaseNameOf(sqlCall);
    if (!OPERATORS_TO_ADJUST.containsKey(operatorName)) {
      return false;
    }
    // Every listed operator is adjusted unconditionally; only date_add depends on the config flag.
    return !DATE_ADD.equals(operatorName) || configs.getOrDefault(CAST_DATEADD_TO_STRING, false);
  }

  /**
   * Returns the call wrapped in a CAST to its table-defined type, or the call itself when the
   * operator has no entry in the table.
   */
  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    final SqlTypeName castTarget = OPERATORS_TO_ADJUST.get(lowercaseNameOf(sqlCall));
    return castTarget == null ? sqlCall : createCast(sqlCall, castTarget);
  }

  /** Builds a CAST call converting {@code node} to the given basic type. */
  private SqlCall createCast(SqlNode node, SqlTypeName typeName) {
    final SqlBasicTypeNameSpec basicTypeName = new SqlBasicTypeNameSpec(typeName, SqlParserPos.ZERO);
    final SqlDataTypeSpec castTypeSpec = new SqlDataTypeSpec(basicTypeName, SqlParserPos.ZERO);
    return SqlStdOperatorTable.CAST.createCall(SqlParserPos.ZERO, node, castTypeSpec);
  }

  /** Operator name of the call, lowercased with Locale.ROOT so the lookup is locale-independent. */
  private static String lowercaseNameOf(SqlCall sqlCall) {
    return sqlCall.getOperator().getName().toLowerCase(Locale.ROOT);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public Object[][] viewTestCasesProvider() {

{ "test", "view_with_outer_explode_struct_array", "SELECT \"$cor0\".\"a\" AS \"a\", \"t0\".\"c\" AS \"c\"\n"
+ "FROM \"test\".\"table_with_struct_array\" AS \"$cor0\"\n"
+ "CROSS JOIN UNNEST(TRANSFORM(\"if\"(\"$cor0\".\"b\" IS NOT NULL AND CAST(CARDINALITY(\"$cor0\".\"b\") AS INTEGER) > 0, \"$cor0\".\"b\", ARRAY[NULL]), x -> ROW(x))) AS \"t0\" (\"c\")" },
+ "CROSS JOIN UNNEST(TRANSFORM(\"if\"(\"$cor0\".\"b\" IS NOT NULL AND CARDINALITY(\"$cor0\".\"b\") > 0, \"$cor0\".\"b\", ARRAY[NULL]), x -> ROW(x))) AS \"t0\" (\"c\")" },

{ "test", "view_with_explode_map", "SELECT \"$cor0\".\"a\" AS \"a\", \"t0\".\"c\" AS \"c\", \"t0\".\"d\" AS \"d\"\n"
+ "FROM \"test\".\"table_with_map\" AS \"$cor0\"\n"
Expand Down Expand Up @@ -546,7 +546,7 @@ public void testTypeCastForDataAddFunction() {
RelNode relNode = hiveToRelConverter.convertSql(
"SELECT date_add('2021-08-20', 1), date_add('2021-08-20 00:00:00', 1), date_sub('2021-08-20', 1), date_sub('2021-08-20 00:00:00', 1)");
String targetSql =
"SELECT CAST(\"date_add\"('day', 1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS VARCHAR(65535)), CAST(\"date_add\"('day', 1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))) AS VARCHAR(65535)), CAST(\"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS VARCHAR(65535)), CAST(\"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))) AS VARCHAR(65535))\n"
"SELECT CAST(\"date_add\"('day', 1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS VARCHAR), CAST(\"date_add\"('day', 1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))) AS VARCHAR), CAST(\"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS VARCHAR), CAST(\"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))) AS VARCHAR)\n"
+ "FROM (VALUES (0)) AS \"t\" (\"ZERO\")";
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
Expand Down

0 comments on commit 575e317

Please sign in to comment.