From 97666d14544bd1d31809ed0e07caca9120b0dd66 Mon Sep 17 00:00:00 2001
From: Liangliang Gu
Date: Tue, 27 Apr 2021 13:11:33 +0800
Subject: [PATCH] support spark 3.1.1 (#2024)

---
 .ci/integration_test.groovy                   |   1 +
 assembly/src/main/assembly/assembly.xml       |  21 ++
 .../com/pingcap/tispark/TiSparkInfo.scala     |   2 +-
 .../tispark/utils/ReflectionUtil.scala        | 155 ++++++++++++++
 .../com/pingcap/tispark/utils/TiUtil.scala    |  38 +++-
 .../apache/spark/sql/TiAggregationImpl.scala  |  13 +-
 .../org/apache/spark/sql/TiStrategy.scala     |  19 +-
 .../expressions/BasicExpression.scala         | 100 +--------
 .../apache/spark/sql/extensions/parser.scala  | 134 +-----------
 .../apache/spark/sql/extensions/rules.scala   | 194 +-----------------
 .../catalyst/catalog/CatalogTestSuite.scala   |   1 -
 pom.xml                                       |  15 +-
 spark-wrapper/spark-3.0/pom.xml               | 123 +++++++++++
 .../com/pingcap/tispark/SparkWrapper.scala    |  41 ++++
 .../expressions/TiBasicExpression.scala       | 132 ++++++++++++
 .../spark/sql/extensions/TiDDLRule.scala      | 102 +++++++++
 .../spark/sql/extensions/TiParser.scala       | 157 ++++++++++++++
 .../sql/extensions/TiResolutionRule.scala     | 100 +++++++++
 .../sql/extensions/TiResolutionRuleV2.scala   |  75 +++++++
 spark-wrapper/spark-3.1/pom.xml               | 123 +++++++++++
 .../com/pingcap/tispark/SparkWrapper.scala    |  52 +++++
 .../expressions/TiBasicExpression.scala       | 132 ++++++++++++
 .../spark/sql/extensions/TiDDLRule.scala      | 120 +++++++++++
 .../spark/sql/extensions/TiParser.scala       | 133 ++++++++++++
 .../sql/extensions/TiResolutionRule.scala     | 100 +++++++++
 .../sql/extensions/TiResolutionRuleV2.scala   |  75 +++++++
 26 files changed, 1717 insertions(+), 441 deletions(-)
 create mode 100644 core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala
 create mode 100644 spark-wrapper/spark-3.0/pom.xml
 create mode 100644 spark-wrapper/spark-3.0/src/main/scala/com/pingcap/tispark/SparkWrapper.scala
 create mode 100644 spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala
 create mode 100644 spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiDDLRule.scala
 create mode 100644 spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala
 create mode 100644 spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala
 create mode 100644 spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRuleV2.scala
 create mode 100644 spark-wrapper/spark-3.1/pom.xml
 create mode 100644 spark-wrapper/spark-3.1/src/main/scala/com/pingcap/tispark/SparkWrapper.scala
 create mode 100644 spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala
 create mode 100644 spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiDDLRule.scala
 create mode 100644 spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala
 create mode 100644 spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala
 create mode 100644 spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRuleV2.scala

diff --git a/.ci/integration_test.groovy b/.ci/integration_test.groovy
index 4241ed1fa9..547077f7e0 100644
--- a/.ci/integration_test.groovy
+++ b/.ci/integration_test.groovy
@@ -241,6 +241,7 @@ def call(ghprbActualCommit, ghprbCommentBody, ghprbPullId, ghprbPullTitle, ghprb
                     """
                     sh """
                     export MAVEN_OPTS="-Xmx6G -XX:MaxPermSize=512M"
+                    mvn clean package -DskipTests
                     mvn test ${MVN_PROFILE} -Dtest=moo ${mvnStr}
                     """
                 }
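Context for the assembly and TiSparkInfo changes that follow: the uber jar now bundles each spark-wrapper module's compiled classes under a version-suffixed resource directory, and "3.1" joins the supported-version list. At runtime the wrapper directory is derived from the running Spark version; a minimal standalone sketch of that derivation (the helper name is assumed for illustration, not part of the patch):

    object WrapperResourceSketch {
      // "3.1.1" -> major "3.1" -> "resources/spark-wrapper-spark-3_1",
      // mirroring the outputDirectory values declared in assembly.xml below.
      def wrapperResourceDir(sparkVersion: String): String = {
        val major = sparkVersion.split("\\.").take(2).mkString(".")
        s"resources/spark-wrapper-spark-${major.replace('.', '_')}"
      }
    }

Calling wrapperResourceDir(org.apache.spark.SPARK_VERSION) on Spark 3.1.1 yields resources/spark-wrapper-spark-3_1, which is where ReflectionUtil points its URLClassLoader when TiSpark runs from the assembled jar.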
diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml
index 2cf3316c8c..4490305238 100644
--- a/assembly/src/main/assembly/assembly.xml
+++ b/assembly/src/main/assembly/assembly.xml
@@ -17,4 +17,25 @@ true
+    <fileSets>
+        <fileSet>
+            <directory>${project.parent.basedir}/spark-wrapper/spark-3.0/target/classes/</directory>
+            <outputDirectory>resources/spark-wrapper-spark-3_0</outputDirectory>
+            <includes>
+                <include>**/*</include>
+            </includes>
+        </fileSet>
+        <fileSet>
+            <directory>${project.parent.basedir}/spark-wrapper/spark-3.1/target/classes/</directory>
+            <outputDirectory>resources/spark-wrapper-spark-3_1</outputDirectory>
+            <includes>
+                <include>**/*</include>
+            </includes>
+        </fileSet>
+    </fileSets>
diff --git a/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala b/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala
index de13741e70..b3f5764ea4 100644
--- a/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala
+++ b/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala
@@ -21,7 +21,7 @@ import org.slf4j.LoggerFactory
 object TiSparkInfo {
   private final val logger = LoggerFactory.getLogger(getClass.getName)

-  val SUPPORTED_SPARK_VERSION: List[String] = "3.0" :: Nil
+  val SUPPORTED_SPARK_VERSION: List[String] = "3.0" :: "3.1" :: Nil

   val SPARK_VERSION: String = org.apache.spark.SPARK_VERSION
diff --git a/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala b/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala
new file mode 100644
index 0000000000..dacd8165b4
--- /dev/null
+++ b/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala
@@ -0,0 +1,155 @@
+/*
+ *
+ * Copyright 2021 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.pingcap.tispark.utils
+
+import com.pingcap.tispark.TiSparkInfo
+import org.apache.spark.sql.{SparkSession, TiContext}
+import org.apache.spark.sql.catalyst.expressions.BasicExpression.TiExpression
+import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.slf4j.LoggerFactory
+
+import java.io.File
+import java.net.{URL, URLClassLoader}
+
+/**
+ * ReflectionUtil reflectively invokes methods whose signatures differ
+ * across Spark versions. Compatibility issues between Spark releases
+ * are solved by reflecting into the spark-wrapper module that matches
+ * the running Spark.
+ */
+object ReflectionUtil {
+  lazy val classLoader: URLClassLoader = {
+    val tisparkClassUrl = this.getClass.getProtectionDomain.getCodeSource.getLocation
+    val tisparkClassPath = new File(tisparkClassUrl.getFile)
+    logger.info(s"tispark class url: ${tisparkClassUrl.toString}")
+
+    val sparkWrapperClassURL: URL = if (tisparkClassPath.isDirectory) {
+      val classDir = new File(
+        s"${tisparkClassPath.getAbsolutePath}/../../../spark-wrapper/spark-${TiSparkInfo.SPARK_MAJOR_VERSION}/target/classes/")
+      if (!classDir.exists()) {
+        throw new Exception(
+          "cannot find spark wrapper classes!
please compile the spark-wrapper project first!") + } + classDir.toURI.toURL + } else { + new URL( + s"jar:$tisparkClassUrl!/resources/spark-wrapper-spark-${TiSparkInfo.SPARK_MAJOR_VERSION + .replace('.', '_')}/") + } + logger.info(s"spark wrapper class url: ${sparkWrapperClassURL.toString}") + + new URLClassLoader(Array(sparkWrapperClassURL), this.getClass.getClassLoader) + } + private val logger = LoggerFactory.getLogger(getClass.getName) + private val SPARK_WRAPPER_CLASS = "com.pingcap.tispark.SparkWrapper" + private val TI_BASIC_EXPRESSION_CLASS = + "org.apache.spark.sql.catalyst.expressions.TiBasicExpression" + private val TI_RESOLUTION_RULE_CLASS = + "org.apache.spark.sql.extensions.TiResolutionRule" + private val TI_RESOLUTION_RULE_V2_CLASS = + "org.apache.spark.sql.extensions.TiResolutionRuleV2" + private val TI_PARSER_CLASS = + "org.apache.spark.sql.extensions.TiParser" + private val TI_DDL_RULE_CLASS = + "org.apache.spark.sql.extensions.TiDDLRule" + + def newAlias(child: Expression, name: String): Alias = { + classLoader + .loadClass(SPARK_WRAPPER_CLASS) + .getDeclaredMethod("newAlias", classOf[Expression], classOf[String]) + .invoke(null, child, name) + .asInstanceOf[Alias] + } + + def newAlias(child: Expression, name: String, exprId: ExprId): Alias = { + classLoader + .loadClass(SPARK_WRAPPER_CLASS) + .getDeclaredMethod("newAlias", classOf[Expression], classOf[String], classOf[ExprId]) + .invoke(null, child, name, exprId) + .asInstanceOf[Alias] + } + + def copySortOrder(sortOrder: SortOrder, child: Expression): SortOrder = { + classLoader + .loadClass(SPARK_WRAPPER_CLASS) + .getDeclaredMethod("copySortOrder", classOf[SortOrder], classOf[Expression]) + .invoke(null, sortOrder, child) + .asInstanceOf[SortOrder] + } + + def trimNonTopLevelAliases(e: Expression): Expression = { + classLoader + .loadClass(SPARK_WRAPPER_CLASS) + .getDeclaredMethod("trimNonTopLevelAliases", classOf[Expression]) + .invoke(null, e) + .asInstanceOf[Expression] + } + + def callTiBasicExpressionConvertToTiExpr(expr: Expression): Option[TiExpression] = { + classLoader + .loadClass(TI_BASIC_EXPRESSION_CLASS) + .getDeclaredMethod("convertToTiExpr", classOf[Expression]) + .invoke(null, expr) + .asInstanceOf[Option[TiExpression]] + } + + def newTiResolutionRule( + getOrCreateTiContext: SparkSession => TiContext, + sparkSession: SparkSession): Rule[LogicalPlan] = { + classLoader + .loadClass(TI_RESOLUTION_RULE_CLASS) + .getDeclaredConstructor(classOf[SparkSession => TiContext], classOf[SparkSession]) + .newInstance(getOrCreateTiContext, sparkSession) + .asInstanceOf[Rule[LogicalPlan]] + } + + def newTiResolutionRuleV2( + getOrCreateTiContext: SparkSession => TiContext, + sparkSession: SparkSession): Rule[LogicalPlan] = { + classLoader + .loadClass(TI_RESOLUTION_RULE_V2_CLASS) + .getDeclaredConstructor(classOf[SparkSession => TiContext], classOf[SparkSession]) + .newInstance(getOrCreateTiContext, sparkSession) + .asInstanceOf[Rule[LogicalPlan]] + } + + def newTiParser( + getOrCreateTiContext: SparkSession => TiContext, + sparkSession: SparkSession, + parserInterface: ParserInterface): ParserInterface = { + classLoader + .loadClass(TI_PARSER_CLASS) + .getDeclaredConstructor( + classOf[SparkSession => TiContext], + classOf[SparkSession], + classOf[ParserInterface]) + .newInstance(getOrCreateTiContext, sparkSession, parserInterface) + .asInstanceOf[ParserInterface] + } + + def newTiDDLRule( + getOrCreateTiContext: SparkSession => TiContext, + sparkSession: SparkSession): Rule[LogicalPlan] = { + classLoader + 
.loadClass(TI_DDL_RULE_CLASS)
+      .getDeclaredConstructor(classOf[SparkSession => TiContext], classOf[SparkSession])
+      .newInstance(getOrCreateTiContext, sparkSession)
+      .asInstanceOf[Rule[LogicalPlan]]
+  }
+}
diff --git a/core/src/main/scala/com/pingcap/tispark/utils/TiUtil.scala b/core/src/main/scala/com/pingcap/tispark/utils/TiUtil.scala
index 1202a3911d..90784196da 100644
--- a/core/src/main/scala/com/pingcap/tispark/utils/TiUtil.scala
+++ b/core/src/main/scala/com/pingcap/tispark/utils/TiUtil.scala
@@ -16,7 +16,6 @@ package com.pingcap.tispark.utils
 import java.util.concurrent.TimeUnit
-
 import com.pingcap.tikv.TiConfiguration
 import com.pingcap.tikv.datatype.TypeMapping
 import com.pingcap.tikv.meta.{TiDAGRequest, TiTableInfo}
@@ -30,7 +29,44 @@ import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
 import org.apache.spark.{SparkConf, sql}
 import org.tikv.kvproto.Kvrpcpb.{CommandPri, IsolationLevel}

+import java.time.{Instant, LocalDate, ZoneId}
+import java.util.TimeZone
+import java.util.concurrent.TimeUnit.NANOSECONDS
+
 object TiUtil {
+  val MICROS_PER_MILLIS = 1000L
+  val MICROS_PER_SECOND = 1000000L
+
+  def defaultTimeZone(): TimeZone = TimeZone.getDefault
+
+  def daysToMillis(days: Int): Long = {
+    daysToMillis(days, defaultTimeZone().toZoneId)
+  }
+
+  def daysToMillis(days: Int, zoneId: ZoneId): Long = {
+    val instant = daysToLocalDate(days).atStartOfDay(zoneId).toInstant
+    toMillis(instantToMicros(instant))
+  }
+
+  /*
+   * Converts the timestamp to milliseconds since epoch. In Spark, timestamp values have
+   * microsecond precision, so this conversion is lossy.
+   */
+  def toMillis(us: Long): Long = {
+    // When the timestamp is negative, i.e. before 1970, we need to adjust the milliseconds portion.
+    // Example: 1965-01-01 10:11:12.123456 is represented as (-157700927876544) in microsecond precision.
+    // In millisecond precision it needs to be represented as (-157700927877).
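The method body that follows is simply Math.floorDiv(us, MICROS_PER_MILLIS). As a standalone check of the numbers in the comment above, showing why floorDiv (rounding toward negative infinity) is required rather than plain integer division (truncation toward zero):

    object ToMillisCheck extends App {
      val micros = -157700927876544L // 1965-01-01 10:11:12.123456, in microseconds
      assert(Math.floorDiv(micros, 1000L) == -157700927877L) // floor: correct
      assert(micros / 1000L == -157700927876L) // truncation: 1 ms too late
    }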
+ Math.floorDiv(us, MICROS_PER_MILLIS) + } + + def instantToMicros(instant: Instant): Long = { + val us = Math.multiplyExact(instant.getEpochSecond, MICROS_PER_SECOND) + val result = Math.addExact(us, NANOSECONDS.toMicros(instant.getNano)) + result + } + + def daysToLocalDate(days: Int): LocalDate = LocalDate.ofEpochDay(days) + def getSchemaFromTable(table: TiTableInfo): StructType = { val fields = new Array[StructField](table.getColumns.size()) for (i <- 0 until table.getColumns.size()) { diff --git a/core/src/main/scala/org/apache/spark/sql/TiAggregationImpl.scala b/core/src/main/scala/org/apache/spark/sql/TiAggregationImpl.scala index 2204fac13e..eaba318ee5 100644 --- a/core/src/main/scala/org/apache/spark/sql/TiAggregationImpl.scala +++ b/core/src/main/scala/org/apache/spark/sql/TiAggregationImpl.scala @@ -15,15 +15,10 @@ package org.apache.spark.sql +import com.pingcap.tispark.utils.ReflectionUtil import org.apache.spark.sql.catalyst.expressions.NamedExpression.newExprId import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.{ - Alias, - Cast, - Divide, - Expression, - NamedExpression -} +import org.apache.spark.sql.catalyst.expressions.{Cast, Divide, Expression, NamedExpression} import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.{DecimalType, DoubleType, FloatType, LongType} @@ -97,7 +92,7 @@ object TiAggregationImpl { case DoubleType => Cast(sum.resultAttribute, DoubleType) case d: DecimalType => Cast(sum.resultAttribute, d) } - (ref: Expression) -> Alias(castedSum, ref.name)(exprId = ref.exprId) + (ref: Expression) -> ReflectionUtil.newAlias(castedSum, ref.name, ref.exprId) } val avgRewrite: PartialFunction[Expression, Expression] = avgRewriteMap.map { @@ -105,7 +100,7 @@ object TiAggregationImpl { val castedSum = Cast(sum.resultAttribute, DoubleType) val castedCount = Cast(count.resultAttribute, DoubleType) val division = Cast(Divide(castedSum, castedCount), ref.dataType) - (ref: Expression) -> Alias(division, ref.name)(exprId = ref.exprId) + (ref: Expression) -> ReflectionUtil.newAlias(division, ref.name, ref.exprId) } val rewrittenResultExpressions = resultExpressions diff --git a/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala b/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala index f8f6502571..412f38d6bb 100644 --- a/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala +++ b/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala @@ -16,7 +16,6 @@ package org.apache.spark.sql import java.util.concurrent.TimeUnit - import com.pingcap.tidb.tipb.EncodeType import com.pingcap.tikv.exception.IgnoreUnsupportedTypeException import com.pingcap.tikv.expression._ @@ -26,10 +25,9 @@ import com.pingcap.tikv.predicates.{PredicateUtils, TiKVScanAnalyzer} import com.pingcap.tikv.region.TiStoreType import com.pingcap.tikv.statistics.TableStatistics import com.pingcap.tispark.statistics.StatisticsManager -import com.pingcap.tispark.utils.TiUtil +import com.pingcap.tispark.utils.{ReflectionUtil, TiUtil} import com.pingcap.tispark.{TiConfigConst, TiDBRelation} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis.CleanupAliases import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.{ Alias, @@ -38,7 +36,6 @@ import org.apache.spark.sql.catalyst.expressions.{ AttributeMap, AttributeSet, Descending, - TiExprUtils, 
Expression, IntegerLiteral, IsNull, @@ -46,7 +43,8 @@ import org.apache.spark.sql.catalyst.expressions.{ NullsFirst, NullsLast, SortOrder, - SubqueryExpression + SubqueryExpression, + TiExprUtils } import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical @@ -378,11 +376,12 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess val newSortExpr = sortOrder.child.transformUp { case a: Attribute => aliases.getOrElse(a, a) } - val trimmedExpr = CleanupAliases.trimNonTopLevelAliases(newSortExpr) - val trimmedSortOrder = sortOrder.copy(child = trimmedExpr) + val trimmedExpr = ReflectionUtil.trimNonTopLevelAliases(newSortExpr) + val trimmedSortOrder = ReflectionUtil.copySortOrder(sortOrder, trimmedExpr) + (sortOrder.direction, sortOrder.nullOrdering) match { case (_ @Ascending, _ @NullsLast) | (_ @Descending, _ @NullsFirst) => - sortOrder.copy(child = IsNull(trimmedExpr)) :: trimmedSortOrder :: Nil + ReflectionUtil.copySortOrder(sortOrder, IsNull(trimmedExpr)) :: trimmedSortOrder :: Nil case _ => trimmedSortOrder :: Nil } @@ -471,11 +470,11 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess source: TiDBRelation, dagReq: TiDAGRequest): Seq[SparkPlan] = { val deterministicAggAliases = aggregateExpressions.collect { - case e if e.deterministic => e.canonicalized -> Alias(e, e.toString())() + case e if e.deterministic => e.canonicalized -> ReflectionUtil.newAlias(e, e.toString()) }.toMap def aliasPushedPartialResult(e: AggregateExpression): Alias = - deterministicAggAliases.getOrElse(e.canonicalized, Alias(e, e.toString())()) + deterministicAggAliases.getOrElse(e.canonicalized, ReflectionUtil.newAlias(e, e.toString())) val residualAggregateExpressions = aggregateExpressions.map { aggExpr => // As `aggExpr` is being pushing down to TiKV, we need to replace the original Catalyst diff --git a/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/BasicExpression.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/BasicExpression.scala index d1f30e02cc..34d0161bfb 100644 --- a/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/BasicExpression.scala +++ b/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/BasicExpression.scala @@ -18,12 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp - -import com.pingcap.tikv.expression._ import com.pingcap.tikv.region.RegionStoreClient.RequestTypes -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.TiConverter +import com.pingcap.tispark.utils.{ReflectionUtil, TiUtil} import org.apache.spark.sql.types._ import org.joda.time.DateTime @@ -43,7 +39,7 @@ object BasicExpression { // and this number of date has compensate of timezone // and must be restored by DateTimeUtils.daysToMillis case DateType => - new DateTime(DateTimeUtils.daysToMillis(value.asInstanceOf[DateTimeUtils.SQLDate])) + new DateTime(TiUtil.daysToMillis(value.asInstanceOf[Int])) case TimestampType => new Timestamp(value.asInstanceOf[Long] / 1000) case StringType => value.toString case _: DecimalType => value.asInstanceOf[Decimal].toBigDecimal.bigDecimal @@ -79,97 +75,7 @@ object BasicExpression { } def convertToTiExpr(expr: Expression): Option[TiExpression] = - expr match { - case Literal(value, dataType) => - Some( - Constant.create(convertLiteral(value, dataType), 
TiConverter.fromSparkType(dataType))) - - case Add(BasicExpression(lhs), BasicExpression(rhs)) => - Some(ArithmeticBinaryExpression.plus(lhs, rhs)) - - case Subtract(BasicExpression(lhs), BasicExpression(rhs)) => - Some(ArithmeticBinaryExpression.minus(lhs, rhs)) - - case e @ Multiply(BasicExpression(lhs), BasicExpression(rhs)) => - Some(ArithmeticBinaryExpression.multiply(TiConverter.fromSparkType(e.dataType), lhs, rhs)) - - case d @ Divide(BasicExpression(lhs), BasicExpression(rhs)) => - Some(ArithmeticBinaryExpression.divide(TiConverter.fromSparkType(d.dataType), lhs, rhs)) - - case And(BasicExpression(lhs), BasicExpression(rhs)) => - Some(LogicalBinaryExpression.and(lhs, rhs)) - - case Or(BasicExpression(lhs), BasicExpression(rhs)) => - Some(LogicalBinaryExpression.or(lhs, rhs)) - - case Alias(BasicExpression(child), _) => - Some(child) - - case IsNull(BasicExpression(child)) => - Some(new TiIsNull(child)) - - case IsNotNull(BasicExpression(child)) => - Some(new TiNot(new TiIsNull(child))) - - case GreaterThan(BasicExpression(lhs), BasicExpression(rhs)) => - Some(ComparisonBinaryExpression.greaterThan(lhs, rhs)) - - case GreaterThanOrEqual(BasicExpression(lhs), BasicExpression(rhs)) => - Some(ComparisonBinaryExpression.greaterEqual(lhs, rhs)) - - case LessThan(BasicExpression(lhs), BasicExpression(rhs)) => - Some(ComparisonBinaryExpression.lessThan(lhs, rhs)) - - case LessThanOrEqual(BasicExpression(lhs), BasicExpression(rhs)) => - Some(ComparisonBinaryExpression.lessEqual(lhs, rhs)) - - case EqualTo(BasicExpression(lhs), BasicExpression(rhs)) => - Some(ComparisonBinaryExpression.equal(lhs, rhs)) - - case Not(EqualTo(BasicExpression(lhs), BasicExpression(rhs))) => - Some(ComparisonBinaryExpression.notEqual(lhs, rhs)) - - case Not(BasicExpression(child)) => - Some(new TiNot(child)) - - case StartsWith(BasicExpression(lhs), BasicExpression(rhs)) => - Some(StringRegExpression.startsWith(lhs, rhs)) - - case Contains(BasicExpression(lhs), BasicExpression(rhs)) => - Some(StringRegExpression.contains(lhs, rhs)) - - case EndsWith(BasicExpression(lhs), BasicExpression(rhs)) => - Some(StringRegExpression.endsWith(lhs, rhs)) - - case Like(BasicExpression(lhs), BasicExpression(rhs), _) => - Some(StringRegExpression.like(lhs, rhs)) - - // Coprocessor has its own behavior of type promoting and overflow check - // so we simply remove it from expression and let cop handle it - case CheckOverflow(BasicExpression(expr), dec: DecimalType, _) => - expr.setDataType(TiConverter.fromSparkType(dec)) - Some(expr) - - case PromotePrecision(BasicExpression(expr)) => - Some(expr) - - case PromotePrecision(Cast(BasicExpression(expr), dec: DecimalType, _)) => - expr.setDataType(TiConverter.fromSparkType(dec)) - Some(expr) - - case PromotePrecision(BasicExpression(expr)) => - Some(expr) - - // TODO: Are all AttributeReference column reference in such context? 
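The nested BasicExpression(lhs) patterns in the cases above (kept verbatim in the TiBasicExpression copies added under spark-wrapper) rely on extractor composition: unapply returns Option[TiExpression], so a child that cannot be converted makes the whole case fall through to the next one. A self-contained analogue of the mechanism, with illustrative names:

    object ExtractorSketch {
      object Pos {
        def unapply(n: Int): Option[Int] = if (n > 0) Some(n) else None
      }

      def classify(pair: (Int, Int)): String =
        pair match {
          case (Pos(a), Pos(b)) => s"both positive: $a and $b" // both extractors matched
          case _ => "not fully convertible" // any failed unapply falls through
        }
    }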
- case attr: AttributeReference => - Some(ColumnRef.create(attr.name, TiConverter.fromSparkType(attr.dataType))) - - case uAttr: UnresolvedAttribute => - Some(ColumnRef.create(uAttr.name, TiConverter.fromSparkType(uAttr.dataType))) - - // TODO: Remove it and let it fail once done all translation - case _ => Option.empty[TiExpression] - } + ReflectionUtil.callTiBasicExpressionConvertToTiExpr(expr) def unapply(expr: Expression): Option[TiExpression] = { convertToTiExpr(expr) diff --git a/core/src/main/scala/org/apache/spark/sql/extensions/parser.scala b/core/src/main/scala/org/apache/spark/sql/extensions/parser.scala index e7e9a4ad26..f9461e8d59 100644 --- a/core/src/main/scala/org/apache/spark/sql/extensions/parser.scala +++ b/core/src/main/scala/org/apache/spark/sql/extensions/parser.scala @@ -15,21 +15,9 @@ package org.apache.spark.sql.extensions -import java.util +import com.pingcap.tispark.utils.ReflectionUtil -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression} import org.apache.spark.sql.catalyst.parser._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.execution.command.{ - CacheTableCommand, - CreateViewCommand, - ExplainCommand, - UncacheTableCommand -} -import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{SparkSession, TiContext, TiExtensions} class TiParserFactory(getOrCreateTiContext: SparkSession => TiContext) @@ -40,125 +28,7 @@ class TiParserFactory(getOrCreateTiContext: SparkSession => TiContext) if (TiExtensions.catalogPluginMode(sparkSession)) { parserInterface } else { - TiParser(getOrCreateTiContext)(sparkSession, parserInterface) - } - } -} - -case class TiParser(getOrCreateTiContext: SparkSession => TiContext)( - sparkSession: SparkSession, - delegate: ParserInterface) - extends ParserInterface { - private lazy val tiContext = getOrCreateTiContext(sparkSession) - private lazy val internal = new SparkSqlParser(sparkSession.sqlContext.conf) - - private val cteTableNames = new ThreadLocal[java.util.Set[String]] { - override def initialValue(): util.Set[String] = new util.HashSet[String]() - } - - /** - * WAR to lead Spark to consider this relation being on local files. - * Otherwise Spark will lookup this relation in his session catalog. - * CHECK Spark [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations]] for details. - */ - private val qualifyTableIdentifier: PartialFunction[LogicalPlan, LogicalPlan] = { - case r @ UnresolvedRelation(tableIdentifier) if needQualify(tableIdentifier) => - r.copy(qualifyTableIdentifierInternal(tableIdentifier)) - case i @ InsertIntoStatement(r @ UnresolvedRelation(tableIdentifier), _, _, _, _) - if needQualify(tableIdentifier) => - // When getting temp view, we leverage legacy catalog. 
- i.copy(r.copy(qualifyTableIdentifierInternal(tableIdentifier))) - case w @ With(_, cteRelations) => - for (x <- cteRelations) { - cteTableNames.get().add(x._1.toLowerCase()) - } - w.copy(cteRelations = cteRelations - .map(p => (p._1, p._2.transform(qualifyTableIdentifier).asInstanceOf[SubqueryAlias]))) - case cv @ CreateViewCommand(_, _, _, _, _, child, _, _, _) => - cv.copy(child = child transform qualifyTableIdentifier) - case e @ ExplainCommand(plan, _) => - e.copy(logicalPlan = plan transform qualifyTableIdentifier) - case c @ CacheTableCommand(tableIdentifier, plan, _, _) - if plan.isEmpty && needQualify(tableIdentifier) => - // Caching an unqualified catalog table. - c.copy(qualifyTableIdentifierInternal(tableIdentifier)) - case c @ CacheTableCommand(_, plan, _, _) if plan.isDefined => - c.copy(plan = Some(plan.get transform qualifyTableIdentifier)) - case u @ UncacheTableCommand(tableIdentifier, _) if needQualify(tableIdentifier) => - // Uncaching an unqualified catalog table. - u.copy(qualifyTableIdentifierInternal(tableIdentifier)) - case logicalPlan => - logicalPlan transformExpressionsUp { - case s: SubqueryExpression => - val cteNamesBeforeSubQuery = new util.HashSet[String]() - cteNamesBeforeSubQuery.addAll(cteTableNames.get()) - val newPlan = s.withNewPlan(s.plan transform qualifyTableIdentifier) - // cte table names in the subquery should not been seen outside subquey - cteTableNames.get().clear() - cteTableNames.get().addAll(cteNamesBeforeSubQuery) - newPlan - } - } - - override def parsePlan(sqlText: String): LogicalPlan = { - val plan = internal.parsePlan(sqlText) - cteTableNames.get().clear() - plan.transform(qualifyTableIdentifier) - } - - override def parseExpression(sqlText: String): Expression = - internal.parseExpression(sqlText) - - override def parseTableIdentifier(sqlText: String): TableIdentifier = - internal.parseTableIdentifier(sqlText) - - override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = - internal.parseFunctionIdentifier(sqlText) - - override def parseTableSchema(sqlText: String): StructType = - internal.parseTableSchema(sqlText) - - override def parseDataType(sqlText: String): DataType = - internal.parseDataType(sqlText) - - private def qualifyTableIdentifierInternal(tableIdentifier: Seq[String]): Seq[String] = { - if (tableIdentifier.size == 1) { - tiContext.tiCatalog.getCurrentDatabase :: tableIdentifier.toList - } else { - tableIdentifier + ReflectionUtil.newTiParser(getOrCreateTiContext, sparkSession, parserInterface) } } - - private def qualifyTableIdentifierInternal( - tableIdentifier: TableIdentifier): TableIdentifier = { - TableIdentifier( - tableIdentifier.table, - Some(tableIdentifier.database.getOrElse(tiContext.tiCatalog.getCurrentDatabase))) - } - - private def needQualify(tableIdentifier: Seq[String]): Boolean = { - tableIdentifier.size == 1 && tiContext.sessionCatalog - .getTempView(tableIdentifier.head) - .isEmpty && !cteTableNames.get().contains(tableIdentifier.head.toLowerCase()) - } - - /** - * Determines whether a table specified by tableIdentifier is - * needs to be qualified. This is used for TiSpark to transform - * plans and decides whether a relation should be resolved or parsed. 
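The qualification logic being removed from parser.scala here (and re-added under spark-wrapper) reduces to one rule: a single-part table name that is neither a temp view nor a CTE name gets the current database prepended, so Spark stops resolving it against its own session catalog. A hedged, self-contained sketch of that rule (names are illustrative, not the patch's API):

    object QualifySketch {
      def qualify(
          parts: Seq[String],
          currentDb: String,
          isTempView: String => Boolean,
          isCteName: String => Boolean): Seq[String] =
        parts match {
          case Seq(table) if !isTempView(table) && !isCteName(table.toLowerCase) =>
            currentDb +: parts // "t" becomes Seq("tidb_db", "t")
          case _ => parts // already qualified, a temp view, or a CTE
        }
    }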
- * - * @param tableIdentifier tableIdentifier - * @return whether it needs qualifying - */ - private def needQualify(tableIdentifier: TableIdentifier): Boolean = { - tableIdentifier.database.isEmpty && tiContext.sessionCatalog - .getTempView(tableIdentifier.table) - .isEmpty - } - - override def parseMultipartIdentifier(sqlText: String): Seq[String] = - internal.parseMultipartIdentifier(sqlText) - - @scala.throws[ParseException]("Text cannot be parsed to a DataType") - override def parseRawDataType(sqlText: String): DataType = ??? } diff --git a/core/src/main/scala/org/apache/spark/sql/extensions/rules.scala b/core/src/main/scala/org/apache/spark/sql/extensions/rules.scala index 890c985e64..275324f8be 100644 --- a/core/src/main/scala/org/apache/spark/sql/extensions/rules.scala +++ b/core/src/main/scala/org/apache/spark/sql/extensions/rules.scala @@ -14,25 +14,10 @@ */ package org.apache.spark.sql.extensions -import com.pingcap.tispark.{MetaManager, TiDBRelation, TiTableReference} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} -import com.pingcap.tispark.statistics.StatisticsManager -import org.apache.spark.sql.catalyst.analysis.{ - EliminateSubqueryAliases, - UnresolvedRelation, - UnresolvedTableOrView -} -import org.apache.spark.sql.catalyst.catalog.CatalogTypes._ -import org.apache.spark.sql.catalyst.catalog.{TiDBTable, TiSessionCatalog} -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.{SparkSession, TiContext} +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import com.pingcap.tispark.utils.ReflectionUtil +import org.apache.spark.sql.{SparkSession, TiContext, TiExtensions} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier} -import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.{AnalysisException, _} import org.slf4j.LoggerFactory class TiResolutionRuleFactory(getOrCreateTiContext: SparkSession => TiContext) @@ -42,75 +27,9 @@ class TiResolutionRuleFactory(getOrCreateTiContext: SparkSession => TiContext) if (TiExtensions.catalogPluginMode(sparkSession)) { // set the class loader to Reflection class loader to avoid class not found exception while loading TiCatalog logger.info("TiSpark running in catalog plugin mode") - TiResolutionRuleV2(getOrCreateTiContext)(sparkSession) + ReflectionUtil.newTiResolutionRuleV2(getOrCreateTiContext, sparkSession) } else { - TiResolutionRule(getOrCreateTiContext)(sparkSession) - } - } -} - -case class TiResolutionRule(getOrCreateTiContext: SparkSession => TiContext)( - sparkSession: SparkSession) - extends Rule[LogicalPlan] { - protected lazy val meta: MetaManager = tiContext.meta - private lazy val autoLoad = tiContext.autoLoad - private lazy val tiCatalog = tiContext.tiCatalog - private lazy val tiSession = tiContext.tiSession - private lazy val sqlContext = tiContext.sqlContext - protected val tiContext: TiContext = getOrCreateTiContext(sparkSession) - - protected def resolveTiDBRelation( - withSubQueryAlias: Boolean = true): Seq[String] => LogicalPlan = - tableIdentifier => { - val dbName = getDatabaseFromIdentifier(tableIdentifier) - val tableName = - if (tableIdentifier.size == 1) tableIdentifier.head else tableIdentifier.tail.head - val table = meta.getTable(dbName, 
tableName) - if (table.isEmpty) { - throw new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'") - } - if (autoLoad) { - StatisticsManager.loadStatisticsInfo(table.get) - } - val sizeInBytes = StatisticsManager.estimateTableSize(table.get) - val tiDBRelation = - TiDBRelation(tiSession, TiTableReference(dbName, tableName, sizeInBytes), meta)( - sqlContext) - if (withSubQueryAlias) { - // Use SubqueryAlias so that projects and joins can correctly resolve - // UnresolvedAttributes in JoinConditions, Projects, Filters, etc. - SubqueryAlias(tableName, LogicalRelation(tiDBRelation)) - } else { - LogicalRelation(tiDBRelation) - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = - plan transformUp resolveTiDBRelations - - protected def resolveTiDBRelations: PartialFunction[LogicalPlan, LogicalPlan] = { - case i @ InsertIntoStatement(UnresolvedRelation(tableIdentifier), _, _, _, _) - if tiCatalog - .catalogOf(tableIdentifier) - .exists(_.isInstanceOf[TiSessionCatalog]) => - i.copy(table = EliminateSubqueryAliases(resolveTiDBRelation()(tableIdentifier))) - case UnresolvedRelation(tableIdentifier) - if tiCatalog - .catalogOf(tableIdentifier) - .exists(_.isInstanceOf[TiSessionCatalog]) => - resolveTiDBRelation()(tableIdentifier) - case UnresolvedTableOrView(tableIdentifier) - if tiCatalog - .catalogOf(tableIdentifier) - .exists(_.isInstanceOf[TiSessionCatalog]) => - resolveTiDBRelation(false)(tableIdentifier) - } - - private def getDatabaseFromIdentifier(tableIdentifier: Seq[String]): String = { - if (tableIdentifier.size == 1) { - tiCatalog.getCurrentDatabase - } else { - tableIdentifier.head + ReflectionUtil.newTiResolutionRule(getOrCreateTiContext, sparkSession) } } } @@ -121,113 +40,12 @@ class TiDDLRuleFactory(getOrCreateTiContext: SparkSession => TiContext) if (TiExtensions.catalogPluginMode(sparkSession)) { TiDDLRuleV2(getOrCreateTiContext)(sparkSession) } else { - TiDDLRule(getOrCreateTiContext)(sparkSession) + ReflectionUtil.newTiDDLRule(getOrCreateTiContext, sparkSession) } } } case class NopCommand(name: String) extends Command {} -case class TiDDLRule(getOrCreateTiContext: SparkSession => TiContext)(sparkSession: SparkSession) - extends Rule[LogicalPlan] { - protected lazy val tiContext: TiContext = getOrCreateTiContext(sparkSession) - - def getDBAndTableName(ident: Identifier): (String, Option[String]) = { - ident.namespace() match { - case Array(db) => - (ident.name(), Some(db)) - case _ => - (ident.name(), None) - } - } - - def isSupportedCatalog(sd: SetCatalogAndNamespace): Boolean = { - if (sd.catalogName.isEmpty) - false - else { - sd.catalogName.get.equals(CatalogManager.SESSION_CATALOG_NAME) - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = - plan transformUp { - // TODO: support other commands that may concern TiSpark catalog. 
- case sd: ShowNamespaces => - TiShowDatabasesCommand(tiContext, sd) - case sd: SetCatalogAndNamespace if isSupportedCatalog(sd) => - TiSetDatabaseCommand(tiContext, sd) - case st: ShowTablesCommand => - TiShowTablesCommand(tiContext, st) - case st: ShowColumnsCommand => - TiShowColumnsCommand(tiContext, st) - case dt: DescribeTableCommand => - TiDescribeTablesCommand( - tiContext, - dt, - DescribeTableInfo( - TableIdentifier(dt.table.table, dt.table.database), - dt.partitionSpec, - dt.isExtended)) - case dt @ DescribeRelation( - LogicalRelation(TiDBRelation(_, tableRef, _, _, _), _, _, _), - _, - _) => - TiDescribeTablesCommand( - tiContext, - dt, - DescribeTableInfo( - TableIdentifier(tableRef.tableName, Some(tableRef.databaseName)), - dt.partitionSpec, - dt.isExtended)) - case dc: DescribeColumnCommand => - TiDescribeColumnCommand(tiContext, dc) - case ct: CreateTableLikeCommand => - TiCreateTableLikeCommand(tiContext, ct) - } -} - -case class TiResolutionRuleV2(getOrCreateTiContext: SparkSession => TiContext)( - sparkSession: SparkSession) - extends Rule[LogicalPlan] { - protected lazy val meta: MetaManager = tiContext.meta - private lazy val autoLoad = tiContext.autoLoad - //private lazy val tiCatalog = tiContext.tiCatalog - private lazy val tiSession = tiContext.tiSession - private lazy val sqlContext = tiContext.sqlContext - protected val tiContext: TiContext = getOrCreateTiContext(sparkSession) - - protected val resolveTiDBRelation: (TiDBTable, Seq[AttributeReference]) => LogicalPlan = - (tiTable, output) => { - if (autoLoad) { - StatisticsManager.loadStatisticsInfo(tiTable.tiTableInfo.get) - } - val sizeInBytes = StatisticsManager.estimateTableSize(tiTable.tiTableInfo.get) - val tiDBRelation = TiDBRelation( - tiSession, - TiTableReference(tiTable.databaseName, tiTable.tableName, sizeInBytes), - meta)(sqlContext) - // Use SubqueryAlias so that projects and joins can correctly resolve - // UnresolvedAttributes in JoinConditions, Projects, Filters, etc. - // todo since there is no UnresolvedAttributes, do we still need the subqueryAlias relation??? 
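Every ReflectionUtil entry point used by these factories follows the same shape: load the version-specific class from the wrapper class loader, select the (SparkSession => TiContext, SparkSession) constructor, instantiate, and cast to an interface that core was compiled against. A generic helper in that style (illustrative; the patch spells each case out instead):

    import com.pingcap.tispark.utils.ReflectionUtil.classLoader
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.{SparkSession, TiContext}

    object RuleLoaderSketch {
      def newRule(
          className: String,
          getOrCreateTiContext: SparkSession => TiContext,
          spark: SparkSession): Rule[LogicalPlan] =
        classLoader
          .loadClass(className)
          .getDeclaredConstructor(classOf[SparkSession => TiContext], classOf[SparkSession])
          .newInstance(getOrCreateTiContext, spark)
          .asInstanceOf[Rule[LogicalPlan]]
    }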
-      SubqueryAlias(tiTable.tableName, LogicalRelation(tiDBRelation, output, None, false))
-    }
-
-  protected def resolveTiDBRelations: PartialFunction[LogicalPlan, LogicalPlan] = {
-    // todo can remove this branch since the target table of insert into statement should never be a tidb table
-    case i @ InsertIntoStatement(DataSourceV2Relation(table, output, _, _, _), _, _, _, _)
-        if table.isInstanceOf[TiDBTable] =>
-      val tiTable = table.asInstanceOf[TiDBTable]
-      i.copy(table = EliminateSubqueryAliases(resolveTiDBRelation(tiTable, output)))
-    case DataSourceV2Relation(table, output, _, _, _) if table.isInstanceOf[TiDBTable] =>
-      val tiTable = table.asInstanceOf[TiDBTable]
-      resolveTiDBRelation(tiTable, output)
-  }
-
-  override def apply(plan: LogicalPlan): LogicalPlan =
-    plan match {
-      case _ =>
-        plan transformUp resolveTiDBRelations
-    }
-}

 case class TiDDLRuleV2(getOrCreateTiContext: SparkSession => TiContext)(
     sparkSession: SparkSession)
diff --git a/core/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestSuite.scala b/core/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestSuite.scala
index 7cd607278c..6b2d87c32e 100644
--- a/core/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestSuite.scala
+++ b/core/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestSuite.scala
@@ -401,7 +401,6 @@ class CatalogTestSuite extends BaseTiSparkTest {
       "desc extended t id",
       skipJDBC = true,
       rTiDB = expectedDescExtendedTableColumn)
-    spark.sql("drop table if exists t")
   }

   test("test schema change") {
diff --git a/pom.xml b/pom.xml
index 7e1cf55f7d..556f711d8b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -70,8 +70,8 @@
         UTF-8
         UTF-8
         3.1.0
-        3.0.0
-        3.0.2
+        3.0.2
+        3.1.1
         2.12
         2.12.10
@@ -161,10 +161,21 @@
         <module>tikv-client</module>
         <module>core</module>
+        <module>spark-wrapper/spark-3.0</module>
+        <module>spark-wrapper/spark-3.1</module>
         <module>assembly</module>
+        <profile>
+            <id>spark-3.1.1</id>
+            <activation>
+                <activeByDefault>false</activeByDefault>
+            </activation>
+            <properties>
+                3.1.1
+            </properties>
+        </profile>
         <id>jenkins</id>
diff --git a/spark-wrapper/spark-3.0/pom.xml b/spark-wrapper/spark-3.0/pom.xml
new file mode 100644
index 0000000000..d395a5ac82
--- /dev/null
+++ b/spark-wrapper/spark-3.0/pom.xml
@@ -0,0 +1,123 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <parent>
+        <groupId>com.pingcap.tispark</groupId>
+        <artifactId>tispark-parent</artifactId>
+        <version>2.5.0-SNAPSHOT</version>
+        <relativePath>../../pom.xml</relativePath>
+    </parent>
+
+    <artifactId>spark-wrapper-spark-3.0</artifactId>
+    <packaging>jar</packaging>
+    <name>TiSpark Project Spark Wrapper Spark-3.0</name>
+    <url>http://github.com/pingcap/tispark</url>
+
+    <properties>
+        <spark.version.wrapper>3.0.2</spark.version.wrapper>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.pingcap.tispark</groupId>
+            <artifactId>tispark-core-internal</artifactId>
+            <version>${project.parent.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_2.12</artifactId>
+            <version>${spark.version.wrapper}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-catalyst_2.12</artifactId>
+            <version>${spark.version.wrapper}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_2.12</artifactId>
+            <version>${spark.version.wrapper}</version>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <sourceDirectory>src/main/scala</sourceDirectory>
+        <plugins>
+            <plugin>
+                <groupId>net.alchim31.maven</groupId>
+                <artifactId>scala-maven-plugin</artifactId>
+                <version>4.3.0</version>
+                <executions>
+                    <execution>
+                        <id>compile-scala</id>
+                        <phase>compile</phase>
+                        <goals>
+                            <goal>add-source</goal>
+                            <goal>compile</goal>
+                        </goals>
+                    </execution>
+                    <execution>
+                        <id>test-compile-scala</id>
+                        <phase>test-compile</phase>
+                        <goals>
+                            <goal>add-source</goal>
+                            <goal>testCompile</goal>
+                        </goals>
+                    </execution>
+                    <execution>
+                        <id>attach-javadocs</id>
+                        <goals>
+                            <goal>doc-jar</goal>
+                        </goals>
+                    </execution>
+                </executions>
+                <configuration>
+                    <scalaVersion>${scala.version}</scalaVersion>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-compiler-plugin</artifactId>
+                <version>2.3.2</version>
+                <configuration>
+                    <source>1.8</source>
+                    <target>1.8</target>
+                    <encoding>UTF-8</encoding>
+                    <showDeprecation>true</showDeprecation>
+                    <showWarnings>true</showWarnings>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-source-plugin</artifactId>
+                <version>3.0.1</version>
+                <executions>
+                    <execution>
+                        <id>attach-sources</id>
+                        <goals>
+                            <goal>jar-no-fork</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.antipathy</groupId>
+                <artifactId>mvn-scalafmt_${scala.binary.version}</artifactId>
+                <version>1.0.3</version>
+                <configuration>
+                    <skipSources>${scalafmt.skip}</skipSources>
+                    <skipTestSources>${scalafmt.skip}</skipTestSources>
+                    <sourceDirectories>
+                        <param>${project.basedir}/src/main/scala</param>
+                    </sourceDirectories>
+                    <testSourceDirectories>
+                        <param>${project.basedir}/src/test/scala</param>
+                    </testSourceDirectories>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
+</project>
diff --git a/spark-wrapper/spark-3.0/src/main/scala/com/pingcap/tispark/SparkWrapper.scala b/spark-wrapper/spark-3.0/src/main/scala/com/pingcap/tispark/SparkWrapper.scala
new file mode 100644
index 0000000000..1d8e85fb22
--- /dev/null
+++
b/spark-wrapper/spark-3.0/src/main/scala/com/pingcap/tispark/SparkWrapper.scala @@ -0,0 +1,41 @@ +/* + * Copyright 2021 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.pingcap.tispark + +import org.apache.spark.sql.catalyst.analysis.CleanupAliases +import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId, Expression, SortOrder} + +object SparkWrapper { + def getVersion: String = { + "SparkWrapper-3.0" + } + + def newAlias(child: Expression, name: String): Alias = { + Alias(child, name)() + } + + def newAlias(child: Expression, name: String, exprId: ExprId): Alias = { + Alias(child, name)(exprId = exprId) + } + + def trimNonTopLevelAliases(e: Expression): Expression = { + CleanupAliases.trimNonTopLevelAliases(e) + } + + def copySortOrder(sortOrder: SortOrder, child: Expression): SortOrder = { + sortOrder.copy(child = child) + } +} diff --git a/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala new file mode 100644 index 0000000000..3b60688f62 --- /dev/null +++ b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala @@ -0,0 +1,132 @@ +/* + * + * Copyright 2021 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * See the License for the specific language governing permissions and + * limitations under the License. 
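SparkWrapper above is the piece that actually differs per Spark release: the curried parameter list of Alias changed between Spark 3.0 and 3.1 (3.1 added a nonInheritableMetadataKeys parameter), and SortOrder's copy signature likewise differs, so a single compiled call site cannot link against both runtimes. Core code therefore goes through the reflective shim; a usage sketch (object name illustrative):

    object AliasUsageSketch {
      import com.pingcap.tispark.utils.ReflectionUtil
      import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}

      // Never call Alias(child, name)() directly in core; the wrapper that
      // matches the running Spark supplies the right constructor.
      val aliased: Alias = ReflectionUtil.newAlias(Literal(1), "one")
    }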
+ * + */ + +package org.apache.spark.sql.catalyst.expressions + +import com.pingcap.tikv.expression.{ + ArithmeticBinaryExpression, + ColumnRef, + ComparisonBinaryExpression, + Constant, + LogicalBinaryExpression, + StringRegExpression +} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.BasicExpression.{ + TiExpression, + TiIsNull, + TiNot, + convertLiteral +} +import org.apache.spark.sql.execution.TiConverter +import org.apache.spark.sql.types.DecimalType + +object TiBasicExpression { + + def convertToTiExpr(expr: Expression): Option[TiExpression] = + expr match { + case Literal(value, dataType) => + Some( + Constant.create(convertLiteral(value, dataType), TiConverter.fromSparkType(dataType))) + + case Add(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ArithmeticBinaryExpression.plus(lhs, rhs)) + + case Subtract(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ArithmeticBinaryExpression.minus(lhs, rhs)) + + case e @ Multiply(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ArithmeticBinaryExpression.multiply(TiConverter.fromSparkType(e.dataType), lhs, rhs)) + + case d @ Divide(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ArithmeticBinaryExpression.divide(TiConverter.fromSparkType(d.dataType), lhs, rhs)) + + case And(BasicExpression(lhs), BasicExpression(rhs)) => + Some(LogicalBinaryExpression.and(lhs, rhs)) + + case Or(BasicExpression(lhs), BasicExpression(rhs)) => + Some(LogicalBinaryExpression.or(lhs, rhs)) + + case Alias(BasicExpression(child), _) => + Some(child) + + case IsNull(BasicExpression(child)) => + Some(new TiIsNull(child)) + + case IsNotNull(BasicExpression(child)) => + Some(new TiNot(new TiIsNull(child))) + + case GreaterThan(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ComparisonBinaryExpression.greaterThan(lhs, rhs)) + + case GreaterThanOrEqual(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ComparisonBinaryExpression.greaterEqual(lhs, rhs)) + + case LessThan(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ComparisonBinaryExpression.lessThan(lhs, rhs)) + + case LessThanOrEqual(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ComparisonBinaryExpression.lessEqual(lhs, rhs)) + + case EqualTo(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ComparisonBinaryExpression.equal(lhs, rhs)) + + case Not(EqualTo(BasicExpression(lhs), BasicExpression(rhs))) => + Some(ComparisonBinaryExpression.notEqual(lhs, rhs)) + + case Not(BasicExpression(child)) => + Some(new TiNot(child)) + + case StartsWith(BasicExpression(lhs), BasicExpression(rhs)) => + Some(StringRegExpression.startsWith(lhs, rhs)) + + case Contains(BasicExpression(lhs), BasicExpression(rhs)) => + Some(StringRegExpression.contains(lhs, rhs)) + + case EndsWith(BasicExpression(lhs), BasicExpression(rhs)) => + Some(StringRegExpression.endsWith(lhs, rhs)) + + case Like(BasicExpression(lhs), BasicExpression(rhs), _) => + Some(StringRegExpression.like(lhs, rhs)) + + // Coprocessor has its own behavior of type promoting and overflow check + // so we simply remove it from expression and let cop handle it + case CheckOverflow(BasicExpression(expr), dec: DecimalType, _) => + expr.setDataType(TiConverter.fromSparkType(dec)) + Some(expr) + + case PromotePrecision(BasicExpression(expr)) => + Some(expr) + + case PromotePrecision(Cast(BasicExpression(expr), dec: DecimalType, _)) => + expr.setDataType(TiConverter.fromSparkType(dec)) + Some(expr) + + case PromotePrecision(BasicExpression(expr)) => + Some(expr) + + // 
TODO: Are all AttributeReference column reference in such context? + case attr: AttributeReference => + Some(ColumnRef.create(attr.name, TiConverter.fromSparkType(attr.dataType))) + + case uAttr: UnresolvedAttribute => + Some(ColumnRef.create(uAttr.name, TiConverter.fromSparkType(uAttr.dataType))) + + // TODO: Remove it and let it fail once done all translation + case _ => Option.empty[TiExpression] + } +} diff --git a/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiDDLRule.scala b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiDDLRule.scala new file mode 100644 index 0000000000..da45a9ed53 --- /dev/null +++ b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiDDLRule.scala @@ -0,0 +1,102 @@ +/* + * Copyright 2021 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.extensions + +import com.pingcap.tispark.TiDBRelation +import org.apache.spark.sql.{SparkSession, TiContext} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.{ + DescribeRelation, + LogicalPlan, + SetCatalogAndNamespace, + ShowNamespaces +} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier} +import org.apache.spark.sql.execution.command.{ + CreateTableLikeCommand, + DescribeColumnCommand, + DescribeTableCommand, + DescribeTableInfo, + ShowColumnsCommand, + ShowTablesCommand, + TiCreateTableLikeCommand, + TiDescribeColumnCommand, + TiDescribeTablesCommand, + TiSetDatabaseCommand, + TiShowColumnsCommand, + TiShowDatabasesCommand, + TiShowTablesCommand +} +import org.apache.spark.sql.execution.datasources.LogicalRelation + +case class TiDDLRule(getOrCreateTiContext: SparkSession => TiContext, sparkSession: SparkSession) + extends Rule[LogicalPlan] { + protected lazy val tiContext: TiContext = getOrCreateTiContext(sparkSession) + + def getDBAndTableName(ident: Identifier): (String, Option[String]) = { + ident.namespace() match { + case Array(db) => + (ident.name(), Some(db)) + case _ => + (ident.name(), None) + } + } + + def isSupportedCatalog(sd: SetCatalogAndNamespace): Boolean = { + if (sd.catalogName.isEmpty) + false + else { + sd.catalogName.get.equals(CatalogManager.SESSION_CATALOG_NAME) + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan transformUp { + // TODO: support other commands that may concern TiSpark catalog. 
+      case sd: ShowNamespaces =>
+        TiShowDatabasesCommand(tiContext, sd)
+      case sd: SetCatalogAndNamespace if isSupportedCatalog(sd) =>
+        TiSetDatabaseCommand(tiContext, sd)
+      case st: ShowTablesCommand =>
+        TiShowTablesCommand(tiContext, st)
+      case st: ShowColumnsCommand =>
+        TiShowColumnsCommand(tiContext, st)
+      case dt: DescribeTableCommand =>
+        TiDescribeTablesCommand(
+          tiContext,
+          dt,
+          DescribeTableInfo(
+            TableIdentifier(dt.table.table, dt.table.database),
+            dt.partitionSpec,
+            dt.isExtended))
+      case dt @ DescribeRelation(
+            LogicalRelation(TiDBRelation(_, tableRef, _, _, _), _, _, _),
+            _,
+            _) =>
+        TiDescribeTablesCommand(
+          tiContext,
+          dt,
+          DescribeTableInfo(
+            TableIdentifier(tableRef.tableName, Some(tableRef.databaseName)),
+            dt.partitionSpec,
+            dt.isExtended))
+      case dc: DescribeColumnCommand =>
+        TiDescribeColumnCommand(tiContext, dc)
+      case ct: CreateTableLikeCommand =>
+        TiCreateTableLikeCommand(tiContext, ct)
+    }
+}
diff --git a/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala
new file mode 100644
index 0000000000..af2f305aaa
--- /dev/null
+++ b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala
@@ -0,0 +1,157 @@
+/*
+ * Copyright 2021 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.extensions
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
+import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface}
+import org.apache.spark.sql.catalyst.plans.logical.{
+  InsertIntoStatement,
+  LogicalPlan,
+  SubqueryAlias,
+  With
+}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.execution.SparkSqlParser
+import org.apache.spark.sql.execution.command.{
+  CacheTableCommand,
+  CreateViewCommand,
+  ExplainCommand,
+  UncacheTableCommand
+}
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.{SparkSession, TiContext}
+
+import java.util
+
+case class TiParser(
+    getOrCreateTiContext: SparkSession => TiContext,
+    sparkSession: SparkSession,
+    delegate: ParserInterface)
+    extends ParserInterface {
+  private lazy val tiContext = getOrCreateTiContext(sparkSession)
+  private lazy val internal = new SparkSqlParser(sparkSession.sqlContext.conf)
+
+  private val cteTableNames = new ThreadLocal[java.util.Set[String]] {
+    override def initialValue(): util.Set[String] = new util.HashSet[String]()
+  }
+
+  /**
+   * Workaround to lead Spark to consider this relation as being on local files.
+   * Otherwise Spark will look up this relation in its session catalog.
+   * See Spark [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations]] for details.
+   */
+  private val qualifyTableIdentifier: PartialFunction[LogicalPlan, LogicalPlan] = {
+    case r @ UnresolvedRelation(tableIdentifier) if needQualify(tableIdentifier) =>
+      r.copy(qualifyTableIdentifierInternal(tableIdentifier))
+    case i @ InsertIntoStatement(r @ UnresolvedRelation(tableIdentifier), _, _, _, _)
+        if needQualify(tableIdentifier) =>
+      // When getting temp view, we leverage legacy catalog.
+      i.copy(r.copy(qualifyTableIdentifierInternal(tableIdentifier)))
+    case w @ With(_, cteRelations) =>
+      for (x <- cteRelations) {
+        cteTableNames.get().add(x._1.toLowerCase())
+      }
+      w.copy(cteRelations = cteRelations
+        .map(p => (p._1, p._2.transform(qualifyTableIdentifier).asInstanceOf[SubqueryAlias])))
+    case cv @ CreateViewCommand(_, _, _, _, _, child, _, _, _) =>
+      cv.copy(child = child transform qualifyTableIdentifier)
+    case e @ ExplainCommand(plan, _) =>
+      e.copy(logicalPlan = plan transform qualifyTableIdentifier)
+    case c @ CacheTableCommand(tableIdentifier, plan, _, _)
+        if plan.isEmpty && needQualify(tableIdentifier) =>
+      // Caching an unqualified catalog table.
+      c.copy(qualifyTableIdentifierInternal(tableIdentifier))
+    case c @ CacheTableCommand(_, plan, _, _) if plan.isDefined =>
+      c.copy(plan = Some(plan.get transform qualifyTableIdentifier))
+    case u @ UncacheTableCommand(tableIdentifier, _) if needQualify(tableIdentifier) =>
+      // Uncaching an unqualified catalog table.
+      u.copy(qualifyTableIdentifierInternal(tableIdentifier))
+    case logicalPlan =>
+      logicalPlan transformExpressionsUp {
+        case s: SubqueryExpression =>
+          val cteNamesBeforeSubQuery = new util.HashSet[String]()
+          cteNamesBeforeSubQuery.addAll(cteTableNames.get())
+          val newPlan = s.withNewPlan(s.plan transform qualifyTableIdentifier)
+          // CTE table names defined in the subquery should not be seen outside the subquery
+          cteTableNames.get().clear()
+          cteTableNames.get().addAll(cteNamesBeforeSubQuery)
+          newPlan
+      }
+  }
+
+  override def parsePlan(sqlText: String): LogicalPlan = {
+    val plan = internal.parsePlan(sqlText)
+    cteTableNames.get().clear()
+    plan.transform(qualifyTableIdentifier)
+  }
+
+  override def parseExpression(sqlText: String): Expression =
+    internal.parseExpression(sqlText)
+
+  override def parseTableIdentifier(sqlText: String): TableIdentifier =
+    internal.parseTableIdentifier(sqlText)
+
+  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+    internal.parseFunctionIdentifier(sqlText)
+
+  override def parseTableSchema(sqlText: String): StructType =
+    internal.parseTableSchema(sqlText)
+
+  override def parseDataType(sqlText: String): DataType =
+    internal.parseDataType(sqlText)
+
+  private def qualifyTableIdentifierInternal(tableIdentifier: Seq[String]): Seq[String] = {
+    if (tableIdentifier.size == 1) {
+      tiContext.tiCatalog.getCurrentDatabase :: tableIdentifier.toList
+    } else {
+      tableIdentifier
+    }
+  }
+
+  private def qualifyTableIdentifierInternal(
+      tableIdentifier: TableIdentifier): TableIdentifier = {
+    TableIdentifier(
+      tableIdentifier.table,
+      Some(tableIdentifier.database.getOrElse(tiContext.tiCatalog.getCurrentDatabase)))
+  }
+
+  private def needQualify(tableIdentifier: Seq[String]): Boolean = {
+    tableIdentifier.size == 1 && tiContext.sessionCatalog
+      .getTempView(tableIdentifier.head)
+      .isEmpty && !cteTableNames.get().contains(tableIdentifier.head.toLowerCase())
+  }
+
+  /**
+   * Determines whether the table specified by tableIdentifier
+   * needs to be qualified. TiSpark uses this to transform
+   * plans and to decide whether a relation should be resolved or parsed.
+ * + * @param tableIdentifier tableIdentifier + * @return whether it needs qualifying + */ + private def needQualify(tableIdentifier: TableIdentifier): Boolean = { + tableIdentifier.database.isEmpty && tiContext.sessionCatalog + .getTempView(tableIdentifier.table) + .isEmpty + } + + override def parseMultipartIdentifier(sqlText: String): Seq[String] = + internal.parseMultipartIdentifier(sqlText) + + @scala.throws[ParseException]("Text cannot be parsed to a DataType") + override def parseRawDataType(sqlText: String): DataType = ??? +} diff --git a/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala new file mode 100644 index 0000000000..cfa444e118 --- /dev/null +++ b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala @@ -0,0 +1,100 @@ +/* + * Copyright 2021 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.extensions + +import com.pingcap.tispark.statistics.StatisticsManager +import com.pingcap.tispark.{MetaManager, TiDBRelation, TiTableReference} +import org.apache.spark.sql.catalyst.analysis.{ + EliminateSubqueryAliases, + UnresolvedRelation, + UnresolvedTableOrView +} +import org.apache.spark.sql.catalyst.catalog.TiSessionCatalog +import org.apache.spark.sql.catalyst.plans.logical.{ + InsertIntoStatement, + LogicalPlan, + SubqueryAlias +} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.{AnalysisException, SparkSession, TiContext} + +case class TiResolutionRule( + getOrCreateTiContext: SparkSession => TiContext, + sparkSession: SparkSession) + extends Rule[LogicalPlan] { + protected lazy val meta: MetaManager = tiContext.meta + private lazy val autoLoad = tiContext.autoLoad + private lazy val tiCatalog = tiContext.tiCatalog + private lazy val tiSession = tiContext.tiSession + private lazy val sqlContext = tiContext.sqlContext + protected val tiContext: TiContext = getOrCreateTiContext(sparkSession) + + protected def resolveTiDBRelation( + withSubQueryAlias: Boolean = true): Seq[String] => LogicalPlan = + tableIdentifier => { + val dbName = getDatabaseFromIdentifier(tableIdentifier) + val tableName = + if (tableIdentifier.size == 1) tableIdentifier.head else tableIdentifier.tail.head + val table = meta.getTable(dbName, tableName) + if (table.isEmpty) { + throw new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'") + } + if (autoLoad) { + StatisticsManager.loadStatisticsInfo(table.get) + } + val sizeInBytes = StatisticsManager.estimateTableSize(table.get) + val tiDBRelation = + TiDBRelation(tiSession, TiTableReference(dbName, tableName, sizeInBytes), meta)( + sqlContext) + if (withSubQueryAlias) { + // Use SubqueryAlias so that projects and joins can correctly resolve + // UnresolvedAttributes in JoinConditions, Projects, Filters, etc. 
+ SubqueryAlias(tableName, LogicalRelation(tiDBRelation)) + } else { + LogicalRelation(tiDBRelation) + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan transformUp resolveTiDBRelations + + protected def resolveTiDBRelations: PartialFunction[LogicalPlan, LogicalPlan] = { + case i @ InsertIntoStatement(UnresolvedRelation(tableIdentifier), _, _, _, _) + if tiCatalog + .catalogOf(tableIdentifier) + .exists(_.isInstanceOf[TiSessionCatalog]) => + i.copy(table = EliminateSubqueryAliases(resolveTiDBRelation()(tableIdentifier))) + case UnresolvedRelation(tableIdentifier) + if tiCatalog + .catalogOf(tableIdentifier) + .exists(_.isInstanceOf[TiSessionCatalog]) => + resolveTiDBRelation()(tableIdentifier) + case UnresolvedTableOrView(tableIdentifier) + if tiCatalog + .catalogOf(tableIdentifier) + .exists(_.isInstanceOf[TiSessionCatalog]) => + resolveTiDBRelation(false)(tableIdentifier) + } + + private def getDatabaseFromIdentifier(tableIdentifier: Seq[String]): String = { + if (tableIdentifier.size == 1) { + tiCatalog.getCurrentDatabase + } else { + tableIdentifier.head + } + } +} diff --git a/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRuleV2.scala b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRuleV2.scala new file mode 100644 index 0000000000..de851ee86e --- /dev/null +++ b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRuleV2.scala @@ -0,0 +1,75 @@ +/* + * Copyright 2021 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.extensions + +import com.pingcap.tispark.statistics.StatisticsManager +import com.pingcap.tispark.{MetaManager, TiDBRelation, TiTableReference} +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.catalog.TiDBTable +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.{ + InsertIntoStatement, + LogicalPlan, + SubqueryAlias +} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.{SparkSession, TiContext} + +case class TiResolutionRuleV2(getOrCreateTiContext: SparkSession => TiContext)( + sparkSession: SparkSession) + extends Rule[LogicalPlan] { + protected lazy val meta: MetaManager = tiContext.meta + private lazy val autoLoad = tiContext.autoLoad + //private lazy val tiCatalog = tiContext.tiCatalog + private lazy val tiSession = tiContext.tiSession + private lazy val sqlContext = tiContext.sqlContext + protected val tiContext: TiContext = getOrCreateTiContext(sparkSession) + + protected val resolveTiDBRelation: (TiDBTable, Seq[AttributeReference]) => LogicalPlan = + (tiTable, output) => { + if (autoLoad) { + StatisticsManager.loadStatisticsInfo(tiTable.tiTableInfo.get) + } + val sizeInBytes = StatisticsManager.estimateTableSize(tiTable.tiTableInfo.get) + val tiDBRelation = TiDBRelation( + tiSession, + TiTableReference(tiTable.databaseName, tiTable.tableName, sizeInBytes), + meta)(sqlContext) + // Use SubqueryAlias so that projects and joins can correctly resolve + // UnresolvedAttributes in JoinConditions, Projects, Filters, etc. + // todo since there is no UnresolvedAttributes, do we still need the subqueryAlias relation??? 
+ SubqueryAlias(tiTable.tableName, LogicalRelation(tiDBRelation, output, None, false))
+ }
+
+ protected def resolveTiDBRelations: PartialFunction[LogicalPlan, LogicalPlan] = {
+ // todo can remove this branch since the target table of insert into statement should never be a tidb table
+ case i @ InsertIntoStatement(DataSourceV2Relation(table, output, _, _, _), _, _, _, _)
+ if table.isInstanceOf[TiDBTable] =>
+ val tiTable = table.asInstanceOf[TiDBTable]
+ i.copy(table = EliminateSubqueryAliases(resolveTiDBRelation(tiTable, output)))
+ case DataSourceV2Relation(table, output, _, _, _) if table.isInstanceOf[TiDBTable] =>
+ val tiTable = table.asInstanceOf[TiDBTable]
+ resolveTiDBRelation(tiTable, output)
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan =
+ plan match {
+ case _ =>
+ plan transformUp resolveTiDBRelations
+ }
+}
diff --git a/spark-wrapper/spark-3.1/pom.xml b/spark-wrapper/spark-3.1/pom.xml
new file mode 100644
index 0000000000..6d81996cba
--- /dev/null
+++ b/spark-wrapper/spark-3.1/pom.xml
@@ -0,0 +1,123 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>com.pingcap.tispark</groupId>
+ <artifactId>tispark-parent</artifactId>
+ <version>2.5.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <artifactId>spark-wrapper-spark-3.1</artifactId>
+ <packaging>jar</packaging>
+ <name>TiSpark Project Spark Wrapper Spark-3.1</name>
+ <url>http://github.com/pingcap/tispark</url>
+
+ <properties>
+ <spark.version.wrapper>3.1.1</spark.version.wrapper>
+ </properties>
+
+ <dependencies>
+ <dependency>
+ <groupId>com.pingcap.tispark</groupId>
+ <artifactId>tispark-core-internal</artifactId>
+ <version>${project.parent.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_2.12</artifactId>
+ <version>${spark.version.wrapper}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-catalyst_2.12</artifactId>
+ <version>${spark.version.wrapper}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_2.12</artifactId>
+ <version>${spark.version.wrapper}</version>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <sourceDirectory>src/main/scala</sourceDirectory>
+ <plugins>
+ <plugin>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ <version>4.3.0</version>
+ <executions>
+ <execution>
+ <id>compile-scala</id>
+ <phase>compile</phase>
+ <goals>
+ <goal>add-source</goal>
+ <goal>compile</goal>
+ </goals>
+ </execution>
+ <execution>
+ <id>test-compile-scala</id>
+ <phase>test-compile</phase>
+ <goals>
+ <goal>add-source</goal>
+ <goal>testCompile</goal>
+ </goals>
+ </execution>
+ <execution>
+ <id>attach-javadocs</id>
+ <goals>
+ <goal>doc-jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <scalaVersion>${scala.version}</scalaVersion>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <version>2.3.2</version>
+ <configuration>
+ <source>1.8</source>
+ <target>1.8</target>
+ <encoding>UTF-8</encoding>
+ <showDeprecation>true</showDeprecation>
+ <showWarnings>true</showWarnings>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-source-plugin</artifactId>
+ <version>3.0.1</version>
+ <executions>
+ <execution>
+ <id>attach-sources</id>
+ <goals>
+ <goal>jar-no-fork</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.antipathy</groupId>
+ <artifactId>mvn-scalafmt_${scala.binary.version}</artifactId>
+ <version>1.0.3</version>
+ <configuration>
+ <skipTestSources>${scalafmt.skip}</skipTestSources>
+ <skipSources>${scalafmt.skip}</skipSources>
+ <sourceDirectories>
+ <param>${project.basedir}/src/main/scala</param>
+ </sourceDirectories>
+ <testSourceDirectories>
+ <param>${project.basedir}/src/test/scala</param>
+ </testSourceDirectories>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/spark-wrapper/spark-3.1/src/main/scala/com/pingcap/tispark/SparkWrapper.scala b/spark-wrapper/spark-3.1/src/main/scala/com/pingcap/tispark/SparkWrapper.scala
new file mode 100644
index 0000000000..ee55770fbd
--- /dev/null
+++ b/spark-wrapper/spark-3.1/src/main/scala/com/pingcap/tispark/SparkWrapper.scala
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2021 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.pingcap.tispark + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + AliasHelper, + ExprId, + Expression, + SortOrder +} + +object SparkWrapper { + def getVersion: String = { + "SparkWrapper-3.1" + } + + def newAlias(child: Expression, name: String): Alias = { + Alias(child, name)() + } + + def newAlias(child: Expression, name: String, exprId: ExprId): Alias = { + Alias(child, name)(exprId = exprId) + } + + def trimNonTopLevelAliases(e: Expression): Expression = { + TiCleanupAliases.trimNonTopLevelAliases2(e) + } + + def copySortOrder(sortOrder: SortOrder, child: Expression): SortOrder = { + sortOrder.copy(child = child) + } +} + +object TiCleanupAliases extends AliasHelper { + def trimNonTopLevelAliases2[T <: Expression](e: T): T = { + super.trimNonTopLevelAliases(e) + } +} diff --git a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala new file mode 100644 index 0000000000..ca6faf8090 --- /dev/null +++ b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala @@ -0,0 +1,132 @@ +/* + * + * Copyright 2021 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package org.apache.spark.sql.catalyst.expressions + +import com.pingcap.tikv.expression.{ + ArithmeticBinaryExpression, + ColumnRef, + ComparisonBinaryExpression, + Constant, + LogicalBinaryExpression, + StringRegExpression +} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.BasicExpression.{ + TiExpression, + TiIsNull, + TiNot, + convertLiteral +} +import org.apache.spark.sql.execution.TiConverter +import org.apache.spark.sql.types.DecimalType + +object TiBasicExpression { + + def convertToTiExpr(expr: Expression): Option[TiExpression] = + expr match { + case Literal(value, dataType) => + Some( + Constant.create(convertLiteral(value, dataType), TiConverter.fromSparkType(dataType))) + + case Add(BasicExpression(lhs), BasicExpression(rhs), _) => + Some(ArithmeticBinaryExpression.plus(lhs, rhs)) + + case Subtract(BasicExpression(lhs), BasicExpression(rhs), _) => + Some(ArithmeticBinaryExpression.minus(lhs, rhs)) + + case e @ Multiply(BasicExpression(lhs), BasicExpression(rhs), _) => + Some(ArithmeticBinaryExpression.multiply(TiConverter.fromSparkType(e.dataType), lhs, rhs)) + + case d @ Divide(BasicExpression(lhs), BasicExpression(rhs), _) => + Some(ArithmeticBinaryExpression.divide(TiConverter.fromSparkType(d.dataType), lhs, rhs)) + + case And(BasicExpression(lhs), BasicExpression(rhs)) => + Some(LogicalBinaryExpression.and(lhs, rhs)) + + case Or(BasicExpression(lhs), BasicExpression(rhs)) => + Some(LogicalBinaryExpression.or(lhs, rhs)) + + case Alias(BasicExpression(child), _) => + Some(child) + + case IsNull(BasicExpression(child)) => + Some(new TiIsNull(child)) + + case IsNotNull(BasicExpression(child)) => + Some(new TiNot(new TiIsNull(child))) + + case 
GreaterThan(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(ComparisonBinaryExpression.greaterThan(lhs, rhs))
+
+ case GreaterThanOrEqual(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(ComparisonBinaryExpression.greaterEqual(lhs, rhs))
+
+ case LessThan(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(ComparisonBinaryExpression.lessThan(lhs, rhs))
+
+ case LessThanOrEqual(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(ComparisonBinaryExpression.lessEqual(lhs, rhs))
+
+ case EqualTo(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(ComparisonBinaryExpression.equal(lhs, rhs))
+
+ case Not(EqualTo(BasicExpression(lhs), BasicExpression(rhs))) =>
+ Some(ComparisonBinaryExpression.notEqual(lhs, rhs))
+
+ case Not(BasicExpression(child)) =>
+ Some(new TiNot(child))
+
+ case StartsWith(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(StringRegExpression.startsWith(lhs, rhs))
+
+ case Contains(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(StringRegExpression.contains(lhs, rhs))
+
+ case EndsWith(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(StringRegExpression.endsWith(lhs, rhs))
+
+ case Like(BasicExpression(lhs), BasicExpression(rhs), _) =>
+ Some(StringRegExpression.like(lhs, rhs))
+
+ // The coprocessor has its own type promotion and overflow checking,
+ // so we simply remove this node from the expression and let the coprocessor handle it.
+ case CheckOverflow(BasicExpression(expr), dec: DecimalType, _) =>
+ expr.setDataType(TiConverter.fromSparkType(dec))
+ Some(expr)
+
+ case PromotePrecision(BasicExpression(expr)) =>
+ Some(expr)
+
+ case PromotePrecision(Cast(BasicExpression(expr), dec: DecimalType, _)) =>
+ expr.setDataType(TiConverter.fromSparkType(dec))
+ Some(expr)
+
+ // TODO: Are all AttributeReferences column references in this context?
+ case attr: AttributeReference =>
+ Some(ColumnRef.create(attr.name, TiConverter.fromSparkType(attr.dataType)))
+
+ case uAttr: UnresolvedAttribute =>
+ Some(ColumnRef.create(uAttr.name, TiConverter.fromSparkType(uAttr.dataType)))
+
+ // TODO: Remove this catch-all and let translation fail once all cases are covered.
+ case _ => Option.empty[TiExpression]
+ }
+}
diff --git a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiDDLRule.scala b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiDDLRule.scala
new file mode 100644
index 0000000000..38dadc1a0a
--- /dev/null
+++ b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiDDLRule.scala
@@ -0,0 +1,120 @@
+/*
+ * Copyright 2021 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.extensions + +import com.pingcap.tispark.TiDBRelation +import org.apache.spark.sql.{SparkSession, TiContext} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.{ + DescribeColumn, + DescribeRelation, + LogicalPlan, + SetCatalogAndNamespace, + ShowColumns, + ShowNamespaces +} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier} +import org.apache.spark.sql.execution.command.{ + CreateTableLikeCommand, + DescribeColumnCommand, + DescribeTableCommand, + DescribeTableInfo, + ShowColumnsCommand, + ShowTablesCommand, + TiCreateTableLikeCommand, + TiDescribeColumnCommand, + TiDescribeTablesCommand, + TiSetDatabaseCommand, + TiShowColumnsCommand, + TiShowDatabasesCommand, + TiShowTablesCommand +} +import org.apache.spark.sql.execution.datasources.LogicalRelation + +case class TiDDLRule(getOrCreateTiContext: SparkSession => TiContext, sparkSession: SparkSession) + extends Rule[LogicalPlan] { + protected lazy val tiContext: TiContext = getOrCreateTiContext(sparkSession) + + def getDBAndTableName(ident: Identifier): (String, Option[String]) = { + ident.namespace() match { + case Array(db) => + (ident.name(), Some(db)) + case _ => + (ident.name(), None) + } + } + + def isSupportedCatalog(sd: SetCatalogAndNamespace): Boolean = { + if (sd.catalogName.isEmpty) + false + else { + sd.catalogName.get.equals(CatalogManager.SESSION_CATALOG_NAME) + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan transformUp { + // TODO: support other commands that may concern TiSpark catalog. + case sd: ShowNamespaces => + TiShowDatabasesCommand(tiContext, sd) + case sd: SetCatalogAndNamespace if isSupportedCatalog(sd) => + TiSetDatabaseCommand(tiContext, sd) + case st: ShowTablesCommand => + TiShowTablesCommand(tiContext, st) + case st: ShowColumnsCommand => + TiShowColumnsCommand(tiContext, st) + case ShowColumns(LogicalRelation(TiDBRelation(_, tableRef, _, _, _), _, _, _), _) => + TiShowColumnsCommand( + tiContext, + ShowColumnsCommand( + None, + new TableIdentifier(tableRef.tableName, Some(tableRef.databaseName)))) + case dt: DescribeTableCommand => + TiDescribeTablesCommand( + tiContext, + dt, + DescribeTableInfo( + TableIdentifier(dt.table.table, dt.table.database), + dt.partitionSpec, + dt.isExtended)) + case dt @ DescribeRelation( + LogicalRelation(TiDBRelation(_, tableRef, _, _, _), _, _, _), + _, + _) => + TiDescribeTablesCommand( + tiContext, + dt, + DescribeTableInfo( + TableIdentifier(tableRef.tableName, Some(tableRef.databaseName)), + dt.partitionSpec, + dt.isExtended)) + case dc: DescribeColumnCommand => + TiDescribeColumnCommand(tiContext, dc) + case DescribeColumn( + LogicalRelation(TiDBRelation(_, tableRef, _, _, _), _, _, _), + colNameParts, + isExtended) => + TiDescribeColumnCommand( + tiContext, + DescribeColumnCommand( + TableIdentifier(tableRef.tableName, Some(tableRef.databaseName)), + colNameParts, + isExtended)) + case ct: CreateTableLikeCommand => + TiCreateTableLikeCommand(tiContext, ct) + } +} diff --git a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala new file mode 100644 index 0000000000..f1cb911279 --- /dev/null +++ b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala @@ -0,0 +1,133 @@ +/* + * Copyright 2021 PingCAP, Inc. 
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.extensions
+
+import org.apache.spark.sql.{SparkSession, TiContext}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.{
+ InsertIntoStatement,
+ LogicalPlan,
+ SubqueryAlias,
+ With
+}
+import org.apache.spark.sql.execution.SparkSqlParser
+import org.apache.spark.sql.execution.command.{
+ CacheTableCommand,
+ CreateViewCommand,
+ ExplainCommand,
+ UncacheTableCommand
+}
+import org.apache.spark.sql.types.{DataType, StructType}
+
+import java.util
+
+case class TiParser(
+ getOrCreateTiContext: SparkSession => TiContext,
+ sparkSession: SparkSession,
+ delegate: ParserInterface)
+ extends ParserInterface {
+ private lazy val tiContext = getOrCreateTiContext(sparkSession)
+ private lazy val internal = new SparkSqlParser()
+
+ private val cteTableNames = new ThreadLocal[java.util.Set[String]] {
+ override def initialValue(): util.Set[String] = new util.HashSet[String]()
+ }
+
+ /**
+ * Workaround to lead Spark to consider this relation as being on local files;
+ * otherwise Spark will look up this relation in its session catalog.
+ * See Spark's [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations]] for details.
+ */
+ private val qualifyTableIdentifier: PartialFunction[LogicalPlan, LogicalPlan] = {
+ case r @ UnresolvedRelation(tableIdentifier, _, _) if needQualify(tableIdentifier) =>
+ r.copy(qualifyTableIdentifierInternal(tableIdentifier))
+ case i @ InsertIntoStatement(r @ UnresolvedRelation(tableIdentifier, _, _), _, _, _, _, _)
+ if needQualify(tableIdentifier) =>
+ // When getting a temp view, we leverage the legacy catalog.
+ i.copy(r.copy(qualifyTableIdentifierInternal(tableIdentifier)))
+ case w @ With(_, cteRelations) =>
+ for (x <- cteRelations) {
+ cteTableNames.get().add(x._1.toLowerCase())
+ }
+ w.copy(cteRelations = cteRelations
+ .map(p => (p._1, p._2.transform(qualifyTableIdentifier).asInstanceOf[SubqueryAlias])))
+ case cv @ CreateViewCommand(_, _, _, _, _, child, _, _, _) =>
+ cv.copy(child = child transform qualifyTableIdentifier)
+ case e @ ExplainCommand(plan, _) =>
+ e.copy(logicalPlan = plan transform qualifyTableIdentifier)
+ case c @ CacheTableCommand(tableIdentifier, plan, _, _, _)
+ if plan.isEmpty && needQualify(tableIdentifier) =>
+ // Caching an unqualified catalog table.
+ c.copy(qualifyTableIdentifierInternal(tableIdentifier))
+ case c @ CacheTableCommand(_, plan, _, _, _) if plan.isDefined =>
+ c.copy(plan = Some(plan.get transform qualifyTableIdentifier))
+ case u @ UncacheTableCommand(tableIdentifier, _) if needQualify(tableIdentifier) =>
+ // Uncaching an unqualified catalog table.
+ u.copy(qualifyTableIdentifierInternal(tableIdentifier))
+ case logicalPlan =>
+ logicalPlan transformExpressionsUp {
+ case s: SubqueryExpression =>
+ val cteNamesBeforeSubQuery = new util.HashSet[String]()
+ cteNamesBeforeSubQuery.addAll(cteTableNames.get())
+ val newPlan = s.withNewPlan(s.plan transform qualifyTableIdentifier)
+ // CTE table names defined in the subquery should not be visible outside the subquery.
+ cteTableNames.get().clear()
+ cteTableNames.get().addAll(cteNamesBeforeSubQuery)
+ newPlan
+ }
+ }
+
+ override def parsePlan(sqlText: String): LogicalPlan = {
+ val plan = internal.parsePlan(sqlText)
+ cteTableNames.get().clear()
+ plan.transform(qualifyTableIdentifier)
+ }
+
+ override def parseExpression(sqlText: String): Expression =
+ internal.parseExpression(sqlText)
+
+ override def parseTableIdentifier(sqlText: String): TableIdentifier =
+ internal.parseTableIdentifier(sqlText)
+
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+ internal.parseFunctionIdentifier(sqlText)
+
+ override def parseTableSchema(sqlText: String): StructType =
+ internal.parseTableSchema(sqlText)
+
+ override def parseDataType(sqlText: String): DataType =
+ internal.parseDataType(sqlText)
+
+ private def qualifyTableIdentifierInternal(tableIdentifier: Seq[String]): Seq[String] = {
+ if (tableIdentifier.size == 1) {
+ tiContext.tiCatalog.getCurrentDatabase :: tableIdentifier.toList
+ } else {
+ tableIdentifier
+ }
+ }
+
+ private def needQualify(tableIdentifier: Seq[String]): Boolean = {
+ tableIdentifier.size == 1 && tiContext.sessionCatalog
+ .getTempView(tableIdentifier.head)
+ .isEmpty && !cteTableNames.get().contains(tableIdentifier.head.toLowerCase())
+ }
+
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] =
+ internal.parseMultipartIdentifier(sqlText)
+}
diff --git a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala
new file mode 100644
index 0000000000..797300c14a
--- /dev/null
+++ b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2021 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.extensions + +import com.pingcap.tispark.{MetaManager, TiDBRelation, TiTableReference} +import com.pingcap.tispark.statistics.StatisticsManager +import org.apache.spark.sql.{AnalysisException, SparkSession, TiContext} +import org.apache.spark.sql.catalyst.analysis.{ + EliminateSubqueryAliases, + UnresolvedRelation, + UnresolvedTableOrView +} +import org.apache.spark.sql.catalyst.catalog.TiSessionCatalog +import org.apache.spark.sql.catalyst.plans.logical.{ + InsertIntoStatement, + LogicalPlan, + SubqueryAlias +} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.LogicalRelation + +case class TiResolutionRule( + getOrCreateTiContext: SparkSession => TiContext, + sparkSession: SparkSession) + extends Rule[LogicalPlan] { + protected lazy val meta: MetaManager = tiContext.meta + private lazy val autoLoad = tiContext.autoLoad + private lazy val tiCatalog = tiContext.tiCatalog + private lazy val tiSession = tiContext.tiSession + private lazy val sqlContext = tiContext.sqlContext + protected val tiContext: TiContext = getOrCreateTiContext(sparkSession) + + protected def resolveTiDBRelation( + withSubQueryAlias: Boolean = true): Seq[String] => LogicalPlan = + tableIdentifier => { + val dbName = getDatabaseFromIdentifier(tableIdentifier) + val tableName = + if (tableIdentifier.size == 1) tableIdentifier.head else tableIdentifier.tail.head + val table = meta.getTable(dbName, tableName) + if (table.isEmpty) { + throw new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'") + } + if (autoLoad) { + StatisticsManager.loadStatisticsInfo(table.get) + } + val sizeInBytes = StatisticsManager.estimateTableSize(table.get) + val tiDBRelation = + TiDBRelation(tiSession, TiTableReference(dbName, tableName, sizeInBytes), meta)( + sqlContext) + if (withSubQueryAlias) { + // Use SubqueryAlias so that projects and joins can correctly resolve + // UnresolvedAttributes in JoinConditions, Projects, Filters, etc. 
+ SubqueryAlias(tableName, LogicalRelation(tiDBRelation)) + } else { + LogicalRelation(tiDBRelation) + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan transformUp resolveTiDBRelations + + protected def resolveTiDBRelations: PartialFunction[LogicalPlan, LogicalPlan] = { + case i @ InsertIntoStatement(UnresolvedRelation(tableIdentifier, _, _), _, _, _, _, _) + if tiCatalog + .catalogOf(tableIdentifier) + .exists(_.isInstanceOf[TiSessionCatalog]) => + i.copy(table = EliminateSubqueryAliases(resolveTiDBRelation()(tableIdentifier))) + case UnresolvedRelation(tableIdentifier, _, _) + if tiCatalog + .catalogOf(tableIdentifier) + .exists(_.isInstanceOf[TiSessionCatalog]) => + resolveTiDBRelation()(tableIdentifier) + case UnresolvedTableOrView(tableIdentifier, _, _) + if tiCatalog + .catalogOf(tableIdentifier) + .exists(_.isInstanceOf[TiSessionCatalog]) => + resolveTiDBRelation(false)(tableIdentifier) + } + + private def getDatabaseFromIdentifier(tableIdentifier: Seq[String]): String = { + if (tableIdentifier.size == 1) { + tiCatalog.getCurrentDatabase + } else { + tableIdentifier.head + } + } +} diff --git a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRuleV2.scala b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRuleV2.scala new file mode 100644 index 0000000000..b91c79cbf0 --- /dev/null +++ b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRuleV2.scala @@ -0,0 +1,75 @@ +/* + * Copyright 2021 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.extensions + +import com.pingcap.tispark.{MetaManager, TiDBRelation, TiTableReference} +import com.pingcap.tispark.statistics.StatisticsManager +import org.apache.spark.sql.{SparkSession, TiContext} +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.catalog.TiDBTable +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.{ + InsertIntoStatement, + LogicalPlan, + SubqueryAlias +} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +case class TiResolutionRuleV2(getOrCreateTiContext: SparkSession => TiContext)( + sparkSession: SparkSession) + extends Rule[LogicalPlan] { + protected lazy val meta: MetaManager = tiContext.meta + private lazy val autoLoad = tiContext.autoLoad + //private lazy val tiCatalog = tiContext.tiCatalog + private lazy val tiSession = tiContext.tiSession + private lazy val sqlContext = tiContext.sqlContext + protected val tiContext: TiContext = getOrCreateTiContext(sparkSession) + + protected val resolveTiDBRelation: (TiDBTable, Seq[AttributeReference]) => LogicalPlan = + (tiTable, output) => { + if (autoLoad) { + StatisticsManager.loadStatisticsInfo(tiTable.tiTableInfo.get) + } + val sizeInBytes = StatisticsManager.estimateTableSize(tiTable.tiTableInfo.get) + val tiDBRelation = TiDBRelation( + tiSession, + TiTableReference(tiTable.databaseName, tiTable.tableName, sizeInBytes), + meta)(sqlContext) + // Use SubqueryAlias so that projects and joins can correctly resolve + // UnresolvedAttributes in JoinConditions, Projects, Filters, etc. + // todo since there is no UnresolvedAttributes, do we still need the subqueryAlias relation??? + SubqueryAlias(tiTable.tableName, LogicalRelation(tiDBRelation, output, None, false)) + } + + protected def resolveTiDBRelations: PartialFunction[LogicalPlan, LogicalPlan] = { + // todo can remove this branch since the target table of insert into statement should never be a tidb table + case i @ InsertIntoStatement(DataSourceV2Relation(table, output, _, _, _), _, _, _, _, _) + if table.isInstanceOf[TiDBTable] => + val tiTable = table.asInstanceOf[TiDBTable] + i.copy(table = EliminateSubqueryAliases(resolveTiDBRelation(tiTable, output))) + case DataSourceV2Relation(table, output, _, _, _) if table.isInstanceOf[TiDBTable] => + val tiTable = table.asInstanceOf[TiDBTable] + resolveTiDBRelation(tiTable, output) + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan match { + case _ => + plan transformUp resolveTiDBRelations + } +}
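Note (not part of the patch): the wrapper classes above only take effect once they are registered with Spark. The following is a minimal sketch of that wiring through SparkSessionExtensions. The TiWiringSketch name and the direct new TiContext(session) construction are illustrative assumptions; in TiSpark the actual entry point is com.pingcap.tispark.TiExtensions, which also caches one TiContext per SparkSession.

import org.apache.spark.sql.extensions.{TiDDLRule, TiParser, TiResolutionRule}
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, TiContext}

class TiWiringSketch extends (SparkSessionExtensions => Unit) {
  // Illustration only: a real implementation would cache one TiContext per session.
  private def getOrCreateTiContext(session: SparkSession): TiContext =
    new TiContext(session)

  override def apply(e: SparkSessionExtensions): Unit = {
    // Wrap the delegate parser so unqualified TiDB table names are
    // database-qualified before analysis (TiParser above).
    e.injectParser((session, delegate) => TiParser(getOrCreateTiContext, session, delegate))
    // Resolve relations that live in the TiDB catalog (TiResolutionRule above).
    e.injectResolutionRule(session => TiResolutionRule(getOrCreateTiContext, session))
    // One plausible place to rewrite supported DDL commands to their
    // TiSpark equivalents (TiDDLRule above).
    e.injectPostHocResolutionRule(session => TiDDLRule(getOrCreateTiContext, session))
  }
}

// Usage:
//   SparkSession.builder().withExtensions(new TiWiringSketch).getOrCreate()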