From 83afee33562f092e07d1b6d4ee5748791f6edcc4 Mon Sep 17 00:00:00 2001
From: tomato <38561029+qidi1@users.noreply.github.com>
Date: Mon, 22 Aug 2022 11:42:53 +0800
Subject: [PATCH] support Spark3.3 (#2492)
---
assembly/src/main/assembly/assembly.xml | 9 +
core-test/pom.xml | 20 +-
core/pom.xml | 16 +-
.../com/pingcap/tispark/TiSparkInfo.scala | 2 +-
.../tispark/utils/ReflectionUtil.scala | 24 +
.../analyzer/TiAuthorizationRule.scala | 12 +-
.../sql/catalyst/catalog/TiCatalog.scala | 8 +-
.../spark/sql/catalyst/parser/TiParser.scala | 3 +
.../plans/logical/BasicLogicalPlan.scala | 29 +
.../delete/DeleteNotSupportSuite.scala | 5 +-
.../overflow/UnsignedOverflowSuite.scala | 2 +-
.../tispark/telemetry/TelemetrySuite.scala | 4 +
docs/authorization_userguide.md | 4 +-
docs/delete_userguide.md | 3 +-
pom.xml | 16 +-
.../plans/logical/TiBasicLogicalPlan.scala | 36 +
.../plans/logical/TiBasicLogicalPlan.scala | 36 +
.../TiAggregationProjectionV2.scala | 2 +-
.../plans/logical/TiBasicLogicalPlan.scala | 37 +
.../connector/write/TiDBWriteBuilder.scala | 1 -
.../TiAggregationProjectionV2.scala | 2 +-
spark-wrapper/spark-3.3/pom.xml | 129 ++++
.../com/pingcap/tispark/SparkWrapper.scala | 52 ++
.../expressions/TiBasicExpression.scala | 133 ++++
.../plans/logical/TiBasicLogicalPlan.scala | 43 ++
.../connector/write/TiDBWriteBuilder.scala | 47 ++
.../TiAggregationProjectionV2.scala | 50 ++
.../spark/sql/extensions/TiStrategy.scala | 658 ++++++++++++++++++
28 files changed, 1356 insertions(+), 27 deletions(-)
create mode 100644 core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BasicLogicalPlan.scala
create mode 100644 spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
create mode 100644 spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
create mode 100644 spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
create mode 100644 spark-wrapper/spark-3.3/pom.xml
create mode 100644 spark-wrapper/spark-3.3/src/main/scala/com/pingcap/tispark/SparkWrapper.scala
create mode 100644 spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala
create mode 100644 spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
create mode 100644 spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala
create mode 100644 spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala
create mode 100644 spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiStrategy.scala
diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml
index e8ad969fa9..bb997e3532 100644
--- a/assembly/src/main/assembly/assembly.xml
+++ b/assembly/src/main/assembly/assembly.xml
@@ -45,5 +45,14 @@
**/*
+
+
+ ${project.parent.basedir}/spark-wrapper/spark-3.3/target/classes/
+
+ resources/spark-wrapper-spark-3_3
+
+ **/*
+
+
diff --git a/core-test/pom.xml b/core-test/pom.xml
index 89673077aa..126fbb03e9 100644
--- a/core-test/pom.xml
+++ b/core-test/pom.xml
@@ -71,17 +71,27 @@
org.apache.logging.log4j
log4j-api
- 2.17.1
+ 2.17.2
org.apache.logging.log4j
log4j-core
- 2.17.1
+ 2.17.2
org.apache.spark
spark-core_${scala.binary.version}
${spark.version.test}
+
+
+ org.apache.logging.log4j
+ log4j-api
+
+
+ org.apache.logging.log4j
+ log4j-core
+
+
org.apache.spark
@@ -118,6 +128,12 @@
org.apache.hadoop
hadoop-client
2.7.2
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
com.google.guava
diff --git a/core/pom.xml b/core/pom.xml
index 6250680b83..ff14dd3f5d 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -44,12 +44,12 @@
org.apache.logging.log4j
log4j-api
- 2.17.1
+ 2.17.2
org.apache.logging.log4j
log4j-core
- 2.17.1
+ 2.17.2
org.apache.spark
@@ -60,6 +60,14 @@
io.netty
netty-all
+
+ org.apache.logging.log4j
+ log4j-api
+
+
+ org.apache.logging.log4j
+ log4j-core
+
@@ -107,6 +115,10 @@
io.netty
netty-all
+
+ org.slf4j
+ slf4j-log4j12
+
diff --git a/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala b/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala
index 638f8713f4..9ff081a147 100644
--- a/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala
+++ b/core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala
@@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory
object TiSparkInfo {
private final val logger = LoggerFactory.getLogger(getClass.getName)
- val SUPPORTED_SPARK_VERSION: List[String] = "3.0" :: "3.1" :: "3.2" :: Nil
+ val SUPPORTED_SPARK_VERSION: List[String] = "3.0" :: "3.1" :: "3.2" :: "3.3" :: Nil
val SPARK_VERSION: String = org.apache.spark.SPARK_VERSION
diff --git a/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala b/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala
index d3d897b50c..4129436e47 100644
--- a/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala
+++ b/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala
@@ -18,14 +18,18 @@
package com.pingcap.tispark.utils
import com.pingcap.tispark.TiSparkInfo
+import com.pingcap.tispark.auth.TiAuthorization
import com.pingcap.tispark.write.TiDBOptions
import org.apache.spark.sql.{SQLContext, SparkSession, Strategy, TiContext}
import org.apache.spark.sql.catalyst.expressions.BasicExpression.TiExpression
import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.{SparkSession, Strategy, TiContext}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.slf4j.LoggerFactory
import java.io.File
+import java.lang.reflect.InvocationTargetException
import java.net.{URL, URLClassLoader}
/**
@@ -60,6 +64,8 @@ object ReflectionUtil {
private val SPARK_WRAPPER_CLASS = "com.pingcap.tispark.SparkWrapper"
private val TI_BASIC_EXPRESSION_CLASS =
"org.apache.spark.sql.catalyst.expressions.TiBasicExpression"
+ private val TI_BASIC_LOGICAL_PLAN_CLASS =
+ "org.apache.spark.sql.catalyst.plans.logical.TiBasicLogicalPlan"
private val TI_STRATEGY_CLASS =
"org.apache.spark.sql.extensions.TiStrategy"
private val TIDB_WRITE_BUILDER_CLASS =
@@ -105,6 +111,24 @@ object ReflectionUtil {
.asInstanceOf[Option[TiExpression]]
}
+ def callTiBasicLogicalPlanVerifyAuthorizationRule(
+ logicalPlan: LogicalPlan,
+ tiAuthorization: Option[TiAuthorization]): LogicalPlan = {
+ try {
+ classLoader
+ .loadClass(TI_BASIC_LOGICAL_PLAN_CLASS)
+ .getDeclaredMethod(
+ "verifyAuthorizationRule",
+ classOf[LogicalPlan],
+ classOf[Option[TiAuthorization]])
+ .invoke(null, logicalPlan, tiAuthorization)
+ .asInstanceOf[LogicalPlan]
+ } catch {
+ case ex: InvocationTargetException =>
+ throw ex.getTargetException
+ }
+ }
+
def newTiStrategy(
getOrCreateTiContext: SparkSession => TiContext,
sparkSession: SparkSession): Strategy = {
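For context, a rough sketch of how the core module is expected to reach the version-specific wrapper classes at runtime: the assembly ships each wrapper's classes under resources/spark-wrapper-spark-3_X (see the assembly change above), and ReflectionUtil builds a class loader over the directory matching the running Spark version. Names below are illustrative and simplified, not the exact implementation:

import java.net.{URL, URLClassLoader}
import com.pingcap.tispark.TiSparkInfo

// Illustrative sketch: pick the wrapper resource dir for the running Spark version,
// e.g. "resources/spark-wrapper-spark-3_3" when running on Spark 3.3.x.
val sparkMajorMinor = TiSparkInfo.SPARK_VERSION.split('.').take(2).mkString("_")
val wrapperUrl: URL = getClass.getClassLoader.getResource(s"resources/spark-wrapper-spark-$sparkMajorMinor")
val classLoader = new URLClassLoader(Array(wrapperUrl), getClass.getClassLoader)
// The version-specific shim is then resolved and invoked reflectively,
// as callTiBasicLogicalPlanVerifyAuthorizationRule does above.
val clazz = classLoader.loadClass("org.apache.spark.sql.catalyst.plans.logical.TiBasicLogicalPlan")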
diff --git a/core/src/main/scala/org/apache/spark/sql/catalyst/analyzer/TiAuthorizationRule.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/analyzer/TiAuthorizationRule.scala
index 9593897b3d..1a65ca5ee1 100644
--- a/core/src/main/scala/org/apache/spark/sql/catalyst/analyzer/TiAuthorizationRule.scala
+++ b/core/src/main/scala/org/apache/spark/sql/catalyst/analyzer/TiAuthorizationRule.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.catalyst.analyzer
import com.pingcap.tispark.auth.TiAuthorization
+import com.pingcap.tispark.v2.TiDBTable
import org.apache.spark.sql.catalyst.plans.logical.{
+ BasicLogicalPlan,
DeleteFromTable,
LogicalPlan,
SetCatalogAndNamespace,
@@ -26,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.{SparkSession, TiContext}
-import com.pingcap.tispark.v2.TiDBTable
/**
* Only work for table v2(catalog plugin)
@@ -34,7 +35,6 @@ import com.pingcap.tispark.v2.TiDBTable
case class TiAuthorizationRule(getOrCreateTiContext: SparkSession => TiContext)(
sparkSession: SparkSession)
extends Rule[LogicalPlan] {
-
protected val tiContext: TiContext = getOrCreateTiContext(sparkSession)
private lazy val tiAuthorization: Option[TiAuthorization] = tiContext.tiAuthorization
@@ -47,12 +47,8 @@ case class TiAuthorizationRule(getOrCreateTiContext: SparkSession => TiContext)(
tiAuthorization)
}
dt
- case sd @ SetCatalogAndNamespace(catalogManager, catalogName, namespace) =>
- if (catalogName.nonEmpty && catalogName.get.equals("tidb_catalog") && namespace.isDefined) {
- namespace.get
- .foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization))
- }
- sd
+ case s: SetCatalogAndNamespace =>
+ BasicLogicalPlan.verifyAuthorizationRule(s, tiAuthorization)
case dr @ DataSourceV2Relation(
TiDBTable(_, tableRef, _, _, _),
output,
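The delegation to BasicLogicalPlan above is needed because the shape of SetCatalogAndNamespace changed in Spark 3.3, so the old three-field pattern match no longer compiles against every supported version. Roughly (simplified sketches of the upstream Catalyst definitions):

// Spark 3.0 - 3.2 (simplified): catalog name and namespace are fields of the plan.
case class SetCatalogAndNamespace(
    catalogManager: CatalogManager,
    catalogName: Option[String],
    namespace: Option[Seq[String]]) extends Command

// Spark 3.3 (simplified): the plan only wraps a child, which the analyzer resolves
// into ResolvedNamespace / ResolvedDBObjectName (handled by the 3.3 wrapper below).
case class SetCatalogAndNamespace(child: LogicalPlan) extends UnaryCommand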
diff --git a/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/TiCatalog.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/TiCatalog.scala
index feda4d7313..629f8e1578 100644
--- a/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/TiCatalog.scala
+++ b/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/TiCatalog.scala
@@ -167,13 +167,17 @@ class TiCatalog extends TableCatalog with SupportsNamespaces {
override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = ???
- override def dropNamespace(namespace: Array[String]): Boolean = ???
-
override def createNamespace(
namespace: Array[String],
metadata: util.Map[String, String]): Unit =
???
+ // for Spark versions earlier than 3.3
+ def dropNamespace(strings: Array[String]): Boolean = ???
+
+ // for Spark versions 3.3 and later
+ def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = ???
+
override def alterNamespace(namespace: Array[String], changes: NamespaceChange*): Unit = ???
}
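Dropping the override modifier keeps this source compiling against every supported Spark version: the SupportsNamespaces interface gained a cascade parameter in Spark 3.3, so only one of the two methods above actually implements the interface at runtime. Simplified signatures for reference:

// Spark 3.0 - 3.2 SupportsNamespaces (simplified)
def dropNamespace(namespace: Array[String]): Boolean
// Spark 3.3 SupportsNamespaces (simplified): cascade controls dropping a non-empty namespace
def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean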
diff --git a/core/src/main/scala/org/apache/spark/sql/catalyst/parser/TiParser.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/parser/TiParser.scala
index 3067c110ed..cde71277b8 100644
--- a/core/src/main/scala/org/apache/spark/sql/catalyst/parser/TiParser.scala
+++ b/core/src/main/scala/org/apache/spark/sql/catalyst/parser/TiParser.scala
@@ -54,6 +54,9 @@ case class TiParser(
@scala.throws[ParseException]("Text cannot be parsed to a DataType")
def parseRawDataType(sqlText: String): DataType = ???
+ @scala.throws[ParseException]("Text cannot be parsed to a LogicalPlan")
+ def parseQuery(sqlText: String): LogicalPlan = ???
+
def getOrElseInitTiCatalog: TiCatalog = {
val catalogManager = sparkSession.sessionState.catalogManager
catalogManager.catalog("tidb_catalog").asInstanceOf[TiCatalog]
diff --git a/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BasicLogicalPlan.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BasicLogicalPlan.scala
new file mode 100644
index 0000000000..6c2431f899
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BasicLogicalPlan.scala
@@ -0,0 +1,29 @@
+/*
+ *
+ * Copyright 2022 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import com.pingcap.tispark.auth.TiAuthorization
+import com.pingcap.tispark.utils.ReflectionUtil
+
+object BasicLogicalPlan {
+ def verifyAuthorizationRule(
+ logicalPlan: LogicalPlan,
+ tiAuthorization: Option[TiAuthorization]): LogicalPlan =
+ ReflectionUtil.callTiBasicLogicalPlanVerifyAuthorizationRule(logicalPlan, tiAuthorization)
+}
diff --git a/core/src/test/scala/com/pingcap/tispark/delete/DeleteNotSupportSuite.scala b/core/src/test/scala/com/pingcap/tispark/delete/DeleteNotSupportSuite.scala
index 6eb8b3be58..3d56f11f37 100644
--- a/core/src/test/scala/com/pingcap/tispark/delete/DeleteNotSupportSuite.scala
+++ b/core/src/test/scala/com/pingcap/tispark/delete/DeleteNotSupportSuite.scala
@@ -31,10 +31,11 @@ class DeleteNotSupportSuite extends BaseBatchWriteTest("test_delete_not_support"
test("Delete without WHERE clause") {
jdbcUpdate(s"create table $dbtable(i int, s int,PRIMARY KEY (i))")
-
the[IllegalArgumentException] thrownBy {
spark.sql(s"delete from $dbtable")
- } should have message "requirement failed: Delete without WHERE clause is not supported"
+ } should (have message "requirement failed: Delete without WHERE clause is not supported"
+ // when the WHERE clause is empty, Spark 3.3 sends AlwaysTrue() to TiSpark.
+ or have message "requirement failed: Delete with alwaysTrue WHERE clause is not supported")
}
test("Delete with alwaysTrue WHERE clause") {
diff --git a/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala b/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala
index 1926c3d305..4b6be247e4 100644
--- a/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala
+++ b/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala
@@ -262,7 +262,7 @@ class UnsignedOverflowSuite extends BaseBatchWriteTest("test_data_type_unsigned_
val jdbcErrorClass = classOf[java.lang.RuntimeException]
val tidbErrorClass = classOf[java.lang.RuntimeException]
val tidbErrorMsgStartWith =
- "Error while encoding: java.lang.RuntimeException: java.lang.Integer is not a valid external type for schema of bigint\nif (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, c1), LongType) AS c1"
+ "Error while encoding: java.lang.RuntimeException: java.lang.Integer is not a valid external type for schema of bigint"
compareTiDBWriteFailureWithJDBC(
List(row),
diff --git a/core/src/test/scala/com/pingcap/tispark/telemetry/TelemetrySuite.scala b/core/src/test/scala/com/pingcap/tispark/telemetry/TelemetrySuite.scala
index f548913e10..f2831dd954 100644
--- a/core/src/test/scala/com/pingcap/tispark/telemetry/TelemetrySuite.scala
+++ b/core/src/test/scala/com/pingcap/tispark/telemetry/TelemetrySuite.scala
@@ -16,6 +16,9 @@
package com.pingcap.tispark.telemetry
+import com.pingcap.tikv.{TiConfiguration, TiSession}
+import com.pingcap.tispark.auth.TiAuthorization.tiConf
+import com.pingcap.tispark.listener.CacheInvalidateListener
import com.pingcap.tispark.utils.HttpClientUtil
import com.sun.net.httpserver.{
HttpExchange,
@@ -24,6 +27,7 @@ import com.sun.net.httpserver.{
HttpsConfigurator,
HttpsServer
}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.test.SharedSQLContext
import org.scalatest.Matchers.{be, noException}
diff --git a/docs/authorization_userguide.md b/docs/authorization_userguide.md
index ae9ed0c1d0..b5ecead472 100644
--- a/docs/authorization_userguide.md
+++ b/docs/authorization_userguide.md
@@ -9,10 +9,10 @@ perform.
This feature allows you to execute SQL in TiSpark with Authorization and authentication, the same behavior as TiDB
## Prerequisites
-
+
- The database's user account must have the `PROCESS` privilege.
- TiSpark version >= 2.5.0
-- Spark version = 3.0.x or 3.1.x
+- Spark version = 3.0.x or 3.1.x or 3.2.x or 3.3.x
## Setup
diff --git a/docs/delete_userguide.md b/docs/delete_userguide.md
index 070a5a696e..6b11ef6198 100644
--- a/docs/delete_userguide.md
+++ b/docs/delete_userguide.md
@@ -11,7 +11,7 @@ spark.sql.catalog.tidb_catalog.pd.addresses ${your_pd_address}
## Requirement
- TiDB 4.x or 5.x
-- Spark >= 3.0
+- Spark = 3.0.x or 3.1.x or 3.2.x or 3.3.x
## Delete with SQL
```
@@ -38,4 +38,3 @@ spark.sql("delete from tidb_catalog.db.table where xxx")
- Delete from partition table is not supported.
- Delete with Pessimistic Transaction Mode is not supported.
-
diff --git a/pom.xml b/pom.xml
index 0de6335645..7041fc1afa 100644
--- a/pom.xml
+++ b/pom.xml
@@ -170,12 +170,13 @@
spark-wrapper/spark-3.0
spark-wrapper/spark-3.1
spark-wrapper/spark-3.2
+ spark-wrapper/spark-3.3
assembly
- spark-3.1.1
+ spark-3.1
false
@@ -186,7 +187,7 @@
- spark-3.2.1
+ spark-3.2
false
@@ -196,6 +197,17 @@
3.2
+
+ spark-3.3
+
+ false
+
+
+ 3.3.0
+ 3.3.0
+ 3.3
+
+
jenkins
diff --git a/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
new file mode 100644
index 0000000000..801862affd
--- /dev/null
+++ b/spark-wrapper/spark-3.0/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2022 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import com.pingcap.tispark.auth.TiAuthorization
+
+object TiBasicLogicalPlan {
+ def verifyAuthorizationRule(
+ logicalPlan: LogicalPlan,
+ tiAuthorization: Option[TiAuthorization]): LogicalPlan = {
+ logicalPlan match {
+ case st @ SetCatalogAndNamespace(catalogManager, catalogName, namespace) =>
+ if (catalogName.nonEmpty && catalogName.get.equals(
+ "tidb_catalog") && namespace.isDefined) {
+ namespace.get
+ .foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization))
+ }
+ st
+ }
+ }
+
+}
diff --git a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
new file mode 100644
index 0000000000..801862affd
--- /dev/null
+++ b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2022 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import com.pingcap.tispark.auth.TiAuthorization
+
+object TiBasicLogicalPlan {
+ def verifyAuthorizationRule(
+ logicalPlan: LogicalPlan,
+ tiAuthorization: Option[TiAuthorization]): LogicalPlan = {
+ logicalPlan match {
+ case st @ SetCatalogAndNamespace(catalogManager, catalogName, namespace) =>
+ if (catalogName.nonEmpty && catalogName.get.equals(
+ "tidb_catalog") && namespace.isDefined) {
+ namespace.get
+ .foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization))
+ }
+ st
+ }
+ }
+
+}
diff --git a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala
index 3603f9c6ae..08193a2675 100644
--- a/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala
+++ b/spark-wrapper/spark-3.1/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.v2.{
/**
* I'm afraid that the same name with the object under the spark-wrapper/spark3.0 will lead to some problems.
- * Although the same name will pass the itegration test
+ * Although the same name will pass the integration test
*/
object TiAggregationProjectionV2 {
type ReturnType = (Seq[Expression], LogicalPlan, TiDBTable, Seq[NamedExpression])
diff --git a/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
new file mode 100644
index 0000000000..1138f3ce4d
--- /dev/null
+++ b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2022 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import com.pingcap.tispark.auth.TiAuthorization
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SetCatalogAndNamespace}
+
+object TiBasicLogicalPlan {
+ def verifyAuthorizationRule(
+ logicalPlan: LogicalPlan,
+ tiAuthorization: Option[TiAuthorization]): LogicalPlan = {
+ logicalPlan match {
+ case st @ SetCatalogAndNamespace(catalogManager, catalogName, namespace) =>
+ if (catalogName.nonEmpty && catalogName.get.equals(
+ "tidb_catalog") && namespace.isDefined) {
+ namespace.get
+ .foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization))
+ }
+ st
+ }
+ }
+
+}
diff --git a/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala
index 98361987db..c8a7fbf7df 100644
--- a/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala
+++ b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala
@@ -31,7 +31,6 @@ case class TiDBWriteBuilder(
new InsertableRelation {
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
val schema = info.schema()
- println("Do write")
val df = sqlContext.sparkSession.createDataFrame(data.toJavaRDD, schema)
df.write
.format("tidb")
diff --git a/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala
index 3603f9c6ae..08193a2675 100644
--- a/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala
+++ b/spark-wrapper/spark-3.2/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.v2.{
/**
* I'm afraid that the same name with the object under the spark-wrapper/spark3.0 will lead to some problems.
- * Although the same name will pass the itegration test
+ * Although the same name will pass the integration test
*/
object TiAggregationProjectionV2 {
type ReturnType = (Seq[Expression], LogicalPlan, TiDBTable, Seq[NamedExpression])
diff --git a/spark-wrapper/spark-3.3/pom.xml b/spark-wrapper/spark-3.3/pom.xml
new file mode 100644
index 0000000000..024340a6d0
--- /dev/null
+++ b/spark-wrapper/spark-3.3/pom.xml
@@ -0,0 +1,129 @@
+
+
+ 4.0.0
+
+ com.pingcap.tispark
+ tispark-parent
+ 3.1.0-SNAPSHOT
+ ../../pom.xml
+
+
+ spark-wrapper-spark-3.3_${scala.version.release}
+ jar
+ TiSpark Project Spark Wrapper Spark-3.3
+ http://github.com/pingcap/tispark
+
+
+ 3.3.0
+
+
+
+
+ com.pingcap.tispark
+ tispark-core-internal
+ ${project.parent.version}
+
+
+ org.apache.spark
+ spark-core_2.12
+ ${spark.version.wrapper}
+
+
+ org.apache.logging.log4j
+ log4j-slf4j-impl
+
+
+
+
+ org.apache.spark
+ spark-catalyst_2.12
+ ${spark.version.wrapper}
+
+
+ org.apache.spark
+ spark-sql_2.12
+ ${spark.version.wrapper}
+
+
+
+
+ src/main/scala
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 4.3.0
+
+
+ compile-scala
+ compile
+
+ add-source
+ compile
+
+
+
+ test-compile-scala
+ test-compile
+
+ add-source
+ testCompile
+
+
+
+ attach-javadocs
+
+ doc-jar
+
+
+
+
+ ${scala.version}
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+ 2.3.2
+
+ 1.8
+ 1.8
+ UTF-8
+ true
+ true
+
+
+
+
+ org.apache.maven.plugins
+ maven-source-plugin
+ 3.0.1
+
+
+ attach-sources
+
+ jar-no-fork
+
+
+
+
+
+
+ org.antipathy
+ mvn-scalafmt_${scala.binary.version}
+ 1.0.3
+
+ ${scalafmt.skip}
+ ${scalafmt.skip}
+
+ ${project.basedir}/src/main/scala
+
+
+ ${project.basedir}/src/test/scala
+
+
+
+
+
+
diff --git a/spark-wrapper/spark-3.3/src/main/scala/com/pingcap/tispark/SparkWrapper.scala b/spark-wrapper/spark-3.3/src/main/scala/com/pingcap/tispark/SparkWrapper.scala
new file mode 100644
index 0000000000..542f39cb9b
--- /dev/null
+++ b/spark-wrapper/spark-3.3/src/main/scala/com/pingcap/tispark/SparkWrapper.scala
@@ -0,0 +1,52 @@
+package com.pingcap.tispark
+/*
+ * Copyright 2022 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.spark.sql.catalyst.expressions.{
+ Alias,
+ AliasHelper,
+ ExprId,
+ Expression,
+ SortOrder
+}
+
+object SparkWrapper {
+ def getVersion: String = {
+ "SparkWrapper-3.3"
+ }
+
+ def newAlias(child: Expression, name: String): Alias = {
+ Alias(child, name)()
+ }
+
+ def newAlias(child: Expression, name: String, exprId: ExprId): Alias = {
+ Alias(child, name)(exprId = exprId)
+ }
+
+ def trimNonTopLevelAliases(e: Expression): Expression = {
+ TiCleanupAliases.trimNonTopLevelAliases2(e)
+ }
+
+ def copySortOrder(sortOrder: SortOrder, child: Expression): SortOrder = {
+ sortOrder.copy(child = child)
+ }
+}
+
+object TiCleanupAliases extends AliasHelper {
+ def trimNonTopLevelAliases2[T <: Expression](e: T): T = {
+ super.trimNonTopLevelAliases(e)
+ }
+}
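The extra TiCleanupAliases object exists because AliasHelper.trimNonTopLevelAliases is protected in Spark, so it cannot be called from the core module directly; the wrapper re-exposes it. Illustrative use from the core side (reached reflectively through ReflectionUtil; the expression name is a placeholder):

// e.g. strip nested aliases from a sort expression before pushing it down
val cleaned = SparkWrapper.trimNonTopLevelAliases(someSortExpression)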
diff --git a/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala
new file mode 100644
index 0000000000..227e547617
--- /dev/null
+++ b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/TiBasicExpression.scala
@@ -0,0 +1,133 @@
+/*
+ *
+ * Copyright 2022 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import com.pingcap.tikv.expression.{
+ ArithmeticBinaryExpression,
+ ColumnRef,
+ ComparisonBinaryExpression,
+ Constant,
+ LogicalBinaryExpression,
+ StringRegExpression
+}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.BasicExpression.{
+ TiExpression,
+ TiIsNull,
+ TiNot,
+ convertLiteral
+}
+import org.apache.spark.sql.execution.TiConverter
+import org.apache.spark.sql.types.DecimalType
+
+object TiBasicExpression {
+
+ def convertToTiExpr(expr: Expression): Option[TiExpression] =
+ expr match {
+ case Literal(value, dataType) =>
+ Some(
+ Constant.create(convertLiteral(value, dataType), TiConverter.fromSparkType(dataType)))
+
+ case Add(BasicExpression(lhs), BasicExpression(rhs), _) =>
+ Some(ArithmeticBinaryExpression.plus(lhs, rhs))
+
+ case Subtract(BasicExpression(lhs), BasicExpression(rhs), _) =>
+ Some(ArithmeticBinaryExpression.minus(lhs, rhs))
+
+ case e @ Multiply(BasicExpression(lhs), BasicExpression(rhs), _) =>
+ Some(ArithmeticBinaryExpression.multiply(TiConverter.fromSparkType(e.dataType), lhs, rhs))
+
+ case d @ Divide(BasicExpression(lhs), BasicExpression(rhs), _) =>
+ Some(ArithmeticBinaryExpression.divide(TiConverter.fromSparkType(d.dataType), lhs, rhs))
+
+ case And(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(LogicalBinaryExpression.and(lhs, rhs))
+
+ case Or(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(LogicalBinaryExpression.or(lhs, rhs))
+
+ case Alias(BasicExpression(child), _) =>
+ Some(child)
+
+ case IsNull(BasicExpression(child)) =>
+ Some(new TiIsNull(child))
+
+ case IsNotNull(BasicExpression(child)) =>
+ Some(new TiNot(new TiIsNull(child)))
+
+ case GreaterThan(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(ComparisonBinaryExpression.greaterThan(lhs, rhs))
+
+ case GreaterThanOrEqual(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(ComparisonBinaryExpression.greaterEqual(lhs, rhs))
+
+ case LessThan(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(ComparisonBinaryExpression.lessThan(lhs, rhs))
+
+ case LessThanOrEqual(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(ComparisonBinaryExpression.lessEqual(lhs, rhs))
+
+ case EqualTo(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(ComparisonBinaryExpression.equal(lhs, rhs))
+
+ case Not(EqualTo(BasicExpression(lhs), BasicExpression(rhs))) =>
+ Some(ComparisonBinaryExpression.notEqual(lhs, rhs))
+
+ case Not(BasicExpression(child)) =>
+ Some(new TiNot(child))
+
+ case StartsWith(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(StringRegExpression.startsWith(lhs, rhs))
+
+ case Contains(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(StringRegExpression.contains(lhs, rhs))
+
+ case EndsWith(BasicExpression(lhs), BasicExpression(rhs)) =>
+ Some(StringRegExpression.endsWith(lhs, rhs))
+
+ case Like(BasicExpression(lhs), BasicExpression(rhs), _) =>
+ Some(StringRegExpression.like(lhs, rhs))
+
+ // Coprocessor has its own behavior of type promoting and overflow check
+ // so we simply remove it from expression and let cop handle it
+ case CheckOverflow(BasicExpression(expr), dec: DecimalType, _) =>
+ expr.setDataType(TiConverter.fromSparkType(dec))
+ Some(expr)
+
+ case PromotePrecision(BasicExpression(expr)) =>
+ Some(expr)
+
+ case PromotePrecision(Cast(BasicExpression(expr), dec: DecimalType, _, _)) =>
+ expr.setDataType(TiConverter.fromSparkType(dec))
+ Some(expr)
+
+ // TODO: Are all AttributeReferences column references in such a context?
+ case attr: AttributeReference =>
+ Some(ColumnRef.create(attr.name, TiConverter.fromSparkType(attr.dataType)))
+
+ case uAttr: UnresolvedAttribute =>
+ Some(ColumnRef.create(uAttr.name, TiConverter.fromSparkType(uAttr.dataType)))
+
+ // TODO: Remove it and let it fail once done all translation
+ case _ => Option.empty[TiExpression]
+ }
+}
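As a usage illustration only, converting a simple Catalyst predicate with the object above (the column name and type are made up for the example; the result is Some(...) when every sub-expression is convertible, None otherwise):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.types.IntegerType

// Hypothetical INT column "c1"; the predicate c1 > 10 becomes a TiKV
// ComparisonBinaryExpression that can be pushed down to the coprocessor.
val c1 = AttributeReference("c1", IntegerType)()
val pushed = TiBasicExpression.convertToTiExpr(GreaterThan(c1, Literal(10)))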
diff --git a/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
new file mode 100644
index 0000000000..d51fc90199
--- /dev/null
+++ b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TiBasicLogicalPlan.scala
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2022 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import com.pingcap.tispark.auth.TiAuthorization
+import org.apache.spark.sql.catalyst.analysis.{ResolvedDBObjectName, ResolvedNamespace}
+
+object TiBasicLogicalPlan {
+ def verifyAuthorizationRule(
+ logicalPlan: LogicalPlan,
+ tiAuthorization: Option[TiAuthorization]): LogicalPlan = {
+ logicalPlan match {
+ case st @ SetCatalogAndNamespace(namespace) =>
+ namespace match {
+ case ResolvedNamespace(catalog, ns) =>
+ if (catalog.name().equals("tidb_catalog")) {
+ ns.foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization))
+ }
+ st
+ case ResolvedDBObjectName(catalog, nameParts) =>
+ if (catalog.name().equals("tidb_catalog")) {
+ nameParts.foreach(TiAuthorization.authorizeForSetDatabase(_, tiAuthorization))
+ }
+ st
+ }
+ }
+ }
+
+}
diff --git a/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala
new file mode 100644
index 0000000000..c8a7fbf7df
--- /dev/null
+++ b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/connector/write/TiDBWriteBuilder.scala
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2022 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.write
+
+import com.pingcap.tispark.write.TiDBOptions
+import org.apache.spark.sql.sources.InsertableRelation
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
+
+case class TiDBWriteBuilder(
+ info: LogicalWriteInfo,
+ tiDBOptions: TiDBOptions,
+ sqlContext: SQLContext)
+ extends WriteBuilder {
+ override def build(): V1Write =
+ new V1Write {
+ override def toInsertableRelation: InsertableRelation = {
+ new InsertableRelation {
+ override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+ val schema = info.schema()
+ val df = sqlContext.sparkSession.createDataFrame(data.toJavaRDD, schema)
+ df.write
+ .format("tidb")
+ .options(tiDBOptions.parameters)
+ .option(TiDBOptions.TIDB_DATABASE, tiDBOptions.database)
+ .option(TiDBOptions.TIDB_TABLE, tiDBOptions.table)
+ .option(TiDBOptions.TIDB_DEDUPLICATE, "false")
+ .mode(SaveMode.Append)
+ .save()
+ }
+ }
+ }
+ }
+}
diff --git a/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala
new file mode 100644
index 0000000000..a85531333c
--- /dev/null
+++ b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiAggregationProjectionV2.scala
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2022 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.extensions
+
+import com.pingcap.tispark.v2.TiDBTable
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.datasources.v2.{
+ DataSourceV2Relation,
+ DataSourceV2ScanRelation
+}
+
+/**
+ * I'm afraid that the same name with the object under the spark-wrapper/spark3.0 will lead to some problems.
+ * Although the same name will pass the integration test
+ */
+object TiAggregationProjectionV2 {
+ type ReturnType = (Seq[Expression], LogicalPlan, TiDBTable, Seq[NamedExpression])
+
+ def unapply(plan: LogicalPlan): Option[ReturnType] =
+ plan match {
+ // Only push down aggregates projection when all filters can be applied and
+ // all projection expressions are column references
+ case PhysicalOperation(
+ projects,
+ filters,
+ rel @ DataSourceV2ScanRelation(
+ DataSourceV2Relation(source: TiDBTable, _, _, _, _),
+ _,
+ _,
+ _)) if projects.forall(_.isInstanceOf[Attribute]) =>
+ Some((filters, rel, source, projects))
+ case _ => Option.empty[ReturnType]
+ }
+}
diff --git a/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiStrategy.scala b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiStrategy.scala
new file mode 100644
index 0000000000..64fc14e2b0
--- /dev/null
+++ b/spark-wrapper/spark-3.3/src/main/scala/org/apache/spark/sql/extensions/TiStrategy.scala
@@ -0,0 +1,658 @@
+/*
+ * Copyright 2022 PingCAP, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.extensions
+
+import com.pingcap.tidb.tipb.EncodeType
+import com.pingcap.tikv.exception.IgnoreUnsupportedTypeException
+import com.pingcap.tikv.expression._
+import com.pingcap.tikv.meta.TiDAGRequest.PushDownType
+import com.pingcap.tikv.meta.{TiDAGRequest, TiTimestamp}
+import com.pingcap.tikv.predicates.{PredicateUtils, TiKVScanAnalyzer}
+import com.pingcap.tikv.region.TiStoreType
+import com.pingcap.tikv.statistics.TableStatistics
+import com.pingcap.tispark.TiConfigConst
+import com.pingcap.tispark.statistics.StatisticsManager
+import com.pingcap.tispark.utils.{ReflectionUtil, TiUtil}
+import com.pingcap.tispark.v2.TiDBTable
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.{
+ Alias,
+ Ascending,
+ Attribute,
+ AttributeMap,
+ AttributeSet,
+ Descending,
+ Expression,
+ IntegerLiteral,
+ IsNull,
+ NamedExpression,
+ NullsFirst,
+ NullsLast,
+ SortOrder,
+ SubqueryExpression,
+ TiExprUtils
+}
+import org.apache.spark.sql.catalyst.planner.TiAggregation
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.datasources.v2.{
+ DataSourceV2Relation,
+ DataSourceV2ScanRelation
+}
+import org.apache.spark.sql.internal.SQLConf
+import org.joda.time.{DateTime, DateTimeZone}
+
+import java.util.concurrent.TimeUnit
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+object TiStrategy {
+ private val assignedTSPlanCache = new mutable.WeakHashMap[LogicalPlan, Boolean]()
+
+ private def hasTSAssigned(plan: LogicalPlan): Boolean = {
+ assignedTSPlanCache.contains(plan)
+ }
+
+ private def markTSAssigned(plan: LogicalPlan): Unit = {
+ plan foreachUp { p =>
+ assignedTSPlanCache.put(p, true)
+ }
+ }
+}
+
+/**
+ * CHECK Spark [[org.apache.spark.sql.Strategy]]
+ *
+ * TODO: There are too many hacks here because we hijack planning
+ * without having full control over the planning stage.
+ * We cannot pass context around during planning, so
+ * a re-extract is needed for push-down, since
+ * a plan tree might contain a Join, which means a single tree
+ * can yield multiple plans to push down.
+ */
+case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSession: SparkSession)
+ extends Strategy
+ with Logging {
+ type TiExpression = com.pingcap.tikv.expression.Expression
+ type TiColumnRef = com.pingcap.tikv.expression.ColumnRef
+ private lazy val tiContext: TiContext = getOrCreateTiContext(sparkSession)
+ private lazy val sqlContext = tiContext.sqlContext
+ private lazy val sqlConf: SQLConf = sqlContext.conf
+
+ def typeBlockList: TypeBlocklist = {
+ val blocklistString =
+ sqlConf.getConfString(TiConfigConst.UNSUPPORTED_TYPES, "")
+ new TypeBlocklist(blocklistString)
+ }
+
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+ TiExtensions.validateCatalog(sparkSession)
+ val ts = if (TiUtil.getTiDBSnapshot(sparkSession).isEmpty) {
+ tiContext.tiSession.getTimestamp
+ } else {
+ tiContext.tiSession.getSnapshotTimestamp
+ }
+
+ if (plan.isStreaming) {
+ // We should use a new timestamp for next batch execution.
+ // Otherwise Spark Structure Streaming will not see new data in TiDB.
+ if (!TiStrategy.hasTSAssigned(plan)) {
+ plan foreachUp applyStartTs(ts, forceUpdate = true)
+ TiStrategy.markTSAssigned(plan)
+ }
+ } else {
+ plan foreachUp applyStartTs(ts)
+ }
+
+ plan
+ .collectFirst {
+ case DataSourceV2ScanRelation(
+ DataSourceV2Relation(table: TiDBTable, _, _, _, _),
+ _,
+ _,
+ _) =>
+ doPlan(table, plan)
+ }
+ .toSeq
+ .flatten
+ }
+
+ def referencedTiColumns(expression: TiExpression): Seq[TiColumnRef] =
+ PredicateUtils.extractColumnRefFromExpression(expression).asScala.toSeq
+
+ /**
+ * build a Seq of used TiColumnRef from AttributeSet and bound them to source table
+ *
+ * @param attributeSet AttributeSet containing projects w/ or w/o filters
+ * @param source source TiDBRelation
+ * @return a Seq of TiColumnRef extracted
+ */
+ def buildTiColumnRefFromColumnSeq(
+ attributeSet: AttributeSet,
+ source: TiDBTable): Seq[TiColumnRef] = {
+ val tiColumnSeq: Seq[TiExpression] = attributeSet.toSeq.map { expr =>
+ TiExprUtils.transformAttrToColRef(expr, source.table)
+ }
+ var tiColumns: mutable.HashSet[TiColumnRef] = mutable.HashSet.empty[TiColumnRef]
+ for (expression <- tiColumnSeq) {
+ val colSetPerExpr = PredicateUtils.extractColumnRefFromExpression(expression)
+ colSetPerExpr.asScala.foreach {
+ tiColumns += _
+ }
+ }
+ tiColumns.toSeq
+ }
+
+ // apply StartTs to every logical plan in Spark Planning stage
+ protected def applyStartTs(
+ ts: TiTimestamp,
+ forceUpdate: Boolean = false): PartialFunction[LogicalPlan, Unit] = {
+ case DataSourceV2ScanRelation(
+ DataSourceV2Relation(r @ TiDBTable(_, _, _, timestamp, _), _, _, _, _),
+ _,
+ _,
+ _) =>
+ if (timestamp == null || forceUpdate) {
+ r.ts = ts
+ }
+ case logicalPlan =>
+ logicalPlan transformExpressionsUp {
+ case s: SubqueryExpression =>
+ s.plan.foreachUp(applyStartTs(ts))
+ s
+ }
+ }
+
+ private def blocklist: ExpressionBlocklist = {
+ val blocklistString = sqlConf.getConfString(TiConfigConst.UNSUPPORTED_PUSHDOWN_EXPR, "")
+ new ExpressionBlocklist(blocklistString)
+ }
+
+ private def allowAggregationPushDown(): Boolean =
+ sqlConf.getConfString(TiConfigConst.ALLOW_AGG_PUSHDOWN, "true").toLowerCase.toBoolean
+
+ private def useIndexScanFirst(): Boolean =
+ sqlConf.getConfString(TiConfigConst.USE_INDEX_SCAN_FIRST, "false").toLowerCase.toBoolean
+
+ private def allowIndexRead(): Boolean =
+ sqlConf.getConfString(TiConfigConst.ALLOW_INDEX_READ, "true").toLowerCase.toBoolean
+
+ private def useStreamingProcess: Boolean =
+ sqlConf.getConfString(TiConfigConst.COPROCESS_STREAMING, "false").toLowerCase.toBoolean
+
+ private def getCodecFormat: EncodeType = {
+ // FIXME: Should use default codec format "chblock", change it back after fix.
+ val codecFormatStr =
+ sqlConf
+ .getConfString(TiConfigConst.CODEC_FORMAT, TiConfigConst.DEFAULT_CODEC_FORMAT)
+ .toLowerCase
+
+ codecFormatStr match {
+ case TiConfigConst.CHUNK_CODEC_FORMAT => EncodeType.TypeChunk
+ case TiConfigConst.DEFAULT_CODEC_FORMAT => EncodeType.TypeCHBlock
+ case _ => EncodeType.TypeDefault
+ }
+ }
+
+ private def eligibleStorageEngines(source: TiDBTable): List[TiStoreType] =
+ TiUtil.getIsolationReadEngines(sqlContext).filter {
+ case TiStoreType.TiKV => true
+ case TiStoreType.TiFlash => source.isTiFlashReplicaAvailable
+ case _ => false
+ }
+
+ private def timeZoneOffsetInSeconds(): Int = {
+ val tz = DateTimeZone.getDefault
+ val instant = DateTime.now.getMillis
+ val offsetInMilliseconds = tz.getOffset(instant)
+ val hours = TimeUnit.MILLISECONDS.toHours(offsetInMilliseconds).toInt
+ val seconds = hours * 3600
+ seconds
+ }
+
+ private def newTiDAGRequest(): TiDAGRequest = {
+ val ts = timeZoneOffsetInSeconds()
+ if (useStreamingProcess) {
+ new TiDAGRequest(PushDownType.STREAMING, ts)
+ } else {
+ new TiDAGRequest(PushDownType.NORMAL, getCodecFormat, ts)
+ }
+ }
+
+ private def toCoprocessorRDD(
+ source: TiDBTable,
+ output: Seq[Attribute],
+ dagRequest: TiDAGRequest): SparkPlan = {
+ dagRequest.setTableInfo(source.table)
+ dagRequest.setStartTs(source.ts)
+
+ val notAllowPushDown = dagRequest.getFields.asScala
+ .map {
+ _.getDataType.getType
+ }
+ .exists {
+ typeBlockList.isUnsupportedType
+ }
+
+ if (notAllowPushDown) {
+ throw new IgnoreUnsupportedTypeException(
+ "Unsupported type found in fields: " + typeBlockList)
+ } else {
+ if (dagRequest.isDoubleRead) {
+ source.dagRequestToRegionTaskExec(dagRequest, output)
+ } else {
+ ColumnarCoprocessorRDD(
+ output,
+ source.logicalPlanToRDD(dagRequest, output),
+ fetchHandle = false)
+ }
+ }
+ }
+
+ private def aggregationToDAGRequest(
+ groupByList: Seq[NamedExpression],
+ aggregates: Seq[AggregateExpression],
+ source: TiDBTable,
+ dagRequest: TiDAGRequest): TiDAGRequest = {
+ aggregates
+ .map {
+ _.aggregateFunction
+ }
+ .foreach { expr =>
+ TiExprUtils.transformAggExprToTiAgg(expr, source.table, dagRequest)
+ }
+
+ groupByList.foreach { expr =>
+ TiExprUtils.transformGroupingToTiGrouping(expr, source.table, dagRequest)
+ }
+
+ dagRequest
+ }
+
+ private def filterToDAGRequest(
+ tiColumns: Seq[TiColumnRef],
+ filters: Seq[Expression],
+ source: TiDBTable,
+ dagRequest: TiDAGRequest): TiDAGRequest = {
+ val tiFilters: Seq[TiExpression] = filters.map {
+ TiExprUtils.transformFilter(_, source.table, dagRequest)
+ }
+
+ val scanBuilder: TiKVScanAnalyzer = new TiKVScanAnalyzer
+
+ val tblStatistics: TableStatistics = StatisticsManager.getTableStatistics(source.table.getId)
+
+ // engines that could be chosen.
+ val engines = eligibleStorageEngines(source)
+
+ if (engines.isEmpty) {
+ throw new RuntimeException(
+ s"No eligible storage engines found for $source, " +
+ s"isolation_read_engines = ${TiUtil.getIsolationReadEngines(sqlContext)}")
+ }
+
+ scanBuilder.buildTiDAGReq(
+ allowIndexRead(),
+ useIndexScanFirst(),
+ engines.contains(TiStoreType.TiKV),
+ engines.contains(TiStoreType.TiFlash),
+ tiColumns.map { colRef =>
+ source.table.getColumn(colRef.getName)
+ }.asJava,
+ tiFilters.asJava,
+ source.table,
+ tblStatistics,
+ source.ts,
+ dagRequest)
+ }
+
+ private def pruneTopNFilterProject(
+ limit: Int,
+ projectList: Seq[NamedExpression],
+ filterPredicates: Seq[Expression],
+ source: TiDBTable,
+ sortOrder: Seq[SortOrder]): SparkPlan = {
+ val request = newTiDAGRequest()
+ request.setLimit(limit)
+ TiExprUtils.transformSortOrderToTiOrderBy(request, sortOrder, source.table)
+
+ pruneFilterProject(projectList, filterPredicates, source, request)
+ }
+
+ private def collectLimit(limit: Int, child: LogicalPlan): SparkPlan =
+ child match {
+ case PhysicalOperation(
+ projectList,
+ filters,
+ DataSourceV2ScanRelation(
+ DataSourceV2Relation(source: TiDBTable, _, _, _, _),
+ _,
+ _,
+ _)) if filters.forall(TiExprUtils.isSupportedFilter(_, source, blocklist)) =>
+ pruneTopNFilterProject(limit, projectList, filters, source, Nil)
+ case _ => planLater(child)
+ }
+
+ private def takeOrderedAndProject(
+ limit: Int,
+ sortOrder: Seq[SortOrder],
+ child: LogicalPlan,
+ project: Seq[NamedExpression]): SparkPlan = {
+ // If sortOrder is empty, limit must be greater than 0
+ if (limit < 0 || (sortOrder.isEmpty && limit == 0)) {
+ return execution.TakeOrderedAndProjectExec(limit, sortOrder, project, planLater(child))
+ }
+
+ child match {
+ case PhysicalOperation(
+ projectList,
+ filters,
+ DataSourceV2ScanRelation(
+ DataSourceV2Relation(source: TiDBTable, _, _, _, _),
+ _,
+ _,
+ _)) if filters.forall(TiExprUtils.isSupportedFilter(_, source, blocklist)) =>
+ val refinedOrders = refineSortOrder(projectList, sortOrder, source)
+ if (refinedOrders.isEmpty) {
+ execution.TakeOrderedAndProjectExec(limit, sortOrder, project, planLater(child))
+ } else {
+ execution.TakeOrderedAndProjectExec(
+ limit,
+ sortOrder,
+ project,
+ pruneTopNFilterProject(limit, projectList, filters, source, refinedOrders.get))
+ }
+ case _ => execution.TakeOrderedAndProjectExec(limit, sortOrder, project, planLater(child))
+ }
+ }
+
+ // refine sort order
+ // 1. sort order expressions are all valid to be pushed
+ // 2. if any reference to projections are valid to be pushed
+ private def refineSortOrder(
+ projectList: Seq[NamedExpression],
+ sortOrders: Seq[SortOrder],
+ source: TiDBTable): Option[Seq[SortOrder]] = {
+ val aliases = AttributeMap(projectList.collect {
+ case a: Alias => a.toAttribute -> a
+ })
+ // Order by desc/asc + nulls first/last
+ //
+ // 1. Order by asc + nulls first:
+ // order by col asc nulls first = order by col asc
+ // 2. Order by desc + nulls first:
+ // order by col desc nulls first = order by col is null desc, col desc
+ // 3. Order by asc + nulls last:
+ // order by col asc nulls last = order by col is null asc, col asc
+ // 4. Order by desc + nulls last:
+ // order by col desc nulls last = order by col desc
+ val refinedSortOrder = sortOrders.flatMap { sortOrder: SortOrder =>
+ val newSortExpr = sortOrder.child.transformUp {
+ case a: Attribute => aliases.getOrElse(a, a)
+ }
+ val trimmedExpr = ReflectionUtil.trimNonTopLevelAliases(newSortExpr)
+ val trimmedSortOrder = ReflectionUtil.copySortOrder(sortOrder, trimmedExpr)
+
+ (sortOrder.direction, sortOrder.nullOrdering) match {
+ case (_ @Ascending, _ @NullsLast) | (_ @Descending, _ @NullsFirst) =>
+ ReflectionUtil.copySortOrder(sortOrder, IsNull(trimmedExpr)) :: trimmedSortOrder :: Nil
+ case _ =>
+ trimmedSortOrder :: Nil
+ }
+ }
+ if (refinedSortOrder
+ .exists(order => !TiExprUtils.isSupportedOrderBy(order.child, source, blocklist))) {
+ Option.empty
+ } else {
+ Some(refinedSortOrder)
+ }
+ }
+
+ private def pruneFilterProject(
+ projectList: Seq[NamedExpression],
+ filterPredicates: Seq[Expression],
+ source: TiDBTable,
+ dagRequest: TiDAGRequest): SparkPlan = {
+
+ val projectSet = AttributeSet(projectList.flatMap(_.references))
+ val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
+
+ val (pushdownFilters: Seq[Expression], residualFilters: Seq[Expression]) =
+ filterPredicates.partition((expression: Expression) =>
+ TiExprUtils.isSupportedFilter(expression, source, blocklist))
+
+ val residualFilter: Option[Expression] =
+ residualFilters.reduceLeftOption(catalyst.expressions.And)
+
+ val tiColumns = buildTiColumnRefFromColumnSeq(projectSet ++ filterSet, source)
+
+ filterToDAGRequest(tiColumns, pushdownFilters, source, dagRequest)
+
+ if (tiColumns.isEmpty) {
+ // we cannot send a request with empty columns
+ if (dagRequest.hasIndex) {
+ // add the first index column so that the plan will contain at least one column.
+ val idxColumn = dagRequest.getIndexInfo.getIndexColumns.get(0)
+ dagRequest.addRequiredColumn(ColumnRef.create(idxColumn.getName, source.table))
+ } else {
+ // add a random column so that the plan will contain at least one column.
+ // if the table contains a primary key then use the PK instead.
+ val column = source.table.getColumns.asScala
+ .collectFirst {
+ case e if e.isPrimaryKey => e
+ }
+ .getOrElse(source.table.getColumn(0))
+ dagRequest.addRequiredColumn(ColumnRef.create(column.getName, source.table))
+ }
+ }
+
+ // Right now we still use a projection even if the only evaluation is applying an alias
+ // to a column. Since this is a no-op, it could be avoided. However, using this
+ // optimization with the current implementation would change the output schema.
+ // TODO: Decouple final output schema from expression evaluation so this copy can be
+ // avoided safely.
+ if (AttributeSet(projectList.map(_.toAttribute)) == projectSet &&
+ filterSet.subsetOf(projectSet)) {
+ // When it is possible to just use column pruning to get the right projection and
+ // when the columns of this projection are enough to evaluate all filter conditions,
+ // just do a scan followed by a filter, with no extra project.
+ val projectSeq: Seq[Attribute] = projectList.asInstanceOf[Seq[Attribute]]
+ projectSeq.foreach(
+ attr =>
+ dagRequest.addRequiredColumn(
+ ColumnRef.create(attr.name, source.table.getColumn(attr.name))))
+ val scan = toCoprocessorRDD(source, projectSeq, dagRequest)
+ residualFilter.fold(scan)(FilterExec(_, scan))
+ } else {
+ // for now all column used will be returned for old interface
+ // TODO: once switch to new interface we change this pruning logic
+ val projectSeq: Seq[Attribute] = (projectSet ++ filterSet).toSeq
+ projectSeq.foreach(
+ attr =>
+ dagRequest.addRequiredColumn(
+ ColumnRef.create(attr.name, source.table.getColumn(attr.name))))
+ val scan = toCoprocessorRDD(source, projectSeq, dagRequest)
+ ProjectExec(projectList, residualFilter.fold(scan)(FilterExec(_, scan)))
+ }
+ }
+
+ private def groupAggregateProjection(
+ tiColumns: Seq[TiColumnRef],
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression],
+ resultExpressions: Seq[NamedExpression],
+ source: TiDBTable,
+ dagReq: TiDAGRequest): Seq[SparkPlan] = {
+ val deterministicAggAliases = aggregateExpressions.collect {
+ case e if e.deterministic => e.canonicalized -> ReflectionUtil.newAlias(e, e.toString())
+ }.toMap
+
+ def aliasPushedPartialResult(e: AggregateExpression): Alias =
+ deterministicAggAliases.getOrElse(e.canonicalized, ReflectionUtil.newAlias(e, e.toString()))
+
+ val residualAggregateExpressions = aggregateExpressions.map { aggExpr =>
+ // As `aggExpr` is being pushed down to TiKV, we need to replace the original Catalyst
+ // aggregate expressions with new ones that merge the partial aggregation results returned by
+ // TiKV.
+ //
+ // NOTE: Unlike simple aggregate functions (e.g., `Max`, `Min`, etc.), `Count` must be
+ // replaced with a `Sum` to sum up the partial counts returned by TiKV.
+ //
+ // NOTE: All `Average`s should have already been rewritten into `Sum`s and `Count`s by the
+ // `TiAggregation` pattern extractor.
+
+ // An attribute referring to the partial aggregation results returned by TiKV.
+ val partialResultRef = aliasPushedPartialResult(aggExpr).toAttribute
+
+ aggExpr.aggregateFunction match {
+ case e: Max => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef))
+ case e: Min => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef))
+ case e: Sum => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef))
+ case e: SpecialSum => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef))
+ case e: First => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef))
+ case _: Count =>
+ aggExpr.copy(aggregateFunction = CountSum(partialResultRef))
+ case _: Average => throw new IllegalStateException("All AVGs should have been rewritten.")
+ case _ => aggExpr
+ }
+ }
+
+ tiColumns foreach {
+ dagReq.addRequiredColumn
+ }
+
+ aggregationToDAGRequest(groupingExpressions, aggregateExpressions.distinct, source, dagReq)
+
+ val aggregateAttributes =
+ aggregateExpressions.map(expr => aliasPushedPartialResult(expr).toAttribute)
+ val groupAttributes = groupingExpressions.map(_.toAttribute)
+
+ // output of Coprocessor plan should contain all references within
+ // aggregates and group by expressions
+ val output = aggregateAttributes ++ groupAttributes
+
+ val groupExpressionMap =
+ groupingExpressions.map(expr => expr.exprId -> expr.toAttribute).toMap
+
+ // resultExpressions might refer to some of the group by expressions.
+ // Those expressions originally refer to table columns but now they refer to
+ // results of the coprocessor.
+ // For example, select a + 1 from t group by a + 1
+ // expression a + 1 has been pushed down to coprocessor
+ // and in turn a + 1 in projection should be replaced by
+ // reference of coprocessor output entirely
+ val rewrittenResultExpressions = resultExpressions.map {
+ _.transform {
+ case e: NamedExpression => groupExpressionMap.getOrElse(e.exprId, e)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ aggregate.AggUtils.planAggregateWithoutDistinct(
+ groupAttributes,
+ residualAggregateExpressions,
+ rewrittenResultExpressions,
+ toCoprocessorRDD(source, output, dagReq))
+ }
+
+ private def isValidAggregates(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression],
+ filters: Seq[Expression],
+ source: TiDBTable): Boolean =
+ allowAggregationPushDown &&
+ filters.forall(TiExprUtils.isSupportedFilter(_, source, blocklist)) &&
+ groupingExpressions.forall(TiExprUtils.isSupportedGroupingExpr(_, source, blocklist)) &&
+ aggregateExpressions.forall(TiExprUtils.isSupportedAggregate(_, source, blocklist)) &&
+ !aggregateExpressions.exists(_.isDistinct) &&
+ // TODO: This is a temporary fix for the issue: https://github.com/pingcap/tispark/issues/1039
+ !groupingExpressions.exists(_.isInstanceOf[Alias])
+
+ // We go through similar logic to the original Spark planning in SparkStrategies.scala.
+ // The difference is that we need to test whether a sub-plan can be consumed entirely by TiKV;
+ // if so, we don't return (don't planLater) and plan the remainder all at once.
+ // TODO: This test should be done once for all children
+ private def doPlan(source: TiDBTable, plan: LogicalPlan): Seq[SparkPlan] =
+ plan match {
+ case logical.ReturnAnswer(rootPlan) =>
+ rootPlan match {
+ case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+ takeOrderedAndProject(limit, order, child, child.output) :: Nil
+ case logical.Limit(
+ IntegerLiteral(limit),
+ logical.Project(projectList, logical.Sort(order, true, child))) =>
+ takeOrderedAndProject(limit, order, child, projectList) :: Nil
+ case logical.Limit(IntegerLiteral(limit), child) =>
+ execution.CollectLimitExec(limit, collectLimit(limit, child)) :: Nil
+ case other => planLater(other) :: Nil
+ }
+ case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+ takeOrderedAndProject(limit, order, child, child.output) :: Nil
+ case logical.Limit(
+ IntegerLiteral(limit),
+ logical.Project(projectList, logical.Sort(order, true, child))) =>
+ takeOrderedAndProject(limit, order, child, projectList) :: Nil
+ case logical.Limit(IntegerLiteral(limit), child) =>
+ execution.CollectLimitExec(limit, collectLimit(limit, child)) :: Nil
+ // Collapse filters and projections and push plan directly
+ case PhysicalOperation(
+ projectList,
+ filters,
+ DataSourceV2ScanRelation(
+ DataSourceV2Relation(source: TiDBTable, _, _, _, _),
+ _,
+ _,
+ _)) =>
+ pruneFilterProject(projectList, filters, source, newTiDAGRequest()) :: Nil
+
+ // Basic logic of original Spark's aggregation plan is:
+ // PhysicalAggregation extractor will rewrite original aggregation
+ // into aggregateExpressions and resultExpressions.
+ // resultExpressions contains only references [[AttributeReference]]
+ // to the result of aggregation. resultExpressions might contain projections
+ // like Add(sumResult, 1).
+ // For an aggregate like agg(expr) + 1, the rewrite process is: rewrite agg(expr) ->
+ // 1. pushdown: agg(expr) as agg1, if avg then sum(expr), count(expr)
+ // 2. residual expr (for Spark itself): agg(agg1) as finalAgg1 the parameter is a
+ // reference to pushed plan's corresponding aggregation
+ // 3. resultExpressions: finalAgg1 + 1, the finalAgg1 is the reference to final result
+ // of the aggregation
+ case TiAggregation(
+ groupingExpressions,
+ aggregateExpressions,
+ resultExpressions,
+ TiAggregationProjectionV2(filters, _, `source`, projects))
+ if isValidAggregates(groupingExpressions, aggregateExpressions, filters, source) =>
+ val projectSet = AttributeSet((projects ++ filters).flatMap {
+ _.references
+ })
+ val tiColumns = buildTiColumnRefFromColumnSeq(projectSet, source)
+ val dagReq: TiDAGRequest =
+ filterToDAGRequest(tiColumns, filters, source, newTiDAGRequest())
+ groupAggregateProjection(
+ tiColumns,
+ groupingExpressions,
+ aggregateExpressions,
+ resultExpressions,
+ `source`,
+ dagReq)
+ case _ => Nil
+ }
+}
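TiStrategy, the parser, and the analysis rules shown in this patch are wired in through TiExtensions; a minimal session setup that exercises the Spark 3.3 code path might look like the sketch below (the PD address and table name are placeholders):

import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .config("spark.sql.extensions", "org.apache.spark.sql.TiExtensions")
  .config("spark.sql.catalog.tidb_catalog", "org.apache.spark.sql.catalyst.catalog.TiCatalog")
  .config("spark.sql.catalog.tidb_catalog.pd.addresses", "127.0.0.1:2379") // placeholder PD address
  .getOrCreate()

// Queries against tidb_catalog are planned by TiStrategy and, where supported,
// pushed down to TiKV / TiFlash coprocessors.
spark.sql("select count(*) from tidb_catalog.test.t").show()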