diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala
index de98d10..df405c8 100644
--- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala
+++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala
@@ -4,10 +4,12 @@ import com.github.mrpowers.spark.fast.tests.DatasetUtils.DatasetOps
 import com.github.mrpowers.spark.fast.tests.DatasetComparer.maxUnequalRowsToShow
 import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
 
 case class DatasetContentMismatch(smth: String) extends Exception(smth)
 case class DatasetCountMismatch(smth: String) extends Exception(smth)
@@ -73,7 +75,7 @@ Expected DataFrame Row Count: '$expectedCount'
     val e = expectedDS.collect().toSeq
     if (!a.approximateSameElements(e, equals)) {
       val arr = ("Actual Content", "Expected Content")
-      val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate)
+      val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, Left(a -> e), truncate)
       throw DatasetContentMismatch(msg)
     }
   }
@@ -159,29 +161,33 @@ Expected DataFrame Row Count: '$expectedCount'
    *   partitioned in the same way and become unreliable when shuffling is involved.
    * @param primaryKeys
    *   unique identifier for each row to ensure accurate comparison of rows
+   * @param checkKeyUniqueness
+   *   if true, verifies that the primary key is actually unique in each dataset before comparing rows
    */
-  def assertLargeDatasetEqualityV2[T: ClassTag](
+  def assertLargeDatasetEqualityV2[T: ClassTag: TypeTag](
       actualDS: Dataset[T],
       expectedDS: Dataset[T],
-      equals: (T, T) => Boolean = (o1: T, o2: T) => o1.equals(o2),
+      equals: Either[(T, T) => Boolean, Option[Column]] = Right(None),
       ignoreNullable: Boolean = false,
       ignoreColumnNames: Boolean = false,
       ignoreColumnOrder: Boolean = false,
       ignoreMetadata: Boolean = true,
+      checkKeyUniqueness: Boolean = false,
      primaryKeys: Seq[String] = Seq.empty,
      truncate: Int = 500
   ): Unit = {
     // first check if the schemas are equal
     SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
     val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
-    assertLargeDatasetContentEqualityV2(actual, expectedDS, equals, primaryKeys, truncate)
+    assertLargeDatasetContentEqualityV2(actual, expectedDS, equals, primaryKeys, checkKeyUniqueness, truncate)
   }
 
-  def assertLargeDatasetContentEqualityV2[T: ClassTag](
+  def assertLargeDatasetContentEqualityV2[T: ClassTag: TypeTag](
       ds1: Dataset[T],
       ds2: Dataset[T],
-      equals: (T, T) => Boolean,
+      equals: Either[(T, T) => Boolean, Option[Column]],
       primaryKeys: Seq[String],
+      checkKeyUniqueness: Boolean,
       truncate: Int
   ): Unit = {
     try {
@@ -195,14 +201,32 @@ Expected DataFrame Row Count: '$expectedCount'
       throw DatasetCountMismatch(countMismatchMessage(actualCount, expectedCount))
     }
 
-    val unequalDf = ds1
-      .joinPair(ds2, primaryKeys)
-      .filter((p: (T, T)) => !equals(p._1, p._2))
+    if (primaryKeys.nonEmpty && checkKeyUniqueness) {
+      assert(ds1.isKeyUnique(primaryKeys), "Primary key is not unique in actual dataset")
+      assert(ds2.isKeyUnique(primaryKeys), "Primary key is not unique in expected dataset")
+    }
+
+    val joinedDf = ds1
+      .outerJoinWith(ds2, primaryKeys)
+
+    val unequalDS = equals match {
+      case Left(customEquals) =>
+        joinedDf.filter((p: (Option[T], Option[T])) =>
+          p match {
+            case (Some(l), Some(r)) => !customEquals(l, r)
+            case (None, None)       => false
+            case _                  => true
+          }
+        )
+
+      case Right(unequalExprOption) =>
+        joinedDf.filter(unequalExprOption.getOrElse(!(col("l") <=> col("r"))))
+    }
 
-    if (!unequalDf.isEmpty) {
-      val (a, e) = unequalDf.take(maxUnequalRowsToShow).toSeq.unzip
+    if (!unequalDS.isEmpty) {
+      val joined = Right(unequalDS.take(maxUnequalRowsToShow).toSeq)
       val arr = ("Actual Content", "Expected Content")
-      val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate)
+      val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, joined, truncate)
       throw DatasetContentMismatch(msg)
     }
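A minimal usage sketch for the new Either-based `equals` parameter (the `Person` case class and the `actualDS`/`expectedDS` values are illustrative, not part of this diff). `Left` carries a row-level comparator applied to the deserialized pairs; `Right` carries an optional Column that selects mismatched rows over the joined "l"/"r" structs, with `Right(None)` falling back to the null-safe struct comparison above so rows present on only one side are still reported:

    import org.apache.spark.sql.functions.col

    case class Person(name: String, age: Int)

    // Left: compare deserialized objects with a custom function.
    assertLargeDatasetEqualityV2(
      actualDS,
      expectedDS,
      equals = Left((a: Person, b: Person) => a.name.equalsIgnoreCase(b.name) && a.age == b.age),
      primaryKeys = Seq("name")
    )

    // Right: flag mismatches with a Column predicate over the joined structs;
    // <=> (null-safe equality) keeps rows missing on one side visible in the diff.
    assertLargeDatasetEqualityV2(
      actualDS,
      expectedDS,
      equals = Right(Some(!(col("l.age") <=> col("r.age")))),
      primaryKeys = Seq("name")
    )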
diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetUtils.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetUtils.scala
index 37da01d..6e7dd4a 100644
--- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetUtils.scala
+++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetUtils.scala
@@ -1,49 +1,74 @@
 package com.github.mrpowers.spark.fast.tests
 
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.OptionEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, TypedColumn}
 
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
 
 private object DatasetUtils {
-  implicit class DatasetOps[T: ClassTag](ds: Dataset[T]) {
+  implicit class DatasetOps[T: ClassTag: TypeTag](ds: Dataset[T]) {
     def zipWithIndex(indexName: String): DataFrame = ds
       .orderBy()
       .withColumn(indexName, row_number().over(Window.orderBy(monotonically_increasing_id())))
       .select(ds.columns.map(col) :+ col(indexName): _*)
 
-    def joinPair(
-        other: Dataset[T],
-        primaryKeys: Seq[String]
-    ): Dataset[(T, T)] = {
-      if (primaryKeys.nonEmpty) {
-        ds
-          .as("l")
-          .joinWith(other.as("r"), primaryKeys.map(k => col(s"l.$k") === col(s"r.$k")).reduce(_ && _))
+    /**
+     * Check if the primary key is actually unique
+     */
+    def isKeyUnique(primaryKey: Seq[String]): Boolean =
+      ds.select(primaryKey.map(col): _*).distinct.count == ds.count
+
+    def outerJoinWith[P: ClassTag: TypeTag](
+        other: Dataset[P],
+        primaryKeys: Seq[String],
+        outerJoinType: String = "full"
+    ): Dataset[(Option[T], Option[P])] = {
+      val (ds1, ds2, key) = if (primaryKeys.nonEmpty) {
+        (ds, other, primaryKeys)
       } else {
         val indexName = s"index_${java.util.UUID.randomUUID}"
-        val columns   = ds.columns
-        val joined = ds
-          .zipWithIndex(indexName)
-          .alias("l")
-          .join(other.zipWithIndex(indexName).alias("r"), indexName)
-
-        val encoder: Encoder[T] = ds.encoder
-        val leftCols            = columns.map(n => col(s"l.$n"))
-        val rightCols           = columns.map(n => col(s"r.$n"))
-        val (pair1, pair2) =
-          if (columns.length == 1 && !(implicitly[ClassTag[T]].runtimeClass == classOf[Row]))
-            (leftCols.head, rightCols.head)
-          else
-            (struct(leftCols: _*), struct(rightCols: _*))
-
-        joined
-          .select(
-            pair1.as("l").as[T](encoder),
-            pair2.as("r").as[T](encoder)
-          )
+        (ds.zipWithIndex(indexName), other.zipWithIndex(indexName), Seq(indexName))
       }
+
+      val joined = ds1
+        .as("l")
+        .join(ds2.as("r"), key, s"${outerJoinType}_outer")
+
+      joined.select(colOptionTypedCol[T]("l", ds.schema, key), colOptionTypedCol[P]("r", other.schema, key))
     }
   }
+
+  private def colOptionTypedCol[P: ClassTag: TypeTag](
+      colName: String,
+      schema: StructType,
+      key: Seq[String]
+  ): TypedColumn[Any, Option[P]] = {
+    val columns   = schema.names.map(n => col(s"$colName.$n"))
+    val isRowType = implicitly[ClassTag[P]].runtimeClass == classOf[Row]
+    val unTypedColumn =
+      if (columns.length == 1 && !isRowType)
+        columns.head
+      else
+        when(key.map(k => col(s"$colName.$k").isNull).reduce(_ && _), lit(null)).otherwise(struct(columns: _*))
+
+    val enc: Encoder[Option[P]] = if (isRowType) {
+      ExpressionEncoder(OptionEncoder(RowEncoder.encoderFor(schema).asInstanceOf[AgnosticEncoder[P]]))
+    } else {
+      ExpressionEncoder()
+    }
+    unTypedColumn.as(colName).as[Option[P]](enc)
+  }
+
+  def colToRowCol(
+      colName: String,
+      schema: StructType
+  ): TypedColumn[Any, Row] = {
+    val columns = schema.names.map(n => col(s"$colName.$n"))
+    struct(columns: _*).as(colName).as[Row](ExpressionEncoder()) // Encoders.row(schema)
+  }
 }
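For reference, a small sketch of what `outerJoinWith` produces (assuming a `SparkSession` named `spark`; `DatasetUtils` is package-private, so these calls only compile inside the library's own packages, and the pair order in the result is not guaranteed):

    import com.github.mrpowers.spark.fast.tests.DatasetUtils.DatasetOps
    import org.apache.spark.sql.{Dataset, Row}
    import spark.implicits._

    val left  = Seq((1, "a"), (2, "b")).toDF("id", "value")
    val right = Seq((2, "b"), (3, "c")).toDF("id", "value")

    // Full outer join keyed on "id"; unmatched sides decode to None:
    //   (Some([1,a]), None), (Some([2,b]), Some([2,b])), (None, Some([3,c]))
    val pairs: Dataset[(Option[Row], Option[Row])] = left.outerJoinWith(right, Seq("id"))

    left.isKeyUnique(Seq("id")) // true: the distinct key count equals the row count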
diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala
index b4b464c..5257052 100644
--- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala
+++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala
@@ -5,23 +5,26 @@ import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
 import org.apache.commons.lang3.StringUtils
 import org.apache.spark.sql.Row
 
+import scala.annotation.tailrec
 import scala.reflect.ClassTag
 
 object ProductUtil {
+  @tailrec
   private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = {
     product match {
-      case null           => Seq.empty
-      case a: Array[_]    => a
-      case i: Iterable[_] => i.toSeq
-      case r: Row         => r.toSeq
-      case p: Product     => p.productIterator.toSeq
-      case s              => Seq(s)
+      case null | None                             => Seq.empty
+      case a: Array[_]                             => a
+      case i: Iterable[_]                          => i.toSeq
+      case r: Row                                  => r.toSeq
+      // Some must be matched before the generic Product case, since Some is itself a Product
+      case Some(ov) if !ov.isInstanceOf[Option[_]] => productOrRowToSeq(ov)
+      case p: Product                              => p.productIterator.toSeq
+      case s                                       => Seq(s)
     }
   }
   private[mrpowers] def showProductDiff[T: ClassTag](
       header: (String, String),
-      actual: Seq[T],
-      expected: Seq[T],
+      data: Either[(Seq[T], Seq[T]), Seq[(Option[T], Option[T])]],
       truncate: Int = 20,
       minColWidth: Int = 3
   ): String = {
@@ -33,7 +36,10 @@ object ProductUtil {
 
     val sb = new StringBuilder
 
-    val fullJoin = actual.zipAll(expected, null, null)
+    val fullJoin = data match {
+      case Left((actual, expected)) => actual.zipAll(expected, null, null)
+      case Right(joined)            => joined
+    }
 
     val diff = fullJoin.map { case (actualRow, expectedRow) =>
       if (actualRow == expectedRow) {
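The new `Option` cases exist because the outer join now feeds `Option`-wrapped values into the diff rendering. Conceptually (the method is `private[mrpowers]`, so these calls are illustrative only):

    import org.apache.spark.sql.Row

    ProductUtil.productOrRowToSeq(None)              // Seq.empty, same as null
    ProductUtil.productOrRowToSeq(Some(Row(1, "a"))) // Seq(1, "a"): Some unwraps, then the Row flattens
    ProductUtil.productOrRowToSeq(Some((1, 2)))      // Seq(1, 2) via productIterator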
diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala
index 32ceaf7..92b66dd 100644
--- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala
+++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala
@@ -14,8 +14,7 @@ object SchemaComparer {
   private def betterSchemaMismatchMessage(actualSchema: StructType, expectedSchema: StructType): String = {
     showProductDiff(
       ("Actual Schema", "Expected Schema"),
-      actualSchema.fields,
-      expectedSchema.fields,
+      Left(actualSchema.fields.toSeq -> expectedSchema.fields.toSeq),
       truncate = 200
     )
   }
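Before the test updates, one more sketch: with duplicate keys a full outer join fans out and the row diff becomes misleading, which is exactly what the new opt-in guard catches up front (`actualDS`/`expectedDS` are again illustrative):

    // Fails fast with "Primary key is not unique in actual dataset"
    // instead of reporting a fanned-out, hard-to-read row diff.
    assertLargeDatasetEqualityV2(
      actualDS,
      expectedDS,
      primaryKeys = Seq("id"),
      checkKeyUniqueness = true
    )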
diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerV2Test.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerV2Test.scala
index 5c31f7e..922af81 100644
--- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerV2Test.scala
+++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerV2Test.scala
@@ -2,10 +2,11 @@ package com.github.mrpowers.spark.fast.tests
 
 import org.apache.spark.sql.types._
 import SparkSessionExt._
+import com.github.mrpowers.spark.fast.tests.DatasetUtils.DatasetOps
 import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch
 import com.github.mrpowers.spark.fast.tests.TestUtilsExt.ExceptionOps
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.functions.{col, lower}
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.functions.col
 import org.scalatest.freespec.AnyFreeSpec
 
 class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
@@ -56,7 +57,7 @@ class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
       Array("red", "green", "blue")
     ).toDS
 
-    assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = (a1: Array[String], a2: Array[String]) => a1.mkString == a2.mkString)
+    assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Left((a1: Array[String], a2: Array[String]) => a1.mkString == a2.mkString))
   }
 
   "can compare Dataset[Map[_]]" in {
@@ -70,7 +71,11 @@ class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
       Map("apple" -> "banana", "apple1" -> "banana1")
     ).toDS
 
-    assertLargeDatasetEqualityV2(sourceDS, expectedDS)
+    assertLargeDatasetEqualityV2(
+      sourceDS,
+      expectedDS,
+      equals = Left((a1: Map[String, String], a2: Map[String, String]) => a1.mkString == a2.mkString)
+    )
   }
 
   "does nothing if the Datasets have the same schemas and content" in {
@@ -223,7 +228,7 @@ class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
         Person("Alice", 5)
       )
     )
-    assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Person.caseInsensitivePersonEquals)
+    assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Left((p1: Person, p2: Person) => Person.caseInsensitivePersonEquals(p1, p2)))
   }
 
   "fails if custom comparator for returns false" in {
@@ -241,7 +246,7 @@ class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
     )
 
     intercept[DatasetContentMismatch] {
-      assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Person.caseInsensitivePersonEquals)
+      assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Left((p1: Person, p2: Person) => Person.caseInsensitivePersonEquals(p1, p2)))
    }
  }
 
@@ -433,5 +438,18 @@ class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
         assertLargeDatasetEqualityV2(ds2, ds1, ignoreMetadata = false)
       }
     }
+
+    "outerJoinWith pairs rows by primary key for Datasets and DataFrames" in {
+      val joinedOuterDS =
+        spark.range(0, 1000000, 1, 8).outerJoinWith(spark.range(0, 1000000, 2, 8), Seq("id"))
+      // only the even ids appear on both sides, so exactly half of the pairs match up
+      val matchedDS = joinedOuterDS.filter(p => p._1 == p._2)
+      assert(matchedDS.count() == 500000)
+
+      val joinedOuterDF =
+        spark.range(0, 1000000, 1, 8).toDF().outerJoinWith(spark.range(0, 1000000, 2, 8).toDF(), Seq("id"))
+      val matchedDF = joinedOuterDF.filter((p: (Option[Row], Option[Row])) => p._1 == p._2)
+      assert(matchedDF.count() == 500000)
+    }
   }
 }
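Finally, for Row-typed inputs `colOptionTypedCol` has to spell out the encoder because there is no implicit `Encoder[Option[Row]]`. The construction below mirrors the Row branch of the diff (Spark 3.5-era catalyst internals, subject to change between Spark versions; in the diff a cast to `AgnosticEncoder[P]` is needed only because `P` is generic there):

    import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.OptionEncoder
    import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
    import org.apache.spark.sql.{Encoder, Row}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val schema = StructType(Seq(StructField("id", IntegerType), StructField("value", StringType)))

    // Wrap the row encoder in OptionEncoder so unmatched join sides decode to None.
    val optionRowEncoder: Encoder[Option[Row]] =
      ExpressionEncoder(OptionEncoder(RowEncoder.encoderFor(schema)))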