From 862af0aec8abd7f63d2232477f1d9d0d0f6296bd Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sat, 28 Sep 2024 15:25:14 +1000 Subject: [PATCH 01/10] Add Table support for StructField Diff --- .../spark/fast/tests/DatasetComparer.scala | 8 +-- ...{DataframeUtil.scala => ProductUtil.scala} | 63 +++++++++++-------- .../spark/fast/tests/SchemaComparer.scala | 32 ++++------ .../fast/tests/DatasetComparerTest.scala | 51 +++++++++++++-- 4 files changed, 96 insertions(+), 58 deletions(-) rename core/src/main/scala/com/github/mrpowers/spark/fast/tests/{DataframeUtil.scala => ProductUtil.scala} (65%) diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala index 70b30dc..e71a115 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala @@ -38,7 +38,7 @@ Expected DataFrame Row Count: '$expectedCount' /** * Raises an error unless `actualDS` and `expectedDS` are equal */ - def assertSmallDatasetEquality[T]( + def assertSmallDatasetEquality[T: ClassTag]( actualDS: Dataset[T], expectedDS: Dataset[T], ignoreNullable: Boolean = false, @@ -53,7 +53,7 @@ Expected DataFrame Row Count: '$expectedCount' assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals) } - def assertSmallDatasetContentEquality[T]( + def assertSmallDatasetContentEquality[T: ClassTag]( actualDS: Dataset[T], expectedDS: Dataset[T], orderedComparison: Boolean, @@ -66,12 +66,12 @@ Expected DataFrame Row Count: '$expectedCount' assertSmallDatasetContentEquality(defaultSortDataset(actualDS), defaultSortDataset(expectedDS), truncate, equals) } - def assertSmallDatasetContentEquality[T](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = { + def assertSmallDatasetContentEquality[T: ClassTag](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = { val a = actualDS.collect().toSeq val e = expectedDS.collect().toSeq if (!a.approximateSameElements(e, equals)) { val arr = ("Actual Content", "Expected Content") - val msg = "Diffs\n" ++ DataframeUtil.showDataframeDiff(arr, a.asRows, e.asRows, truncate) + val msg = "Diffs\n" ++ ProductUtil.showProductDiff[T](arr, a, e, truncate) throw DatasetContentMismatch(msg) } } diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala similarity index 65% rename from core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala rename to core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala index 6cfde87..ffbe668 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala @@ -1,35 +1,48 @@ package com.github.mrpowers.spark.fast.tests import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red} +import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.Row -import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps -object DataframeUtil { - private[mrpowers] def showDataframeDiff( +import scala.reflect.ClassTag + +object ProductUtil { + private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = { + product match { + case null => Seq.empty + case r: Row => r.toSeq + case p: Product => p.productIterator.toSeq + case _ => throw new IllegalArgumentException("Only Row and Product types are supported") + } + } + private[mrpowers] def showProductDiff[T: ClassTag]( header: (String, String), - actual: Seq[Row], - expected: Seq[Row], + actual: Seq[T], + expected: Seq[T], truncate: Int = 20, - minColWidth: Int = 3 + minColWidth: Int = 3, + defaultVal: T = null.asInstanceOf[T], + border: (String, String) = ("[", "]") ): String = { + val className = implicitly[ClassTag[T]].runtimeClass.getSimpleName + val prodToString: Seq[Any] => String = s => s.mkString(s"$className${border._1}", ",", border._2) + val emptyProd = s"$className()" val sb = new StringBuilder - val fullJoin = actual.zipAll(expected, Row(), Row()) + val fullJoin = actual.zipAll(expected, defaultVal, defaultVal) + val diff = fullJoin.map { case (actualRow, expectedRow) => - if (equals(actualRow, expectedRow)) { + if (actualRow == expectedRow) { List(DarkGray(actualRow.toString), DarkGray(expectedRow.toString)) } else { - val actualSeq = actualRow.toSeq - val expectedSeq = expectedRow.toSeq + val actualSeq = productOrRowToSeq(actualRow) + val expectedSeq = productOrRowToSeq(expectedRow) if (actualSeq.isEmpty) - List( - Red("[]"), - Green(expectedSeq.mkString("[", ",", "]")) - ) + List(Red(emptyProd), Green(prodToString(expectedSeq))) else if (expectedSeq.isEmpty) - List(Red(actualSeq.mkString("[", ",", "]")), Green("[]")) + List(Red(prodToString(actualSeq)), Green(emptyProd)) else { val withEquals = actualSeq .zip(expectedSeq) @@ -38,12 +51,8 @@ object DataframeUtil { } val allFieldsAreNotEqual = !withEquals.exists(_._3) if (allFieldsAreNotEqual) { - List( - Red(actualSeq.mkString("[", ",", "]")), - Green(expectedSeq.mkString("[", ",", "]")) - ) + List(Red(prodToString(actualSeq)), Green(prodToString(expectedSeq))) } else { - val coloredDiff = withEquals .map { case (actualRowField, expectedRowField, true) => @@ -51,9 +60,9 @@ object DataframeUtil { case (actualRowField, expectedRowField, false) => (Red(actualRowField.toString), Green(expectedRowField.toString)) } - val start = DarkGray("[") + val start = DarkGray(s"$className${border._1}") val sep = DarkGray(",") - val end = DarkGray("]") + val end = DarkGray(border._2) List( coloredDiff.map(_._1).mkStr(start, sep, end), coloredDiff.map(_._2).mkStr(start, sep, end) @@ -69,11 +78,12 @@ object DataframeUtil { val colWidths = Array.fill(numCols)(minColWidth) // Compute the width of each column - for ((cell, i) <- headerSeq.zipWithIndex) { + headerSeq.zipWithIndex.foreach({ case (cell, i) => colWidths(i) = math.max(colWidths(i), cell.length) - } - for (row <- diff) { - for ((cell, i) <- row.zipWithIndex) { + }) + + diff.foreach { row => + row.zipWithIndex.foreach { case (cell, i) => colWidths(i) = math.max(colWidths(i), cell.length) } } @@ -117,5 +127,4 @@ object DataframeUtil { sb.toString } - } diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala index ce1edfe..9316d76 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala @@ -1,29 +1,20 @@ package com.github.mrpowers.spark.fast.tests +import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff import org.apache.spark.sql.Dataset -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType} object SchemaComparer { - case class DatasetSchemaMismatch(smth: String) extends Exception(smth) private def betterSchemaMismatchMessage[T](actualDS: Dataset[T], expectedDS: Dataset[T]): String = { - "\nActual Schema Field | Expected Schema Field\n" + actualDS.schema - .zipAll( - expectedDS.schema, - "", - "" - ) - .map { - case (sf1, sf2) if sf1 == sf2 => - ufansi.Color.Blue(s"$sf1 | $sf2") - case ("", sf2) => - ufansi.Color.Red(s"MISSING | $sf2") - case (sf1, "") => - ufansi.Color.Red(s"$sf1 | MISSING") - case (sf1, sf2) => - ufansi.Color.Red(s"$sf1 | $sf2") - } - .mkString("\n") + showProductDiff( + ("Actual Schema", "Expected Schema"), + actualDS.schema.fields, + expectedDS.schema.fields, + truncate = 200, + defaultVal = StructField("SPARK_FAST_TEST_MISSING_FIELD", NullType), + border = ("(", ")") + ) } def assertSchemaEqual[T]( @@ -36,7 +27,7 @@ object SchemaComparer { require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.") if (!SchemaComparer.equals(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)) { throw DatasetSchemaMismatch( - betterSchemaMismatchMessage(actualDS, expectedDS) + "Diffs\n" + betterSchemaMismatchMessage(actualDS, expectedDS) ) } } @@ -76,5 +67,4 @@ object SchemaComparer { case _ => dt1 == dt2 } } - } diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala index f7a9f57..d3b2f73 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala @@ -154,31 +154,70 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes } "throws an error if the DataFrames have different schemas" in { + val nestedSchema = StructType( + Seq( + StructField( + "attributes", + StructType( + Seq( + StructField("PostCode", IntegerType, nullable = true) + ) + ), + nullable = true + ) + ) + ) + + val nestedSchema2 = StructType( + Seq( + StructField( + "attributes", + StructType( + Seq( + StructField("PostCode", StringType, nullable = true) + ) + ), + nullable = true + ) + ) + ) + val sourceDF = spark.createDF( List( - (1), - (5) + (1, 2.0, null), + (5, 3.0, null) ), - List(("number", IntegerType, true)) + List( + ("number", IntegerType, true), + ("float", DoubleType, true), + ("nestedField", nestedSchema, true) + ) ) val expectedDF = spark.createDF( List( - (1, "word"), - (5, "word") + (1, "word", null, 1L), + (5, "word", null, 2L) ), List( ("number", IntegerType, true), - ("word", StringType, true) + ("word", StringType, true), + ("nestedField", nestedSchema2, true), + ("long", LongType, true) ) ) val e = intercept[DatasetSchemaMismatch] { assertLargeDatasetEquality(sourceDF, expectedDF) } + println(e) val e2 = intercept[DatasetSchemaMismatch] { assertSmallDatasetEquality(sourceDF, expectedDF) } + println(e2) + + sourceDF.schema.printTreeString() + expectedDF.schema.printTreeString() } "throws an error if the DataFrames content is different" in { From 86fe715ce10b8df6f21d4403d39c713304371402 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sat, 28 Sep 2024 17:29:28 +1000 Subject: [PATCH 02/10] Determine bracket based on type --- .../github/mrpowers/spark/fast/tests/ProductUtil.scala | 10 ++++++---- .../mrpowers/spark/fast/tests/SchemaComparer.scala | 3 +-- .../spark/fast/tests/DataFrameComparerTest.scala | 4 ++-- .../spark/fast/tests/DatasetComparerTest.scala | 4 ++-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala index ffbe668..ab2da0e 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala @@ -22,12 +22,14 @@ object ProductUtil { expected: Seq[T], truncate: Int = 20, minColWidth: Int = 3, - defaultVal: T = null.asInstanceOf[T], - border: (String, String) = ("[", "]") + defaultVal: T = null.asInstanceOf[T] ): String = { - val className = implicitly[ClassTag[T]].runtimeClass.getSimpleName + + val runTimeClass = implicitly[ClassTag[T]].runtimeClass + val className = runTimeClass.getSimpleName + val border = if (runTimeClass == classOf[Row]) ("[", "]") else ("(", ")") val prodToString: Seq[Any] => String = s => s.mkString(s"$className${border._1}", ",", border._2) - val emptyProd = s"$className()" + val emptyProd = s"$className${border._1}${border._2}" val sb = new StringBuilder diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala index 9316d76..f67fa73 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala @@ -12,8 +12,7 @@ object SchemaComparer { actualDS.schema.fields, expectedDS.schema.fields, truncate = 200, - defaultVal = StructField("SPARK_FAST_TEST_MISSING_FIELD", NullType), - border = ("(", ")") + defaultVal = StructField("SPARK_FAST_TEST_MISSING_FIELD", NullType) ) } diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala index 3cda19c..2cf2ae9 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala @@ -74,8 +74,8 @@ class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with Spar val colourGroup = e.getMessage.extractColorGroup val expectedColourGroup = colourGroup.get(Console.GREEN) val actualColourGroup = colourGroup.get(Console.RED) - assert(expectedColourGroup.contains(Seq("uk", "[steve,10,aus]"))) - assert(actualColourGroup.contains(Seq("france", "[mark,11,usa]"))) + assert(expectedColourGroup.contains(Seq("uk", "Row[steve,10,aus]"))) + assert(actualColourGroup.contains(Seq("france", "Row[mark,11,usa]"))) } "works well for wide DataFrames" in { diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala index d3b2f73..bcc49a9 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala @@ -62,8 +62,8 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes val colourGroup = e.getMessage.extractColorGroup val expectedColourGroup = colourGroup.get(Console.GREEN) val actualColourGroup = colourGroup.get(Console.RED) - assert(expectedColourGroup.contains(Seq("[frank,10]", "lucy"))) - assert(actualColourGroup.contains(Seq("[bob,1]", "alice"))) + assert(expectedColourGroup.contains(Seq("Person(frank,10)", "lucy"))) + assert(actualColourGroup.contains(Seq("Person(bob,1)", "alice"))) } "works with really long columns" in { From 6a2f2c5c44129f371aafbc073e78be2d9dc0ecd1 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sat, 28 Sep 2024 18:01:52 +1000 Subject: [PATCH 03/10] Add Test for schema diff --- .../spark/fast/tests/ProductUtil.scala | 2 +- .../spark/fast/tests/SchemaComparer.scala | 3 +- .../fast/tests/DataFrameComparerTest.scala | 37 ++++++++- .../fast/tests/DatasetComparerTest.scala | 81 +++++++++++++++++-- 4 files changed, 111 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala index ab2da0e..4840ee0 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala @@ -29,7 +29,7 @@ object ProductUtil { val className = runTimeClass.getSimpleName val border = if (runTimeClass == classOf[Row]) ("[", "]") else ("(", ")") val prodToString: Seq[Any] => String = s => s.mkString(s"$className${border._1}", ",", border._2) - val emptyProd = s"$className${border._1}${border._2}" + val emptyProd = "MISSING" val sb = new StringBuilder diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala index f67fa73..89f6783 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala @@ -11,8 +11,7 @@ object SchemaComparer { ("Actual Schema", "Expected Schema"), actualDS.schema.fields, expectedDS.schema.fields, - truncate = 200, - defaultVal = StructField("SPARK_FAST_TEST_MISSING_FIELD", NullType) + truncate = 200 ) } diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala index 2cf2ae9..757acb5 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala @@ -1,6 +1,6 @@ package com.github.mrpowers.spark.fast.tests -import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType} +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType} import SparkSessionExt._ import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch import com.github.mrpowers.spark.fast.tests.StringExt.StringOps @@ -331,6 +331,41 @@ class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with Spar assertLargeDataFrameEquality(sourceDF, expectedDF) } } + + "correctly mark unequal schema field" in { + val sourceDF = spark.createDF( + List( + (1, 2.0), + (5, 3.0) + ), + List( + ("number", IntegerType, true), + ("float", DoubleType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + (1, "word", 1L), + (5, "word", 2L) + ), + List( + ("number", IntegerType, true), + ("word", StringType, true), + ("long", LongType, true) + ) + ) + + val e = intercept[DatasetSchemaMismatch] { + assertSmallDataFrameEquality(sourceDF, expectedDF) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})"))) + assert(actualColourGroup.contains(Seq("float", "DoubleType", "MISSING"))) + } } "assertApproximateDataFrameEquality" - { diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala index bcc49a9..e47b59a 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala @@ -207,17 +207,13 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes ) ) - val e = intercept[DatasetSchemaMismatch] { + intercept[DatasetSchemaMismatch] { assertLargeDatasetEquality(sourceDF, expectedDF) } - println(e) - val e2 = intercept[DatasetSchemaMismatch] { + + intercept[DatasetSchemaMismatch] { assertSmallDatasetEquality(sourceDF, expectedDF) } - println(e2) - - sourceDF.schema.printTreeString() - expectedDF.schema.printTreeString() } "throws an error if the DataFrames content is different" in { @@ -469,6 +465,41 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes assertLargeDatasetEquality(ds1, ds2, ignoreColumnOrder = true) assertLargeDatasetEquality(ds2, ds1, ignoreColumnOrder = true) } + + "correctly mark unequal schema field" in { + val sourceDF = spark.createDF( + List( + (1, 2.0), + (5, 3.0) + ), + List( + ("number", IntegerType, true), + ("float", DoubleType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + (1, "word", 1L), + (5, "word", 2L) + ), + List( + ("number", IntegerType, true), + ("word", StringType, true), + ("long", LongType, true) + ) + ) + + val e = intercept[DatasetSchemaMismatch] { + assertLargeDatasetEquality(sourceDF, expectedDF) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})"))) + assert(actualColourGroup.contains(Seq("float", "DoubleType", "MISSING"))) + } } "assertSmallDatasetEquality" - { @@ -650,9 +681,43 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes Person("alice", 5) ).toDS.select("age", "name").as(ds1.encoder) - assertSmallDatasetEquality(ds1, ds2, ignoreColumnOrder = true) assertSmallDatasetEquality(ds2, ds1, ignoreColumnOrder = true) } + + "correctly mark unequal schema field" in { + val sourceDF = spark.createDF( + List( + (1, 2.0), + (5, 3.0) + ), + List( + ("number", IntegerType, true), + ("float", DoubleType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + (1, "word", 1L), + (5, "word", 2L) + ), + List( + ("number", IntegerType, true), + ("word", StringType, true), + ("long", LongType, true) + ) + ) + + val e = intercept[DatasetSchemaMismatch] { + assertSmallDatasetEquality(sourceDF, expectedDF) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})"))) + assert(actualColourGroup.contains(Seq("float", "DoubleType", "MISSING"))) + } } "defaultSortDataset" - { From eb91137da64f08fbffd2ce621f3c4ee7fffec0f6 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Tue, 1 Oct 2024 18:45:09 +1000 Subject: [PATCH 04/10] Handle single valued case --- .../com/github/mrpowers/spark/fast/tests/ProductUtil.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala index 4840ee0..a776f9b 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala @@ -10,10 +10,10 @@ import scala.reflect.ClassTag object ProductUtil { private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = { product match { - case null => Seq.empty case r: Row => r.toSeq case p: Product => p.productIterator.toSeq - case _ => throw new IllegalArgumentException("Only Row and Product types are supported") + case null => Seq.empty + case s => Seq(s) } } private[mrpowers] def showProductDiff[T: ClassTag]( From 9df300484ce9bcb175dcd4dd63d56963a9b95d3e Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Wed, 2 Oct 2024 08:04:18 +1000 Subject: [PATCH 05/10] Make Row Diff not display class name --- .../github/mrpowers/spark/fast/tests/ProductUtil.scala | 9 ++++----- .../spark/fast/tests/DataFrameComparerTest.scala | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala index a776f9b..72c8d7a 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala @@ -26,9 +26,8 @@ object ProductUtil { ): String = { val runTimeClass = implicitly[ClassTag[T]].runtimeClass - val className = runTimeClass.getSimpleName - val border = if (runTimeClass == classOf[Row]) ("[", "]") else ("(", ")") - val prodToString: Seq[Any] => String = s => s.mkString(s"$className${border._1}", ",", border._2) + val (className, lBracket, rBracket) = if (runTimeClass == classOf[Row]) ("", "[", "]") else (runTimeClass.getSimpleName, "(", ")") + val prodToString: Seq[Any] => String = s => s.mkString(s"$className$lBracket", ",", rBracket) val emptyProd = "MISSING" val sb = new StringBuilder @@ -62,9 +61,9 @@ object ProductUtil { case (actualRowField, expectedRowField, false) => (Red(actualRowField.toString), Green(expectedRowField.toString)) } - val start = DarkGray(s"$className${border._1}") + val start = DarkGray(s"$className$lBracket") val sep = DarkGray(",") - val end = DarkGray(border._2) + val end = DarkGray(rBracket) List( coloredDiff.map(_._1).mkStr(start, sep, end), coloredDiff.map(_._2).mkStr(start, sep, end) diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala index 757acb5..3b61324 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala @@ -74,8 +74,8 @@ class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with Spar val colourGroup = e.getMessage.extractColorGroup val expectedColourGroup = colourGroup.get(Console.GREEN) val actualColourGroup = colourGroup.get(Console.RED) - assert(expectedColourGroup.contains(Seq("uk", "Row[steve,10,aus]"))) - assert(actualColourGroup.contains(Seq("france", "Row[mark,11,usa]"))) + assert(expectedColourGroup.contains(Seq("uk", "[steve,10,aus]"))) + assert(actualColourGroup.contains(Seq("france", "[mark,11,usa]"))) } "works well for wide DataFrames" in { From f7848c3f6ba275631d547fe3180d30428d2c6860 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Thu, 3 Oct 2024 08:41:41 +1000 Subject: [PATCH 06/10] Disallow input default val --- .../github/mrpowers/spark/fast/tests/DatasetComparer.scala | 2 +- .../com/github/mrpowers/spark/fast/tests/ProductUtil.scala | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala index e71a115..e77a59d 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala @@ -71,7 +71,7 @@ Expected DataFrame Row Count: '$expectedCount' val e = expectedDS.collect().toSeq if (!a.approximateSameElements(e, equals)) { val arr = ("Actual Content", "Expected Content") - val msg = "Diffs\n" ++ ProductUtil.showProductDiff[T](arr, a, e, truncate) + val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate) throw DatasetContentMismatch(msg) } } diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala index 72c8d7a..000cbf1 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala @@ -10,9 +10,9 @@ import scala.reflect.ClassTag object ProductUtil { private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = { product match { + case null => Seq.empty case r: Row => r.toSeq case p: Product => p.productIterator.toSeq - case null => Seq.empty case s => Seq(s) } } @@ -22,7 +22,6 @@ object ProductUtil { expected: Seq[T], truncate: Int = 20, minColWidth: Int = 3, - defaultVal: T = null.asInstanceOf[T] ): String = { val runTimeClass = implicitly[ClassTag[T]].runtimeClass @@ -32,7 +31,7 @@ object ProductUtil { val sb = new StringBuilder - val fullJoin = actual.zipAll(expected, defaultVal, defaultVal) + val fullJoin = actual.zipAll(expected, null, null) val diff = fullJoin.map { case (actualRow, expectedRow) => if (actualRow == expectedRow) { From 398ec2fcce373b445553f4e616dd174c1cb061ef Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Thu, 3 Oct 2024 08:42:26 +1000 Subject: [PATCH 07/10] formatting --- .../com/github/mrpowers/spark/fast/tests/ProductUtil.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala index 000cbf1..d2d1f15 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala @@ -21,7 +21,7 @@ object ProductUtil { actual: Seq[T], expected: Seq[T], truncate: Int = 20, - minColWidth: Int = 3, + minColWidth: Int = 3 ): String = { val runTimeClass = implicitly[ClassTag[T]].runtimeClass From 49d143fc0a80406d048053b29403471415828c69 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Fri, 4 Oct 2024 17:58:54 +1000 Subject: [PATCH 08/10] handle Iterable cases --- .../spark/fast/tests/ProductUtil.scala | 12 ++- .../fast/tests/DatasetComparerTest.scala | 96 +++++++++++++++++++ 2 files changed, 103 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala index d2d1f15..b4b464c 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala @@ -10,10 +10,12 @@ import scala.reflect.ClassTag object ProductUtil { private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = { product match { - case null => Seq.empty - case r: Row => r.toSeq - case p: Product => p.productIterator.toSeq - case s => Seq(s) + case null => Seq.empty + case a: Array[_] => a + case i: Iterable[_] => i.toSeq + case r: Row => r.toSeq + case p: Product => p.productIterator.toSeq + case s => Seq(s) } } private[mrpowers] def showProductDiff[T: ClassTag]( @@ -45,7 +47,7 @@ object ProductUtil { List(Red(prodToString(actualSeq)), Green(emptyProd)) else { val withEquals = actualSeq - .zip(expectedSeq) + .zipAll(expectedSeq, "MISSING", "MISSING") .map { case (actualRowField, expectedRowField) => (actualRowField, expectedRowField, actualRowField == expectedRowField) } diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala index e47b59a..6091d4d 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala @@ -2,8 +2,10 @@ package com.github.mrpowers.spark.fast.tests import org.apache.spark.sql.types._ import SparkSessionExt._ +import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch import com.github.mrpowers.spark.fast.tests.StringExt.StringOps +import org.apache.spark.sql.Row import org.scalatest.freespec.AnyFreeSpec object Person { @@ -66,6 +68,100 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes assert(actualColourGroup.contains(Seq("Person(bob,1)", "alice"))) } + "correctly mark unequal element for Dataset[String]" in { + import spark.implicits._ + val sourceDS = Seq("word", "StringType", "StructField(long,LongType,true,{})").toDS + + val expectedDS = List("word", "StringType", "StructField(long,LongType2,true,{})").toDS + + val e = intercept[DatasetContentMismatch] { + assertSmallDatasetEquality(sourceDS, expectedDS) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("StructField(long,LongType2,true,{})"))) + assert(actualColourGroup.contains(Seq("StructField(long,LongType,true,{})"))) + } + + "correctly mark unequal element for Dataset[Seq[String]]" in { + import spark.implicits._ + + val sourceDS = Seq( + Seq("apple", "banana", "cherry"), + Seq("dog", "cat"), + Seq("red", "green", "blue") + ).toDS + + val expectedDS = Seq( + Seq("apple", "banana2"), + Seq("dog", "cat"), + Seq("red", "green", "blue") + ).toDS + + val e = intercept[DatasetContentMismatch] { + assertSmallDatasetEquality(sourceDS, expectedDS) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("banana2", "MISSING"))) + assert(actualColourGroup.contains(Seq("banana", "cherry"))) + } + + "correctly mark unequal element for Dataset[Array[String]]" in { + import spark.implicits._ + + val sourceDS = Seq( + Array("apple", "banana", "cherry"), + Array("dog", "cat"), + Array("red", "green", "blue") + ).toDS + + val expectedDS = Seq( + Array("apple", "banana2"), + Array("dog", "cat"), + Array("red", "green", "blue") + ).toDS + + val e = intercept[DatasetContentMismatch] { + assertSmallDatasetEquality(sourceDS, expectedDS) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("banana2", "MISSING"))) + assert(actualColourGroup.contains(Seq("banana", "cherry"))) + } + + "correctly mark unequal element for Dataset[Map[String, String]]" in { + import spark.implicits._ + + val sourceDS = Seq( + Map("apple" -> "banana", "apple1" -> "banana1"), + Map("apple" -> "banana", "apple1" -> "banana1") + ).toDS + + val expectedDS = Seq( + Map("apple" -> "banana1", "apple1" -> "banana1"), + Map("apple" -> "banana", "apple1" -> "banana1") + ).toDS + + val e = intercept[DatasetContentMismatch] { + assertSmallDatasetEquality(sourceDS, expectedDS) + } + println(e) + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("(apple, banana1)"))) + assert(actualColourGroup.contains(Seq("(apple, banana)"))) + } + "works with really long columns" in { val sourceDS = Seq( Person("juanisareallygoodguythatilikealotOK", 5), From f4bb04849cebce190d428d68b15efbdb79c9484b Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Fri, 4 Oct 2024 18:06:48 +1000 Subject: [PATCH 09/10] remove space --- .../mrpowers/spark/fast/tests/DatasetComparerTest.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala index 6091d4d..257541a 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala @@ -153,13 +153,12 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes val e = intercept[DatasetContentMismatch] { assertSmallDatasetEquality(sourceDS, expectedDS) } - println(e) val colourGroup = e.getMessage.extractColorGroup val expectedColourGroup = colourGroup.get(Console.GREEN) val actualColourGroup = colourGroup.get(Console.RED) - assert(expectedColourGroup.contains(Seq("(apple, banana1)"))) - assert(actualColourGroup.contains(Seq("(apple, banana)"))) + assert(expectedColourGroup.contains(Seq("(apple,banana1)"))) + assert(actualColourGroup.contains(Seq("(apple,banana)"))) } "works with really long columns" in { From 0697c7886c5be110e4fe1f7ae18becff3b89fa88 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Fri, 4 Oct 2024 18:14:15 +1000 Subject: [PATCH 10/10] Fix String test --- .../mrpowers/spark/fast/tests/DatasetComparerTest.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala index 257541a..726965c 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala @@ -81,8 +81,8 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes val colourGroup = e.getMessage.extractColorGroup val expectedColourGroup = colourGroup.get(Console.GREEN) val actualColourGroup = colourGroup.get(Console.RED) - assert(expectedColourGroup.contains(Seq("StructField(long,LongType2,true,{})"))) - assert(actualColourGroup.contains(Seq("StructField(long,LongType,true,{})"))) + assert(expectedColourGroup.contains(Seq("String(StructField(long,LongType2,true,{}))"))) + assert(actualColourGroup.contains(Seq("String(StructField(long,LongType,true,{}))"))) } "correctly mark unequal element for Dataset[Seq[String]]" in {