Push down the aggregation's pre-projection ahead of the expand node
lgbo-ustc committed Sep 6, 2024
1 parent 37d09c1 commit 05351c9
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
@@ -376,6 +376,15 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
)
}

// Move the pre-projection for an aggregation ahead of the expand node,
// for example, select a, b, sum(c+d) from t group by a, b with cube
def enablePushdownPreProjectionAheadExpand(): Boolean = {
SparkEnv.get.conf.getBoolean(
"spark.gluten.sql.columnar.backend.ch.enable_pushdown_preprojection_ahead_expand",
true
)
}

override def enableNativeWriteFiles(): Boolean = {
GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
}
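The flag defaults to true and is read from SparkEnv's conf, so it must be in place before the session (and thus SparkEnv) starts. A minimal sketch of turning the rule off for one application, assuming a plain SparkSession entry point; the config key comes from the diff above, everything else is illustrative:

import org.apache.spark.sql.SparkSession

// Hedged sketch: disable the pre-projection pushdown for this application.
// Setting the key on the builder ensures SparkEnv.get.conf sees it when
// enablePushdownPreProjectionAheadExpand() is evaluated.
val spark = SparkSession
  .builder()
  .appName("preprojection-pushdown-demo")
  .config(
    "spark.gluten.sql.columnar.backend.ch.enable_pushdown_preprojection_ahead_expand",
    "false")
  .getOrCreate()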
@@ -76,6 +76,7 @@ private object CHRuleApi {
injector.injectTransform(_ => EliminateLocalSort)
injector.injectTransform(_ => CollapseProjectExecTransformer)
injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session))
injector.injectTransform(c => PushdownExtraProjectionBeforeExpand.apply(c.session))
injector.injectTransform(
c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session))
injector.injectTransform(c => InsertTransitions(c.outputsColumnar))
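Conceptually, the injected PushdownExtraProjectionBeforeExpand rule moves the projection that computes aggregate inputs (here d + c) from above the expand node to below it, so each grouping-set copy emitted by Expand already carries one computed column instead of both source columns. A toy, hedged model of the before/after plan shapes; these case classes are illustrative stand-ins, not Gluten's actual operator classes:

object ExpandPushdownSketch {
  // Toy plan algebra, for illustration only.
  sealed trait Plan
  case class Scan(table: String) extends Plan
  case class Project(exprs: Seq[String], child: Plan) extends Plan
  case class Expand(groupingSets: Seq[Seq[String]], child: Plan) extends Plan
  case class Aggregate(keys: Seq[String], aggs: Seq[String], child: Plan) extends Plan

  // Before the rule: the pre-projection sits between Expand and Aggregate,
  // so Expand must replicate both c and d for every grouping set.
  val before = Aggregate(
    keys = Seq("a", "b"),
    aggs = Seq("sum(_pre0)"),
    child = Project(
      Seq("a", "b", "d + c AS _pre0"),
      Expand(
        Seq(Seq("a", "b"), Seq("a"), Seq("b"), Seq.empty),
        Scan("t1"))))

  // After the rule: the projection is pushed below Expand, so each expanded
  // row carries only the single precomputed column _pre0.
  val after = Aggregate(
    keys = Seq("a", "b"),
    aggs = Seq("sum(_pre0)"),
    child = Expand(
      Seq(Seq("a", "b"), Seq("a"), Seq("b"), Seq.empty),
      Project(Seq("a", "b", "d + c AS _pre0"), Scan("t1"))))
}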
@@ -547,5 +547,21 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
compareResultsAgainstVanillaSpark(sql, true, { _ => })
spark.sql("drop table cross_join_t")
}

test("Pushdown aggregation pre-projection ahead expand") {
spark.sql("create table t1(a bigint, b bigint, c bigint, d bigint) using parquet")
spark.sql("insert into t1 values(1,2,3,4), (1,2,4,5), (1,3,4,5), (2,3,4,5)")
var sql = """
| select a, b, sum(d+c) from t1 group by a, b with cube
| order by a, b
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
sql = """
| select a, b , sum(a+c), sum(b+d) from t1 group by a,b with cube
| order by a,b
|""".stripMargin
spark.sql("drop table t1")
}
}
// scalastyle:off line.size.limit
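To check that the rewrite fired, the physical plan can be inspected directly; when the rule applies, the project computing d + c should sit below the expand operator rather than above it. A hedged snippet, assuming a session with the Gluten ClickHouse backend enabled and the t1 table from the test above:

// Hedged: inspect operator order in the physical plan. With the rule
// enabled, the projection feeding sum(d + c) should appear as a child
// of the expand operator instead of its parent.
val df = spark.sql("select a, b, sum(d + c) from t1 group by a, b with cube")
df.explain("formatted")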
