Spark 3.4: Implement SupportsRuntimeFiltering (#273)
zhouyifan279 authored Oct 23, 2023
1 parent b58bb6c commit fed4cd2
Showing 6 changed files with 100 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/configurations/02_sql_configurations.md
@@ -20,6 +20,7 @@ spark.clickhouse.ignoreUnsupportedTransform|false|ClickHouse supports using comp
spark.clickhouse.read.compression.codec|lz4|The codec used to decompress data for reading. Supported codecs: none, lz4.|0.5.0
spark.clickhouse.read.distributed.convertLocal|true|When reading Distributed table, read local table instead of itself. If `true`, ignore `spark.clickhouse.read.distributed.useClusterNodes`.|0.1.0
spark.clickhouse.read.format|json|Serialize format for reading. Supported formats: json, binary|0.6.0
spark.clickhouse.read.runtimeFilter.enabled|false|Enable runtime filter for reading.|0.8.0
spark.clickhouse.read.splitByPartitionId|true|If `true`, construct input partition filter by virtual column `_partition_id`, instead of partition value. There are known bugs to assemble SQL predication by partition value. This feature requires ClickHouse Server v21.6+|0.4.0
spark.clickhouse.useNullableQuerySchema|false|If `true`, mark all the fields of the query schema as nullable when executing `CREATE/REPLACE TABLE ... AS SELECT ...` on creating the table. Note, this configuration requires SPARK-43390(available in Spark 3.5), w/o this patch, it always acts as `true`.|0.8.0
spark.clickhouse.write.batchSize|10000|The number of records per batch on writing to ClickHouse.|0.1.0
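The `spark.clickhouse.read.runtimeFilter.enabled` option added above is off by default and is a plain SQL conf, so it can be toggled per session. A minimal sketch of turning it on (assuming a Spark session that already has the ClickHouse catalog configured, as the test suites below do):

// Mirrors the SET statements used in the tests in this commit.
spark.sql("SET spark.clickhouse.read.runtimeFilter.enabled=true")
// Equivalent via the runtime conf API:
spark.conf.set("spark.clickhouse.read.runtimeFilter.enabled", "true")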
@@ -16,6 +16,8 @@ package org.apache.spark.sql.clickhouse.cluster

import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.READ_DISTRIBUTED_CONVERT_LOCAL
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec

class ClickHouseClusterReadSuite extends SparkClickHouseClusterTest {

@@ -83,4 +85,33 @@ class ClickHouseClusterReadSuite extends SparkClickHouseClusterTest {
)
}
}

test("runtime filter - distributed table") {
withSimpleDistTable("single_replica", "runtime_db", "runtime_tbl", true) { (_, db, tbl_dist, _) =>
spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=false")
checkAnswer(
spark.sql(s"SELECT id FROM $db.$tbl_dist " +
s"WHERE id IN (" +
s" SELECT id FROM $db.$tbl_dist " +
s" WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" +
s")"),
Row(1)
)

spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=true")
val df = spark.sql(s"SELECT id FROM $db.$tbl_dist " +
s"WHERE id IN (" +
s" SELECT id FROM $db.$tbl_dist " +
s" WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" +
s")")
checkAnswer(df, Row(1))
val runtimeFilterExists = df.queryExecution.sparkPlan.exists {
case BatchScanExec(_, _, runtimeFilters, _, _, table, _, _, _)
if table.name() == TableIdentifier(tbl_dist, Some(db)).quotedString
&& runtimeFilters.nonEmpty => true
case _ => false
}
assert(runtimeFilterExists)
}
}
}
@@ -15,6 +15,8 @@
package org.apache.spark.sql.clickhouse.single

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.types._

class ClickHouseSingleSuite extends SparkClickHouseSingleTest {
@@ -451,4 +453,36 @@ class ClickHouseSingleSuite extends SparkClickHouseSingleTest {
spark.sql(s"UNCACHE TABLE $db.$tbl")
}
}

test("runtime filter") {
val db = "runtime_db"
val tbl = "runtime_tbl"

withSimpleTable(db, tbl, true) {
spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=false")
checkAnswer(
spark.sql(s"SELECT id FROM $db.$tbl " +
s"WHERE id IN (" +
s" SELECT id FROM $db.$tbl " +
s" WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" +
s")"),
Row(1)
)

spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=true")
val df = spark.sql(s"SELECT id FROM $db.$tbl " +
s"WHERE id IN (" +
s" SELECT id FROM $db.$tbl " +
s" WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" +
s")")
checkAnswer(df, Row(1))
val runtimeFilterExists = df.queryExecution.sparkPlan.exists {
case BatchScanExec(_, _, runtimeFilters, _, _, table, _, _, _)
if table.name() == TableIdentifier(tbl, Some(db)).quotedString
&& runtimeFilters.nonEmpty => true
case _ => false
}
assert(runtimeFilterExists)
}
}
}
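Both suites assert the same thing: once the flag is on, the physical plan contains a `BatchScanExec` over the ClickHouse table whose `runtimeFilters` list is non-empty. A rough way to eyeball this outside the test harness, assuming a session where the test helpers have already created `runtime_db.runtime_tbl`:

// Enable the flag, run the same IN-subquery shape as the test, and print the plan.
spark.sql("SET spark.clickhouse.read.runtimeFilter.enabled=true")
val df = spark.sql(
  """SELECT id FROM runtime_db.runtime_tbl
    |WHERE id IN (SELECT id FROM runtime_db.runtime_tbl
    |             WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') BETWEEN '2021-01-01' AND '2022-01-01')
    |""".stripMargin)
// The BatchScan node over runtime_tbl should report non-empty runtime filters,
// which is exactly what the pattern match on BatchScanExec above verifies.
println(df.queryExecution.sparkPlan)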
@@ -173,6 +173,13 @@ object ClickHouseSQLConf {
.transform(_.toLowerCase)
.createWithDefault("json")

val RUNTIME_FILTER_ENABLED: ConfigEntry[Boolean] =
buildConf("spark.clickhouse.read.runtimeFilter.enabled")
.doc("Enable runtime filter for reading.")
.version("0.8.0")
.booleanConf
.createWithDefault(false)

val WRITE_FORMAT: ConfigEntry[String] =
buildConf("spark.clickhouse.write.format")
.doc("Serialize format for writing. Supported formats: json, arrow")
@@ -48,6 +48,9 @@ class ReadOptions(_options: JMap[String, String]) extends SparkOptions {

def format: String =
eval(READ_FORMAT.key, READ_FORMAT)

def runtimeFilterEnabled: Boolean =
eval(RUNTIME_FILTER_ENABLED.key, RUNTIME_FILTER_ENABLED)
}

class WriteOptions(_options: JMap[String, String]) extends SparkOptions {
@@ -16,7 +16,7 @@ package xenon.clickhouse.read

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.clickhouse.ClickHouseSQLConf._
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference, Transform}
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.read._
@@ -127,8 +127,14 @@

class ClickHouseBatchScan(scanJob: ScanJobDescription) extends Scan with Batch
with SupportsReportPartitioning
with SupportsRuntimeFiltering
with PartitionReaderFactory
with ClickHouseHelper {
with ClickHouseHelper
with SQLHelper {

implicit private val tz: ZoneId = scanJob.tz

private var runtimeFilters: Array[Filter] = Array.empty

val database: String = scanJob.database
val table: String = scanJob.table
@@ -187,9 +193,13 @@ class ClickHouseBatchScan(scanJob: ScanJobDescription) extends Scan with Batch
override def createReader(_partition: InputPartition): PartitionReader[InternalRow] = {
val format = scanJob.readOptions.format
val partition = _partition.asInstanceOf[ClickHouseInputPartition]
val finalScanJob = scanJob.copy(filtersExpr =
scanJob.filtersExpr + " AND "
+ compileFilters(AlwaysTrue :: runtimeFilters.toList)
)
format match {
case "json" => new ClickHouseJsonReader(scanJob, partition)
case "binary" => new ClickHouseBinaryReader(scanJob, partition)
case "json" => new ClickHouseJsonReader(finalScanJob, partition)
case "binary" => new ClickHouseBinaryReader(finalScanJob, partition)
case unsupported => throw CHClientException(s"Unsupported read format: $unsupported")
}
}
@@ -198,4 +208,14 @@
BlocksReadMetric(),
BytesReadMetric()
)

override def filterAttributes(): Array[NamedReference] =
if (scanJob.readOptions.runtimeFilterEnabled) {
scanJob.readSchema.fields.map(field => Expressions.column(field.name))
} else {
Array.empty
}

override def filter(filters: Array[Filter]): Unit =
runtimeFilters = filters
}
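Taken together: Spark calls `filterAttributes()` during planning to learn which columns may receive runtime filters (every read-schema column when the option is enabled, otherwise none), later hands the derived predicates to `filter(...)`, and `createReader` appends them to the pushed-down SQL through `compileFilters`, with `AlwaysTrue` prepended so the conjunction is never empty. A rough illustration of the kind of predicate string that gets appended, using hypothetical rendering logic rather than the connector's actual `SQLHelper`:

import org.apache.spark.sql.sources.{AlwaysTrue, Filter, In}

// Hypothetical stand-in for compileFilters: render Spark source filters as a
// ClickHouse-style conjunction (the real helper handles many more filter types).
def renderFilters(filters: Seq[Filter]): String =
  filters.map {
    case AlwaysTrue()     => "1=1"
    case In(attr, values) => s"`$attr` IN (${values.mkString(", ")})"
    case _                => "1=1" // unknown filters degrade to a no-op predicate in this sketch
  }.mkString(" AND ")

// With runtimeFilters = Array(In("id", Array(1))), the scan SQL would end up
// looking roughly like: SELECT ... WHERE <static filters> AND 1=1 AND `id` IN (1)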
