diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 9c08ec71c1fde..554b73181116c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -507,32 +507,38 @@ case class SortMergeJoinExec( } /** - * Creates variables for left part of result row. + * Creates variables and declarations for left part of result row. * * In order to defer the access after condition and also only access once in the loop, * the variables should be declared separately from accessing the columns, we can't use the * codegen of BoundReference here. */ - private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { + private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = { ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) - // declare it as class member, so we can access the column before or in the loop. - ctx.addMutableState(ctx.javaType(a.dataType), value) + val javaType = ctx.javaType(a.dataType) + val defaultValue = ctx.defaultValue(a.dataType) if (a.nullable) { val isNull = ctx.freshName("isNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) val code = s""" |$isNull = $leftRow.isNullAt($i); - |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin - ExprCode(code, isNull, value) + val leftVarsDecl = + s""" + |boolean $isNull = false; + |$javaType $value = $defaultValue; + """.stripMargin + (ExprCode(code, isNull, value), leftVarsDecl) } else { - ExprCode(s"$value = $valueCode;", "false", value) + val code = s"$value = $valueCode;" + val leftVarsDecl = s"""$javaType $value = $defaultValue;""" + (ExprCode(code, "false", value), leftVarsDecl) } - } + }.unzip } /** @@ -580,7 +586,7 @@ case class SortMergeJoinExec( val (leftRow, matches) = genScanner(ctx) // Create variables for row from both sides. - val leftVars = createLeftVars(ctx, leftRow) + val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow) val rightRow = ctx.freshName("rightRow") val rightVars = createRightVar(ctx, rightRow) @@ -617,6 +623,7 @@ case class SortMergeJoinExec( s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { + | ${leftVarDecl.mkString("\n")} | ${beforeLoop.trim} | scala.collection.Iterator $iterator = $matches.generateIterator(); | while ($iterator.hasNext()) {