Commit 1a48157

support spark 3.2 (pingcap#2287)
shiyuhang0 authored May 10, 2022
1 parent a16affa commit 1a48157
Showing 19 changed files with 1,079 additions and 51 deletions.
9 changes: 9 additions & 0 deletions assembly/src/main/assembly/assembly.xml
@@ -37,5 +37,14 @@
<include>**/*</include>
</includes>
</fileSet>
<fileSet>
<directory>
${project.parent.basedir}/spark-wrapper/spark-3.2/target/classes/
</directory>
<outputDirectory>resources/spark-wrapper-spark-3_2</outputDirectory>
<includes>
<include>**/*</include>
</includes>
</fileSet>
</fileSets>
</assembly>
17 changes: 17 additions & 0 deletions core/pom.xml
@@ -51,6 +51,17 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${spark.version.compile}</version>
<exclusions>
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.75.Final</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
@@ -87,6 +98,12 @@
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
<version>2.7.2</version>
<exclusions>
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
2 changes: 1 addition & 1 deletion core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala
@@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory
object TiSparkInfo {
private final val logger = LoggerFactory.getLogger(getClass.getName)

val SUPPORTED_SPARK_VERSION: List[String] = "3.0" :: "3.1" :: Nil
val SUPPORTED_SPARK_VERSION: List[String] = "3.0" :: "3.1" :: "3.2" :: Nil

val SPARK_VERSION: String = org.apache.spark.SPARK_VERSION

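For context on how this list is used: a minimal sketch of a startup gate that compares only the major.minor prefix of the running Spark version (the object and method names here are illustrative, not TiSparkInfo's actual API):

import org.apache.spark.SPARK_VERSION

object SparkVersionGate {
  // Mirrors the list above; "3.2.1" matches through its "3.2" prefix.
  val supported: List[String] = "3.0" :: "3.1" :: "3.2" :: Nil

  def isSupported(version: String = SPARK_VERSION): Boolean =
    supported.exists(version.startsWith)

  def check(): Unit =
    if (!isSupported())
      throw new IllegalStateException(
        s"Spark $SPARK_VERSION is not supported; expected one of ${supported.mkString(", ")}")
}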
15 changes: 5 additions & 10 deletions core/src/main/scala/com/pingcap/tispark/v2/TiDBTable.scala
@@ -22,12 +22,12 @@ import com.pingcap.tikv.key.Handle
import com.pingcap.tikv.meta.{TiDAGRequest, TiTableInfo, TiTimestamp}
import com.pingcap.tispark.utils.TiUtil
import com.pingcap.tispark.v2.TiDBTable.{getDagRequestToRegionTaskExec, getLogicalPlanToRDD}
import com.pingcap.tispark.v2.sink.TiDBWriterBuilder
import com.pingcap.tispark.v2.sink.TiDBWriteBuilder
import com.pingcap.tispark.write.{TiDBDelete, TiDBOptions}
import com.pingcap.tispark.TiTableReference
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.connector.catalog.{
SupportsDelete,
SupportsRead,
@@ -61,7 +61,7 @@ import org.apache.spark.sql.sources.{
import org.apache.spark.sql.tispark.{TiHandleRDD, TiRowRDD}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.{SQLContext, SparkSession, execution}
import org.apache.spark.sql.{SQLContext, execution}
import org.slf4j.LoggerFactory

import java.sql.{Date, SQLException, Timestamp}
@@ -129,7 +129,6 @@ case class TiDBTable(
override def capabilities(): util.Set[TableCapability] = {
val capabilities = new util.HashSet[TableCapability]
capabilities.add(TableCapability.BATCH_READ)
capabilities.add(TableCapability.V1_BATCH_WRITE)
capabilities
}

@@ -159,7 +158,7 @@
}
// Get TiDBOptions
val tiDBOptions = new TiDBOptions(scalaMap)
TiDBWriterBuilder(info, tiDBOptions, sqlContext)
TiDBWriteBuilder(info, tiDBOptions, sqlContext)
}

override def deleteWhere(filters: Array[Filter]): Unit = {
@@ -247,7 +246,6 @@ object TiDBTable {
dagRequest,
session.getConf,
session.getTimestamp,
session,
sqlContext.sparkSession)
}

@@ -298,10 +296,7 @@
DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
s"'${timestampFormatter.format(timestampValue)}'"
case dateValue: Date => "'" + dateValue + "'"
case dateValue: LocalDate =>
val dateFormatter = DateFormatter(
DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
s"'${dateFormatter.format(dateValue)}'"
case dateValue: LocalDate => "'" + dateValue + "'"
case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
case _ => value
}
@@ -16,27 +16,12 @@

package com.pingcap.tispark.v2.sink

import com.pingcap.tispark.write.{TiDBOptions, TiDBWriter}
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, V1WriteBuilder}
import org.apache.spark.sql.sources.InsertableRelation
import com.pingcap.tispark.write.TiDBOptions
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}

case class TiDBWriterBuilder(
case class TiDBWriteBuilder(
info: LogicalWriteInfo,
tiDBOptions: TiDBOptions,
sqlContext: SQLContext)
extends V1WriteBuilder {

// Use V1WriteBuilder before turn to v2
override def buildForV1Write(): InsertableRelation = { (data: DataFrame, overwrite: Boolean) =>
{
val saveMode = if (overwrite) {
SaveMode.Overwrite
} else {
SaveMode.Append
}
TiDBWriter.write(data, sqlContext, saveMode, tiDBOptions)
}
}

}
extends WriteBuilder {}
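Spark 3.2 drops the V1WriteBuilder trait used above, which is presumably why this builder shrinks to a bare WriteBuilder and the actual write wiring moves into the version-specific spark-wrapper modules. A rough sketch of how a Spark 3.2 wrapper could reuse the existing TiDBWriter through the replacement V1Write interface (an assumption about the wrapper, not code from this commit; the class name is illustrative):

import com.pingcap.tispark.write.{TiDBOptions, TiDBWriter}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, V1Write, Write, WriteBuilder}
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

case class TiDBWriteBuilderSketch(
    info: LogicalWriteInfo,
    tiDBOptions: TiDBOptions,
    sqlContext: SQLContext)
    extends WriteBuilder {

  // Spark 3.2 replaces buildForV1Write with a V1Write returned from build();
  // the InsertableRelation body mirrors the removed code above.
  override def build(): Write = new V1Write {
    override def toInsertableRelation(): InsertableRelation =
      (data: DataFrame, overwrite: Boolean) => {
        val saveMode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
        TiDBWriter.write(data, sqlContext, saveMode, tiDBOptions)
      }
  }
}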
@@ -69,7 +69,8 @@ object TiExprUtils {
case _: Average =>
throw new IllegalArgumentException("Should never be here")

case f @ Sum(BasicExpression(arg)) =>
case f: Sum =>
val arg = BasicExpression.unapply(f.child).get
addingSumAggToDAgReq(meta, dagRequest, f, arg)

case f @ PromotedSum(BasicExpression(arg)) =>
@@ -175,7 +176,15 @@
tiDBRelation: TiDBTable,
blocklist: ExpressionBlocklist): Boolean =
aggExpr.aggregateFunction match {
case Average(_) | Sum(_) | SumNotNullable(_) | PromotedSum(_) | Min(_) | Max(_) =>
case _: Average =>
!aggExpr.isDistinct &&
aggExpr.aggregateFunction.children
.forall(isSupportedBasicExpression(_, tiDBRelation, blocklist))
case _: Sum =>
!aggExpr.isDistinct &&
aggExpr.aggregateFunction.children
.forall(isSupportedBasicExpression(_, tiDBRelation, blocklist))
case SumNotNullable(_) | PromotedSum(_) | Min(_) | Max(_) =>
!aggExpr.isDistinct &&
aggExpr.aggregateFunction.children
.forall(isSupportedBasicExpression(_, tiDBRelation, blocklist))
@@ -111,4 +111,13 @@ case class SpecialSum(child: Expression, retType: DataType, initVal: Any)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function sum")

/**
* The implementation is the same as [[org.apache.spark.sql.catalyst.expressions.aggregate.Sum]]
* @param newChildren
*/
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
assert(newChildren.size == 1, "Incorrect number of children")
copy(child = newChildren.head)
}
}
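Background for this addition: since Spark 3.2, Catalyst rebuilds tree nodes through withNewChildrenInternal instead of reflective copying, so a custom expression such as SpecialSum must implement it or tree transforms fail. A small illustration of a rewrite that exercises this path, using standard Spark expressions (not TiSpark code):

object WithNewChildrenDemo {
  import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal}

  // mapChildren rebuilds the Add node via withNewChildrenInternal in Spark 3.2.
  val expr: Expression = Add(Literal(1), Literal(2))
  val rewritten: Expression = expr.mapChildren {
    case Literal(v: Int, _) => Literal(v * 10)
    case other => other
  }
  // rewritten is Add(Literal(10), Literal(20))
}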
@@ -64,24 +64,24 @@ object TiAggregationImpl {
._1

val sumsRewriteMap = sums.map {
case s @ AggregateExpression(Sum(ref), _, _, _, _) =>
case s @ AggregateExpression(r: Sum, _, _, _, _) =>
// need cast long type to decimal type
val sum =
if (ref.dataType.eq(LongType)) PromotedSum(ref) else Sum(ref)
if (r.child.dataType.eq(LongType)) PromotedSum(r.child) else Sum(r.child)
s.resultAttribute -> s.copy(aggregateFunction = sum, resultId = newExprId)
}.toMap

// An auxiliary map that maps result attribute IDs of all detected `Average`s to corresponding
// converted `Sum`s and `Count`s.
val avgRewriteMap = averages.map {
case a @ AggregateExpression(Average(ref), _, _, _, _) =>
case a @ AggregateExpression(r: Average, _, _, _, _) =>
// We need to do a type promotion on Sum(Long) to avoid LongType overflow in Average rewrite
// scenarios to stay consistent with original spark's Average behaviour
val sum =
if (ref.dataType.eq(LongType)) PromotedSum(ref) else Sum(ref)
if (r.child.dataType.eq(LongType)) PromotedSum(r.child) else Sum(r.child)
a.resultAttribute -> Seq(
a.copy(aggregateFunction = sum, resultId = newExprId),
a.copy(aggregateFunction = Count(ref), resultId = newExprId))
a.copy(aggregateFunction = Count(r.child), resultId = newExprId))
}.toMap

val sumRewrite = sumsRewriteMap.map {
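The move from Sum(ref)/Average(ref) extractor patterns to type matches here (and in TiExprUtils above) is likely because Spark 3.2 adds a second constructor parameter to these aggregates, so the old single-argument patterns no longer compile across versions. Matching on the type and reading child keeps the code version-agnostic; a minimal illustrative helper (hypothetical, not part of this commit):

object AggregateChild {
  import org.apache.spark.sql.catalyst.expressions.Expression
  import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Sum}

  // Return the single child of a Sum/Average without depending on the
  // constructor arity, which differs between Spark 3.1 and 3.2.
  def aggregateChild(agg: AggregateExpression): Option[Expression] =
    agg.aggregateFunction match {
      case s: Sum     => Some(s.child)
      case a: Average => Some(a.child)
      case _          => None
    }
}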
@@ -112,7 +112,6 @@ case class ColumnarRegionTaskExec(
dagRequest: TiDAGRequest,
tiConf: TiConfiguration,
ts: TiTimestamp,
@transient private val session: TiSession,
@transient private val sparkSession: SparkSession)
extends UnaryExecNode {

@@ -392,4 +391,8 @@
override protected def doExecute(): RDD[InternalRow] = {
WholeStageCodegenExec(this)(codegenStageId = 0).execute()
}

protected def withNewChildInternal(newChild: SparkPlan): ColumnarRegionTaskExec = {
copy(child = newChild)
}
}
@@ -164,8 +164,10 @@ class TiAuthIntegrationSuite extends SharedSQLContext {
test(f"Show tables should not contain invisible table") {
noException should be thrownBy spark.sql(s"use $databaseWithPrefix")

// Spark 3.2 adds an isTemporary column to `show tables`, so we need to exclude it.
val tables = spark
.sql(s"show tables")
.drop("isTemporary")
.collect()
.map(row => row.toString())
.toList
@@ -75,15 +75,4 @@ class BasicBatchWriteSuite extends BaseBatchWriteWithoutDropTableTest("test_data
caught.getMessage
.equals("SaveMode: Overwrite is not supported. TiSpark only support SaveMode.Append."))
}

// Experimental
test("Test Datasource api v2 write") {
jdbcUpdate(s"drop table if exists $dbtable")
jdbcUpdate(s"create table $dbtable(i int, s varchar(128))")
jdbcUpdate(s"insert into $dbtable values(null, 'Hello'), (2, 'TiDB')")
val data: RDD[Row] = sc.makeRDD(List(row3, row4))
val df = sqlContext.createDataFrame(data, schema)
df.writeTo(s"tidb_catalog.$database.$table").options(tidbOptions).append()
testTiDBSelect(Seq(row1, row2, row3, row4))
}
}
@@ -48,6 +48,7 @@ class DeleteWhereClauseSuite extends BaseBatchWriteTest("test_delete_where_claus
"s<'0'",
"s>='0'",
"s<='0'",
"s ='0\\'0'",
"s>'0' and s<'2'",
"s<'0' or s>'2'",
"s like '%1%'",
@@ -123,7 +123,7 @@ class BasePlanTest extends BaseTiSparkTest {
def getEstimatedRowCount[T](df: Dataset[T], tableName: String): Double =
extractTiSparkPlans(df).collect { extractDAGRequest }.head.getEstimatedCount

def toPlan[T](df: Dataset[T]): SparkPlan = df.queryExecution.executedPlan
def toPlan[T](df: Dataset[T]): SparkPlan = df.queryExecution.sparkPlan

private def fail[T](df: Dataset[T], message: String, throwable: Throwable): Unit = {
df.explain
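The switch from executedPlan to sparkPlan in toPlan is presumably related to adaptive query execution being enabled by default in Spark 3.2: executedPlan then comes back wrapped in an AdaptiveSparkPlanExec, while sparkPlan is still the plain physical plan the plan tests inspect. A small sketch of the difference (illustrative only; assumes an active SparkSession named spark):

object PlanInspection {
  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

  def showPlans(spark: SparkSession): Unit = {
    val df = spark.range(100).groupBy().count()
    df.queryExecution.executedPlan match {
      case a: AdaptiveSparkPlanExec => println(s"executedPlan is AQE-wrapped: $a")
      case p                        => println(s"executedPlan: $p")
    }
    // sparkPlan is the physical plan before adaptive re-optimization, so
    // custom operators stay directly visible in the tree.
    println(s"sparkPlan: ${df.queryExecution.sparkPlan}")
  }
}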
11 changes: 11 additions & 0 deletions pom.xml
@@ -167,6 +167,7 @@
<module>core</module>
<module>spark-wrapper/spark-3.0</module>
<module>spark-wrapper/spark-3.1</module>
<module>spark-wrapper/spark-3.2</module>
<module>assembly</module>
</modules>

@@ -181,6 +182,16 @@
<spark.version.test>3.1.1</spark.version.test>
</properties>
</profile>
<profile>
<id>spark-3.2.1</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<properties>
<spark.version.compile>3.2.1</spark.version.compile>
<spark.version.test>3.2.1</spark.version.test>
</properties>
</profile>
<profile>
<id>jenkins</id>
<modules>