diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 3e4344f98bce0..0fa11b9c45038 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -255,12 +255,18 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
       val newRightGroup = rewriteAttrs(c.rightGroup, rightAttrMap)
       val newLeftOrder = rewriteAttrs(c.leftOrder, leftAttrMap)
       val newRightOrder = rewriteAttrs(c.rightOrder, rightAttrMap)
-      val newKeyDes = c.keyDeserializer.asInstanceOf[UnresolvedDeserializer]
-        .copy(inputAttributes = newLeftGroup)
-      val newLeftDes = c.leftDeserializer.asInstanceOf[UnresolvedDeserializer]
-        .copy(inputAttributes = newLeftAttr)
-      val newRightDes = c.rightDeserializer.asInstanceOf[UnresolvedDeserializer]
-        .copy(inputAttributes = newRightAttr)
+      val newKeyDes = c.keyDeserializer match {
+        case u: UnresolvedDeserializer => u.copy(inputAttributes = newLeftGroup)
+        case e: Expression => e.withNewChildren(rewriteAttrs(e.children, leftAttrMap))
+      }
+      val newLeftDes = c.leftDeserializer match {
+        case u: UnresolvedDeserializer => u.copy(inputAttributes = newLeftAttr)
+        case e: Expression => e.withNewChildren(rewriteAttrs(e.children, leftAttrMap))
+      }
+      val newRightDes = c.rightDeserializer match {
+        case u: UnresolvedDeserializer => u.copy(inputAttributes = newRightAttr)
+        case e: Expression => e.withNewChildren(rewriteAttrs(e.children, rightAttrMap))
+      }
       c.copy(keyDeserializer = newKeyDes, leftDeserializer = newLeftDes,
         rightDeserializer = newRightDes, leftGroup = newLeftGroup,
         rightGroup = newRightGroup, leftAttr = newLeftAttr, rightAttr = newRightAttr,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b939ed40c7db6..fdb2ec30fdd2d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
 
 import scala.collection.immutable.HashSet
+import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 import scala.util.Random
 
@@ -952,6 +953,25 @@ class DatasetSuite extends QueryTest
     assert(result2.length == 3)
   }
 
+  test("SPARK-48718: cogroup deserializer expr is resolved before dedup relation") {
+    val lhs = spark.createDataFrame(
+      List(Row(123L)).asJava,
+      StructType(Seq(StructField("GROUPING_KEY", LongType)))
+    )
+    val rhs = spark.createDataFrame(
+      List(Row(0L, 123L)).asJava,
+      StructType(Seq(StructField("ID", LongType), StructField("GROUPING_KEY", LongType)))
+    )
+
+    val lhsKV = lhs.groupByKey((r: Row) => r.getAs[Long]("GROUPING_KEY"))
+    val rhsKV = rhs.groupByKey((r: Row) => r.getAs[Long]("GROUPING_KEY"))
+    val cogrouped = lhsKV.cogroup(rhsKV)(
+      (a: Long, b: Iterator[Row], c: Iterator[Row]) => Iterator(0L)
+    )
+    val joined = rhs.join(cogrouped, col("ID") === col("value"), "left")
+    checkAnswer(joined, Row(0L, 123L, 0L) :: Nil)
+  }
+
   test("SPARK-34806: observation on datasets") {
     val namedObservation = Observation("named")
     val unnamedObservation = Observation()
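
For context, a minimal standalone sketch of the scenario the new test covers, adapted from the test body above. The object name `Spark48718Repro` and the local-mode session setup are illustrative assumptions, not part of the patch. Joining a cogroup result back against one of its inputs duplicates the underlying relation, so DeduplicateRelations must rewrite the cogroup's deserializer expressions after they have already been resolved; the old unconditional `asInstanceOf[UnresolvedDeserializer]` cast then failed with a ClassCastException, whereas the patched code pattern-matches and rewrites the children of an already-resolved deserializer instead.

// Illustrative repro sketch (hypothetical object name, local-mode session assumed).
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import scala.jdk.CollectionConverters._

object Spark48718Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val lhs = spark.createDataFrame(
      List(Row(123L)).asJava,
      StructType(Seq(StructField("GROUPING_KEY", LongType))))
    val rhs = spark.createDataFrame(
      List(Row(0L, 123L)).asJava,
      StructType(Seq(StructField("ID", LongType), StructField("GROUPING_KEY", LongType))))

    // Cogroup the two inputs on GROUPING_KEY; the result is a Dataset[Long]
    // whose single column is named "value".
    val cogrouped = lhs.groupByKey((r: Row) => r.getAs[Long]("GROUPING_KEY"))
      .cogroup(rhs.groupByKey((r: Row) => r.getAs[Long]("GROUPING_KEY")))(
        (_: Long, _: Iterator[Row], _: Iterator[Row]) => Iterator(0L))

    // Joining back against `rhs` duplicates its relation inside the plan,
    // which triggers DeduplicateRelations on the already-analyzed cogroup.
    rhs.join(cogrouped, col("ID") === col("value"), "left").show()

    spark.stop()
  }
}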