From 83afee33562f092e07d1b6d4ee5748791f6edcc4 Mon Sep 17 00:00:00 2001 From: tomato <38561029+qidi1@users.noreply.github.com> Date: Mon, 22 Aug 2022 11:42:53 +0800 Subject: [PATCH] support Spark3.3 (#2492) --- assembly/src/main/assembly/assembly.xml | 9 + core-test/pom.xml | 20 +- core/pom.xml | 16 +- .../com/pingcap/tispark/TiSparkInfo.scala | 2 +- .../tispark/utils/ReflectionUtil.scala | 24 + .../analyzer/TiAuthorizationRule.scala | 12 +- .../sql/catalyst/catalog/TiCatalog.scala | 8 +- .../spark/sql/catalyst/parser/TiParser.scala | 3 + .../plans/logical/BasicLogicalPlan.scala | 29 + .../delete/DeleteNotSupportSuite.scala | 5 +- .../overflow/UnsignedOverflowSuite.scala | 2 +- .../tispark/telemetry/TelemetrySuite.scala | 4 + docs/authorization_userguide.md | 4 +- docs/delete_userguide.md | 3 +- pom.xml | 16 +- .../plans/logical/TiBasicLogicalPlan.scala | 36 + .../plans/logical/TiBasicLogicalPlan.scala | 36 + .../TiAggregationProjectionV2.scala | 2 +- .../plans/logical/TiBasicLogicalPlan.scala | 37 + .../connector/write/TiDBWriteBuilder.scala | 1 - .../TiAggregationProjectionV2.scala | 2 +- spark-wrapper/spark-3.3/pom.xml | 129 ++++ .../com/pingcap/tispark/SparkWrapper.scala | 52 ++ .../expressions/TiBasicExpression.scala | 133 ++++ .../plans/logical/TiBasicLogicalPlan.scala | 43 ++ .../connector/write/TiDBWriteBuilder.scala | 47 ++ .../TiAggregationProjectionV2.scala | 50 ++ .../spark/sql/extensions/TiStrategy.scala | 658 ++++++++++++++++++ 28 files changed, 1356 insertions(+), 27 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BasicLogicalPlan.scala create mode 100644 spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala create mode 100644 spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala create mode 100644 spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala create mode 100644 spark-wrapper/spark-3.3/pom.xml create mode 100644 spark-wrapper/spark-3.3/src/main/scala/com/pingcap/tispark/SparkWrapper.scala create mode 100644 spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala create mode 100644 spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala create mode 100644 spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala create mode 100644 spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala create mode 100644 spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiStrategy.scala diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml index e8ad969fa9..bb997e3532 100644 --- a/assembly/src/main/assembly/assembly.xml +++ b/assembly/src/main/assembly/assembly.xml @@ -45,5 +45,14 @@ **/* + + + ${project.parent.basedir}/spark-wrapper/spark-3.3/target/classes/ + + resources/spark-wrapper-spark-3_3 + + **/* + + diff --git a/core-test/pom.xml b/core-test/pom.xml index 89673077aa..126fbb03e9 100644 --- a/core-test/pom.xml +++ b/core-test/pom.xml @@ -71,17 +71,27 @@ org.apache.logging.log4j log4j-api - 2.17.1 + 2.17.2 org.apache.logging.log4j log4j-core - 2.17.1 + 2.17.2 org.apache.spark spark-core_${scala.binary.version} ${spark.version.test} + + + org.apache.logging.log4j + log4j-api + + + org.apache.logging.log4j + log4j-core + + 
org.apache.spark @@ -118,6 +128,12 @@ org.apache.hadoop hadoop-client 2.7.2 + + + org.slf4j + slf4j-log4j12 + + com.google.guava diff --git a/core/pom.xml b/core/pom.xml index 6250680b83..ff14dd3f5d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -44,12 +44,12 @@ org.apache.logging.log4j log4j-api - 2.17.1 + 2.17.2 org.apache.logging.log4j log4j-core - 2.17.1 + 2.17.2 org.apache.spark @@ -60,6 +60,14 @@ io.netty netty-all + + org.apache.logging.log4j + log4j-api + + + org.apache.logging.log4j + log4j-core + @@ -107,6 +115,10 @@ io.netty netty-all + + org.slf4j + slf4j-log4j12 + diff --git a/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala b/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala index 638f8713f4..9ff081a147 100644 --- a/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala +++ b/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala @@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory object TiSparkInfo { private final val logger = LoggerFactory.getLogger(getClass.getName) - val SUPPORTED_SPARK_VERSION: List[String] = "3.0" :: "3.1" :: "3.2" :: Nil + val SUPPORTED_SPARK_VERSION: List[String] = "3.0" :: "3.1" :: "3.2" :: "3.3" :: Nil val SPARK_VERSION: String = org.apache.spark.SPARK_VERSION diff --git a/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala b/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala index d3d897b50c..4129436e47 100644 --- a/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala +++ b/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala @@ -18,14 +18,18 @@ package com.pingcap.tispark.utils import com.pingcap.tispark.TiSparkInfo +import com.pingcap.tispark.auth.TiAuthorization import com.pingcap.tispark.write.TiDBOptions import org.apache.spark.sql.{SQLContext, SparkSession, Strategy, TiContext} import org.apache.spark.sql.catalyst.expressions.BasicExpression.TiExpression import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId, Expression, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.{SparkSession, Strategy, TiContext} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.slf4j.LoggerFactory import java.io.File +import java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} /** @@ -60,6 +64,8 @@ object ReflectionUtil { private val SPARK_WRAPPER_CLASS = "com.pingcap.tispark.SparkWrapper" private val TI_BASIC_EXPRESSION_CLASS = "org.apache.spark.sql.catalyst.expressions.TiBasicExpression" + private val TI_BASIC_LOGICAL_PLAN_CLASS = + "org.apache.spark.sql.catalyst.plans.logical.TiBasicLogicalPlan" private val TI_STRATEGY_CLASS = "org.apache.spark.sql.extensions.TiStrategy" private val TIDB_WRITE_BUILDER_CLASS = @@ -105,6 +111,24 @@ object ReflectionUtil { .asInstanceOf[Option[TiExpression]] } + def callTiBasicLogicalPlanVerifyAuthorizationRule( + logicalPlan: LogicalPlan, + tiAuthorization: Option[TiAuthorization]): LogicalPlan = { + try { + classLoader + .loadClass(TI_BASIC_LOGICAL_PLAN_CLASS) + .getDeclaredMethod( + "verifyAuthorizationRule", + classOf[LogicalPlan], + classOf[Option[TiAuthorization]]) + .invoke(null, logicalPlan, tiAuthorization) + .asInstanceOf[LogicalPlan] + } catch { + case ex: InvocationTargetException => + throw ex.getTargetException + } + } + def newTiStrategy( getOrCreateTiContext: SparkSession => TiContext, sparkSession: SparkSession): Strategy = { diff --git 
a/core/src/main/scala/org/apache/spark/sql/catalyst/analyzer/TiAuthorizationRule.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/analyzer/TiAuthorizationRule.scala index 9593897b3d..1a65ca5ee1 100644 --- a/core/src/main/scala/org/apache/spark/sql/catalyst/analyzer/TiAuthorizationRule.scala +++ b/core/src/main/scala/org/apache/spark/sql/catalyst/analyzer/TiAuthorizationRule.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.analyzer import com.pingcap.tispark.auth.TiAuthorization +import com.pingcap.tispark.v2.TiDBTable import org.apache.spark.sql.catalyst.plans.logical.{ + BasicLogicalPlan, DeleteFromTable, LogicalPlan, SetCatalogAndNamespace, @@ -26,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.{SparkSession, TiContext} -import com.pingcap.tispark.v2.TiDBTable /** * Only work for table v2(catalog plugin) @@ -34,7 +35,6 @@ import com.pingcap.tispark.v2.TiDBTable case class TiAuthorizationRule(getOrCreateTiContext: SparkSession => TiContext)( sparkSession: SparkSession) extends Rule[LogicalPlan] { - protected val tiContext: TiContext = getOrCreateTiContext(sparkSession) private lazy val tiAuthorization: Option[TiAuthorization] = tiContext.tiAuthorization @@ -47,12 +47,8 @@ case class TiAuthorizationRule(getOrCreateTiContext: SparkSession => TiContext)( tiAuthorization) } dt - case sd @ SetCatalogAndNamespace(catalogManager, catalogName, namespace) => - if (catalogName.nonEmpty && catalogName.get.equals("tidb_catalog") && namespace.isDefined) { - namespace.get - .foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization)) - } - sd + case s: SetCatalogAndNamespace => + BasicLogicalPlan.verifyAuthorizationRule(s, tiAuthorization) case dr @ DataSourceV2Relation( TiDBTable(_, tableRef, _, _, _), output, diff --git a/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/TiCatalog.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/TiCatalog.scala index feda4d7313..629f8e1578 100644 --- a/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/TiCatalog.scala +++ b/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/TiCatalog.scala @@ -167,13 +167,17 @@ class TiCatalog extends TableCatalog with SupportsNamespaces { override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = ??? - override def dropNamespace(namespace: Array[String]): Boolean = ??? - override def createNamespace( namespace: Array[String], metadata: util.Map[String, String]): Unit = ??? + // for spark version smaller than 3.3 + def dropNamespace(strings: Array[String]): Boolean = ??? + + // for spark version bigger equal 3.3 + def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = ??? + override def alterNamespace(namespace: Array[String], changes: NamespaceChange*): Unit = ??? } diff --git a/core/src/main/scala/org/apache/spark/sql/catalyst/parser/TiParser.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/parser/TiParser.scala index 3067c110ed..cde71277b8 100644 --- a/core/src/main/scala/org/apache/spark/sql/catalyst/parser/TiParser.scala +++ b/core/src/main/scala/org/apache/spark/sql/catalyst/parser/TiParser.scala @@ -54,6 +54,9 @@ case class TiParser( @scala.throws[ParseException]("Text cannot be parsed to a DataType") def parseRawDataType(sqlText: String): DataType = ??? 
+ @scala.throws[ParseException]("Text cannot be parsed to a LogicalPlan") + def parseQuery(sqlText: String): LogicalPlan = ??? + def getOrElseInitTiCatalog: TiCatalog = { val catalogManager = sparkSession.sessionState.catalogManager catalogManager.catalog("tidb_catalog").asInstanceOf[TiCatalog] diff --git a/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BasicLogicalPlan.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BasicLogicalPlan.scala new file mode 100644 index 0000000000..6c2431f899 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BasicLogicalPlan.scala @@ -0,0 +1,29 @@ +/* + * + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import com.pingcap.tispark.auth.TiAuthorization +import com.pingcap.tispark.utils.ReflectionUtil + +object BasicLogicalPlan { + def verifyAuthorizationRule( + logicalPlan: LogicalPlan, + tiAuthorization: Option[TiAuthorization]): LogicalPlan = + ReflectionUtil.callTiBasicLogicalPlanVerifyAuthorizationRule(logicalPlan, tiAuthorization) +} diff --git a/core/src/test/scala/com/pingcap/tispark/delete/DeleteNotSupportSuite.scala b/core/src/test/scala/com/pingcap/tispark/delete/DeleteNotSupportSuite.scala index 6eb8b3be58..3d56f11f37 100644 --- a/core/src/test/scala/com/pingcap/tispark/delete/DeleteNotSupportSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/delete/DeleteNotSupportSuite.scala @@ -31,10 +31,11 @@ class DeleteNotSupportSuite extends BaseBatchWriteTest("test_delete_not_support" test("Delete without WHERE clause") { jdbcUpdate(s"create table $dbtable(i int, s int,PRIMARY KEY (i))") - the[IllegalArgumentException] thrownBy { spark.sql(s"delete from $dbtable") - } should have message "requirement failed: Delete without WHERE clause is not supported" + } should (have message "requirement failed: Delete without WHERE clause is not supported" + // when where clause is empty, Spark3.3 will send AlwaysTrue() to TiSpark. 
+ or have message "requirement failed: Delete with alwaysTrue WHERE clause is not supported") } test("Delete with alwaysTrue WHERE clause") { diff --git a/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala b/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala index 1926c3d305..4b6be247e4 100644 --- a/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala @@ -262,7 +262,7 @@ class UnsignedOverflowSuite extends BaseBatchWriteTest("test_data_type_unsigned_ val jdbcErrorClass = classOf[java.lang.RuntimeException] val tidbErrorClass = classOf[java.lang.RuntimeException] val tidbErrorMsgStartWith = - "Error while encoding: java.lang.RuntimeException: java.lang.Integer is not a valid external type for schema of bigint\nif (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, c1), LongType) AS c1" + "Error while encoding: java.lang.RuntimeException: java.lang.Integer is not a valid external type for schema of bigint" compareTiDBWriteFailureWithJDBC( List(row), diff --git a/core/src/test/scala/com/pingcap/tispark/telemetry/TelemetrySuite.scala b/core/src/test/scala/com/pingcap/tispark/telemetry/TelemetrySuite.scala index f548913e10..f2831dd954 100644 --- a/core/src/test/scala/com/pingcap/tispark/telemetry/TelemetrySuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/telemetry/TelemetrySuite.scala @@ -16,6 +16,9 @@ package com.pingcap.tispark.telemetry +import com.pingcap.tikv.{TiConfiguration, TiSession} +import com.pingcap.tispark.auth.TiAuthorization.tiConf +import com.pingcap.tispark.listener.CacheInvalidateListener import com.pingcap.tispark.utils.HttpClientUtil import com.sun.net.httpserver.{ HttpExchange, @@ -24,6 +27,7 @@ import com.sun.net.httpserver.{ HttpsConfigurator, HttpsServer } +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.test.SharedSQLContext import org.scalatest.Matchers.{be, noException} diff --git a/docs/authorization_userguide.md b/docs/authorization_userguide.md index ae9ed0c1d0..b5ecead472 100644 --- a/docs/authorization_userguide.md +++ b/docs/authorization_userguide.md @@ -9,10 +9,10 @@ perform. This feature allows you to execute SQL in TiSpark with Authorization and authentication, the same behavior as TiDB ## Prerequisites - + - The database's user account must have the `PROCESS` privilege. - TiSpark version >= 2.5.0 -- Spark version = 3.0.x or 3.1.x +- Spark version = 3.0.x or 3.1.x or 3.2.x or 3.3.x ## Setup diff --git a/docs/delete_userguide.md b/docs/delete_userguide.md index 070a5a696e..6b11ef6198 100644 --- a/docs/delete_userguide.md +++ b/docs/delete_userguide.md @@ -11,7 +11,7 @@ spark.sql.catalog.tidb_catalog.pd.addresses ${your_pd_address} ## Requirement - TiDB 4.x or 5.x -- Spark >= 3.0 +- Spark = 3.0.x or 3.1.x or 3.2.x or 3.3.x ## Delete with SQL ``` @@ -38,4 +38,3 @@ spark.sql("delete from tidb_catalog.db.table where xxx") - Delete from partition table is not supported. - Delete with Pessimistic Transaction Mode is not supported. 
- diff --git a/pom.xml b/pom.xml index 0de6335645..7041fc1afa 100644 --- a/pom.xml +++ b/pom.xml @@ -170,12 +170,13 @@ spark-wrapper/spark-3.0 spark-wrapper/spark-3.1 spark-wrapper/spark-3.2 + spark-wrapper/spark-3.3 assembly - spark-3.1.1 + spark-3.1 false @@ -186,7 +187,7 @@ - spark-3.2.1 + spark-3.2 false @@ -196,6 +197,17 @@ 3.2 + + spark-3.3 + + false + + + 3.3.0 + 3.3.0 + 3.3 + + jenkins diff --git a/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala new file mode 100644 index 0000000000..801862affd --- /dev/null +++ b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import com.pingcap.tispark.auth.TiAuthorization + +object TiBasicLogicalPlan { + def verifyAuthorizationRule( + logicalPlan: LogicalPlan, + tiAuthorization: Option[TiAuthorization]): LogicalPlan = { + logicalPlan match { + case st @ SetCatalogAndNamespace(catalogManager, catalogName, namespace) => + if (catalogName.nonEmpty && catalogName.get.equals( + "tidb_catalog") && namespace.isDefined) { + namespace.get + .foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization)) + } + st + } + } + +} diff --git a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala new file mode 100644 index 0000000000..801862affd --- /dev/null +++ b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import com.pingcap.tispark.auth.TiAuthorization + +object TiBasicLogicalPlan { + def verifyAuthorizationRule( + logicalPlan: LogicalPlan, + tiAuthorization: Option[TiAuthorization]): LogicalPlan = { + logicalPlan match { + case st @ SetCatalogAndNamespace(catalogManager, catalogName, namespace) => + if (catalogName.nonEmpty && catalogName.get.equals( + "tidb_catalog") && namespace.isDefined) { + namespace.get + .foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization)) + } + st + } + } + +} diff --git a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala index 3603f9c6ae..08193a2675 100644 --- a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala +++ b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.v2.{ /** * I'm afraid that the same name with the object under the spark-wrapper/spark3.0 will lead to some problems. - * Although the same name will pass the itegration test + * Although the same name will pass the integration test */ object TiAggregationProjectionV2 { type ReturnType = (Seq[Expression], LogicalPlan, TiDBTable, Seq[NamedExpression]) diff --git a/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala new file mode 100644 index 0000000000..1138f3ce4d --- /dev/null +++ b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala @@ -0,0 +1,37 @@ +/* + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import com.pingcap.tispark.auth.TiAuthorization +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SetCatalogAndNamespace} + +object TiBasicLogicalPlan { + def verifyAuthorizationRule( + logicalPlan: LogicalPlan, + tiAuthorization: Option[TiAuthorization]): LogicalPlan = { + logicalPlan match { + case st @ SetCatalogAndNamespace(catalogManager, catalogName, namespace) => + if (catalogName.nonEmpty && catalogName.get.equals( + "tidb_catalog") && namespace.isDefined) { + namespace.get + .foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization)) + } + st + } + } + +} diff --git a/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala index 98361987db..c8a7fbf7df 100644 --- a/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala +++ b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala @@ -31,7 +31,6 @@ case class TiDBWriteBuilder( new InsertableRelation { override def insert(data: DataFrame, overwrite: Boolean): Unit = { val schema = info.schema() - println("Do write") val df = sqlContext.sparkSession.createDataFrame(data.toJavaRDD, schema) df.write .format("tidb") diff --git a/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala index 3603f9c6ae..08193a2675 100644 --- a/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala +++ b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.v2.{ /** * I'm afraid that the same name with the object under the spark-wrapper/spark3.0 will lead to some problems. 
- * Although the same name will pass the itegration test + * Although the same name will pass the integration test */ object TiAggregationProjectionV2 { type ReturnType = (Seq[Expression], LogicalPlan, TiDBTable, Seq[NamedExpression]) diff --git a/spark-wrapper/spark-3.3/pom.xml b/spark-wrapper/spark-3.3/pom.xml new file mode 100644 index 0000000000..024340a6d0 --- /dev/null +++ b/spark-wrapper/spark-3.3/pom.xml @@ -0,0 +1,129 @@ + + + 4.0.0 + + com.pingcap.tispark + tispark-parent + 3.1.0-SNAPSHOT + ../../pom.xml + + + spark-wrapper-spark-3.3_${scala.version.release} + jar + TiSpark Project Spark Wrapper Spark-3.3 + http://github.copm/pingcap/tispark + + + 3.3.0 + + + + + com.pingcap.tispark + tispark-core-internal + ${project.parent.version} + + + org.apache.spark + spark-core_2.12 + ${spark.version.wrapper} + + + org.apache.logging.log4j + log4j-slf4j-impl + + + + + org.apache.spark + spark-catalyst_2.12 + ${spark.version.wrapper} + + + org.apache.spark + spark-sql_2.12 + ${spark.version.wrapper} + + + + + src/main/scala + + + net.alchim31.maven + scala-maven-plugin + 4.3.0 + + + compile-scala + compile + + add-source + compile + + + + test-compile-scala + test-compile + + add-source + testCompile + + + + attach-javadocs + + doc-jar + + + + + ${scala.version} + + + + org.apache.maven.plugins + maven-compiler-plugin + 2.3.2 + + 1.8 + 1.8 + UTF-8 + true + true + + + + + org.apache.maven.plugins + maven-source-plugin + 3.0.1 + + + attach-sources + + jar-no-fork + + + + + + + org.antipathy + mvn-scalafmt_${scala.binary.version} + 1.0.3 + + ${scalafmt.skip} + ${scalafmt.skip} + + ${project.basedir}/src/main/scala + + + ${project.basedir}/src/test/scala + + + + + + diff --git a/spark-wrapper/spark-3.3/src/main/scala/com/pingcap/tispark/SparkWrapper.scala b/spark-wrapper/spark-3.3/src/main/scala/com/pingcap/tispark/SparkWrapper.scala new file mode 100644 index 0000000000..542f39cb9b --- /dev/null +++ b/spark-wrapper/spark-3.3/src/main/scala/com/pingcap/tispark/SparkWrapper.scala @@ -0,0 +1,52 @@ +package com.pingcap.tispark +/* + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + AliasHelper, + ExprId, + Expression, + SortOrder +} + +object SparkWrapper { + def getVersion: String = { + "SparkWrapper-3.3" + } + + def newAlias(child: Expression, name: String): Alias = { + Alias(child, name)() + } + + def newAlias(child: Expression, name: String, exprId: ExprId): Alias = { + Alias(child, name)(exprId = exprId) + } + + def trimNonTopLevelAliases(e: Expression): Expression = { + TiCleanupAliases.trimNonTopLevelAliases2(e) + } + + def copySortOrder(sortOrder: SortOrder, child: Expression): SortOrder = { + sortOrder.copy(child = child) + } +} + +object TiCleanupAliases extends AliasHelper { + def trimNonTopLevelAliases2[T <: Expression](e: T): T = { + super.trimNonTopLevelAliases(e) + } +} diff --git a/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala new file mode 100644 index 0000000000..227e547617 --- /dev/null +++ b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala @@ -0,0 +1,133 @@ +/* + * + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package org.apache.spark.sql.catalyst.expressions + +import com.pingcap.tikv.expression.{ + ArithmeticBinaryExpression, + ColumnRef, + ComparisonBinaryExpression, + Constant, + LogicalBinaryExpression, + StringRegExpression +} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.BasicExpression.{ + TiExpression, + TiIsNull, + TiNot, + convertLiteral +} +import org.apache.spark.sql.execution.TiConverter +import org.apache.spark.sql.types.DecimalType + +object TiBasicExpression { + + def convertToTiExpr(expr: Expression): Option[TiExpression] = + expr match { + case Literal(value, dataType) => + Some( + Constant.create(convertLiteral(value, dataType), TiConverter.fromSparkType(dataType))) + + case Add(BasicExpression(lhs), BasicExpression(rhs), _) => + Some(ArithmeticBinaryExpression.plus(lhs, rhs)) + + case Subtract(BasicExpression(lhs), BasicExpression(rhs), _) => + Some(ArithmeticBinaryExpression.minus(lhs, rhs)) + + case e @ Multiply(BasicExpression(lhs), BasicExpression(rhs), _) => + Some(ArithmeticBinaryExpression.multiply(TiConverter.fromSparkType(e.dataType), lhs, rhs)) + + case d @ Divide(BasicExpression(lhs), BasicExpression(rhs), _) => + Some(ArithmeticBinaryExpression.divide(TiConverter.fromSparkType(d.dataType), lhs, rhs)) + + case And(BasicExpression(lhs), BasicExpression(rhs)) => + Some(LogicalBinaryExpression.and(lhs, rhs)) + + case Or(BasicExpression(lhs), BasicExpression(rhs)) => + Some(LogicalBinaryExpression.or(lhs, rhs)) + + case Alias(BasicExpression(child), _) => + Some(child) + + case IsNull(BasicExpression(child)) => + Some(new TiIsNull(child)) + + case IsNotNull(BasicExpression(child)) => + Some(new TiNot(new TiIsNull(child))) + + case GreaterThan(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ComparisonBinaryExpression.greaterThan(lhs, rhs)) + + case GreaterThanOrEqual(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ComparisonBinaryExpression.greaterEqual(lhs, rhs)) + + case LessThan(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ComparisonBinaryExpression.lessThan(lhs, rhs)) + + case LessThanOrEqual(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ComparisonBinaryExpression.lessEqual(lhs, rhs)) + + case EqualTo(BasicExpression(lhs), BasicExpression(rhs)) => + Some(ComparisonBinaryExpression.equal(lhs, rhs)) + + case Not(EqualTo(BasicExpression(lhs), BasicExpression(rhs))) => + Some(ComparisonBinaryExpression.notEqual(lhs, rhs)) + + case Not(BasicExpression(child)) => + Some(new TiNot(child)) + + case StartsWith(BasicExpression(lhs), BasicExpression(rhs)) => + Some(StringRegExpression.startsWith(lhs, rhs)) + + case Contains(BasicExpression(lhs), BasicExpression(rhs)) => + Some(StringRegExpression.contains(lhs, rhs)) + + case EndsWith(BasicExpression(lhs), BasicExpression(rhs)) => + Some(StringRegExpression.endsWith(lhs, rhs)) + + case Like(BasicExpression(lhs), BasicExpression(rhs), _) => + Some(StringRegExpression.like(lhs, rhs)) + + // Coprocessor has its own behavior of type promoting and overflow check + // so we simply remove it from expression and let cop handle it + case CheckOverflow(BasicExpression(expr), dec: DecimalType, _) => + expr.setDataType(TiConverter.fromSparkType(dec)) + Some(expr) + + case PromotePrecision(BasicExpression(expr)) => + Some(expr) + + case PromotePrecision(Cast(BasicExpression(expr), dec: DecimalType, _, _)) => + expr.setDataType(TiConverter.fromSparkType(dec)) + Some(expr) + + case PromotePrecision(BasicExpression(expr)) => + 
Some(expr) + + // TODO: Are all AttributeReference column reference in such context? + case attr: AttributeReference => + Some(ColumnRef.create(attr.name, TiConverter.fromSparkType(attr.dataType))) + + case uAttr: UnresolvedAttribute => + Some(ColumnRef.create(uAttr.name, TiConverter.fromSparkType(uAttr.dataType))) + + // TODO: Remove it and let it fail once done all translation + case _ => Option.empty[TiExpression] + } +} diff --git a/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala new file mode 100644 index 0000000000..d51fc90199 --- /dev/null +++ b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala @@ -0,0 +1,43 @@ +/* + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import com.pingcap.tispark.auth.TiAuthorization +import org.apache.spark.sql.catalyst.analysis.{ResolvedDBObjectName, ResolvedNamespace} + +object TiBasicLogicalPlan { + def verifyAuthorizationRule( + logicalPlan: LogicalPlan, + tiAuthorization: Option[TiAuthorization]): LogicalPlan = { + logicalPlan match { + case st @ SetCatalogAndNamespace(namespace) => + namespace match { + case ResolvedNamespace(catalog, ns) => + if (catalog.name().equals("tidb_catalog")) { + ns.foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization)) + } + st + case ResolvedDBObjectName(catalog, nameParts) => + if (catalog.name().equals("tidb_catalog")) { + nameParts.foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization)) + } + st + } + } + } + +} diff --git a/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala new file mode 100644 index 0000000000..c8a7fbf7df --- /dev/null +++ b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala @@ -0,0 +1,47 @@ +/* + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.write + +import com.pingcap.tispark.write.TiDBOptions +import org.apache.spark.sql.sources.InsertableRelation +import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} + +case class TiDBWriteBuilder( + info: LogicalWriteInfo, + tiDBOptions: TiDBOptions, + sqlContext: SQLContext) + extends WriteBuilder { + override def build(): V1Write = + new V1Write { + override def toInsertableRelation: InsertableRelation = { + new InsertableRelation { + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + val schema = info.schema() + val df = sqlContext.sparkSession.createDataFrame(data.toJavaRDD, schema) + df.write + .format("tidb") + .options(tiDBOptions.parameters) + .option(TiDBOptions.TIDB_DATABASE, tiDBOptions.database) + .option(TiDBOptions.TIDB_TABLE, tiDBOptions.table) + .option(TiDBOptions.TIDB_DEDUPLICATE, "false") + .mode(SaveMode.Append) + .save() + } + } + } + } +} diff --git a/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala new file mode 100644 index 0000000000..a85531333c --- /dev/null +++ b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala @@ -0,0 +1,50 @@ +/* + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.extensions + +import com.pingcap.tispark.v2.TiDBTable +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.datasources.v2.{ + DataSourceV2Relation, + DataSourceV2ScanRelation +} + +/** + * I'm afraid that the same name with the object under the spark-wrapper/spark3.0 will lead to some problems. 
+ * Although the same name will pass the integration test + */ +object TiAggregationProjectionV2 { + type ReturnType = (Seq[Expression], LogicalPlan, TiDBTable, Seq[NamedExpression]) + + def unapply(plan: LogicalPlan): Option[ReturnType] = + plan match { + // Only push down aggregates projection when all filters can be applied and + // all projection expressions are column references + case PhysicalOperation( + projects, + filters, + rel @ DataSourceV2ScanRelation( + DataSourceV2Relation(source: TiDBTable, _, _, _, _), + _, + _, + _)) if projects.forall(_.isInstanceOf[Attribute]) => + Some((filters, rel, source, projects)) + case _ => Option.empty[ReturnType] + } +} diff --git a/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiStrategy.scala b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiStrategy.scala new file mode 100644 index 0000000000..64fc14e2b0 --- /dev/null +++ b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiStrategy.scala @@ -0,0 +1,658 @@ +/* + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.extensions + +import com.pingcap.tidb.tipb.EncodeType +import com.pingcap.tikv.exception.IgnoreUnsupportedTypeException +import com.pingcap.tikv.expression._ +import com.pingcap.tikv.meta.TiDAGRequest.PushDownType +import com.pingcap.tikv.meta.{TiDAGRequest, TiTimestamp} +import com.pingcap.tikv.predicates.{PredicateUtils, TiKVScanAnalyzer} +import com.pingcap.tikv.region.TiStoreType +import com.pingcap.tikv.statistics.TableStatistics +import com.pingcap.tispark.TiConfigConst +import com.pingcap.tispark.statistics.StatisticsManager +import com.pingcap.tispark.utils.{ReflectionUtil, TiUtil} +import com.pingcap.tispark.v2.TiDBTable +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + Ascending, + Attribute, + AttributeMap, + AttributeSet, + Descending, + Expression, + IntegerLiteral, + IsNull, + NamedExpression, + NullsFirst, + NullsLast, + SortOrder, + SubqueryExpression, + TiExprUtils +} +import org.apache.spark.sql.catalyst.planner.TiAggregation +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources.v2.{ + DataSourceV2Relation, + DataSourceV2ScanRelation +} +import org.apache.spark.sql.internal.SQLConf +import org.joda.time.{DateTime, DateTimeZone} + +import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ +import scala.collection.mutable + +object TiStrategy { + private val assignedTSPlanCache = new mutable.WeakHashMap[LogicalPlan, Boolean]() + + private def hasTSAssigned(plan: LogicalPlan): Boolean = { + assignedTSPlanCache.contains(plan) + } + + private def 
markTSAssigned(plan: LogicalPlan): Unit = { + plan foreachUp { p => + assignedTSPlanCache.put(p, true) + } + } +} + +/** + * CHECK Spark [[org.apache.spark.sql.Strategy]] + * + * TODO: Too many hacks here since we hijack the planning + * but we don't have full control over planning stage + * We cannot pass context around during planning so + * a re-extract needed for push-down since + * a plan tree might contain Join which causes a single tree + * have multiple plans to push-down + */ +case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSession: SparkSession) + extends Strategy + with Logging { + type TiExpression = com.pingcap.tikv.expression.Expression + type TiColumnRef = com.pingcap.tikv.expression.ColumnRef + private lazy val tiContext: TiContext = getOrCreateTiContext(sparkSession) + private lazy val sqlContext = tiContext.sqlContext + private lazy val sqlConf: SQLConf = sqlContext.conf + + def typeBlockList: TypeBlocklist = { + val blocklistString = + sqlConf.getConfString(TiConfigConst.UNSUPPORTED_TYPES, "") + new TypeBlocklist(blocklistString) + } + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + TiExtensions.validateCatalog(sparkSession) + val ts = if (TiUtil.getTiDBSnapshot(sparkSession).isEmpty) { + tiContext.tiSession.getTimestamp + } else { + tiContext.tiSession.getSnapshotTimestamp + } + + if (plan.isStreaming) { + // We should use a new timestamp for next batch execution. + // Otherwise Spark Structure Streaming will not see new data in TiDB. + if (!TiStrategy.hasTSAssigned(plan)) { + plan foreachUp applyStartTs(ts, forceUpdate = true) + TiStrategy.markTSAssigned(plan) + } + } else { + plan foreachUp applyStartTs(ts) + } + + plan + .collectFirst { + case DataSourceV2ScanRelation( + DataSourceV2Relation(table: TiDBTable, _, _, _, _), + _, + _, + _) => + doPlan(table, plan) + } + .toSeq + .flatten + } + + def referencedTiColumns(expression: TiExpression): Seq[TiColumnRef] = + PredicateUtils.extractColumnRefFromExpression(expression).asScala.toSeq + + /** + * build a Seq of used TiColumnRef from AttributeSet and bound them to source table + * + * @param attributeSet AttributeSet containing projects w/ or w/o filters + * @param source source TiDBRelation + * @return a Seq of TiColumnRef extracted + */ + def buildTiColumnRefFromColumnSeq( + attributeSet: AttributeSet, + source: TiDBTable): Seq[TiColumnRef] = { + val tiColumnSeq: Seq[TiExpression] = attributeSet.toSeq.map { expr => + TiExprUtils.transformAttrToColRef(expr, source.table) + } + var tiColumns: mutable.HashSet[TiColumnRef] = mutable.HashSet.empty[TiColumnRef] + for (expression <- tiColumnSeq) { + val colSetPerExpr = PredicateUtils.extractColumnRefFromExpression(expression) + colSetPerExpr.asScala.foreach { + tiColumns += _ + } + } + tiColumns.toSeq + } + + // apply StartTs to every logical plan in Spark Planning stage + protected def applyStartTs( + ts: TiTimestamp, + forceUpdate: Boolean = false): PartialFunction[LogicalPlan, Unit] = { + case DataSourceV2ScanRelation( + DataSourceV2Relation(r @ TiDBTable(_, _, _, timestamp, _), _, _, _, _), + _, + _, + _) => + if (timestamp == null || forceUpdate) { + r.ts = ts + } + case logicalPlan => + logicalPlan transformExpressionsUp { + case s: SubqueryExpression => + s.plan.foreachUp(applyStartTs(ts)) + s + } + } + + private def blocklist: ExpressionBlocklist = { + val blocklistString = sqlConf.getConfString(TiConfigConst.UNSUPPORTED_PUSHDOWN_EXPR, "") + new ExpressionBlocklist(blocklistString) + } + + private def 
allowAggregationPushDown(): Boolean = + sqlConf.getConfString(TiConfigConst.ALLOW_AGG_PUSHDOWN, "true").toLowerCase.toBoolean + + private def useIndexScanFirst(): Boolean = + sqlConf.getConfString(TiConfigConst.USE_INDEX_SCAN_FIRST, "false").toLowerCase.toBoolean + + private def allowIndexRead(): Boolean = + sqlConf.getConfString(TiConfigConst.ALLOW_INDEX_READ, "true").toLowerCase.toBoolean + + private def useStreamingProcess: Boolean = + sqlConf.getConfString(TiConfigConst.COPROCESS_STREAMING, "false").toLowerCase.toBoolean + + private def getCodecFormat: EncodeType = { + // FIXME: Should use default codec format "chblock", change it back after fix. + val codecFormatStr = + sqlConf + .getConfString(TiConfigConst.CODEC_FORMAT, TiConfigConst.DEFAULT_CODEC_FORMAT) + .toLowerCase + + codecFormatStr match { + case TiConfigConst.CHUNK_CODEC_FORMAT => EncodeType.TypeChunk + case TiConfigConst.DEFAULT_CODEC_FORMAT => EncodeType.TypeCHBlock + case _ => EncodeType.TypeDefault + } + } + + private def eligibleStorageEngines(source: TiDBTable): List[TiStoreType] = + TiUtil.getIsolationReadEngines(sqlContext).filter { + case TiStoreType.TiKV => true + case TiStoreType.TiFlash => source.isTiFlashReplicaAvailable + case _ => false + } + + private def timeZoneOffsetInSeconds(): Int = { + val tz = DateTimeZone.getDefault + val instant = DateTime.now.getMillis + val offsetInMilliseconds = tz.getOffset(instant) + val hours = TimeUnit.MILLISECONDS.toHours(offsetInMilliseconds).toInt + val seconds = hours * 3600 + seconds + } + + private def newTiDAGRequest(): TiDAGRequest = { + val ts = timeZoneOffsetInSeconds() + if (useStreamingProcess) { + new TiDAGRequest(PushDownType.STREAMING, ts) + } else { + new TiDAGRequest(PushDownType.NORMAL, getCodecFormat, ts) + } + } + + private def toCoprocessorRDD( + source: TiDBTable, + output: Seq[Attribute], + dagRequest: TiDAGRequest): SparkPlan = { + dagRequest.setTableInfo(source.table) + dagRequest.setStartTs(source.ts) + + val notAllowPushDown = dagRequest.getFields.asScala + .map { + _.getDataType.getType + } + .exists { + typeBlockList.isUnsupportedType + } + + if (notAllowPushDown) { + throw new IgnoreUnsupportedTypeException( + "Unsupported type found in fields: " + typeBlockList) + } else { + if (dagRequest.isDoubleRead) { + source.dagRequestToRegionTaskExec(dagRequest, output) + } else { + ColumnarCoprocessorRDD( + output, + source.logicalPlanToRDD(dagRequest, output), + fetchHandle = false) + } + } + } + + private def aggregationToDAGRequest( + groupByList: Seq[NamedExpression], + aggregates: Seq[AggregateExpression], + source: TiDBTable, + dagRequest: TiDAGRequest): TiDAGRequest = { + aggregates + .map { + _.aggregateFunction + } + .foreach { expr => + TiExprUtils.transformAggExprToTiAgg(expr, source.table, dagRequest) + } + + groupByList.foreach { expr => + TiExprUtils.transformGroupingToTiGrouping(expr, source.table, dagRequest) + } + + dagRequest + } + + private def filterToDAGRequest( + tiColumns: Seq[TiColumnRef], + filters: Seq[Expression], + source: TiDBTable, + dagRequest: TiDAGRequest): TiDAGRequest = { + val tiFilters: Seq[TiExpression] = filters.map { + TiExprUtils.transformFilter(_, source.table, dagRequest) + } + + val scanBuilder: TiKVScanAnalyzer = new TiKVScanAnalyzer + + val tblStatistics: TableStatistics = StatisticsManager.getTableStatistics(source.table.getId) + + // engines that could be chosen. 
+ val engines = eligibleStorageEngines(source) + + if (engines.isEmpty) { + throw new RuntimeException( + s"No eligible storage engines found for $source, " + + s"isolation_read_engines = ${TiUtil.getIsolationReadEngines(sqlContext)}") + } + + scanBuilder.buildTiDAGReq( + allowIndexRead(), + useIndexScanFirst(), + engines.contains(TiStoreType.TiKV), + engines.contains(TiStoreType.TiFlash), + tiColumns.map { colRef => + source.table.getColumn(colRef.getName) + }.asJava, + tiFilters.asJava, + source.table, + tblStatistics, + source.ts, + dagRequest) + } + + private def pruneTopNFilterProject( + limit: Int, + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + source: TiDBTable, + sortOrder: Seq[SortOrder]): SparkPlan = { + val request = newTiDAGRequest() + request.setLimit(limit) + TiExprUtils.transformSortOrderToTiOrderBy(request, sortOrder, source.table) + + pruneFilterProject(projectList, filterPredicates, source, request) + } + + private def collectLimit(limit: Int, child: LogicalPlan): SparkPlan = + child match { + case PhysicalOperation( + projectList, + filters, + DataSourceV2ScanRelation( + DataSourceV2Relation(source: TiDBTable, _, _, _, _), + _, + _, + _)) if filters.forall(TiExprUtils.isSupportedFilter(_, source, blocklist)) => + pruneTopNFilterProject(limit, projectList, filters, source, Nil) + case _ => planLater(child) + } + + private def takeOrderedAndProject( + limit: Int, + sortOrder: Seq[SortOrder], + child: LogicalPlan, + project: Seq[NamedExpression]): SparkPlan = { + // If sortOrder is empty, limit must be greater than 0 + if (limit < 0 || (sortOrder.isEmpty && limit == 0)) { + return execution.TakeOrderedAndProjectExec(limit, sortOrder, project, planLater(child)) + } + + child match { + case PhysicalOperation( + projectList, + filters, + DataSourceV2ScanRelation( + DataSourceV2Relation(source: TiDBTable, _, _, _, _), + _, + _, + _)) if filters.forall(TiExprUtils.isSupportedFilter(_, source, blocklist)) => + val refinedOrders = refineSortOrder(projectList, sortOrder, source) + if (refinedOrders.isEmpty) { + execution.TakeOrderedAndProjectExec(limit, sortOrder, project, planLater(child)) + } else { + execution.TakeOrderedAndProjectExec( + limit, + sortOrder, + project, + pruneTopNFilterProject(limit, projectList, filters, source, refinedOrders.get)) + } + case _ => execution.TakeOrderedAndProjectExec(limit, sortOrder, project, planLater(child)) + } + } + + // refine sort order + // 1. sort order expressions are all valid to be pushed + // 2. if any reference to projections are valid to be pushed + private def refineSortOrder( + projectList: Seq[NamedExpression], + sortOrders: Seq[SortOrder], + source: TiDBTable): Option[Seq[SortOrder]] = { + val aliases = AttributeMap(projectList.collect { + case a: Alias => a.toAttribute -> a + }) + // Order by desc/asc + nulls first/last + // + // 1. Order by asc + nulls first: + // order by col asc nulls first = order by col asc + // 2. Order by desc + nulls first: + // order by col desc nulls first = order by col is null desc, col desc + // 3. Order by asc + nulls last: + // order by col asc nulls last = order by col is null asc, col asc + // 4. 
Order by desc + nulls last: + // order by col desc nulls last = order by col desc + val refinedSortOrder = sortOrders.flatMap { sortOrder: SortOrder => + val newSortExpr = sortOrder.child.transformUp { + case a: Attribute => aliases.getOrElse(a, a) + } + val trimmedExpr = ReflectionUtil.trimNonTopLevelAliases(newSortExpr) + val trimmedSortOrder = ReflectionUtil.copySortOrder(sortOrder, trimmedExpr) + + (sortOrder.direction, sortOrder.nullOrdering) match { + case (_ @Ascending, _ @NullsLast) | (_ @Descending, _ @NullsFirst) => + ReflectionUtil.copySortOrder(sortOrder, IsNull(trimmedExpr)) :: trimmedSortOrder :: Nil + case _ => + trimmedSortOrder :: Nil + } + } + if (refinedSortOrder + .exists(order => !TiExprUtils.isSupportedOrderBy(order.child, source, blocklist))) { + Option.empty + } else { + Some(refinedSortOrder) + } + } + + private def pruneFilterProject( + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + source: TiDBTable, + dagRequest: TiDAGRequest): SparkPlan = { + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + + val (pushdownFilters: Seq[Expression], residualFilters: Seq[Expression]) = + filterPredicates.partition((expression: Expression) => + TiExprUtils.isSupportedFilter(expression, source, blocklist)) + + val residualFilter: Option[Expression] = + residualFilters.reduceLeftOption(catalyst.expressions.And) + + val tiColumns = buildTiColumnRefFromColumnSeq(projectSet ++ filterSet, source) + + filterToDAGRequest(tiColumns, pushdownFilters, source, dagRequest) + + if (tiColumns.isEmpty) { + // we cannot send a request with empty columns + if (dagRequest.hasIndex) { + // add the first index column so that the plan will contain at least one column. + val idxColumn = dagRequest.getIndexInfo.getIndexColumns.get(0) + dagRequest.addRequiredColumn(ColumnRef.create(idxColumn.getName, source.table)) + } else { + // add a random column so that the plan will contain at least one column. + // if the table contains a primary key then use the PK instead. + val column = source.table.getColumns.asScala + .collectFirst { + case e if e.isPrimaryKey => e + } + .getOrElse(source.table.getColumn(0)) + dagRequest.addRequiredColumn(ColumnRef.create(column.getName, source.table)) + } + } + + // Right now we still use a projection even if the only evaluation is applying an alias + // to a column. Since this is a no-op, it could be avoided. However, using this + // optimization with the current implementation would change the output schema. + // TODO: Decouple final output schema from expression evaluation so this copy can be + // avoided safely. + if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && + filterSet.subsetOf(projectSet)) { + // When it is possible to just use column pruning to get the right projection and + // when the columns of this projection are enough to evaluate all filter conditions, + // just do a scan followed by a filter, with no extra project. 
+ val projectSeq: Seq[Attribute] = projectList.asInstanceOf[Seq[Attribute]] + projectSeq.foreach( + attr => + dagRequest.addRequiredColumn( + ColumnRef.create(attr.name, source.table.getColumn(attr.name)))) + val scan = toCoprocessorRDD(source, projectSeq, dagRequest) + residualFilter.fold(scan)(FilterExec(_, scan)) + } else { + // for now all column used will be returned for old interface + // TODO: once switch to new interface we change this pruning logic + val projectSeq: Seq[Attribute] = (projectSet ++ filterSet).toSeq + projectSeq.foreach( + attr => + dagRequest.addRequiredColumn( + ColumnRef.create(attr.name, source.table.getColumn(attr.name)))) + val scan = toCoprocessorRDD(source, projectSeq, dagRequest) + ProjectExec(projectList, residualFilter.fold(scan)(FilterExec(_, scan))) + } + } + + private def groupAggregateProjection( + tiColumns: Seq[TiColumnRef], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + source: TiDBTable, + dagReq: TiDAGRequest): Seq[SparkPlan] = { + val deterministicAggAliases = aggregateExpressions.collect { + case e if e.deterministic => e.canonicalized -> ReflectionUtil.newAlias(e, e.toString()) + }.toMap + + def aliasPushedPartialResult(e: AggregateExpression): Alias = + deterministicAggAliases.getOrElse(e.canonicalized, ReflectionUtil.newAlias(e, e.toString())) + + val residualAggregateExpressions = aggregateExpressions.map { aggExpr => + // As `aggExpr` is being pushing down to TiKV, we need to replace the original Catalyst + // aggregate expressions with new ones that merges the partial aggregation results returned by + // TiKV. + // + // NOTE: Unlike simple aggregate functions (e.g., `Max`, `Min`, etc.), `Count` must be + // replaced with a `Sum` to sum up the partial counts returned by TiKV. + // + // NOTE: All `Average`s should have already been rewritten into `Sum`s and `Count`s by the + // `TiAggregation` pattern extractor. + + // An attribute referring to the partial aggregation results returned by TiKV. 
+ val partialResultRef = aliasPushedPartialResult(aggExpr).toAttribute + + aggExpr.aggregateFunction match { + case e: Max => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef)) + case e: Min => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef)) + case e: Sum => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef)) + case e: SpecialSum => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef)) + case e: First => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef)) + case _: Count => + aggExpr.copy(aggregateFunction = CountSum(partialResultRef)) + case _: Average => throw new IllegalStateException("All AVGs should have been rewritten.") + case _ => aggExpr + } + } + + tiColumns foreach { + dagReq.addRequiredColumn + } + + aggregationToDAGRequest(groupingExpressions, aggregateExpressions.distinct, source, dagReq) + + val aggregateAttributes = + aggregateExpressions.map(expr => aliasPushedPartialResult(expr).toAttribute) + val groupAttributes = groupingExpressions.map(_.toAttribute) + + // output of Coprocessor plan should contain all references within + // aggregates and group by expressions + val output = aggregateAttributes ++ groupAttributes + + val groupExpressionMap = + groupingExpressions.map(expr => expr.exprId -> expr.toAttribute).toMap + + // resultExpression might refer to some of the group by expressions + // Those expressions originally refer to table columns but now it refers to + // results of coprocessor. + // For example, select a + 1 from t group by a + 1 + // expression a + 1 has been pushed down to coprocessor + // and in turn a + 1 in projection should be replaced by + // reference of coprocessor output entirely + val rewrittenResultExpressions = resultExpressions.map { + _.transform { + case e: NamedExpression => groupExpressionMap.getOrElse(e.exprId, e) + }.asInstanceOf[NamedExpression] + } + + aggregate.AggUtils.planAggregateWithoutDistinct( + groupAttributes, + residualAggregateExpressions, + rewrittenResultExpressions, + toCoprocessorRDD(source, output, dagReq)) + } + + private def isValidAggregates( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + filters: Seq[Expression], + source: TiDBTable): Boolean = + allowAggregationPushDown && + filters.forall(TiExprUtils.isSupportedFilter(_, source, blocklist)) && + groupingExpressions.forall(TiExprUtils.isSupportedGroupingExpr(_, source, blocklist)) && + aggregateExpressions.forall(TiExprUtils.isSupportedAggregate(_, source, blocklist)) && + !aggregateExpressions.exists(_.isDistinct) && + // TODO: This is a temporary fix for the issue: https://github.com/pingcap/tispark/issues/1039 + !groupingExpressions.exists(_.isInstanceOf[Alias]) + + // We do through similar logic with original Spark as in SparkStrategies.scala + // Difference is we need to test if a sub-plan can be consumed all together by TiKV + // and then we don't return (don't planLater) and plan the remaining all at once + // TODO: This test should be done once for all children + private def doPlan(source: TiDBTable, plan: LogicalPlan): Seq[SparkPlan] = + plan match { + case logical.ReturnAnswer(rootPlan) => + rootPlan match { + case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => + takeOrderedAndProject(limit, order, child, child.output) :: Nil + case logical.Limit( + IntegerLiteral(limit), + logical.Project(projectList, logical.Sort(order, true, child))) => + takeOrderedAndProject(limit, order, child, projectList) :: Nil 
+ case logical.Limit(IntegerLiteral(limit), child) => + execution.CollectLimitExec(limit, collectLimit(limit, child)) :: Nil + case other => planLater(other) :: Nil + } + case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => + takeOrderedAndProject(limit, order, child, child.output) :: Nil + case logical.Limit( + IntegerLiteral(limit), + logical.Project(projectList, logical.Sort(order, true, child))) => + takeOrderedAndProject(limit, order, child, projectList) :: Nil + case logical.Limit(IntegerLiteral(limit), child) => + execution.CollectLimitExec(limit, collectLimit(limit, child)) :: Nil + // Collapse filters and projections and push plan directly + case PhysicalOperation( + projectList, + filters, + DataSourceV2ScanRelation( + DataSourceV2Relation(source: TiDBTable, _, _, _, _), + _, + _, + _)) => + pruneFilterProject(projectList, filters, source, newTiDAGRequest()) :: Nil + + // Basic logic of original Spark's aggregation plan is: + // PhysicalAggregation extractor will rewrite original aggregation + // into aggregateExpressions and resultExpressions. + // resultExpressions contains only references [[AttributeReference]] + // to the result of aggregation. resultExpressions might contain projections + // like Add(sumResult, 1). + // For a aggregate like agg(expr) + 1, the rewrite process is: rewrite agg(expr) -> + // 1. pushdown: agg(expr) as agg1, if avg then sum(expr), count(expr) + // 2. residual expr (for Spark itself): agg(agg1) as finalAgg1 the parameter is a + // reference to pushed plan's corresponding aggregation + // 3. resultExpressions: finalAgg1 + 1, the finalAgg1 is the reference to final result + // of the aggregation + case TiAggregation( + groupingExpressions, + aggregateExpressions, + resultExpressions, + TiAggregationProjectionV2(filters, _, `source`, projects)) + if isValidAggregates(groupingExpressions, aggregateExpressions, filters, source) => + val projectSet = AttributeSet((projects ++ filters).flatMap { + _.references + }) + val tiColumns = buildTiColumnRefFromColumnSeq(projectSet, source) + val dagReq: TiDAGRequest = + filterToDAGRequest(tiColumns, filters, source, newTiDAGRequest()) + groupAggregateProjection( + tiColumns, + groupingExpressions, + aggregateExpressions, + resultExpressions, + `source`, + dagReq) + case _ => Nil + } +}
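
For readers who want to see the new Spark 3.3 pieces exercised together, the following is a minimal, hypothetical smoke test rather than part of the patch: `TiExtensions` and `TiCatalog` are the classes referenced in the diff above, while the application name, PD address, and the `test`/`t` database and table names are placeholders to substitute with your own.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical end-to-end check against a Spark 3.3 build of TiSpark (assumes a running
// TiDB/PD cluster; the PD address and table names below are placeholders).
object Spark33Smoke {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("tispark-spark-3.3-smoke")
      // Injects TiStrategy, TiParser and TiAuthorizationRule at session build time.
      .config("spark.sql.extensions", "org.apache.spark.sql.TiExtensions")
      // Registers the catalog implemented by TiCatalog in this patch.
      .config(
        "spark.sql.catalog.tidb_catalog",
        "org.apache.spark.sql.catalyst.catalog.TiCatalog")
      .config("spark.sql.catalog.tidb_catalog.pd.addresses", "127.0.0.1:2379")
      .getOrCreate()

    // USE is resolved through SetCatalogAndNamespace; on Spark 3.3 the authorization check
    // is routed via BasicLogicalPlan.verifyAuthorizationRule to the 3.3 TiBasicLogicalPlan.
    spark.sql("use tidb_catalog.test")

    // Scans and aggregate pushdown are planned by the Spark 3.3 copy of TiStrategy.
    spark.sql("select count(*) from t").show()

    // Deletes require a WHERE clause; on Spark 3.3 an empty clause arrives as AlwaysTrue()
    // and is rejected, per the updated DeleteNotSupportSuite expectation.
    spark.sql("delete from tidb_catalog.test.t where id < 0")

    spark.stop()
  }
}
```

The sketch assumes the assembly was built with the new `spark-3.3` Maven profile so that the `resources/spark-wrapper-spark-3_3` classes added to assembly.xml are available; ReflectionUtil then loads the matching wrapper classes at runtime based on the running Spark version reported by TiSparkInfo.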