Skip to content

Commit

Permalink
Merge pull request #166 from zeotuan/schema-color-diff
Browse files Browse the repository at this point in the history
Support Schema color diff and Improve Dataset color diff message
  • Loading branch information
zeotuan authored Oct 16, 2024
2 parents c53aeff + 0697c78 commit 8127ab1
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Expected DataFrame Row Count: '$expectedCount'
/**
* Raises an error unless `actualDS` and `expectedDS` are equal
*/
def assertSmallDatasetEquality[T](
def assertSmallDatasetEquality[T: ClassTag](
actualDS: Dataset[T],
expectedDS: Dataset[T],
ignoreNullable: Boolean = false,
Expand All @@ -53,7 +53,7 @@ Expected DataFrame Row Count: '$expectedCount'
assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals)
}

def assertSmallDatasetContentEquality[T](
def assertSmallDatasetContentEquality[T: ClassTag](
actualDS: Dataset[T],
expectedDS: Dataset[T],
orderedComparison: Boolean,
Expand All @@ -66,12 +66,12 @@ Expected DataFrame Row Count: '$expectedCount'
assertSmallDatasetContentEquality(defaultSortDataset(actualDS), defaultSortDataset(expectedDS), truncate, equals)
}

def assertSmallDatasetContentEquality[T](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = {
def assertSmallDatasetContentEquality[T: ClassTag](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = {
val a = actualDS.collect().toSeq
val e = expectedDS.collect().toSeq
if (!a.approximateSameElements(e, equals)) {
val arr = ("Actual Content", "Expected Content")
val msg = "Diffs\n" ++ DataframeUtil.showDataframeDiff(arr, a.asRows, e.asRows, truncate)
val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate)
throw DatasetContentMismatch(msg)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,59 +1,70 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red}
import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.Row
import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
object DataframeUtil {

private[mrpowers] def showDataframeDiff(
import scala.reflect.ClassTag

object ProductUtil {
/**
 * Flattens an arbitrary value into the sequence of its elements so that two values can be compared element by element.
 *
 *   - `null` yields an empty sequence
 *   - arrays and other `Iterable`s yield their elements
 *   - a `Row` yields its column values
 *   - any other `Product` (e.g. a case class or tuple) yields its fields
 *   - everything else is wrapped in a one-element sequence
 *
 * The case order matters: the `Iterable` case is tried before the `Product` case, so a value that is both is expanded by its elements.
 *
 * @param product
 *   the value to flatten; may be `null`
 * @return
 *   the elements of `product` as described above
 */
private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = product match {
  case null                  => Seq.empty
  case array: Array[_]       => array
  case iterable: Iterable[_] => iterable.toSeq
  case row: Row              => row.toSeq
  case prod: Product         => prod.productIterator.toSeq
  case single                => Seq(single)
}
private[mrpowers] def showProductDiff[T: ClassTag](
header: (String, String),
actual: Seq[Row],
expected: Seq[Row],
actual: Seq[T],
expected: Seq[T],
truncate: Int = 20,
minColWidth: Int = 3
): String = {

val runTimeClass = implicitly[ClassTag[T]].runtimeClass
val (className, lBracket, rBracket) = if (runTimeClass == classOf[Row]) ("", "[", "]") else (runTimeClass.getSimpleName, "(", ")")
val prodToString: Seq[Any] => String = s => s.mkString(s"$className$lBracket", ",", rBracket)
val emptyProd = "MISSING"

val sb = new StringBuilder

val fullJoin = actual.zipAll(expected, Row(), Row())
val fullJoin = actual.zipAll(expected, null, null)

val diff = fullJoin.map { case (actualRow, expectedRow) =>
if (equals(actualRow, expectedRow)) {
if (actualRow == expectedRow) {
List(DarkGray(actualRow.toString), DarkGray(expectedRow.toString))
} else {
val actualSeq = actualRow.toSeq
val expectedSeq = expectedRow.toSeq
val actualSeq = productOrRowToSeq(actualRow)
val expectedSeq = productOrRowToSeq(expectedRow)
if (actualSeq.isEmpty)
List(
Red("[]"),
Green(expectedSeq.mkString("[", ",", "]"))
)
List(Red(emptyProd), Green(prodToString(expectedSeq)))
else if (expectedSeq.isEmpty)
List(Red(actualSeq.mkString("[", ",", "]")), Green("[]"))
List(Red(prodToString(actualSeq)), Green(emptyProd))
else {
val withEquals = actualSeq
.zip(expectedSeq)
.zipAll(expectedSeq, "MISSING", "MISSING")
.map { case (actualRowField, expectedRowField) =>
(actualRowField, expectedRowField, actualRowField == expectedRowField)
}
val allFieldsAreNotEqual = !withEquals.exists(_._3)
if (allFieldsAreNotEqual) {
List(
Red(actualSeq.mkString("[", ",", "]")),
Green(expectedSeq.mkString("[", ",", "]"))
)
List(Red(prodToString(actualSeq)), Green(prodToString(expectedSeq)))
} else {

val coloredDiff = withEquals
.map {
case (actualRowField, expectedRowField, true) =>
(DarkGray(actualRowField.toString), DarkGray(expectedRowField.toString))
case (actualRowField, expectedRowField, false) =>
(Red(actualRowField.toString), Green(expectedRowField.toString))
}
val start = DarkGray("[")
val start = DarkGray(s"$className$lBracket")
val sep = DarkGray(",")
val end = DarkGray("]")
val end = DarkGray(rBracket)
List(
coloredDiff.map(_._1).mkStr(start, sep, end),
coloredDiff.map(_._2).mkStr(start, sep, end)
Expand All @@ -69,11 +80,12 @@ object DataframeUtil {
val colWidths = Array.fill(numCols)(minColWidth)

// Compute the width of each column
for ((cell, i) <- headerSeq.zipWithIndex) {
headerSeq.zipWithIndex.foreach({ case (cell, i) =>
colWidths(i) = math.max(colWidths(i), cell.length)
}
for (row <- diff) {
for ((cell, i) <- row.zipWithIndex) {
})

diff.foreach { row =>
row.zipWithIndex.foreach { case (cell, i) =>
colWidths(i) = math.max(colWidths(i), cell.length)
}
}
Expand Down Expand Up @@ -117,5 +129,4 @@ object DataframeUtil {

sb.toString
}

}
Original file line number Diff line number Diff line change
@@ -1,29 +1,18 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType}

object SchemaComparer {

case class DatasetSchemaMismatch(smth: String) extends Exception(smth)
private def betterSchemaMismatchMessage[T](actualDS: Dataset[T], expectedDS: Dataset[T]): String = {
"\nActual Schema Field | Expected Schema Field\n" + actualDS.schema
.zipAll(
expectedDS.schema,
"",
""
)
.map {
case (sf1, sf2) if sf1 == sf2 =>
ufansi.Color.Blue(s"$sf1 | $sf2")
case ("", sf2) =>
ufansi.Color.Red(s"MISSING | $sf2")
case (sf1, "") =>
ufansi.Color.Red(s"$sf1 | MISSING")
case (sf1, sf2) =>
ufansi.Color.Red(s"$sf1 | $sf2")
}
.mkString("\n")
showProductDiff(
("Actual Schema", "Expected Schema"),
actualDS.schema.fields,
expectedDS.schema.fields,
truncate = 200
)
}

def assertSchemaEqual[T](
Expand All @@ -36,7 +25,7 @@ object SchemaComparer {
require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.")
if (!SchemaComparer.equals(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)) {
throw DatasetSchemaMismatch(
betterSchemaMismatchMessage(actualDS, expectedDS)
"Diffs\n" + betterSchemaMismatchMessage(actualDS, expectedDS)
)
}
}
Expand Down Expand Up @@ -76,5 +65,4 @@ object SchemaComparer {
case _ => dt1 == dt2
}
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType}
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}
import SparkSessionExt._
import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch
import com.github.mrpowers.spark.fast.tests.StringExt.StringOps
Expand Down Expand Up @@ -331,6 +331,41 @@ class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with Spar
assertLargeDataFrameEquality(sourceDF, expectedDF)
}
}

"correctly mark unequal schema field" in {
  // Actual side: two columns, (number: Integer, float: Double).
  val actualDF = spark.createDF(
    List((1, 2.0), (5, 3.0)),
    List(("number", IntegerType, true), ("float", DoubleType, true))
  )

  // Expected side: the second column differs in name and type, and a third column is added.
  val wantedDF = spark.createDF(
    List((1, "word", 1L), (5, "word", 2L)),
    List(("number", IntegerType, true), ("word", StringType, true), ("long", LongType, true))
  )

  // Comparing the two frames must fail with a schema mismatch.
  val mismatch = intercept[DatasetSchemaMismatch] {
    assertSmallDataFrameEquality(actualDF, wantedDF)
  }

  // NOTE(review): extractColorGroup presumably groups the message fragments by ANSI colour code — confirm against StringExt.
  val fragmentsByColour = mismatch.getMessage.extractColorGroup
  val greenFragments = fragmentsByColour.get(Console.GREEN)
  val redFragments = fragmentsByColour.get(Console.RED)

  // Green fragments come from the expected schema, red fragments from the actual one (including the MISSING placeholder).
  assert(greenFragments.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})")))
  assert(redFragments.contains(Seq("float", "DoubleType", "MISSING")))
}
}

"assertApproximateDataFrameEquality" - {
Expand Down
Loading

0 comments on commit 8127ab1

Please sign in to comment.