From 575e3171b883f4d5949f9d98b66e3cfd0eb67e85 Mon Sep 17 00:00:00 2001
From: Jiefan Li
Date: Thu, 6 Apr 2023 20:00:43 -0400
Subject: [PATCH] [Coral-Trino] Migrate 'adjustReturnTypeWithCast' transformation from RelNode to SqlNode layer (#385)

---
Reviewer notes (this area, between the "---" marker and the first "diff --git" line, is
outside both the commit message and the diff):
 * The return-type CAST that Calcite2TrinoUDFConverter#adjustReturnTypeWithCast added at the
   RelNode layer is removed; ReturnTypeAdjustmentTransformer now adds it at the SqlNode layer
   and is registered as the last transformer in CoralToTrinoSqlCallConverter.
 * The lookup key changes from "datediff" to "date_diff", and "date_sub" is no longer listed.
   The updated test expectation shows Hive date_sub calls rendered as "date_add"('day', n * -1, ...)
   inside the CAST, so they are reached through the "date_add" key.
 * "date_add" is cast to VARCHAR only when CAST_DATEADD_TO_STRING is set to true in the configs.
 * Expected SQL changes in the tests: VARCHAR(65535) becomes VARCHAR, and the CARDINALITY call
   inside the UNNEST(TRANSFORM("if"(...))) expression is no longer wrapped in CAST(... AS INTEGER).

 .../rel2trino/Calcite2TrinoUDFConverter.java  | 29 +-------
 .../CoralToTrinoSqlCallConverter.java         |  5 +-
 .../ReturnTypeAdjustmentTransformer.java      | 68 +++++++++++++++++++
 .../rel2trino/HiveToTrinoConverterTest.java   |  4 +-
 4 files changed, 75 insertions(+), 31 deletions(-)
 create mode 100644 coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/ReturnTypeAdjustmentTransformer.java

diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java
index bd441426e..7b34c90d6 100644
--- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java
+++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java
@@ -8,12 +8,9 @@
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Locale;
 import java.util.Map;
 import java.util.Optional;
 
-import com.google.common.collect.ImmutableMap;
-
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelShuttle;
 import org.apache.calcite.rel.RelShuttleImpl;
@@ -213,7 +210,7 @@ public RexNode visitCall(RexCall call) {
         }
       }
 
