diff --git a/coral-hive/src/main/java/com/linkedin/coral/transformers/CoralRelToSqlNodeConverter.java b/coral-hive/src/main/java/com/linkedin/coral/transformers/CoralRelToSqlNodeConverter.java
index 4f2175aeb..1297476fd 100644
--- a/coral-hive/src/main/java/com/linkedin/coral/transformers/CoralRelToSqlNodeConverter.java
+++ b/coral-hive/src/main/java/com/linkedin/coral/transformers/CoralRelToSqlNodeConverter.java
@@ -350,14 +350,17 @@ private SqlNode generateRightChildForSqlJoinWithLateralViews(BiRel e, Result rig
   }
 
   /**
-   * Override this method to handle the conversion for RelNode `f(x).y.z` where `f` is an operator, which
-   * returns a struct containing field `y`, `y` is also a struct containing field `z`.
+   * Override this method to handle the conversion for {@link RexFieldAccess} `f(x).y.z` where `f` is an operator,
+   * which returns a struct containing field `y`, `y` is also a struct containing field `z`.
    *
-   * Calcite will convert this RelNode to a SqlIdentifier directly (check
+   * Calcite will convert this RelNode to a {@link SqlIdentifier} directly (check
    * {@link org.apache.calcite.rel.rel2sql.SqlImplementor.Context#toSql(RexProgram, RexNode)}),
    * which is not aligned with our expectation since we want to apply transformations on `f(x)` with
    * {@link com.linkedin.coral.common.transformers.SqlCallTransformer}. Therefore, we override this
-   * method to convert `f(x)` to SqlCall, `.` to {@link com.linkedin.coral.common.functions.FunctionFieldReferenceOperator#DOT}
+   * method to convert `f(x)` to {@link SqlCall}, `.` to {@link com.linkedin.coral.common.functions.FunctionFieldReferenceOperator#DOT},
+   * so `f(x).y.z` will be converted to `(f(x).y).z`.
+   *
+   * Check `CoralSparkTest#testConvertFieldAccessOnFunctionCall` for unit test and example.
    */
   @Override
   public Context aliasContext(Map<String, RelDataType> aliases, boolean qualified) {
@@ -373,7 +376,8 @@ public SqlNode toSql(RexProgram program, RexNode rex) {
           accessNames.add(((RexFieldAccess) referencedExpr).getField().getName());
           referencedExpr = ((RexFieldAccess) referencedExpr).getReferenceExpr();
         }
-        if (referencedExpr.getKind() == SqlKind.OTHER_FUNCTION || referencedExpr.getKind() == SqlKind.CAST) {
+        final SqlKind sqlKind = referencedExpr.getKind();
+        if (sqlKind == SqlKind.OTHER_FUNCTION || sqlKind == SqlKind.CAST || sqlKind == SqlKind.ROW) {
           SqlNode functionCall = toSql(program, referencedExpr);
           Collections.reverse(accessNames);
           for (String accessName : accessNames) {
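To make the first hunk's rewrite concrete, here is a minimal sketch of rebuilding `f(x).y.z` as `(f(x).y).z` once the field names have been peeled off the `RexFieldAccess` chain. The `DOT.createCall` operand order (struct expression first, then field identifier) is an assumption inferred from this diff, not verified against the Coral API:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;

/** Illustrative only: rebuilds f(x).y.z as (f(x).y).z from a peeled access chain. */
final class DotRewriteSketch {
  /**
   * @param functionCall the SqlCall generated for f(x)
   * @param accessNamesInnermostFirst field names collected while unwrapping the
   *                                  RexFieldAccess chain, e.g. ["z", "y"] for f(x).y.z
   */
  static SqlNode rebuild(SqlNode functionCall, List<String> accessNamesInnermostFirst) {
    List<String> accessNames = new ArrayList<>(accessNamesInnermostFirst);
    Collections.reverse(accessNames); // ["y", "z"]
    SqlNode node = functionCall;
    for (String accessName : accessNames) {
      // Assumed operand order: (struct expression, field identifier).
      node = FunctionFieldReferenceOperator.DOT.createCall(SqlParserPos.ZERO, node,
          new SqlIdentifier(accessName, SqlParserPos.ZERO));
    }
    return node; // unparses as (f(x).y).z, keeping f(x) as a transformable SqlCall
  }
}
```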
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java
index 6a45a7a74..c92e8d9e8 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java
@@ -9,147 +9,151 @@
 import org.apache.calcite.sql.SqlCall;
 import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.util.SqlShuttle;
 
 import com.linkedin.coral.common.transformers.OperatorRenameSqlCallTransformer;
 import com.linkedin.coral.common.transformers.SqlCallTransformers;
 import com.linkedin.coral.spark.containers.SparkUDFInfo;
-import com.linkedin.coral.spark.transformers.FallBackToHiveUDFTransformer;
-import com.linkedin.coral.spark.transformers.TransportableUDFTransformer;
+import com.linkedin.coral.spark.transformers.FallBackToLinkedInHiveUDFTransformer;
+import com.linkedin.coral.spark.transformers.TransportUDFTransformer;
 
-import static com.linkedin.coral.spark.transformers.TransportableUDFTransformer.*;
+import static com.linkedin.coral.spark.transformers.TransportUDFTransformer.*;
 
 
 /**
  * This class extends the class of {@link org.apache.calcite.sql.util.SqlShuttle} and initialize a {@link com.linkedin.coral.common.transformers.SqlCallTransformers}
  * which containing a list of {@link com.linkedin.coral.common.transformers.SqlCallTransformer} to traverse the hierarchy of a {@link org.apache.calcite.sql.SqlCall}
  * and converts the functions from Coral operator to Spark operator if it is required
+ *
+ * In this converter, we need to apply {@link TransportUDFTransformer} before {@link FallBackToLinkedInHiveUDFTransformer}
+ * because we should try to transform a UDF to an equivalent Transport UDF before falling back to LinkedIn Hive UDF.
  */
 public class CoralToSparkSqlCallConverter extends SqlShuttle {
   private final SqlCallTransformers sqlCallTransformers;
 
   public CoralToSparkSqlCallConverter(Set<SparkUDFInfo> sparkUDFInfos) {
     this.sqlCallTransformers = SqlCallTransformers.of(
-        // Transportable UDFs
-        new TransportableUDFTransformer("com.linkedin.dali.udf.date.hive.DateFormatToEpoch",
+        // Transport UDFs
+        new TransportUDFTransformer("com.linkedin.dali.udf.date.hive.DateFormatToEpoch",
             "com.linkedin.stdudfs.daliudfs.spark.DateFormatToEpoch", DALI_UDFS_IVY_URL_SPARK_2_11,
             DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos),
-        new TransportableUDFTransformer("com.linkedin.dali.udf.date.hive.EpochToDateFormat",
+        new TransportUDFTransformer("com.linkedin.dali.udf.date.hive.EpochToDateFormat",
             "com.linkedin.stdudfs.daliudfs.spark.EpochToDateFormat", DALI_UDFS_IVY_URL_SPARK_2_11,
             DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos),
-        new TransportableUDFTransformer("com.linkedin.dali.udf.date.hive.EpochToEpochMilliseconds",
+        new TransportUDFTransformer("com.linkedin.dali.udf.date.hive.EpochToEpochMilliseconds",
            "com.linkedin.stdudfs.daliudfs.spark.EpochToEpochMilliseconds", DALI_UDFS_IVY_URL_SPARK_2_11,
            DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos),
-        new TransportableUDFTransformer("com.linkedin.dali.udf.isguestmemberid.hive.IsGuestMemberId",
+        new TransportUDFTransformer("com.linkedin.dali.udf.isguestmemberid.hive.IsGuestMemberId",
            "com.linkedin.stdudfs.daliudfs.spark.IsGuestMemberId", DALI_UDFS_IVY_URL_SPARK_2_11,
            DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos),
-        new TransportableUDFTransformer("com.linkedin.dali.udf.istestmemberid.hive.IsTestMemberId",
+        new TransportUDFTransformer("com.linkedin.dali.udf.istestmemberid.hive.IsTestMemberId",
            "com.linkedin.stdudfs.daliudfs.spark.IsTestMemberId", DALI_UDFS_IVY_URL_SPARK_2_11,
            DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos),
-        new TransportableUDFTransformer("com.linkedin.dali.udf.maplookup.hive.MapLookup",
+        new TransportUDFTransformer("com.linkedin.dali.udf.maplookup.hive.MapLookup",
            "com.linkedin.stdudfs.daliudfs.spark.MapLookup", DALI_UDFS_IVY_URL_SPARK_2_11,
            DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos),
-        new TransportableUDFTransformer("com.linkedin.dali.udf.sanitize.hive.Sanitize",
+        new TransportUDFTransformer("com.linkedin.dali.udf.sanitize.hive.Sanitize",
            "com.linkedin.stdudfs.daliudfs.spark.Sanitize", DALI_UDFS_IVY_URL_SPARK_2_11,
            DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos),
-        new TransportableUDFTransformer("com.linkedin.dali.udf.watbotcrawlerlookup.hive.WATBotCrawlerLookup",
+        new TransportUDFTransformer("com.linkedin.dali.udf.watbotcrawlerlookup.hive.WATBotCrawlerLookup",
            "com.linkedin.stdudfs.daliudfs.spark.WatBotCrawlerLookup", DALI_UDFS_IVY_URL_SPARK_2_11,
            DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos),
"com.linkedin.stdudfs.daliudfs.spark.WatBotCrawlerLookup", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.DateFormatToEpoch", + new TransportUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.DateFormatToEpoch", "com.linkedin.stdudfs.daliudfs.spark.DateFormatToEpoch", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.EpochToDateFormat", + new TransportUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.EpochToDateFormat", "com.linkedin.stdudfs.daliudfs.spark.EpochToDateFormat", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.EpochToEpochMilliseconds", + new TransportUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.EpochToEpochMilliseconds", "com.linkedin.stdudfs.daliudfs.spark.EpochToEpochMilliseconds", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.GetProfileSections", + new TransportUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.GetProfileSections", "com.linkedin.stdudfs.daliudfs.spark.GetProfileSections", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.stringudfs.hive.InitCap", + new TransportUDFTransformer("com.linkedin.stdudfs.stringudfs.hive.InitCap", "com.linkedin.stdudfs.stringudfs.spark.InitCap", "ivy://com.linkedin.standard-udfs-common-sql-udfs:standard-udfs-string-udfs:1.0.1?classifier=spark_2.11", "ivy://com.linkedin.standard-udfs-common-sql-udfs:standard-udfs-string-udfs:1.0.1?classifier=spark_2.12", sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.IsGuestMemberId", + new TransportUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.IsGuestMemberId", "com.linkedin.stdudfs.daliudfs.spark.IsGuestMemberId", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.IsTestMemberId", + new TransportUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.IsTestMemberId", "com.linkedin.stdudfs.daliudfs.spark.IsTestMemberId", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.MapLookup", + new TransportUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.MapLookup", "com.linkedin.stdudfs.daliudfs.spark.MapLookup", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.PortalLookup", + new TransportUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.PortalLookup", "com.linkedin.stdudfs.daliudfs.spark.PortalLookup", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.Sanitize", + new TransportUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.Sanitize", "com.linkedin.stdudfs.daliudfs.spark.Sanitize", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.userinterfacelookup.hive.UserInterfaceLookup", + new TransportUDFTransformer("com.linkedin.stdudfs.userinterfacelookup.hive.UserInterfaceLookup", 
"com.linkedin.stdudfs.userinterfacelookup.spark.UserInterfaceLookup", "ivy://com.linkedin.standard-udf-userinterfacelookup:userinterfacelookup-std-udf:0.0.27?classifier=spark_2.11", "ivy://com.linkedin.standard-udf-userinterfacelookup:userinterfacelookup-std-udf:0.0.27?classifier=spark_2.12", sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.WatBotCrawlerLookup", + new TransportUDFTransformer("com.linkedin.stdudfs.daliudfs.hive.WatBotCrawlerLookup", "com.linkedin.stdudfs.daliudfs.spark.WatBotCrawlerLookup", DALI_UDFS_IVY_URL_SPARK_2_11, DALI_UDFS_IVY_URL_SPARK_2_12, sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.jemslookup.udf.hive.JemsLookup", + new TransportUDFTransformer("com.linkedin.jemslookup.udf.hive.JemsLookup", "com.linkedin.jemslookup.udf.spark.JemsLookup", "ivy://com.linkedin.jobs-udf:jems-udfs:2.1.7?classifier=spark_2.11", "ivy://com.linkedin.jobs-udf:jems-udfs:2.1.7?classifier=spark_2.12", sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.parsing.hive.UserAgentParser", + new TransportUDFTransformer("com.linkedin.stdudfs.parsing.hive.UserAgentParser", "com.linkedin.stdudfs.parsing.spark.UserAgentParser", "ivy://com.linkedin.standard-udfs-parsing:parsing-stdudfs:3.0.3?classifier=spark_2.11", "ivy://com.linkedin.standard-udfs-parsing:parsing-stdudfs:3.0.3?classifier=spark_2.12", sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.parsing.hive.Ip2Str", + new TransportUDFTransformer("com.linkedin.stdudfs.parsing.hive.Ip2Str", "com.linkedin.stdudfs.parsing.spark.Ip2Str", "ivy://com.linkedin.standard-udfs-parsing:parsing-stdudfs:3.0.3?classifier=spark_2.11", "ivy://com.linkedin.standard-udfs-parsing:parsing-stdudfs:3.0.3?classifier=spark_2.12", sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.stdudfs.lookup.hive.BrowserLookup", + new TransportUDFTransformer("com.linkedin.stdudfs.lookup.hive.BrowserLookup", "com.linkedin.stdudfs.lookup.spark.BrowserLookup", "ivy://com.linkedin.standard-udfs-parsing:parsing-stdudfs:3.0.3?classifier=spark_2.11", "ivy://com.linkedin.standard-udfs-parsing:parsing-stdudfs:3.0.3?classifier=spark_2.12", sparkUDFInfos), - new TransportableUDFTransformer("com.linkedin.jobs.udf.hive.ConvertIndustryCode", + new TransportUDFTransformer("com.linkedin.jobs.udf.hive.ConvertIndustryCode", "com.linkedin.jobs.udf.spark.ConvertIndustryCode", "ivy://com.linkedin.jobs-udf:jobs-udfs:2.1.6?classifier=spark_2.11", "ivy://com.linkedin.jobs-udf:jobs-udfs:2.1.6?classifier=spark_2.12", sparkUDFInfos), - // Transportable UDF for unit test - new TransportableUDFTransformer("com.linkedin.coral.hive.hive2rel.CoralTestUDF", + // Transport UDF for unit test + new TransportUDFTransformer("com.linkedin.coral.hive.hive2rel.CoralTestUDF", "com.linkedin.coral.spark.CoralTestUDF", "ivy://com.linkedin.coral.spark.CoralTestUDF?classifier=spark_2.11", null, sparkUDFInfos), // Built-in operator - new OperatorRenameSqlCallTransformer("CARDINALITY", 1, "size"), + new OperatorRenameSqlCallTransformer(SqlStdOperatorTable.CARDINALITY, 1, "size"), // Fall back to the original Hive UDF defined in StaticHiveFunctionRegistry after failing to apply transformers above - new FallBackToHiveUDFTransformer(sparkUDFInfos)); + new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos)); } @Override diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java index 5207826a6..73dd86e07 100644 --- 
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java
index 5207826a6..73dd86e07 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java
@@ -46,6 +46,8 @@ public class SparkSqlRewriter extends SqlShuttle {
  * is translated to
  * SELECT named_struct(.....)
  *
+ * Check `CoralSparkTest#testAvoidCastToRow` for unit test and a more complex example.
+ *
  * Also replaces:
  *
  * CAST(NULL AS NULL)
@@ -70,15 +72,14 @@ && containsSqlRowTypeSpec((SqlDataTypeSpec) call.getOperandList().get(1))) {
 
   private boolean containsSqlRowTypeSpec(SqlDataTypeSpec sqlDataTypeSpec) {
     if (sqlDataTypeSpec instanceof SqlRowTypeSpec) {
       return true;
-    }
-    if (sqlDataTypeSpec instanceof SqlArrayTypeSpec) {
+    } else if (sqlDataTypeSpec instanceof SqlArrayTypeSpec) {
       return containsSqlRowTypeSpec(((SqlArrayTypeSpec) sqlDataTypeSpec).getElementTypeSpec());
-    }
-    if (sqlDataTypeSpec instanceof SqlMapTypeSpec) {
+    } else if (sqlDataTypeSpec instanceof SqlMapTypeSpec) {
       return containsSqlRowTypeSpec(((SqlMapTypeSpec) sqlDataTypeSpec).getKeyTypeSpec())
           || containsSqlRowTypeSpec(((SqlMapTypeSpec) sqlDataTypeSpec).getValueTypeSpec());
+    } else {
+      return false;
     }
-    return false;
   }
 
   /**
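A trace of the restructured recursion on one concrete type. The type below is an invented example; the type-spec classes and accessors are exactly the ones shown in the hunk:

```java
// CAST(... AS MAP<VARCHAR, ARRAY<ROW(b INTEGER)>>)
//
// containsSqlRowTypeSpec(mapSpec)            // SqlMapTypeSpec: recurse into key, then value
//   containsSqlRowTypeSpec(varcharSpec)      // plain SqlDataTypeSpec: else branch -> false
//   containsSqlRowTypeSpec(arraySpec)        // SqlArrayTypeSpec: recurse into element
//     containsSqlRowTypeSpec(rowSpec)        // SqlRowTypeSpec -> true
//
// => true, so the rewriter drops this CAST and keeps the named_struct value,
//    which is what testAvoidCastToRow asserts.
```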
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToHiveUDFTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java
similarity index 69%
rename from coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToHiveUDFTransformer.java
rename to coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java
index a6060053d..a727ca37c 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToHiveUDFTransformer.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java
@@ -23,14 +23,14 @@
 
 /**
- * After failing to transform UDF with {@link TransportableUDFTransformer},
+ * After failing to transform UDF with {@link TransportUDFTransformer},
  * we use this transformer to fall back to the original Hive UDF defined in
  * {@link com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry}.
  * This is reasonable since Spark understands and has ability to run Hive UDF.
- * Check `CoralSparkTest#testFallBackToHiveUDFTransformer()` for an example.
+ * Check `CoralSparkTest#testFallBackToLinkedInHiveUDFTransformer()` for an example.
  */
-public class FallBackToHiveUDFTransformer extends SqlCallTransformer {
-  private static final Logger LOG = LoggerFactory.getLogger(FallBackToHiveUDFTransformer.class);
+public class FallBackToLinkedInHiveUDFTransformer extends SqlCallTransformer {
+  private static final Logger LOG = LoggerFactory.getLogger(FallBackToLinkedInHiveUDFTransformer.class);
 
   /**
    * Some LinkedIn UDFs get registered correctly in a SparkSession, and hence a DataFrame is successfully
@@ -46,33 +46,35 @@ public class FallBackToHiveUDFTransformer extends SqlCallTransformer {
       "com.linkedin.coral.hive.hive2rel.CoralTestUnsupportedUDF");
 
   private final Set<SparkUDFInfo> sparkUDFInfos;
 
-  public FallBackToHiveUDFTransformer(Set<SparkUDFInfo> sparkUDFInfos) {
+  public FallBackToLinkedInHiveUDFTransformer(Set<SparkUDFInfo> sparkUDFInfos) {
     this.sparkUDFInfos = sparkUDFInfos;
   }
 
   @Override
   protected boolean condition(SqlCall sqlCall) {
-    final String functionClassName = sqlCall.getOperator().getName();
-    if (UNSUPPORTED_HIVE_UDFS.contains(functionClassName)) {
-      throw new UnsupportedUDFException(functionClassName);
-    }
-    return functionClassName.contains(".") && !functionClassName.equals(".");
+    final SqlOperator operator = sqlCall.getOperator();
+    final String operatorName = operator.getName();
+    return operator instanceof VersionedSqlUserDefinedFunction && operatorName.contains(".")
+        && !operatorName.equals(".");
   }
 
   @Override
   protected SqlCall transform(SqlCall sqlCall) {
     final VersionedSqlUserDefinedFunction operator = (VersionedSqlUserDefinedFunction) sqlCall.getOperator();
-    final String functionClassName = operator.getName();
-    final String expandedFunctionName = operator.getViewDependentFunctionName();
+    final String operatorName = operator.getName();
+    if (UNSUPPORTED_HIVE_UDFS.contains(operatorName)) {
+      throw new UnsupportedUDFException(operatorName);
+    }
+    final String viewDependentFunctionName = operator.getViewDependentFunctionName();
     final List<String> dependencies = operator.getIvyDependencies();
     List<URI> listOfUris = dependencies.stream().map(URI::create).collect(Collectors.toList());
-    LOG.info("Function: {} is not a Builtin UDF or Transportable UDF. We fall back to its Hive "
-        + "function with ivy dependency: {}", functionClassName, String.join(",", dependencies));
+    LOG.info("Function: {} is not a Builtin UDF or Transport UDF. We fall back to its Hive "
+        + "function with ivy dependency: {}", operatorName, String.join(",", dependencies));
     final SparkUDFInfo sparkUDFInfo =
-        new SparkUDFInfo(functionClassName, expandedFunctionName, listOfUris, SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF);
+        new SparkUDFInfo(operatorName, viewDependentFunctionName, listOfUris, SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF);
     sparkUDFInfos.add(sparkUDFInfo);
     final SqlOperator convertedFunction =
-        createSqlOperatorOfFunction(expandedFunctionName, operator.getReturnTypeInference());
+        createSqlOperator(viewDependentFunctionName, operator.getReturnTypeInference());
     return convertedFunction.createCall(sqlCall.getParserPosition(), sqlCall.getOperandList());
   }
 }
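A hedged before/after of what this fallback now produces; the UDF class and view names below are invented for illustration:

```java
// Operator of the incoming call: a VersionedSqlUserDefinedFunction whose name is the
// Hive UDF class registered in StaticHiveFunctionRegistry (hypothetical example):
//
//   com.linkedin.mydb.udfs.hive.MyLookup(x, y)
//
// transform(sqlCall) rewrites the call to the view-dependent function name:
//
//   mydb_myview_MyLookup(x, y)
//
// and records one SparkUDFInfo(operatorName, viewDependentFunctionName, ivyUris,
// UDFTYPE.HIVE_CUSTOM_UDF) so the caller can register that name before running the
// SQL. Note the behavioral shift in this hunk: unsupported UDFs now throw inside
// transform(...) rather than condition(...), so condition() is a pure predicate.
```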
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportableUDFTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java
similarity index 76%
rename from coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportableUDFTransformer.java
rename to coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java
index 4e1c3fcd4..272fd197d 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportableUDFTransformer.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java
@@ -22,10 +22,10 @@
 
 /**
- * This transformer transforms legacy Hive UDFs to an equivalent Transportable UDF.
- * Check `CoralSparkTest#testTransportableUDFTransformer()` for example.
+ * This transformer transforms legacy Hive UDFs to an equivalent registered Transport UDF.
+ * Check `CoralSparkTest#testTransportUDFTransformer()` for example.
  */
-public class TransportableUDFTransformer extends SqlCallTransformer {
+public class TransportUDFTransformer extends SqlCallTransformer {
   private final String hiveUDFClassName;
   private final String sparkUDFClassName;
   private final String artifactoryUrlSpark211;
@@ -33,7 +33,7 @@ public class TransportableUDFTransformer extends SqlCallTransformer {
   private final Set<SparkUDFInfo> sparkUDFInfos;
   private ScalaVersion scalaVersion;
 
-  public TransportableUDFTransformer(String hiveUDFClassName, String sparkUDFClassName, String artifactoryUrlSpark211,
+  public TransportUDFTransformer(String hiveUDFClassName, String sparkUDFClassName, String artifactoryUrlSpark211,
       String artifactoryUrlSpark212, Set<SparkUDFInfo> sparkUDFInfos) {
     this.hiveUDFClassName = hiveUDFClassName;
     this.sparkUDFClassName = sparkUDFClassName;
@@ -42,7 +42,7 @@ public TransportableUDFTransformer(String hiveUDFClass
     this.sparkUDFInfos = sparkUDFInfos;
   }
 
-  private static final Logger LOG = LoggerFactory.getLogger(TransportableUDFTransformer.class);
+  private static final Logger LOG = LoggerFactory.getLogger(TransportUDFTransformer.class);
   public static final String DALI_UDFS_IVY_URL_SPARK_2_11 =
       "ivy://com.linkedin.standard-udfs-dali-udfs:standard-udfs-dali-udfs:2.0.3?classifier=spark_2.11";
   public static final String DALI_UDFS_IVY_URL_SPARK_2_12 =
@@ -56,7 +56,8 @@ public enum ScalaVersion {
   @Override
   protected boolean condition(SqlCall sqlCall) {
     scalaVersion = getScalaVersionOfSpark();
-    if (!hiveUDFClassName.equalsIgnoreCase(sqlCall.getOperator().getName())) {
+    if (!(sqlCall.getOperator() instanceof VersionedSqlUserDefinedFunction)
+        || !hiveUDFClassName.equalsIgnoreCase(sqlCall.getOperator().getName())) {
       return false;
     }
     if (scalaVersion == ScalaVersion.SCALA_2_11 && artifactoryUrlSpark211 != null
@@ -72,23 +73,26 @@ protected boolean condition(SqlCall sqlCall) {
   @Override
   protected SqlCall transform(SqlCall sqlCall) {
     final VersionedSqlUserDefinedFunction operator = (VersionedSqlUserDefinedFunction) sqlCall.getOperator();
-    final String functionName = operator.getViewDependentFunctionName();
-    sparkUDFInfos.add(new SparkUDFInfo(sparkUDFClassName, functionName,
+    final String viewDependentFunctionName = operator.getViewDependentFunctionName();
+    sparkUDFInfos.add(new SparkUDFInfo(sparkUDFClassName, viewDependentFunctionName,
         Collections.singletonList(
             URI.create(scalaVersion == ScalaVersion.SCALA_2_11 ? artifactoryUrlSpark211 : artifactoryUrlSpark212)),
         SparkUDFInfo.UDFTYPE.TRANSPORTABLE_UDF));
-    final SqlOperator convertedFunction = createSqlOperatorOfFunction(functionName, operator.getReturnTypeInference());
+    final SqlOperator convertedFunction =
+        createSqlOperator(viewDependentFunctionName, operator.getReturnTypeInference());
     return convertedFunction.createCall(sqlCall.getParserPosition(), sqlCall.getOperandList());
   }
 
   public ScalaVersion getScalaVersionOfSpark() {
     try {
       String sparkVersion = SparkSession.active().version();
-      if (sparkVersion.matches("2\\.[\\d\\.]*"))
+      if (sparkVersion.matches("2\\.[\\d\\.]*")) {
         return ScalaVersion.SCALA_2_11;
-      if (sparkVersion.matches("3\\.[\\d\\.]*"))
+      } else if (sparkVersion.matches("3\\.[\\d\\.]*")) {
         return ScalaVersion.SCALA_2_12;
-      throw new IllegalStateException(String.format("Unsupported Spark Version %s", sparkVersion));
+      } else {
+        throw new IllegalStateException(String.format("Unsupported Spark Version %s", sparkVersion));
+      }
     } catch (IllegalStateException | NoClassDefFoundError ex) {
       LOG.warn("Couldn't determine Spark version, falling back to scala_2.11: {}", ex.getMessage());
       return ScalaVersion.SCALA_2_11;
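The version check above is plain regex matching, so it can be sanity-checked without a SparkSession. A small self-contained demo using the exact patterns from the hunk (the class name is invented):

```java
public class SparkVersionRegexDemo {
  public static void main(String[] args) {
    System.out.println("2.4.8".matches("2\\.[\\d\\.]*")); // true  -> SCALA_2_11
    System.out.println("3.1.2".matches("3\\.[\\d\\.]*")); // true  -> SCALA_2_12
    System.out.println("4.0.0".matches("2\\.[\\d\\.]*")); // false
    // Note: an unsupported version throws IllegalStateException *inside* the try
    // block, so it is caught by the same handler as "no active SparkSession" and
    // the method still falls back to SCALA_2_11 with a warning.
  }
}
```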
diff --git a/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportableUDFTransformerTest.java b/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java
similarity index 56%
rename from coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportableUDFTransformerTest.java
rename to coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java
index f5c961e9b..dc0ae839c 100644
--- a/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportableUDFTransformerTest.java
+++ b/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java
@@ -11,27 +11,27 @@
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
-import com.linkedin.coral.spark.transformers.TransportableUDFTransformer;
+import com.linkedin.coral.spark.transformers.TransportUDFTransformer;
 
 
-public class TransportableUDFTransformerTest {
-  final TransportableUDFTransformer transportableUDFTransformer = new TransportableUDFTransformer(
+public class TransportUDFTransformerTest {
+  final TransportUDFTransformer _transportUDFTransformer = new TransportUDFTransformer(
       "com.linkedin.coral.hive.hive2rel.CoralTestUDF", "com.linkedin.coral.spark.CoralTestUDF",
       "ivy://com.linkedin.coral.spark.CoralTestUDF", null, new HashSet<>());
 
   @Test
   public void testScalaVersionWithSparkSession() {
-    SparkSession ss = SparkSession.builder().appName(TransportableUDFTransformerTest.class.getSimpleName())
+    SparkSession ss = SparkSession.builder().appName(TransportUDFTransformerTest.class.getSimpleName())
         .master("local[1]").enableHiveSupport().getOrCreate();
-    Assert.assertEquals(transportableUDFTransformer.getScalaVersionOfSpark(),
-        TransportableUDFTransformer.ScalaVersion.SCALA_2_12);
+    Assert.assertEquals(_transportUDFTransformer.getScalaVersionOfSpark(),
+        TransportUDFTransformer.ScalaVersion.SCALA_2_12);
     ss.close();
   }
 
   @Test
   public void testDefaultScalaVersion() {
     // If SparkSession is not active, getScalaVersion should return Scala2.11
-    Assert.assertEquals(transportableUDFTransformer.getScalaVersionOfSpark(),
-        TransportableUDFTransformer.ScalaVersion.SCALA_2_11);
+    Assert.assertEquals(_transportUDFTransformer.getScalaVersionOfSpark(),
+        TransportUDFTransformer.ScalaVersion.SCALA_2_11);
   }
 }
diff --git a/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportableUDFTransformerTest.java b/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java
similarity index 57%
rename from coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportableUDFTransformerTest.java
rename to coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java
index 9a08f55ea..9d68f3c6d 100644
--- a/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportableUDFTransformerTest.java
+++ b/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java
@@ -11,28 +11,28 @@
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
-import com.linkedin.coral.spark.transformers.TransportableUDFTransformer;
+import com.linkedin.coral.spark.transformers.TransportUDFTransformer;
 
 
-public class TransportableUDFTransformerTest {
+public class TransportUDFTransformerTest {
 
-  final TransportableUDFTransformer transportableUDFTransformer = new TransportableUDFTransformer(
+  final TransportUDFTransformer _transportUDFTransformer = new TransportUDFTransformer(
       "com.linkedin.coral.hive.hive2rel.CoralTestUDF", "com.linkedin.coral.spark.CoralTestUDF",
       "ivy://com.linkedin.coral.spark.CoralTestUDF?classifier=spark_2.11", null, new HashSet<>());
 
   @Test
   public void testScalaVersionWithSparkSession() {
-    SparkSession ss = SparkSession.builder().appName(TransportableUDFTransformerTest.class.getSimpleName())
+    SparkSession ss = SparkSession.builder().appName(TransportUDFTransformerTest.class.getSimpleName())
         .master("local[1]").enableHiveSupport().getOrCreate();
-    Assert.assertEquals(transportableUDFTransformer.getScalaVersionOfSpark(),
-        TransportableUDFTransformer.ScalaVersion.SCALA_2_11);
+    Assert.assertEquals(_transportUDFTransformer.getScalaVersionOfSpark(),
+        TransportUDFTransformer.ScalaVersion.SCALA_2_11);
     ss.close();
   }
 
   @Test
   public void testDefaultScalaVersion() {
     // If SparkSession is not active, getScalaVersionOfSpark should return Scala2.11
-    Assert.assertEquals(transportableUDFTransformer.getScalaVersionOfSpark(),
-        TransportableUDFTransformer.ScalaVersion.SCALA_2_11);
+    Assert.assertEquals(_transportUDFTransformer.getScalaVersionOfSpark(),
+        TransportUDFTransformer.ScalaVersion.SCALA_2_11);
   }
 }
diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java
index ac567e14e..96fec8591 100644
--- a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java
+++ b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java
@@ -101,9 +101,9 @@ public void testAllowBaseTableInView() {
   }
 
   @Test
-  public void testTransportableUDFTransformer() {
-    // Dali view foo_dali_udf contains a UDF defined with TransportableUDFTransformer.
-    // The actual values are determined by the parameter values of TransportableUDFTransformer.
+  public void testTransportUDFTransformer() {
+    // Dali view foo_dali_udf contains a UDF defined with TransportUDFTransformer.
+    // The actual values are determined by the parameter values of TransportUDFTransformer.
     RelNode relNode = TestUtils.toRelNode("default", "foo_dali_udf");
     CoralSpark coralSpark = CoralSpark.create(relNode);
     List<SparkUDFInfo> udfJars = coralSpark.getSparkUDFInfoList();
@@ -115,7 +115,7 @@ public void testTransportableUDFTransformer() {
     String udfFunctionName = udfJars.get(0).getFunctionName();
     String targetFunctionName = "default_foo_dali_udf_LessThanHundred";
     assertEquals(udfFunctionName, targetFunctionName);
-    // check if CoralSpark can fetch artifactory url defined in TransportableUDFTransformer
+    // check if CoralSpark can fetch artifactory url defined in TransportUDFTransformer
     List<String> listOfUriStrings = convertToListOfUriStrings(udfJars.get(0).getArtifactoryUrls());
     String targetArtifactoryUrl = "ivy://com.linkedin.coral.spark.CoralTestUDF?classifier=spark_2.11";
     assertTrue(listOfUriStrings.contains(targetArtifactoryUrl));
@@ -129,8 +129,8 @@ public void testTransportableUDFTransformer() {
   }
 
   @Test
-  public void testFallBackToHiveUDFTransformer() {
-    // Dali view foo_dali_udf2 contains a UDF not defined with OperatorBasedSqlCallTransformer or TransportableUDFTransformer.
+  public void testFallBackToLinkedInHiveUDFTransformer() {
+    // Dali view foo_dali_udf2 contains a UDF not defined with OperatorBasedSqlCallTransformer or TransportUDFTransformer.
     // We need to fall back to the udf initially defined in HiveFunctionRegistry.
     // Then the function Name comes from Hive metastore in the format dbName_viewName_funcBaseName.
     RelNode relNode = TestUtils.toRelNode("default", "foo_dali_udf2");
@@ -165,7 +165,7 @@ public void testUnsupportedUdf() {
 
   @Test
   public void testTwoFunctionsWithDependencies() {
-    // Dali view foo_dali_udf3 contains 2 UDFs. One UDF is defined with TransportableUDFTransformer. The other one is not.
+    // Dali view foo_dali_udf3 contains 2 UDFs. One UDF is defined with TransportUDFTransformer. The other one is not.
     // We need to fall back the second one to the udf initially defined in HiveFunctionRegistry.
     RelNode relNode = TestUtils.toRelNode("default", "foo_dali_udf3");
     CoralSpark coralSpark = CoralSpark.create(relNode);
@@ -837,6 +837,22 @@ public void testRedundantCastRemovedFromCaseCall() {
     assertEquals(expandedSql, targetSql);
   }
 
+  @Test
+  public void testConvertFieldAccessOnFunctionCall() {
+    RelNode relNode = TestUtils.toRelNode("SELECT named_struct('a', named_struct('b', 1)).a.b");
+
+    String targetSql = "SELECT (named_struct('a', named_struct('b', 1)).a).b\n" + "FROM (VALUES (0)) t (ZERO)";
+    assertEquals(CoralSpark.create(relNode).getSparkSql(), targetSql);
+  }
+
+  @Test
+  public void testAvoidCastToRow() {
+    RelNode relNode = TestUtils.toRelNode("SELECT named_struct('a', array(named_struct('b', 1)))");
+
+    String targetSql = "SELECT named_struct('a', ARRAY (named_struct('b', 1)))\n" + "FROM (VALUES (0)) t (ZERO)";
+    assertEquals(CoralSpark.create(relNode).getSparkSql(), targetSql);
+  }
+
   private static String getCoralSparkTranslatedSqlWithAliasFromCoralSchema(String db, String view) {
     RelNode relNode = TestUtils.toRelNode(db, view);
     Schema schema = TestUtils.getAvroSchemaForView(db, view, false);