From 0e8c6b1f1048e078a8770ef9652d6fb2942e481d Mon Sep 17 00:00:00 2001 From: Rahul Sharma Date: Mon, 16 Dec 2024 16:58:18 -0500 Subject: [PATCH] Fix row level bug when composing outcome - When a check fails due to a precondition failure, the row level results are not evaluated correctly. - For example, let's say a check has a completeness constraint which passes, and a minimum constraint which fails due to a precondition failure. - The row level results will be the results for just the completeness constraint. There will be no results generated for the minimum constraint, and therefore the row level results will be incorrect. - We fix this by adding a default outcome for when the row level result column is not provided by the analyzer. --- .../com/amazon/deequ/VerificationResult.scala | 4 ++-- .../amazon/deequ/VerificationSuiteTest.scala | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/VerificationResult.scala b/src/main/scala/com/amazon/deequ/VerificationResult.scala index b9b450f2..24b68726 100644 --- a/src/main/scala/com/amazon/deequ/VerificationResult.scala +++ b/src/main/scala/com/amazon/deequ/VerificationResult.scala @@ -31,6 +31,7 @@ import com.amazon.deequ.repository.SimpleResultSerde import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.{col, monotonically_increasing_id} import java.util.UUID @@ -144,7 +145,7 @@ object VerificationResult { val constraint = constraintResult.constraint constraint match { case asserted: RowLevelAssertedConstraint => - constraintResult.metric.flatMap(metricToColumn).map(asserted.assertion(_)) + constraintResult.metric.flatMap(metricToColumn).map(asserted.assertion(_)).orElse(Some(lit(false))) case _: RowLevelConstraint => constraintResult.metric.flatMap(metricToColumn) case _: RowLevelGroupedConstraint => @@ -160,7 +161,6 @@ object VerificationResult { } } - private[this] def getSimplifiedCheckResultOutput( verificationResult: VerificationResult) : Seq[SimpleCheckResultOutput] = { diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index 146579e8..15c5ca44 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -1996,6 +1996,29 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } "Verification Suite's Row Level Results" should { + "yield correct results for invalid column type" in withSparkSession { sparkSession => + import sparkSession.implicits._ + val df = Seq( + ("1", 1, "blue"), + ("2", 2, "green"), + ("3", 3, "blue"), + ("4", 4, "red"), + ("5", 5, "purple") + ).toDF("id", "id2", "color") + val invalidColumn = "id" + val validColumn = "id2" + val checkOnInvalidColumnDescription = s"check on $invalidColumn" + val checkOnValidColumnDescription = s"check on $validColumn" + val checkOnInvalidColumn = Check(CheckLevel.Error, checkOnInvalidColumnDescription).hasMin(invalidColumn, _ >= 3).isComplete(invalidColumn) + val checkOnValidColumn = Check(CheckLevel.Error, checkOnValidColumnDescription).hasMin(validColumn, _ >= 3).isComplete(validColumn) + val verificationResult = VerificationSuite().onData(df).addChecks(Seq(checkOnInvalidColumn, checkOnValidColumn)).run() + val rowLevelResults = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df).collect() + val invalidColumnCheckRowLevelResults = rowLevelResults.map(_.getAs[Boolean](checkOnInvalidColumnDescription)) + val validColumnCheckRowLevelResults = rowLevelResults.map(_.getAs[Boolean](checkOnValidColumnDescription)) + invalidColumnCheckRowLevelResults shouldBe Seq(false, false, false, false, false) + validColumnCheckRowLevelResults shouldBe Seq(false, false, true, true, true) + } + "yield correct results for satisfies check" in withSparkSession { sparkSession => import sparkSession.implicits._ val df = Seq(