support spark 3.1.1 (pingcap#2024)

marsishandsome authored Apr 27, 2021
1 parent b81669a commit 97666d1

Showing 26 changed files with 1,717 additions and 441 deletions.
1 change: 1 addition & 0 deletions .ci/integration_test.groovy
@@ -241,6 +241,7 @@ def call(ghprbActualCommit, ghprbCommentBody, ghprbPullId, ghprbPullTitle, ghprb
"""
sh """
export MAVEN_OPTS="-Xmx6G -XX:MaxPermSize=512M"
mvn clean package -DskipTests
mvn test ${MVN_PROFILE} -Dtest=moo ${mvnStr}
"""
}
21 changes: 21 additions & 0 deletions assembly/src/main/assembly/assembly.xml
@@ -17,4 +17,25 @@
<unpack>true</unpack>
</dependencySet>
</dependencySets>

<fileSets>
<fileSet>
<directory>
${project.parent.basedir}/spark-wrapper/spark-3.0/target/classes/
</directory>
<outputDirectory>resources/spark-wrapper-spark-3_0</outputDirectory>
<includes>
<include>**/*</include>
</includes>
</fileSet>
<fileSet>
<directory>
${project.parent.basedir}/spark-wrapper/spark-3.1/target/classes/
</directory>
<outputDirectory>resources/spark-wrapper-spark-3_1</outputDirectory>
<includes>
<include>**/*</include>
</includes>
</fileSet>
</fileSets>
</assembly>
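
These two fileSets bundle each wrapper module's compiled classes into versioned resources/ directories inside the assembly jar; the CI change above runs mvn clean package -DskipTests first so those classes exist before the tests execute. A minimal sketch of the naming convention, assuming it mirrors the jar URL built in ReflectionUtil below (wrapperResourcePath is a hypothetical helper, not part of this PR):

// Hypothetical helper illustrating the directory-name convention:
// "3.1" -> "resources/spark-wrapper-spark-3_1", matching the
// outputDirectory entries above and the jar URL in ReflectionUtil.
def wrapperResourcePath(sparkMajorVersion: String): String =
  s"resources/spark-wrapper-spark-${sparkMajorVersion.replace('.', '_')}"

wrapperResourcePath("3.0") // "resources/spark-wrapper-spark-3_0"
wrapperResourcePath("3.1") // "resources/spark-wrapper-spark-3_1"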
2 changes: 1 addition & 1 deletion core/src/main/scala/com/pingcap/tispark/TiSparkInfo.scala
@@ -21,7 +21,7 @@ import org.slf4j.LoggerFactory
object TiSparkInfo {
private final val logger = LoggerFactory.getLogger(getClass.getName)

val SUPPORTED_SPARK_VERSION: List[String] = "3.0" :: Nil
val SUPPORTED_SPARK_VERSION: List[String] = "3.0" :: "3.1" :: Nil

val SPARK_VERSION: String = org.apache.spark.SPARK_VERSION

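With "3.1" added to SUPPORTED_SPARK_VERSION, version checks elsewhere in TiSpark now accept Spark 3.1.x. A hedged sketch of how such a list is typically consulted (isVersionSupported is illustrative; the actual check lives outside this diff):

// Illustrative only: accept a runtime Spark version when its
// major.minor prefix appears in SUPPORTED_SPARK_VERSION.
def isVersionSupported(sparkVersion: String): Boolean =
  TiSparkInfo.SUPPORTED_SPARK_VERSION.exists(v => sparkVersion.startsWith(v))

isVersionSupported("3.1.1") // true after this change
isVersionSupported("3.2.0") // false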
155 changes: 155 additions & 0 deletions core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala
@@ -0,0 +1,155 @@
/*
*
* Copyright 2021 PingCAP, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.pingcap.tispark.utils

import com.pingcap.tispark.TiSparkInfo
import org.apache.spark.sql.{SparkSession, TiContext}
import org.apache.spark.sql.catalyst.expressions.BasicExpression.TiExpression
import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId, Expression, SortOrder}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.slf4j.LoggerFactory

import java.io.File
import java.net.{URL, URLClassLoader}

/**
* ReflectionUtil reflectively invokes methods whose signatures differ
* across Spark versions. Future cross-version compatibility issues
* should also be resolved here via reflection.
*/
object ReflectionUtil {
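// Builds a child class loader that serves the version-matched SparkWrapper
// classes: from the sibling spark-wrapper build output when TiSpark runs from
// a classes directory (development), or from the versioned resources/ folder
// inside the assembly jar (deployment, as packaged by assembly.xml above).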
lazy val classLoader: URLClassLoader = {
val tisparkClassUrl = this.getClass.getProtectionDomain.getCodeSource.getLocation
val tisparkClassPath = new File(tisparkClassUrl.getFile)
logger.info(s"tispark class url: ${tisparkClassUrl.toString}")

val sparkWrapperClassURL: URL = if (tisparkClassPath.isDirectory) {
val classDir = new File(
s"${tisparkClassPath.getAbsolutePath}/../../../spark-wrapper/spark-${TiSparkInfo.SPARK_MAJOR_VERSION}/target/classes/")
if (!classDir.exists()) {
throw new Exception(
"cannot find spark wrapper classes! please compile the spark-wrapper project first!")
}
classDir.toURI.toURL
} else {
new URL(
s"jar:$tisparkClassUrl!/resources/spark-wrapper-spark-${TiSparkInfo.SPARK_MAJOR_VERSION
.replace('.', '_')}/")
}
logger.info(s"spark wrapper class url: ${sparkWrapperClassURL.toString}")

new URLClassLoader(Array(sparkWrapperClassURL), this.getClass.getClassLoader)
}
private val logger = LoggerFactory.getLogger(getClass.getName)
private val SPARK_WRAPPER_CLASS = "com.pingcap.tispark.SparkWrapper"
private val TI_BASIC_EXPRESSION_CLASS =
"org.apache.spark.sql.catalyst.expressions.TiBasicExpression"
private val TI_RESOLUTION_RULE_CLASS =
"org.apache.spark.sql.extensions.TiResolutionRule"
private val TI_RESOLUTION_RULE_V2_CLASS =
"org.apache.spark.sql.extensions.TiResolutionRuleV2"
private val TI_PARSER_CLASS =
"org.apache.spark.sql.extensions.TiParser"
private val TI_DDL_RULE_CLASS =
"org.apache.spark.sql.extensions.TiDDLRule"

def newAlias(child: Expression, name: String): Alias = {
classLoader
.loadClass(SPARK_WRAPPER_CLASS)
.getDeclaredMethod("newAlias", classOf[Expression], classOf[String])
.invoke(null, child, name)
.asInstanceOf[Alias]
}

def newAlias(child: Expression, name: String, exprId: ExprId): Alias = {
classLoader
.loadClass(SPARK_WRAPPER_CLASS)
.getDeclaredMethod("newAlias", classOf[Expression], classOf[String], classOf[ExprId])
.invoke(null, child, name, exprId)
.asInstanceOf[Alias]
}

def copySortOrder(sortOrder: SortOrder, child: Expression): SortOrder = {
classLoader
.loadClass(SPARK_WRAPPER_CLASS)
.getDeclaredMethod("copySortOrder", classOf[SortOrder], classOf[Expression])
.invoke(null, sortOrder, child)
.asInstanceOf[SortOrder]
}

def trimNonTopLevelAliases(e: Expression): Expression = {
classLoader
.loadClass(SPARK_WRAPPER_CLASS)
.getDeclaredMethod("trimNonTopLevelAliases", classOf[Expression])
.invoke(null, e)
.asInstanceOf[Expression]
}

def callTiBasicExpressionConvertToTiExpr(expr: Expression): Option[TiExpression] = {
classLoader
.loadClass(TI_BASIC_EXPRESSION_CLASS)
.getDeclaredMethod("convertToTiExpr", classOf[Expression])
.invoke(null, expr)
.asInstanceOf[Option[TiExpression]]
}

def newTiResolutionRule(
getOrCreateTiContext: SparkSession => TiContext,
sparkSession: SparkSession): Rule[LogicalPlan] = {
classLoader
.loadClass(TI_RESOLUTION_RULE_CLASS)
.getDeclaredConstructor(classOf[SparkSession => TiContext], classOf[SparkSession])
.newInstance(getOrCreateTiContext, sparkSession)
.asInstanceOf[Rule[LogicalPlan]]
}

def newTiResolutionRuleV2(
getOrCreateTiContext: SparkSession => TiContext,
sparkSession: SparkSession): Rule[LogicalPlan] = {
classLoader
.loadClass(TI_RESOLUTION_RULE_V2_CLASS)
.getDeclaredConstructor(classOf[SparkSession => TiContext], classOf[SparkSession])
.newInstance(getOrCreateTiContext, sparkSession)
.asInstanceOf[Rule[LogicalPlan]]
}

def newTiParser(
getOrCreateTiContext: SparkSession => TiContext,
sparkSession: SparkSession,
parserInterface: ParserInterface): ParserInterface = {
classLoader
.loadClass(TI_PARSER_CLASS)
.getDeclaredConstructor(
classOf[SparkSession => TiContext],
classOf[SparkSession],
classOf[ParserInterface])
.newInstance(getOrCreateTiContext, sparkSession, parserInterface)
.asInstanceOf[ParserInterface]
}

def newTiDDLRule(
getOrCreateTiContext: SparkSession => TiContext,
sparkSession: SparkSession): Rule[LogicalPlan] = {
classLoader
.loadClass(TI_DDL_RULE_CLASS)
.getDeclaredConstructor(classOf[SparkSession => TiContext], classOf[SparkSession])
.newInstance(getOrCreateTiContext, sparkSession)
.asInstanceOf[Rule[LogicalPlan]]
}
}
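
Every public method above delegates to a class compiled once per supported Spark version. As a reference point, here is a minimal sketch of how the Spark 3.0 flavor of com.pingcap.tispark.SparkWrapper could look; the real implementations live in the spark-wrapper/spark-3.0 and spark-wrapper/spark-3.1 modules, which this page does not show, and the 3.1 module must use the 3.1 equivalents of these APIs (for example, trimNonTopLevelAliases moved into the AliasHelper trait in Spark 3.1):

package com.pingcap.tispark

import org.apache.spark.sql.catalyst.analysis.CleanupAliases
import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId, Expression, SortOrder}

// Sketch of the Spark 3.0 wrapper. Methods on a top-level Scala object get
// static forwarders, which is what ReflectionUtil's invoke(null, ...) expects.
object SparkWrapper {
  def newAlias(child: Expression, name: String): Alias =
    Alias(child, name)()

  def newAlias(child: Expression, name: String, exprId: ExprId): Alias =
    Alias(child, name)(exprId = exprId)

  def copySortOrder(sortOrder: SortOrder, child: Expression): SortOrder =
    sortOrder.copy(child = child)

  def trimNonTopLevelAliases(e: Expression): Expression =
    CleanupAliases.trimNonTopLevelAliases(e)
}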
38 changes: 37 additions & 1 deletion core/src/main/scala/com/pingcap/tispark/utils/TiUtil.scala
@@ -16,7 +16,6 @@
package com.pingcap.tispark.utils

import java.util.concurrent.TimeUnit

import com.pingcap.tikv.TiConfiguration
import com.pingcap.tikv.datatype.TypeMapping
import com.pingcap.tikv.meta.{TiDAGRequest, TiTableInfo}
@@ -30,7 +29,44 @@ import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
import org.apache.spark.{SparkConf, sql}
import org.tikv.kvproto.Kvrpcpb.{CommandPri, IsolationLevel}

import java.time.{Instant, LocalDate, ZoneId}
import java.util.TimeZone
import java.util.concurrent.TimeUnit.NANOSECONDS

object TiUtil {
val MICROS_PER_MILLIS = 1000L
val MICROS_PER_SECOND = 1000000L

def defaultTimeZone(): TimeZone = TimeZone.getDefault

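// The date/time helpers below appear to be adapted from Spark's internal
// DateTimeUtils, inlined here so TiSpark does not depend on internal APIs
// that moved between Spark 3.0 and 3.1.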
def daysToMillis(days: Int): Long = {
daysToMillis(days, defaultTimeZone().toZoneId)
}

def daysToMillis(days: Int, zoneId: ZoneId): Long = {
val instant = daysToLocalDate(days).atStartOfDay(zoneId).toInstant
toMillis(instantToMicros(instant))
}

/*
* Converts the timestamp to milliseconds since epoch. In Spark, timestamp values have
* microsecond precision, so this conversion is lossy.
*/
def toMillis(us: Long): Long = {
// When the timestamp is negative, i.e. before 1970, we need to adjust the milliseconds portion.
// Example: 1965-01-01 10:11:12.123456 is represented as -157700927876544 with microsecond
// precision; with millisecond precision it must become -157700927877.
Math.floorDiv(us, MICROS_PER_MILLIS)
}

def instantToMicros(instant: Instant): Long = {
val us = Math.multiplyExact(instant.getEpochSecond, MICROS_PER_SECOND)
val result = Math.addExact(us, NANOSECONDS.toMicros(instant.getNano))
result
}

def daysToLocalDate(days: Int): LocalDate = LocalDate.ofEpochDay(days)

def getSchemaFromTable(table: TiTableInfo): StructType = {
val fields = new Array[StructField](table.getColumns.size())
for (i <- 0 until table.getColumns.size()) {
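A quick usage sketch of the new helpers (illustrative values, computed by hand; UTC is chosen for reproducibility):

import java.time.ZoneId

// 2021-04-27 is epoch day 18744; at UTC midnight that is 1619481600000 ms.
val ms = TiUtil.daysToMillis(18744, ZoneId.of("UTC"))

// Pre-1970 timestamps floor toward negative infinity, so the microsecond
// value from the comment above becomes -157700927877 ms, not -157700927876.
val pre1970 = TiUtil.toMillis(-157700927876544L)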
13 changes: 4 additions & 9 deletions core/src/main/scala/org/apache/spark/sql/TiAggregationImpl.scala
@@ -15,15 +15,10 @@

package org.apache.spark.sql

import com.pingcap.tispark.utils.ReflectionUtil
import org.apache.spark.sql.catalyst.expressions.NamedExpression.newExprId
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.{
Alias,
Cast,
Divide,
Expression,
NamedExpression
}
import org.apache.spark.sql.catalyst.expressions.{Cast, Divide, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{DecimalType, DoubleType, FloatType, LongType}
@@ -97,15 +92,15 @@ object TiAggregationImpl {
case DoubleType => Cast(sum.resultAttribute, DoubleType)
case d: DecimalType => Cast(sum.resultAttribute, d)
}
(ref: Expression) -> Alias(castedSum, ref.name)(exprId = ref.exprId)
(ref: Expression) -> ReflectionUtil.newAlias(castedSum, ref.name, ref.exprId)
}

val avgRewrite: PartialFunction[Expression, Expression] = avgRewriteMap.map {
case (ref, Seq(sum, count)) =>
val castedSum = Cast(sum.resultAttribute, DoubleType)
val castedCount = Cast(count.resultAttribute, DoubleType)
val division = Cast(Divide(castedSum, castedCount), ref.dataType)
(ref: Expression) -> Alias(division, ref.name)(exprId = ref.exprId)
(ref: Expression) -> ReflectionUtil.newAlias(division, ref.name, ref.exprId)
}

val rewrittenResultExpressions = resultExpressions
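Both rewrite maps now build their result aliases through ReflectionUtil.newAlias rather than calling Alias directly, presumably because Alias's secondary parameter list changed in Spark 3.1 and a direct call is not binary-compatible with both versions. A sketch of the recomposed average, using hypothetical attribute names for the partial results a pushed-down scan returns:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Divide}
import org.apache.spark.sql.types.{DoubleType, LongType}

// Hypothetical partial-result attributes (names are illustrative only):
val sumAttr = AttributeReference("sum_x", DoubleType)()
val countAttr = AttributeReference("count_x", LongType)()

// avg(x) is recomposed as sum(x) / count(x); the result is then re-aliased
// under the original expression's ExprId via ReflectionUtil.newAlias so that
// downstream operators still resolve their references.
val avg = Divide(Cast(sumAttr, DoubleType), Cast(countAttr, DoubleType))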
19 changes: 9 additions & 10 deletions core/src/main/scala/org/apache/spark/sql/TiStrategy.scala
@@ -16,7 +16,6 @@
package org.apache.spark.sql

import java.util.concurrent.TimeUnit

import com.pingcap.tidb.tipb.EncodeType
import com.pingcap.tikv.exception.IgnoreUnsupportedTypeException
import com.pingcap.tikv.expression._
@@ -26,10 +25,9 @@ import com.pingcap.tikv.predicates.{PredicateUtils, TiKVScanAnalyzer}
import com.pingcap.tikv.region.TiStoreType
import com.pingcap.tikv.statistics.TableStatistics
import com.pingcap.tispark.statistics.StatisticsManager
import com.pingcap.tispark.utils.TiUtil
import com.pingcap.tispark.utils.{ReflectionUtil, TiUtil}
import com.pingcap.tispark.{TiConfigConst, TiDBRelation}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.CleanupAliases
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.{
Alias,
@@ -38,15 +36,15 @@ import org.apache.spark.sql.catalyst.expressions.{
AttributeMap,
AttributeSet,
Descending,
TiExprUtils,
Expression,
IntegerLiteral,
IsNull,
NamedExpression,
NullsFirst,
NullsLast,
SortOrder,
SubqueryExpression
SubqueryExpression,
TiExprUtils
}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
@@ -378,11 +376,12 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess
val newSortExpr = sortOrder.child.transformUp {
case a: Attribute => aliases.getOrElse(a, a)
}
val trimmedExpr = CleanupAliases.trimNonTopLevelAliases(newSortExpr)
val trimmedSortOrder = sortOrder.copy(child = trimmedExpr)
val trimmedExpr = ReflectionUtil.trimNonTopLevelAliases(newSortExpr)
val trimmedSortOrder = ReflectionUtil.copySortOrder(sortOrder, trimmedExpr)

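// ASC NULLS LAST and DESC NULLS FIRST are the non-default orderings, so they
// are emulated by prepending an IsNull key with the same sort direction:
// ascending puts false (non-null) rows first, descending puts true (null) first.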
(sortOrder.direction, sortOrder.nullOrdering) match {
case (_ @Ascending, _ @NullsLast) | (_ @Descending, _ @NullsFirst) =>
sortOrder.copy(child = IsNull(trimmedExpr)) :: trimmedSortOrder :: Nil
ReflectionUtil.copySortOrder(sortOrder, IsNull(trimmedExpr)) :: trimmedSortOrder :: Nil
case _ =>
trimmedSortOrder :: Nil
}
@@ -471,11 +470,11 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess
source: TiDBRelation,
dagReq: TiDAGRequest): Seq[SparkPlan] = {
val deterministicAggAliases = aggregateExpressions.collect {
case e if e.deterministic => e.canonicalized -> Alias(e, e.toString())()
case e if e.deterministic => e.canonicalized -> ReflectionUtil.newAlias(e, e.toString())
}.toMap

def aliasPushedPartialResult(e: AggregateExpression): Alias =
deterministicAggAliases.getOrElse(e.canonicalized, Alias(e, e.toString())())
deterministicAggAliases.getOrElse(e.canonicalized, ReflectionUtil.newAlias(e, e.toString()))

val residualAggregateExpressions = aggregateExpressions.map { aggExpr =>
// As `aggExpr` is being pushed down to TiKV, we need to replace the original Catalyst
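The switch from sortOrder.copy(child = ...) to ReflectionUtil.copySortOrder follows the same pattern: SortOrder is constructed differently across the two Spark versions, so copies go through the wrapper. A sketch of what the null-ordering rewrite emits for ORDER BY c ASC NULLS LAST, written against the Spark 3.0 constructor (the four-argument shape below is one of the signatures the wrapper hides; in the real code both keys are produced by ReflectionUtil.copySortOrder):

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, IsNull, NullsLast, SortOrder}
import org.apache.spark.sql.types.IntegerType

// Hypothetical column c, requested as ASC NULLS LAST:
val c = AttributeReference("c", IntegerType)()
val requested = SortOrder(c, Ascending, NullsLast, Set.empty)

// The rewrite prepends the IsNull key and keeps the original key, which
// together emulate NULLS LAST when the pushed-down sort only honors Spark's
// default null ordering (an inference from the pattern match above).
val pushed = SortOrder(IsNull(c), Ascending, NullsLast, Set.empty) :: requested :: Nil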
(Diffs for the remaining 19 changed files are not shown.)
