Skip to content

Commit

Permalink
[SPARK-48197][SQL] Avoid assert error for invalid lambda function
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

`ExpressionBuilder` asserts all its input expressions to be resolved during lookup, which is not true as the analyzer rule `ResolveFunctions` can trigger function lookup even if the input expression contains unresolved lambda functions.

This PR updates that assert to check non-lambda inputs only, and fail earlier if the input contains lambda functions. In the future, if we use `ExpressionBuilder` to register higher-order functions, we can relax it.

### Why are the changes needed?

better error message

### Does this PR introduce _any_ user-facing change?

no, only changes error message

### How was this patch tested?

new test

### Was this patch authored or co-authored using generative AI tooling?

no

Closes apache#46475 from cloud-fan/minor.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
cloud-fan committed May 9, 2024
1 parent 337f980 commit 7e79e91
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,14 @@ object FunctionRegistry {
since: Option[String] = None): (String, (ExpressionInfo, FunctionBuilder)) = {
val info = FunctionRegistryBase.expressionInfo[T](name, since)
val funcBuilder = (expressions: Seq[Expression]) => {
assert(expressions.forall(_.resolved), "function arguments must be resolved.")
val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction])
if (lambdas.nonEmpty && !builder.supportsLambda) {
throw new AnalysisException(
errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
messageParameters = Map(
"class" -> builder.getClass.getCanonicalName))
}
assert(others.forall(_.resolved), "function arguments must be resolved.")
val rearrangedExpressions = rearrangeExpressions(name, builder, expressions)
val expr = builder.build(name, rearrangedExpressions)
if (setAlias) expr.setTagValue(FUNC_ALIAS, name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ trait FunctionBuilderBase[T] {
}

def build(funcName: String, expressions: Seq[Expression]): T

def supportsLambda: Boolean = false
}

object NamedParametersSupport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ org.apache.spark.sql.AnalysisException
}


-- !query
select ceil(x -> x) as v
-- !query analysis
org.apache.spark.sql.AnalysisException
{
"errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
"sqlState" : "42K0D",
"messageParameters" : {
"class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 19,
"fragment" : "ceil(x -> x)"
} ]
}


-- !query
select transform(zs, z -> z) as v from nested
-- !query analysis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ org.apache.spark.sql.AnalysisException
}


-- !query
select ceil(x -> x) as v
-- !query analysis
org.apache.spark.sql.AnalysisException
{
"errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
"sqlState" : "42K0D",
"messageParameters" : {
"class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 19,
"fragment" : "ceil(x -> x)"
} ]
}


-- !query
select transform(zs, z -> z) as v from nested
-- !query analysis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ create or replace temporary view nested as values

-- Only allow lambda's in higher order functions.
select upper(x -> x) as v;
-- Also test functions registered with `ExpressionBuilder`.
select ceil(x -> x) as v;

-- Identity transform an array
select transform(zs, z -> z) as v from nested;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ org.apache.spark.sql.AnalysisException
}


-- !query
select ceil(x -> x) as v
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
{
"errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
"sqlState" : "42K0D",
"messageParameters" : {
"class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 19,
"fragment" : "ceil(x -> x)"
} ]
}


-- !query
select transform(zs, z -> z) as v from nested
-- !query schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ org.apache.spark.sql.AnalysisException
}


-- !query
select ceil(x -> x) as v
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
{
"errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
"sqlState" : "42K0D",
"messageParameters" : {
"class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 19,
"fragment" : "ceil(x -> x)"
} ]
}


-- !query
select transform(zs, z -> z) as v from nested
-- !query schema
Expand Down

0 comments on commit 7e79e91

Please sign in to comment.