-      return adjustReturnTypeWithCast(rexBuilder, super.visitCall(call));
+      return super.visitCall(call);
     }
 
     private Optional<RexNode> visitConcat(RexCall call) {
@@ -358,30 +355,6 @@ private Optional<RexNode> visitCast(RexCall call) {
 
       return Optional.empty();
     }
-
-    /**
-     * This method is to cast the converted call to the same return type in Hive with certain version.
-     * e.g. `datediff` in Hive returns int type, but the corresponding function `date_diff` in Trino returns bigint type
-     * the type discrepancy would cause the issue while querying the view on Trino, so we need to add the CAST for them
-     */
-    private RexNode adjustReturnTypeWithCast(RexBuilder rexBuilder, RexNode call) {
-      if (!(call instanceof RexCall)) {
-        return call;
-      }
-      final String lowercaseOperatorName = ((RexCall) call).getOperator().getName().toLowerCase(Locale.ROOT);
-      final ImmutableMap<String, RelDataType> operatorsToAdjust =
-          ImmutableMap.of("datediff", typeFactory.createSqlType(INTEGER), "cardinality",
-              typeFactory.createSqlType(INTEGER), "ceil", typeFactory.createSqlType(BIGINT), "ceiling",
-              typeFactory.createSqlType(BIGINT), "floor", typeFactory.createSqlType(BIGINT));
-      if (operatorsToAdjust.containsKey(lowercaseOperatorName)) {
-        return rexBuilder.makeCast(operatorsToAdjust.get(lowercaseOperatorName), call);
-      }
-      if (configs.getOrDefault(CAST_DATEADD_TO_STRING, false)
-          && (lowercaseOperatorName.equals("date_add") || lowercaseOperatorName.equals("date_sub"))) {
-        return rexBuilder.makeCast(typeFactory.createSqlType(VARCHAR), call);
-      }
-      return call;
-    }
   }
 
   private static SqlOperator createSqlOperatorOfFunction(String functionName, SqlReturnTypeInference typeInference) {
diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/CoralToTrinoSqlCallConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/CoralToTrinoSqlCallConverter.java
index c3b5c5324..4a67d6ac2 100644
--- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/CoralToTrinoSqlCallConverter.java
+++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/CoralToTrinoSqlCallConverter.java
@@ -28,6 +28,7 @@
 import com.linkedin.coral.trino.rel2trino.transformers.CurrentTimestampTransformer;
 import com.linkedin.coral.trino.rel2trino.transformers.GenericCoralRegistryOperatorRenameSqlCallTransformer;
 import com.linkedin.coral.trino.rel2trino.transformers.MapValueConstructorTransformer;
+import com.linkedin.coral.trino.rel2trino.transformers.ReturnTypeAdjustmentTransformer;
 import com.linkedin.coral.trino.rel2trino.transformers.ToDateOperatorTransformer;
 
 import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*;
@@ -116,7 +117,9 @@ protected SqlCall transform(SqlCall sqlCall) {
             "com.linkedin.stdudfs.urnextractor.hive.UrnExtractorFunctionWrapper", 1, "urn_extractor"),
         new CoralRegistryOperatorRenameSqlCallTransformer(
             "com.linkedin.stdudfs.hive.daliudfs.UrnExtractorFunctionWrapper", 1, "urn_extractor"),
-        new GenericCoralRegistryOperatorRenameSqlCallTransformer());
+        new GenericCoralRegistryOperatorRenameSqlCallTransformer(),
+
+        new ReturnTypeAdjustmentTransformer(configs));
   }
 
   private SqlOperator hiveToCoralSqlOperator(String functionName) {
diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/ReturnTypeAdjustmentTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/ReturnTypeAdjustmentTransformer.java
new file mode 100644
index 000000000..4ddf342c4
--- /dev/null
+++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/ReturnTypeAdjustmentTransformer.java
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.trino.rel2trino.transformers;
+
+import java.util.Locale;
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+
+import org.apache.calcite.sql.SqlBasicTypeNameSpec;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlDataTypeSpec;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql.type.SqlTypeName;
+
+import com.linkedin.coral.common.transformers.SqlCallTransformer;
+
+import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*;
+
+
+/**
+ * This transformer casts the result of some Trino functions to the same return type as in Hive.
+ *
+ * Example:
+ * "DATEDIFF" function in Hive returns int type, but the corresponding function "DATE_DIFF" in Trino
+ * returns bigint type. To ensure compatibility, a "CAST" is added to convert the result to the int type.
+ */
+public class ReturnTypeAdjustmentTransformer extends SqlCallTransformer {
+
+  private static final Map<String, SqlTypeName> OPERATORS_TO_ADJUST = ImmutableMap.<String, SqlTypeName> builder()
+      .put("date_diff", SqlTypeName.INTEGER).put("cardinality", SqlTypeName.INTEGER).put("ceil", SqlTypeName.BIGINT)
+      .put("ceiling", SqlTypeName.BIGINT).put("floor", SqlTypeName.BIGINT).put("date_add", SqlTypeName.VARCHAR).build();
+  private final Map<String, Boolean> configs;
+
+  public ReturnTypeAdjustmentTransformer(Map<String, Boolean> configs) {
+    this.configs = configs;
+  }
+
+  @Override
+  protected boolean condition(SqlCall sqlCall) {
+    String lowercaseOperatorName = sqlCall.getOperator().getName().toLowerCase(Locale.ROOT);
+    if ("date_add".equals(lowercaseOperatorName) && !configs.getOrDefault(CAST_DATEADD_TO_STRING, false)) {
+      return false;
+    }
+    return OPERATORS_TO_ADJUST.containsKey(lowercaseOperatorName);
+  }
+
+  @Override
+  protected SqlCall transform(SqlCall sqlCall) {
+    String lowercaseOperatorName = sqlCall.getOperator().getName().toLowerCase(Locale.ROOT);
+    SqlTypeName targetType = OPERATORS_TO_ADJUST.get(lowercaseOperatorName);
+    if (targetType != null) {
+      return createCast(sqlCall, targetType);
+    }
+    return sqlCall;
+  }
+
+  private SqlCall createCast(SqlNode node, SqlTypeName typeName) {
+    SqlDataTypeSpec targetTypeNode =
+        new SqlDataTypeSpec(new SqlBasicTypeNameSpec(typeName, SqlParserPos.ZERO), SqlParserPos.ZERO);
+    return SqlStdOperatorTable.CAST.createCall(SqlParserPos.ZERO, node, targetTypeNode);
+  }
+}
diff --git a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java
index 8dfc725d2..af1bfb22e 100644
--- a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java
+++ b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java
@@ -124,7 +124,7 @@ public Object[][] viewTestCasesProvider() {
 
         { "test", "view_with_outer_explode_struct_array", "SELECT \"$cor0\".\"a\" AS \"a\", \"t0\".\"c\" AS \"c\"\n"
             + "FROM \"test\".\"table_with_struct_array\" AS \"$cor0\"\n"
-            + "CROSS JOIN UNNEST(TRANSFORM(\"if\"(\"$cor0\".\"b\" IS NOT NULL AND CAST(CARDINALITY(\"$cor0\".\"b\") AS INTEGER) > 0, \"$cor0\".\"b\", ARRAY[NULL]), x -> ROW(x))) AS \"t0\" (\"c\")" },
+            + "CROSS JOIN UNNEST(TRANSFORM(\"if\"(\"$cor0\".\"b\" IS NOT NULL AND CARDINALITY(\"$cor0\".\"b\") > 0, \"$cor0\".\"b\", ARRAY[NULL]), x -> ROW(x))) AS \"t0\" (\"c\")" },
 
         { "test", "view_with_explode_map", "SELECT \"$cor0\".\"a\" AS \"a\", \"t0\".\"c\" AS \"c\", \"t0\".\"d\" AS \"d\"\n"
             + "FROM \"test\".\"table_with_map\" AS \"$cor0\"\n"
@@ -546,7 +546,7 @@ public void testTypeCastForDataAddFunction() {
     RelNode relNode = hiveToRelConverter.convertSql(
         "SELECT date_add('2021-08-20', 1), date_add('2021-08-20 00:00:00', 1), date_sub('2021-08-20', 1), date_sub('2021-08-20 00:00:00', 1)");
     String targetSql =
-        "SELECT CAST(\"date_add\"('day', 1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS VARCHAR(65535)), CAST(\"date_add\"('day', 1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))) AS VARCHAR(65535)), CAST(\"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS VARCHAR(65535)), CAST(\"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))) AS VARCHAR(65535))\n"
+        "SELECT CAST(\"date_add\"('day', 1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS VARCHAR), CAST(\"date_add\"('day', 1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))) AS VARCHAR), CAST(\"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS VARCHAR), CAST(\"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))) AS VARCHAR)\n"
             + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")";
     String expandedSql = relToTrinoConverter.convert(relNode);
     assertEquals(expandedSql, targetSql);