[SPARK-48273][SQL] Fix late rewrite of PlanWithUnresolvedIdentifier
### What changes were proposed in this pull request?

`PlanWithUnresolvedIdentifier` is rewritten late in analysis, so rules that have already run, such as
`SubstituteUnresolvedOrdinals`, never see the newly built plan. This causes the following queries to fail:
```
create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
--
cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
--
create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1;
```
Fix this by explicitly applying the early analyzer rules to the plan produced by the rewrite (a toy model of the problem and the fix is sketched below).
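The sketch below is a minimal, self-contained model of the ordering problem, using toy types rather than Spark's actual classes: `DeferredPlan` stands in for `PlanWithUnresolvedIdentifier`, and `substituteOrdinals` plays the role of `SubstituteUnresolvedOrdinals`.

```scala
// Toy model of the bug and the fix -- illustrative names, not Spark's API.
object LateRewriteSketch extends App {
  sealed trait Plan
  case class GroupByOrdinal(ordinal: Int) extends Plan    // unresolved `group by 1`
  case class GroupByColumn(name: String) extends Plan     // resolved form
  case class DeferredPlan(build: () => Plan) extends Plan // PlanWithUnresolvedIdentifier

  // Early rule: replace grouping ordinals with the column they point at.
  def substituteOrdinals(p: Plan): Plan = p match {
    case GroupByOrdinal(_) => GroupByColumn("my_col")
    case other             => other
  }

  // Late rule: materialize the deferred plan, optionally re-running early rules.
  def resolveIdentifier(p: Plan, rerunEarly: Boolean): Plan = p match {
    case DeferredPlan(build) =>
      val built = build()
      if (rerunEarly) substituteOrdinals(built) else built
    case other => other
  }

  val plan: Plan = DeferredPlan(() => GroupByOrdinal(1))

  // Buggy order: the early rule runs first and sees no ordinal yet, so the
  // late rewrite leaks an unresolved ordinal out of analysis.
  println(resolveIdentifier(substituteOrdinals(plan), rerunEarly = false)) // GroupByOrdinal(1)

  // Fixed order: the late rewrite re-applies the early rule to the built plan.
  println(resolveIdentifier(substituteOrdinals(plan), rerunEarly = true))  // GroupByColumn(my_col)
}
```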

### Why are the changes needed?

To fix the described bug.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes the mentioned problematic queries.

### How was this patch tested?

Updated the existing `identifier-clause.sql` golden files.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#46580 from nikolamand-db/SPARK-48273.

Authored-by: Nikola Mandic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
nikolamand-db authored and cloud-fan committed May 28, 2024
Parent: 7fe1b93 · Commit: 731a2cf
Showing 6 changed files with 139 additions and 7 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala:

```diff
@@ -254,7 +254,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     TypeCoercion.typeCoercionRules
   }
 
-  override def batches: Seq[Batch] = Seq(
+  private def earlyBatches: Seq[Batch] = Seq(
     Batch("Substitution", fixedPoint,
       new SubstituteExecuteImmediate(catalogManager),
       // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule.
@@ -274,7 +274,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     Batch("Simple Sanity Check", Once,
       LookupFunctions),
     Batch("Keep Legacy Outputs", Once,
-      KeepLegacyOutputs),
+      KeepLegacyOutputs)
+  )
+
+  override def batches: Seq[Batch] = earlyBatches ++ Seq(
     Batch("Resolution", fixedPoint,
       new ResolveCatalogs(catalogManager) ::
       ResolveInsertInto ::
@@ -319,7 +322,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveTimeZone ::
       ResolveRandomSeed ::
       ResolveBinaryArithmetic ::
-      ResolveIdentifierClause ::
+      new ResolveIdentifierClause(earlyBatches) ::
       ResolveUnion ::
       ResolveRowLevelCommandAssignments ::
       RewriteDeleteFromTable ::
```
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala:

```diff
@@ -20,19 +20,24 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
 import org.apache.spark.sql.types.StringType
 
 /**
  * Resolves the identifier expressions and builds the original plans/expressions.
  */
-object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper {
+class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch])
+  extends Rule[LogicalPlan] with AliasHelper with EvalHelper {
+
+  private val executor = new RuleExecutor[LogicalPlan] {
+    override def batches: Seq[Batch] = earlyBatches.asInstanceOf[Seq[Batch]]
+  }
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
     _.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
     case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved =>
-      p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr))
+      executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr)))
     case other =>
       other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
         case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
```
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala:

```diff
@@ -147,7 +147,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
     override val maxIterationsSetting: String = null) extends Strategy
 
   /** A batch of rules. */
-  protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)
+  protected[catalyst] case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)
 
   /** Defines a sequence of rule batches, to be overridden by the implementation. */
   protected def batches: Seq[Batch]
```
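Widening `Batch` from `protected` to `protected[catalyst]` matters because `ResolveIdentifierClause` lives in a sibling subpackage and must name the inner `Batch` class through the type projection `RuleExecutor[LogicalPlan]#Batch`. The following is a toy sketch of that pattern under simplified signatures (names are illustrative, not Spark's API):

```scala
// Toy model of the RuleExecutor/Batch reuse pattern.
object TypeProjectionSketch extends App {
  abstract class Executor[T] {
    // Inner case class: from outside an instance it can only be named
    // via the type projection Executor[T]#Batch.
    case class Batch(name: String, rules: Seq[T => T])
    def batches: Seq[Batch]
    // Run every rule of every batch once, in order.
    def execute(t: T): T =
      batches.foldLeft(t) { (acc, b) => b.rules.foldLeft(acc)((x, r) => r(x)) }
  }

  // A collaborator that receives batches built by one executor and re-runs
  // them inside its own nested executor -- the shape of this PR's fix.
  class ReRunner(early: Seq[Executor[Int]#Batch]) {
    private val executor = new Executor[Int] {
      // Unchecked but safe here: Batch erases identically across instances,
      // mirroring `earlyBatches.asInstanceOf[Seq[Batch]]` in the real change.
      override def batches: Seq[Batch] = early.asInstanceOf[Seq[Batch]]
    }
    def rerun(n: Int): Int = executor.execute(n)
  }

  val outer = new Executor[Int] {
    override def batches: Seq[Batch] = Seq(Batch("inc-then-double", Seq(_ + 1, _ * 2)))
  }
  println(new ReRunner(outer.batches).rerun(20)) // (20 + 1) * 2 = 42
}
```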
sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out:

```
@@ -926,6 +926,65 @@ org.apache.spark.sql.catalyst.parser.ParseException
}


-- !query
create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
-- !query analysis
CreateViewCommand `v1`, (select my_col from (values (1), (2), (1) as (my_col)) group by 1), false, false, LocalTempView, UNSUPPORTED, true
+- Aggregate [my_col#x], [my_col#x]
+- SubqueryAlias __auto_generated_subquery_name
+- SubqueryAlias as
+- LocalRelation [my_col#x]


-- !query
cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
-- !query analysis
CacheTableAsSelect t1, (select my_col from (values (1), (2), (1) as (my_col)) group by 1), false, true
+- Aggregate [my_col#x], [my_col#x]
+- SubqueryAlias __auto_generated_subquery_name
+- SubqueryAlias as
+- LocalRelation [my_col#x]


-- !query
create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
-- !query analysis
CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t2`, ErrorIfExists, [my_col]
+- Aggregate [my_col#x], [my_col#x]
+- SubqueryAlias __auto_generated_subquery_name
+- SubqueryAlias as
+- LocalRelation [my_col#x]


-- !query
insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1
-- !query analysis
InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t2, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t2], Append, `spark_catalog`.`default`.`t2`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t2), [my_col]
+- Aggregate [my_col#x], [my_col#x]
+- SubqueryAlias __auto_generated_subquery_name
+- SubqueryAlias as
+- LocalRelation [my_col#x]


-- !query
drop view v1
-- !query analysis
DropTempViewCommand v1


-- !query
drop table t1
-- !query analysis
DropTempViewCommand t1


-- !query
drop table t2
-- !query analysis
DropTable false, false
+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2


-- !query
SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
-- !query analysis
```
sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql:

```
@@ -132,6 +132,15 @@ CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.a
DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg');
CREATE TEMPORARY VIEW IDENTIFIER('default.v')(c1) AS VALUES(1);

-- SPARK-48273: Aggregation operation in statements using identifier clause for table name
create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1;
drop view v1;
drop table t1;
drop table t2;

-- Not supported
SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1);
SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1'));
```
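As a usage illustration (not part of the PR), the first failing statement can be replayed through the public `SparkSession` API; on a build that includes this fix the view is created and queries against it succeed:

```scala
// Hedged repro sketch: requires the spark-sql artifact on the classpath.
import org.apache.spark.sql.SparkSession

object Spark48273Repro extends App {
  val spark = SparkSession.builder()
    .master("local[1]")
    .appName("SPARK-48273-repro")
    .getOrCreate()

  // Before the fix this failed analysis: the `group by 1` ordinal inside the
  // IDENTIFIER()-named view definition was never substituted.
  spark.sql(
    "create temporary view identifier('v1') as " +
      "(select my_col from (values (1), (2), (1) as (my_col)) group by 1)")

  spark.sql("select * from v1 order by my_col").show() // rows: 1, 2
  spark.stop()
}
```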
sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out:

```
@@ -1059,6 +1059,62 @@ org.apache.spark.sql.catalyst.parser.ParseException
}


-- !query
create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
-- !query schema
struct<>
-- !query output



-- !query
cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
-- !query schema
struct<>
-- !query output



-- !query
create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
-- !query schema
struct<>
-- !query output



-- !query
insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1
-- !query schema
struct<>
-- !query output



-- !query
drop view v1
-- !query schema
struct<>
-- !query output



-- !query
drop table t1
-- !query schema
struct<>
-- !query output



-- !query
drop table t2
-- !query schema
struct<>
-- !query output



-- !query
SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
-- !query schema
```
