diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala index 70b30dc..e77a59d 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala @@ -38,7 +38,7 @@ Expected DataFrame Row Count: '$expectedCount' /** * Raises an error unless `actualDS` and `expectedDS` are equal */ - def assertSmallDatasetEquality[T]( + def assertSmallDatasetEquality[T: ClassTag]( actualDS: Dataset[T], expectedDS: Dataset[T], ignoreNullable: Boolean = false, @@ -53,7 +53,7 @@ Expected DataFrame Row Count: '$expectedCount' assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals) } - def assertSmallDatasetContentEquality[T]( + def assertSmallDatasetContentEquality[T: ClassTag]( actualDS: Dataset[T], expectedDS: Dataset[T], orderedComparison: Boolean, @@ -66,12 +66,12 @@ Expected DataFrame Row Count: '$expectedCount' assertSmallDatasetContentEquality(defaultSortDataset(actualDS), defaultSortDataset(expectedDS), truncate, equals) } - def assertSmallDatasetContentEquality[T](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = { + def assertSmallDatasetContentEquality[T: ClassTag](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = { val a = actualDS.collect().toSeq val e = expectedDS.collect().toSeq if (!a.approximateSameElements(e, equals)) { val arr = ("Actual Content", "Expected Content") - val msg = "Diffs\n" ++ DataframeUtil.showDataframeDiff(arr, a.asRows, e.asRows, truncate) + val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate) throw DatasetContentMismatch(msg) } } diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala similarity index 64% rename from core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala rename to core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala index 6cfde87..b4b464c 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala @@ -1,49 +1,60 @@ package com.github.mrpowers.spark.fast.tests import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red} +import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.Row -import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps -object DataframeUtil { - private[mrpowers] def showDataframeDiff( +import scala.reflect.ClassTag + +object ProductUtil { + private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = { + product match { + case null => Seq.empty + case a: Array[_] => a + case i: Iterable[_] => i.toSeq + case r: Row => r.toSeq + case p: Product => p.productIterator.toSeq + case s => Seq(s) + } + } + private[mrpowers] def showProductDiff[T: ClassTag]( header: (String, String), - actual: Seq[Row], - expected: Seq[Row], + actual: Seq[T], + expected: Seq[T], truncate: Int = 20, minColWidth: Int = 3 ): String = { + val runTimeClass = implicitly[ClassTag[T]].runtimeClass + val (className, lBracket, rBracket) = if (runTimeClass == classOf[Row]) ("", "[", "]") else (runTimeClass.getSimpleName, "(", ")") + val prodToString: Seq[Any] => String = s => s.mkString(s"$className$lBracket", ",", rBracket) + val emptyProd = "MISSING" + val sb = new StringBuilder - val fullJoin = actual.zipAll(expected, Row(), Row()) + val fullJoin = actual.zipAll(expected, null, null) + val diff = fullJoin.map { case (actualRow, expectedRow) => - if (equals(actualRow, expectedRow)) { + if (actualRow == expectedRow) { List(DarkGray(actualRow.toString), DarkGray(expectedRow.toString)) } else { - val actualSeq = actualRow.toSeq - val expectedSeq = expectedRow.toSeq + val actualSeq = productOrRowToSeq(actualRow) + val expectedSeq = productOrRowToSeq(expectedRow) if (actualSeq.isEmpty) - List( - Red("[]"), - Green(expectedSeq.mkString("[", ",", "]")) - ) + List(Red(emptyProd), Green(prodToString(expectedSeq))) else if (expectedSeq.isEmpty) - List(Red(actualSeq.mkString("[", ",", "]")), Green("[]")) + List(Red(prodToString(actualSeq)), Green(emptyProd)) else { val withEquals = actualSeq - .zip(expectedSeq) + .zipAll(expectedSeq, "MISSING", "MISSING") .map { case (actualRowField, expectedRowField) => (actualRowField, expectedRowField, actualRowField == expectedRowField) } val allFieldsAreNotEqual = !withEquals.exists(_._3) if (allFieldsAreNotEqual) { - List( - Red(actualSeq.mkString("[", ",", "]")), - Green(expectedSeq.mkString("[", ",", "]")) - ) + List(Red(prodToString(actualSeq)), Green(prodToString(expectedSeq))) } else { - val coloredDiff = withEquals .map { case (actualRowField, expectedRowField, true) => @@ -51,9 +62,9 @@ object DataframeUtil { case (actualRowField, expectedRowField, false) => (Red(actualRowField.toString), Green(expectedRowField.toString)) } - val start = DarkGray("[") + val start = DarkGray(s"$className$lBracket") val sep = DarkGray(",") - val end = DarkGray("]") + val end = DarkGray(rBracket) List( coloredDiff.map(_._1).mkStr(start, sep, end), coloredDiff.map(_._2).mkStr(start, sep, end) @@ -69,11 +80,12 @@ object DataframeUtil { val colWidths = Array.fill(numCols)(minColWidth) // Compute the width of each column - for ((cell, i) <- headerSeq.zipWithIndex) { + headerSeq.zipWithIndex.foreach({ case (cell, i) => colWidths(i) = math.max(colWidths(i), cell.length) - } - for (row <- diff) { - for ((cell, i) <- row.zipWithIndex) { + }) + + diff.foreach { row => + row.zipWithIndex.foreach { case (cell, i) => colWidths(i) = math.max(colWidths(i), cell.length) } } @@ -117,5 +129,4 @@ object DataframeUtil { sb.toString } - } diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala index ce1edfe..89f6783 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala @@ -1,29 +1,18 @@ package com.github.mrpowers.spark.fast.tests +import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff import org.apache.spark.sql.Dataset -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType} object SchemaComparer { - case class DatasetSchemaMismatch(smth: String) extends Exception(smth) private def betterSchemaMismatchMessage[T](actualDS: Dataset[T], expectedDS: Dataset[T]): String = { - "\nActual Schema Field | Expected Schema Field\n" + actualDS.schema - .zipAll( - expectedDS.schema, - "", - "" - ) - .map { - case (sf1, sf2) if sf1 == sf2 => - ufansi.Color.Blue(s"$sf1 | $sf2") - case ("", sf2) => - ufansi.Color.Red(s"MISSING | $sf2") - case (sf1, "") => - ufansi.Color.Red(s"$sf1 | MISSING") - case (sf1, sf2) => - ufansi.Color.Red(s"$sf1 | $sf2") - } - .mkString("\n") + showProductDiff( + ("Actual Schema", "Expected Schema"), + actualDS.schema.fields, + expectedDS.schema.fields, + truncate = 200 + ) } def assertSchemaEqual[T]( @@ -36,7 +25,7 @@ object SchemaComparer { require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.") if (!SchemaComparer.equals(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)) { throw DatasetSchemaMismatch( - betterSchemaMismatchMessage(actualDS, expectedDS) + "Diffs\n" + betterSchemaMismatchMessage(actualDS, expectedDS) ) } } @@ -76,5 +65,4 @@ object SchemaComparer { case _ => dt1 == dt2 } } - } diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala index 3cda19c..3b61324 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala @@ -1,6 +1,6 @@ package com.github.mrpowers.spark.fast.tests -import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType} +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType} import SparkSessionExt._ import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch import com.github.mrpowers.spark.fast.tests.StringExt.StringOps @@ -331,6 +331,41 @@ class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with Spar assertLargeDataFrameEquality(sourceDF, expectedDF) } } + + "correctly mark unequal schema field" in { + val sourceDF = spark.createDF( + List( + (1, 2.0), + (5, 3.0) + ), + List( + ("number", IntegerType, true), + ("float", DoubleType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + (1, "word", 1L), + (5, "word", 2L) + ), + List( + ("number", IntegerType, true), + ("word", StringType, true), + ("long", LongType, true) + ) + ) + + val e = intercept[DatasetSchemaMismatch] { + assertSmallDataFrameEquality(sourceDF, expectedDF) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})"))) + assert(actualColourGroup.contains(Seq("float", "DoubleType", "MISSING"))) + } } "assertApproximateDataFrameEquality" - { diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala index f7a9f57..726965c 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala @@ -2,8 +2,10 @@ package com.github.mrpowers.spark.fast.tests import org.apache.spark.sql.types._ import SparkSessionExt._ +import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch import com.github.mrpowers.spark.fast.tests.StringExt.StringOps +import org.apache.spark.sql.Row import org.scalatest.freespec.AnyFreeSpec object Person { @@ -62,8 +64,101 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes val colourGroup = e.getMessage.extractColorGroup val expectedColourGroup = colourGroup.get(Console.GREEN) val actualColourGroup = colourGroup.get(Console.RED) - assert(expectedColourGroup.contains(Seq("[frank,10]", "lucy"))) - assert(actualColourGroup.contains(Seq("[bob,1]", "alice"))) + assert(expectedColourGroup.contains(Seq("Person(frank,10)", "lucy"))) + assert(actualColourGroup.contains(Seq("Person(bob,1)", "alice"))) + } + + "correctly mark unequal element for Dataset[String]" in { + import spark.implicits._ + val sourceDS = Seq("word", "StringType", "StructField(long,LongType,true,{})").toDS + + val expectedDS = List("word", "StringType", "StructField(long,LongType2,true,{})").toDS + + val e = intercept[DatasetContentMismatch] { + assertSmallDatasetEquality(sourceDS, expectedDS) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("String(StructField(long,LongType2,true,{}))"))) + assert(actualColourGroup.contains(Seq("String(StructField(long,LongType,true,{}))"))) + } + + "correctly mark unequal element for Dataset[Seq[String]]" in { + import spark.implicits._ + + val sourceDS = Seq( + Seq("apple", "banana", "cherry"), + Seq("dog", "cat"), + Seq("red", "green", "blue") + ).toDS + + val expectedDS = Seq( + Seq("apple", "banana2"), + Seq("dog", "cat"), + Seq("red", "green", "blue") + ).toDS + + val e = intercept[DatasetContentMismatch] { + assertSmallDatasetEquality(sourceDS, expectedDS) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("banana2", "MISSING"))) + assert(actualColourGroup.contains(Seq("banana", "cherry"))) + } + + "correctly mark unequal element for Dataset[Array[String]]" in { + import spark.implicits._ + + val sourceDS = Seq( + Array("apple", "banana", "cherry"), + Array("dog", "cat"), + Array("red", "green", "blue") + ).toDS + + val expectedDS = Seq( + Array("apple", "banana2"), + Array("dog", "cat"), + Array("red", "green", "blue") + ).toDS + + val e = intercept[DatasetContentMismatch] { + assertSmallDatasetEquality(sourceDS, expectedDS) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("banana2", "MISSING"))) + assert(actualColourGroup.contains(Seq("banana", "cherry"))) + } + + "correctly mark unequal element for Dataset[Map[String, String]]" in { + import spark.implicits._ + + val sourceDS = Seq( + Map("apple" -> "banana", "apple1" -> "banana1"), + Map("apple" -> "banana", "apple1" -> "banana1") + ).toDS + + val expectedDS = Seq( + Map("apple" -> "banana1", "apple1" -> "banana1"), + Map("apple" -> "banana", "apple1" -> "banana1") + ).toDS + + val e = intercept[DatasetContentMismatch] { + assertSmallDatasetEquality(sourceDS, expectedDS) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("(apple,banana1)"))) + assert(actualColourGroup.contains(Seq("(apple,banana)"))) } "works with really long columns" in { @@ -154,29 +249,64 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes } "throws an error if the DataFrames have different schemas" in { + val nestedSchema = StructType( + Seq( + StructField( + "attributes", + StructType( + Seq( + StructField("PostCode", IntegerType, nullable = true) + ) + ), + nullable = true + ) + ) + ) + + val nestedSchema2 = StructType( + Seq( + StructField( + "attributes", + StructType( + Seq( + StructField("PostCode", StringType, nullable = true) + ) + ), + nullable = true + ) + ) + ) + val sourceDF = spark.createDF( List( - (1), - (5) + (1, 2.0, null), + (5, 3.0, null) ), - List(("number", IntegerType, true)) + List( + ("number", IntegerType, true), + ("float", DoubleType, true), + ("nestedField", nestedSchema, true) + ) ) val expectedDF = spark.createDF( List( - (1, "word"), - (5, "word") + (1, "word", null, 1L), + (5, "word", null, 2L) ), List( ("number", IntegerType, true), - ("word", StringType, true) + ("word", StringType, true), + ("nestedField", nestedSchema2, true), + ("long", LongType, true) ) ) - val e = intercept[DatasetSchemaMismatch] { + intercept[DatasetSchemaMismatch] { assertLargeDatasetEquality(sourceDF, expectedDF) } - val e2 = intercept[DatasetSchemaMismatch] { + + intercept[DatasetSchemaMismatch] { assertSmallDatasetEquality(sourceDF, expectedDF) } } @@ -430,6 +560,41 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes assertLargeDatasetEquality(ds1, ds2, ignoreColumnOrder = true) assertLargeDatasetEquality(ds2, ds1, ignoreColumnOrder = true) } + + "correctly mark unequal schema field" in { + val sourceDF = spark.createDF( + List( + (1, 2.0), + (5, 3.0) + ), + List( + ("number", IntegerType, true), + ("float", DoubleType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + (1, "word", 1L), + (5, "word", 2L) + ), + List( + ("number", IntegerType, true), + ("word", StringType, true), + ("long", LongType, true) + ) + ) + + val e = intercept[DatasetSchemaMismatch] { + assertLargeDatasetEquality(sourceDF, expectedDF) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})"))) + assert(actualColourGroup.contains(Seq("float", "DoubleType", "MISSING"))) + } } "assertSmallDatasetEquality" - { @@ -611,9 +776,43 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes Person("alice", 5) ).toDS.select("age", "name").as(ds1.encoder) - assertSmallDatasetEquality(ds1, ds2, ignoreColumnOrder = true) assertSmallDatasetEquality(ds2, ds1, ignoreColumnOrder = true) } + + "correctly mark unequal schema field" in { + val sourceDF = spark.createDF( + List( + (1, 2.0), + (5, 3.0) + ), + List( + ("number", IntegerType, true), + ("float", DoubleType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + (1, "word", 1L), + (5, "word", 2L) + ), + List( + ("number", IntegerType, true), + ("word", StringType, true), + ("long", LongType, true) + ) + ) + + val e = intercept[DatasetSchemaMismatch] { + assertSmallDatasetEquality(sourceDF, expectedDF) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})"))) + assert(actualColourGroup.contains(Seq("float", "DoubleType", "MISSING"))) + } } "defaultSortDataset" - {