From 63288699635bfc8e98f6cee34570c1479acf386c Mon Sep 17 00:00:00 2001 From: Edward Cho Date: Mon, 12 Feb 2024 17:26:13 -0500 Subject: [PATCH 1/8] Modified Completeness analyzer to label filtered rows as null for row-level results --- .../com/amazon/deequ/analyzers/Analyzer.scala | 21 ++++++++++++- .../amazon/deequ/analyzers/Completeness.scala | 14 +++++++-- .../deequ/analyzers/CompletenessTest.scala | 30 +++++++++++++++++++ .../amazon/deequ/utils/FixtureSupport.scala | 24 +++++++++++++++ 4 files changed, 86 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index a80405825..c8e49c6b5 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -17,6 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers._ +import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.amazon.deequ.analyzers.NullBehavior.NullBehavior import com.amazon.deequ.analyzers.runners._ import com.amazon.deequ.metrics.DoubleMetric @@ -255,12 +256,18 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo } } -case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore) +case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore, + filteredRow: FilteredRow = FilteredRow.NULL) object NullBehavior extends Enumeration { type NullBehavior = Value val Ignore, EmptyString, Fail = Value } +object FilteredRow extends Enumeration { + type FilteredRow = Value + val NULL, TRUE = Value +} + /** Base class for analyzers that compute ratios of matching predicates */ abstract class PredicateMatchingAnalyzer( name: String, @@ -490,6 +497,18 @@ private[deequ] object Analyzers { conditionalSelectionFromColumns(selection, conditionColumn) } + def conditionalSelectionFilteredFromColumns( + selection: Column, + conditionColumn: Option[Column], + filterTreatment: String) + : Column = { + conditionColumn + .map { condition => { + when(not(condition), expr(filterTreatment)).when(condition, selection) + } } + .getOrElse(selection) + } + private[this] def conditionalSelectionFromColumns( selection: Column, conditionColumn: Option[Column]) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala index 5e80e2f6e..342f92a67 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala @@ -20,8 +20,11 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested} import org.apache.spark.sql.functions.sum import org.apache.spark.sql.types.{IntegerType, StructType} import Analyzers._ +import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.{Column, Row} /** Completeness is the fraction of non-null values in a column of a DataFrame. 
*/ @@ -30,9 +33,8 @@ case class Completeness(column: String, where: Option[String] = None) extends FilterableAnalyzer { override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = { - ifNoNullsIn(result, offset, howMany = 2) { _ => - NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion)) + NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults)) } } @@ -51,4 +53,12 @@ case class Completeness(column: String, where: Option[String] = None) extends @VisibleForTesting // required by some tests that compare analyzer results to an expected state private[deequ] def criterion: Column = conditionalSelection(column, where).isNotNull + + @VisibleForTesting + private[deequ] def rowLevelResults: Column = { + val filteredRow = FilteredRow.NULL + val whereCondition = where.map { expression => expr(expression)} + conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, filteredRow.toString) + } + } diff --git a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala index b1cdf3014..de33f2327 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala @@ -23,6 +23,8 @@ import com.amazon.deequ.utils.FixtureSupport import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec +import scala.util.Success + class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { "Completeness" should { @@ -37,5 +39,33 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe Seq(true, true, true, true, false, true, true, false) } + + "return row-level results for columns filtered using where" in withSparkSession { session => + + val data = getDfForWhereClause(session) + + val completenessCountry = Completeness("ZipCode", Option("State = \"CA\"")) + val state = completenessCountry.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = completenessCountry.computeMetricFrom(state) + + // Address Line 3 is null only where Address Line 2 is null. 
With the where clause, completeness should be 1.0 + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe + Seq(true, true, false, false) + } + + "return row-level results for columns filtered as null" in withSparkSession { session => + + val data = getDfCompleteAndInCompleteColumns(session) + + val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")) + val state = completenessAtt2.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) + + val df = data.withColumn("new", metric.fullColumn.get) + println("Filtered as null") + df.show(false) + df.collect().map(_.getAs[Any]("new")).toSeq shouldBe + Seq(true, null, false, true, null, true) + } } } diff --git a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala index 9b6ad9d4e..601134a53 100644 --- a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala +++ b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala @@ -338,6 +338,19 @@ trait FixtureSupport { .toDF("att1", "att2") } + def getDfWithDistinctValuesQuotes(sparkSession: SparkSession): DataFrame = { + import sparkSession.implicits._ + + Seq( + ("a", null, "Already Has "), + ("a", null, " Can't Proceed"), + (null, "can't", "Already Has "), + ("b", "help", " Can't Proceed"), + ("b", "but", "Already Has "), + ("c", "wouldn't", " Can't Proceed")) + .toDF("att1", "att2", "reason") + } + def getDfWithConditionallyUninformativeColumns(sparkSession: SparkSession): DataFrame = { import sparkSession.implicits._ Seq( @@ -409,6 +422,17 @@ trait FixtureSupport { ).toDF("item.one", "att1", "att2") } + def getDfForWhereClause(sparkSession: SparkSession): DataFrame = { + import sparkSession.implicits._ + + Seq( + ("Acme", "90210", "CA", "Los Angeles"), + ("Acme", "90211", "CA", "Los Angeles"), + ("Robocorp", null, "NJ", null), + ("Robocorp", null, "NY", "New York") + ).toDF("Company", "ZipCode", "State", "City") + } + def getDfCompleteAndInCompleteColumnsWithPeriod(sparkSession: SparkSession): DataFrame = { import sparkSession.implicits._ From 62da8ddf8605043f7c28452186d21819c5c5f894 Mon Sep 17 00:00:00 2001 From: Edward Cho Date: Mon, 12 Feb 2024 18:17:05 -0500 Subject: [PATCH 2/8] Modified GroupingAnalyzers and Uniqueness analyzer to label filtered rows as null for row-level results --- .../com/amazon/deequ/analyzers/Analyzer.scala | 11 ++- .../deequ/analyzers/GroupingAnalyzers.scala | 15 ++- .../deequ/analyzers/UniqueValueRatio.scala | 9 +- .../amazon/deequ/analyzers/Uniqueness.scala | 12 ++- .../amazon/deequ/VerificationSuiteTest.scala | 96 +++++++++++++++++++ .../deequ/analyzers/AnalyzerTests.scala | 4 +- .../deequ/analyzers/UniquenessTest.scala | 30 ++++++ .../com/amazon/deequ/checks/CheckTest.scala | 32 ++++++- 8 files changed, 200 insertions(+), 9 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index c8e49c6b5..3e88849c4 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -72,6 +72,8 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable { */ def computeStateFrom(data: DataFrame): Option[S] + def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[S] + /** * Compute the metric from the state (sufficient statistics) * @param state wrapper holding a state of type S (required due to typing issues...) 
@@ -98,13 +100,14 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable { def calculate( data: DataFrame, aggregateWith: Option[StateLoader] = None, - saveStatesWith: Option[StatePersister] = None) + saveStatesWith: Option[StatePersister] = None, + filterCondition: Option[String] = None) : M = { try { preconditions.foreach { condition => condition(data.schema) } - val state = computeStateFrom(data) + val state = computeStateFrom(data, filterCondition) calculateMetric(state, aggregateWith, saveStatesWith) } catch { @@ -191,6 +194,10 @@ trait ScanShareableAnalyzer[S <: State[_], +M <: Metric[_]] extends Analyzer[S, fromAggregationResult(result, 0) } + override def computeStateFrom(data: DataFrame, where: Option[String]): Option[S] = { + computeStateFrom(data) + } + /** Produces a metric from the aggregation result */ private[deequ] def metricFromAggregationResult( result: Row, diff --git a/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala b/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala index 2090d8231..a7197e89c 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.functions.count import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.functions.when /** Base class for all analyzers that operate the frequencies of groups in the data */ abstract class FrequencyBasedAnalyzer(columnsToGroupOn: Seq[String]) @@ -43,6 +44,10 @@ abstract class FrequencyBasedAnalyzer(columnsToGroupOn: Seq[String]) Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns())) } + override def computeStateFrom(data: DataFrame, where: Option[String]): Option[FrequenciesAndNumRows] = { + Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns, where)) + } + /** We need at least one grouping column, and all specified columns must exist */ override def preconditions: Seq[StructType => Unit] = { Seq(atLeastOne(columnsToGroupOn)) ++ columnsToGroupOn.map { hasColumn } ++ @@ -88,7 +93,15 @@ object FrequencyBasedAnalyzer { .count() // Set rows with value count 1 to true, and otherwise false - val fullColumn: Column = count(UNIQUENESS_ID).over(Window.partitionBy(columnsToGroupBy: _*)) + val fullColumn: Column = { + val window = Window.partitionBy(columnsToGroupBy: _*) + where.map { + condition => { + count(when(expr(condition), UNIQUENESS_ID)).over(window) + } + }.getOrElse(count(UNIQUENESS_ID).over(window)) + } + FrequenciesAndNumRows(frequencies, numRows, Option(fullColumn)) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala index d3c8aeb68..d31cdd2c7 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala @@ -18,6 +18,8 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL import com.amazon.deequ.metrics.DoubleMetric +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.not import org.apache.spark.sql.functions.when import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.functions.{col, count, lit, sum} @@ -34,7 +36,12 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None) override def fromAggregationResult(result: 
Row, offset: Int, fullColumn: Option[Column] = None): DoubleMetric = { val numUniqueValues = result.getDouble(offset) val numDistinctValues = result.getLong(offset + 1).toDouble - val fullColumnUniqueness = when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false) + val conditionColumn = where.map { expression => expr(expression) } + val fullColumnUniqueness = conditionColumn.map { + condition => { + when(not(condition), expr(FilteredRow.NULL.toString)).when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false) + } + }.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)) toSuccessMetric(numUniqueValues / numDistinctValues, Option(fullColumnUniqueness)) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala index 959f4734c..88772a413 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala @@ -17,11 +17,14 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL +import com.amazon.deequ.analyzers.Analyzers.conditionalCount import com.amazon.deequ.metrics.DoubleMetric import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.when import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.not +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.sum import org.apache.spark.sql.types.DoubleType @@ -33,11 +36,16 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None) with FilterableAnalyzer { override def aggregationFunctions(numRows: Long): Seq[Column] = { - (sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil + (sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil } override def fromAggregationResult(result: Row, offset: Int, fullColumn: Option[Column]): DoubleMetric = { - val fullColumnUniqueness = when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false) + val conditionColumn = where.map { expression => expr(expression) } + val fullColumnUniqueness = conditionColumn.map { + condition => { + when(not(condition), expr(FilteredRow.NULL.toString)).when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false) + } + }.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)) super.fromAggregationResult(result, offset, Option(fullColumnUniqueness)) } diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index a468b8a34..3f2f14b66 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -304,6 +304,43 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec assert(Seq(true, true, true, false, false, false).sameElements(rowLevel8)) } + "generate a result that contains row-level results with filter with null for filtered rows" in withSparkSession { session => + val data = getDfCompleteAndInCompleteColumns(session) + + val completeness = new Check(CheckLevel.Error, "rule1").hasCompleteness("att2", _ > 0.7, None).where("att1 = \"a\"") + val uniqueness = new Check(CheckLevel.Error, "rule2").hasUniqueness("att1", _ > 0.5, None) + val uniquenessWhere = new Check(CheckLevel.Error, "rule3").isUnique("att1").where("item < 3") + val expectedColumn1 
= completeness.description + val expectedColumn2 = uniqueness.description + val expectedColumn3 = uniquenessWhere.description + + + val suite = new VerificationSuite().onData(data) + .addCheck(completeness) + .addCheck(uniqueness) + .addCheck(uniquenessWhere) + + val result: VerificationResult = suite.run() + + assert(result.status == CheckStatus.Error) + + val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") + + val expectedColumns: Set[String] = + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + assert(resultData.columns.toSet == expectedColumns) + + val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, null, false, true, null, true).sameElements(rowLevel1)) + + val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) + assert(Seq(false, false, false, false, false, false).sameElements(rowLevel2)) + + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, null, null, null, null).sameElements(rowLevel3)) + + } + "generate a result that contains row-level results for null column values" in withSparkSession { session => val data = getDfCompleteAndInCompleteColumnsAndVarLengthStrings(session) @@ -459,6 +496,38 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } + "accept analysis config for mandatory analysis for checks with filters" in withSparkSession { sparkSession => + + import sparkSession.implicits._ + val df = getDfFull(sparkSession) + + val result = { + val checkToSucceed = Check(CheckLevel.Error, "group-1") + .hasCompleteness("att1", _ > 0.5, null) // 1.0 + .where("att2 = \"c\"") + val uniquenessCheck = Check(CheckLevel.Error, "group-2") + .isUnique("att1") + .where("item > 3") + + + VerificationSuite().onData(df).addCheck(checkToSucceed).addCheck(uniquenessCheck).run() + } + + assert(result.status == CheckStatus.Success) + + val analysisDf = AnalyzerContext.successMetricsAsDataFrame(sparkSession, + AnalyzerContext(result.metrics)) + + val expected = Seq( + ("Column", "att1", "Completeness (where: att2 = \"c\")", 1.0), + ("Column", "att1", "Uniqueness (where: item > 3)", 1.0)) + .toDF("entity", "instance", "name", "value") + + + assertSameRows(analysisDf, expected) + + } + "run the analysis even there are no constraints" in withSparkSession { sparkSession => import sparkSession.implicits._ @@ -786,6 +855,33 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } } + "A well-defined check should pass even if an ill-defined check is also configured quotes" in withSparkSession { + sparkSession => + val df = getDfWithDistinctValuesQuotes(sparkSession) + + val rangeCheck = Check(CheckLevel.Error, "a") + .isContainedIn("att2", Array("can't", "help", "but", "wouldn't")) + + val reasonCheck = Check(CheckLevel.Error, "a") + .isContainedIn("reason", Array("Already Has ", " Can't Proceed")) + + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(rangeCheck) + .addCheck(reasonCheck) + .run() + + val checkSuccessResult = verificationResult.checkResults(rangeCheck) +// checkSuccessResult.constraintResults.map(_.message) shouldBe List(None) + println(checkSuccessResult.constraintResults.map(_.message)) + assert(checkSuccessResult.status == CheckStatus.Success) + + val reasonResult = verificationResult.checkResults(reasonCheck) + // checkSuccessResult.constraintResults.map(_.message) shouldBe List(None) + 
println(checkSuccessResult.constraintResults.map(_.message)) + assert(checkSuccessResult.status == CheckStatus.Success) + } + "A well-defined check should pass even if an ill-defined check is also configured" in withSparkSession { sparkSession => val df = getDfWithNameAndAge(sparkSession) diff --git a/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala index 03787b886..1c0b28d1a 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala @@ -63,7 +63,9 @@ class AnalyzerTests extends AnyWordSpec with Matchers with SparkContextSpec with val result2 = Completeness("att2").calculate(dfMissing) assert(result2 == DoubleMetric(Entity.Column, "Completeness", "att2", Success(0.75), result2.fullColumn)) - + val result3 = Completeness("att2", Option("att1 is NOT NULL")).calculate(dfMissing) + assert(result3 == DoubleMetric(Entity.Column, + "Completeness", "att2", Success(4.0/6.0), result3.fullColumn)) } "fail on wrong column input" in withSparkSession { sparkSession => diff --git a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala index 5d6d6808f..e3492464b 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala @@ -117,4 +117,34 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit .withColumn("new", metric.fullColumn.get).orderBy("unique") .collect().map(_.getAs[Boolean]("new")) shouldBe Seq(true, true, true, true, true, true) } + + "return filtered row-level results for uniqueness with null" in withSparkSession { session => + + val data = getDfWithUniqueColumns(session) + + val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4")) + val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4")) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + // Adding column with UNIQUENESS_ID, since it's only added in VerificationResult.getRowLevelResults + val resultDf = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) + .withColumn("new", metric.fullColumn.get).orderBy("unique") + resultDf + .collect().map(_.getAs[Any]("new")) shouldBe Seq(true, true, true, null, null, null) + } + + "return filtered row-level results for uniqueness with null on multiple columns" in withSparkSession { session => + + val data = getDfWithUniqueColumns(session) + + val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2")) + val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2")) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + // Adding column with UNIQUENESS_ID, since it's only added in VerificationResult.getRowLevelResults + val resultDf = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) + .withColumn("new", metric.fullColumn.get).orderBy("unique") + resultDf + .collect().map(_.getAs[Any]("new")) shouldBe Seq(null, null, true, true, true, true) + } } diff --git a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala index 5a21079ae..749ec7a33 100644 --- a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala +++ 
b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala @@ -62,18 +62,39 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix val check3 = Check(CheckLevel.Warning, "group-2-W") .hasCompleteness("att2", _ > 0.8) // 0.75 + val check4 = Check(CheckLevel.Error, "group-3") + .isComplete("att2", None) // 1.0 with filter + .where("att2 is NOT NULL") + .hasCompleteness("att2", _ == 1.0, None) // 1.0 with filter + .where("att2 is NOT NULL") + val context = runChecks(getDfCompleteAndInCompleteColumns(sparkSession), - check1, check2, check3) + check1, check2, check3, check4) context.metricMap.foreach { println } assertEvaluatesTo(check1, context, CheckStatus.Success) assertEvaluatesTo(check2, context, CheckStatus.Error) assertEvaluatesTo(check3, context, CheckStatus.Warning) + assertEvaluatesTo(check4, context, CheckStatus.Success) assert(check1.getRowLevelConstraintColumnNames() == Seq("Completeness-att1", "Completeness-att1")) assert(check2.getRowLevelConstraintColumnNames() == Seq("Completeness-att2")) assert(check3.getRowLevelConstraintColumnNames() == Seq("Completeness-att2")) + assert(check4.getRowLevelConstraintColumnNames() == Seq("Completeness-att2", "Completeness-att2")) + } + + "return the correct check status for completeness with where" in withSparkSession { sparkSession => + + val check = Check(CheckLevel.Error, "group-3") + .hasCompleteness("ZipCode", _ > 0.6, None) // 1.0 with filter + .where("City is NOT NULL") + + val context = runChecks(getDfForWhereClause(sparkSession), check) + + assertEvaluatesTo(check, context, CheckStatus.Success) + + assert(check.getRowLevelConstraintColumnNames() == Seq("Completeness-ZipCode")) } "return the correct check status for combined completeness" in @@ -164,7 +185,6 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix assert(constraintStatuses.head == ConstraintStatus.Success) assert(constraintStatuses(1) == ConstraintStatus.Success) assert(constraintStatuses(2) == ConstraintStatus.Success) - assert(constraintStatuses(3) == ConstraintStatus.Failure) assert(constraintStatuses(4) == ConstraintStatus.Failure) } @@ -515,6 +535,14 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix assertEvaluatesTo(numericRangeCheck9, numericRangeResults, CheckStatus.Success) } + "correctly evaluate range constraints when values have single quote(') in string" in withSparkSession { sparkSession => + val rangeCheck = Check(CheckLevel.Error, "a") + .isContainedIn("att2", Array("can't", "help", "but", "wouldn't")) + + val rangeResults = runChecks(getDfWithDistinctValuesQuotes(sparkSession), rangeCheck) + assertEvaluatesTo(rangeCheck, rangeResults, CheckStatus.Success) + } + "return the correct check status for histogram constraints" in withSparkSession { sparkSession => From 2cdd9a5b45eee4711720a1274887551ee7167e1f Mon Sep 17 00:00:00 2001 From: Edward Cho Date: Mon, 12 Feb 2024 18:17:35 -0500 Subject: [PATCH 3/8] Adjustments for modifying the calculate method to take in a filterCondition --- .../com/amazon/deequ/analyzers/CustomSql.scala | 4 ++++ .../deequ/analyzers/DatasetMatchAnalyzer.scala | 4 ++++ .../com/amazon/deequ/analyzers/Histogram.scala | 4 ++++ .../analyzers/runners/AnalysisRunnerTests.scala | 14 ++++++++++---- .../constraints/AnalysisBasedConstraintTest.scala | 6 +++++- 5 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala index 
b8dc2692a..bc14bd184 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala @@ -55,6 +55,10 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double } } + override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[CustomSqlState] = { + computeStateFrom(data) + } + /** * Compute the metric from the state (sufficient statistics) * diff --git a/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala index cdf0e5061..bfcc5c06d 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala @@ -86,6 +86,10 @@ case class DatasetMatchAnalyzer(dfToCompare: DataFrame, } } + override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[DatasetMatchState] = { + computeStateFrom(data) + } + override def computeMetricFrom(state: Option[DatasetMatchState]): DoubleMetric = { val metric = state match { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala index 42a7e72e5..277b52aea 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala @@ -76,6 +76,10 @@ case class Histogram( Some(FrequenciesAndNumRows(frequencies, totalCount)) } + override def computeStateFrom(data: DataFrame, where: Option[String]): Option[FrequenciesAndNumRows] = { + computeStateFrom(data) + } + override def computeMetricFrom(state: Option[FrequenciesAndNumRows]): HistogramMetric = { state match { diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala index 4ffc9eeb9..f19e053bc 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala @@ -137,7 +137,8 @@ class AnalysisRunnerTests extends AnyWordSpec UniqueValueRatio(Seq("att1"), Some("att3 > 0")) :: Nil val (separateResults, numSeparateJobs) = sparkMonitor.withMonitoringSession { stat => - val results = analyzers.map { _.calculate(df) }.toSet + val results = analyzers.map { analyzer => + analyzer.calculate(df, filterCondition = analyzer.filterCondition) }.toSet (results, stat.jobCount) } @@ -160,7 +161,9 @@ class AnalysisRunnerTests extends AnyWordSpec UniqueValueRatio(Seq("att1", "att2"), Some("att3 > 0")) :: Nil val (separateResults, numSeparateJobs) = sparkMonitor.withMonitoringSession { stat => - val results = analyzers.map { _.calculate(df) }.toSet + val results = analyzers.map { analyzer => + analyzer.calculate(df, filterCondition = analyzer.filterCondition) + }.toSet (results, stat.jobCount) } @@ -184,7 +187,9 @@ class AnalysisRunnerTests extends AnyWordSpec Uniqueness("att1", Some("att3 = 0")) :: Nil val (separateResults, numSeparateJobs) = sparkMonitor.withMonitoringSession { stat => - val results = analyzers.map { _.calculate(df) }.toSet + val results = analyzers.map { analyzer => + analyzer.calculate(df, filterCondition = analyzer.filterCondition) + }.toSet (results, stat.jobCount) } @@ -193,9 +198,10 @@ class AnalysisRunnerTests extends AnyWordSpec (results, stat.jobCount) } + assert(separateResults.asInstanceOf[Set[DoubleMetric]].size == 
runnerResults.asInstanceOf[Set[DoubleMetric]].size) assert(numSeparateJobs == analyzers.length * 2) assert(numCombinedJobs == analyzers.length * 2) - assert(separateResults.toString == runnerResults.toString) +// assert(separateResults == runnerResults.toString) } "reuse existing results" in diff --git a/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala b/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala index f8188165c..c9164ba6a 100644 --- a/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala +++ b/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala @@ -58,7 +58,8 @@ class AnalysisBasedConstraintTest extends WordSpec with Matchers with SparkConte override def calculate( data: DataFrame, stateLoader: Option[StateLoader], - statePersister: Option[StatePersister]) + statePersister: Option[StatePersister], + filterCondition: Option[String]) : DoubleMetric = { val value: Try[Double] = Try { require(data.columns.contains(column), s"Missing column $column") @@ -71,6 +72,9 @@ class AnalysisBasedConstraintTest extends WordSpec with Matchers with SparkConte throw new NotImplementedError() } + override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[NumMatches] = { + computeStateFrom(data) + } override def computeMetricFrom(state: Option[NumMatches]): DoubleMetric = { throw new NotImplementedError() From e13f6ba694c867a44c960a09ddd064affe837575 Mon Sep 17 00:00:00 2001 From: Edward Cho Date: Tue, 13 Feb 2024 09:48:31 -0500 Subject: [PATCH 4/8] Add RowLevelFilterTreatement trait and object to determine how filtered rows will be labeled (default True) --- .../amazon/deequ/VerificationRunBuilder.scala | 21 ++++++- .../com/amazon/deequ/analyzers/Analyzer.scala | 24 +++++--- .../amazon/deequ/analyzers/Completeness.scala | 5 +- .../deequ/analyzers/UniqueValueRatio.scala | 4 +- .../amazon/deequ/analyzers/Uniqueness.scala | 3 +- .../utilities/RowLevelFilterTreatement.scala | 32 +++++++++++ .../amazon/deequ/VerificationSuiteTest.scala | 56 ++++++++++++++++--- .../deequ/analyzers/CompletenessTest.scala | 29 +++++----- .../deequ/analyzers/UniquenessTest.scala | 37 ++++++++++++ .../runners/AnalysisRunnerTests.scala | 9 ++- .../com/amazon/deequ/checks/CheckTest.scala | 2 +- 11 files changed, 178 insertions(+), 44 deletions(-) create mode 100644 src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala diff --git a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala index a4ee45f6b..a4724dcf8 100644 --- a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala +++ b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala @@ -22,10 +22,14 @@ import com.amazon.deequ.analyzers.{State, _} import com.amazon.deequ.checks.{Check, CheckLevel} import com.amazon.deequ.metrics.Metric import com.amazon.deequ.repository._ +import com.amazon.deequ.utilities.FilteredRow +import com.amazon.deequ.utilities.FilteredRow.FilteredRow +import com.amazon.deequ.utilities.RowLevelFilterTreatment +import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl import org.apache.spark.sql.{DataFrame, SparkSession} /** A class to build a VerificationRun using a fluent API */ -class VerificationRunBuilder(val data: DataFrame) { +class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatment { protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty @@ -135,6 +139,18 @@ class 
VerificationRunBuilder(val data: DataFrame) { this } + /** + * Sets how row level results will be treated for the Verification run + * + * @param filteredRow enum to determine how filtered row level results are labeled (TRUE | NULL) + */ + def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = { + RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow)) + this + } + + def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment + /** * Set a metrics repository associated with the current data to enable features like reusing * previously computed results and storing the results of the current run. @@ -159,7 +175,6 @@ class VerificationRunBuilder(val data: DataFrame) { new VerificationRunBuilderWithSparkSession(this, Option(sparkSession)) } - def run(): VerificationResult = { VerificationSuite().doVerificationRun( data, @@ -338,4 +353,4 @@ case class AnomalyCheckConfig( description: String, withTagValues: Map[String, String] = Map.empty, afterDate: Option[Long] = None, - beforeDate: Option[Long] = None) + beforeDate: Option[Long] = None) \ No newline at end of file diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index 3e88849c4..327a7a14b 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -17,7 +17,6 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers._ -import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.amazon.deequ.analyzers.NullBehavior.NullBehavior import com.amazon.deequ.analyzers.runners._ import com.amazon.deequ.metrics.DoubleMetric @@ -25,6 +24,11 @@ import com.amazon.deequ.metrics.Entity import com.amazon.deequ.metrics.FullColumn import com.amazon.deequ.metrics.Metric import com.amazon.deequ.utilities.ColumnUtil.removeEscapeColumn +import com.amazon.deequ.utilities.FilteredRow +import com.amazon.deequ.utilities.FilteredRow.FilteredRow +import com.amazon.deequ.utilities.RowLevelFilterTreatment +import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl +import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row @@ -63,7 +67,7 @@ trait DoubleValuedState[S <: DoubleValuedState[S]] extends State[S] { } /** Common trait for all analyzers which generates metrics from states computed on data frames */ -trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable { +trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable with RowLevelFilterTreatment { /** * Compute the state (sufficient statistics) from the data @@ -175,6 +179,14 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable { source.load[S](this).foreach { state => target.persist(this, state) } } + @VisibleForTesting + private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = { + RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow)) + this + } + + def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment + } /** An analyzer that runs a set of aggregation functions over the data, @@ -263,18 +275,12 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo } } -case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore, - filteredRow: 
FilteredRow = FilteredRow.NULL) +case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore) object NullBehavior extends Enumeration { type NullBehavior = Value val Ignore, EmptyString, Fail = Value } -object FilteredRow extends Enumeration { - type FilteredRow = Value - val NULL, TRUE = Value -} - /** Base class for analyzers that compute ratios of matching predicates */ abstract class PredicateMatchingAnalyzer( name: String, diff --git a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala index 342f92a67..7107c834b 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala @@ -20,9 +20,7 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested} import org.apache.spark.sql.functions.sum import org.apache.spark.sql.types.{IntegerType, StructType} import Analyzers._ -import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.google.common.annotations.VisibleForTesting -import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.expr import org.apache.spark.sql.{Column, Row} @@ -56,9 +54,8 @@ case class Completeness(column: String, where: Option[String] = None) extends @VisibleForTesting private[deequ] def rowLevelResults: Column = { - val filteredRow = FilteredRow.NULL val whereCondition = where.map { expression => expr(expression)} - conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, filteredRow.toString) + conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, rowLevelFilterTreatment.toString) } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala index d31cdd2c7..6cfdc6383 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala @@ -39,7 +39,9 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None) val conditionColumn = where.map { expression => expr(expression) } val fullColumnUniqueness = conditionColumn.map { condition => { - when(not(condition), expr(FilteredRow.NULL.toString)).when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false) + when(not(condition), expr(rowLevelFilterTreatment.toString)) + .when((fullColumn.getOrElse(null)).equalTo(1), true) + .otherwise(false) } }.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)) toSuccessMetric(numUniqueValues / numDistinctValues, Option(fullColumnUniqueness)) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala index 88772a413..f87069d2e 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala @@ -17,7 +17,6 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL -import com.amazon.deequ.analyzers.Analyzers.conditionalCount import com.amazon.deequ.metrics.DoubleMetric import org.apache.spark.sql.Column import org.apache.spark.sql.Row @@ -43,7 +42,7 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None) val conditionColumn = where.map { expression => expr(expression) } val fullColumnUniqueness = conditionColumn.map { condition => { - when(not(condition), 
expr(FilteredRow.NULL.toString)).when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false) + when(not(condition), expr(rowLevelFilterTreatment.toString)).when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false) } }.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)) super.fromAggregationResult(result, offset, Option(fullColumnUniqueness)) diff --git a/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala b/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala new file mode 100644 index 000000000..8a2459bf3 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala @@ -0,0 +1,32 @@ +package com.amazon.deequ.utilities +import com.amazon.deequ.utilities.FilteredRow.FilteredRow + +/** + * Trait that defines how row level results will be treated when a filter is applied to an analyzer + */ +trait RowLevelFilterTreatment { + def rowLevelFilterTreatment: FilteredRow +} + +/** + * Companion object for RowLevelFilterTreatment + * Defines a sharedInstance that can be used throughout the VerificationRunBuilder + */ +object RowLevelFilterTreatment { + private var _sharedInstance: RowLevelFilterTreatment = new RowLevelFilterTreatmentImpl(FilteredRow.TRUE) + + def sharedInstance: RowLevelFilterTreatment = _sharedInstance + + def setSharedInstance(instance: RowLevelFilterTreatment): Unit = { + _sharedInstance = instance + } +} + +class RowLevelFilterTreatmentImpl(initialFilterTreatment: FilteredRow) extends RowLevelFilterTreatment { + override val rowLevelFilterTreatment: FilteredRow = initialFilterTreatment +} + +object FilteredRow extends Enumeration { + type FilteredRow = Value + val NULL, TRUE = Value +} \ No newline at end of file diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index 3f2f14b66..a4a521353 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -29,6 +29,7 @@ import com.amazon.deequ.metrics.Entity import com.amazon.deequ.repository.MetricsRepository import com.amazon.deequ.repository.ResultKey import com.amazon.deequ.repository.memory.InMemoryMetricsRepository +import com.amazon.deequ.utilities.FilteredRow import com.amazon.deequ.utils.CollectionUtils.SeqExtensions import com.amazon.deequ.utils.FixtureSupport import com.amazon.deequ.utils.TempFileUtils @@ -304,7 +305,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec assert(Seq(true, true, true, false, false, false).sameElements(rowLevel8)) } - "generate a result that contains row-level results with filter with null for filtered rows" in withSparkSession { session => + "generate a result that contains row-level results with filter with true for filtered rows" in withSparkSession { session => val data = getDfCompleteAndInCompleteColumns(session) val completeness = new Check(CheckLevel.Error, "rule1").hasCompleteness("att2", _ > 0.7, None).where("att1 = \"a\"") @@ -325,7 +326,44 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec assert(result.status == CheckStatus.Error) val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") + resultData.show(false) + val expectedColumns: Set[String] = + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + assert(resultData.columns.toSet == expectedColumns) + + val rowLevel1 = 
resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, true, true).sameElements(rowLevel1)) + val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) + assert(Seq(false, false, false, false, false, false).sameElements(rowLevel2)) + + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(rowLevel3)) + + } + + "generate a result that contains row-level results with filter with null for filtered rows" in withSparkSession { session => + val data = getDfCompleteAndInCompleteColumns(session) + + val completeness = new Check(CheckLevel.Error, "rule1").hasCompleteness("att2", _ > 0.7, None).where("att1 = \"a\"") + val uniqueness = new Check(CheckLevel.Error, "rule2").hasUniqueness("att1", _ > 0.5, None) + val uniquenessWhere = new Check(CheckLevel.Error, "rule3").isUnique("att1").where("item < 3") + val expectedColumn1 = completeness.description + val expectedColumn2 = uniqueness.description + val expectedColumn3 = uniquenessWhere.description + + val suite = new VerificationSuite().onData(data) + .withRowLevelFilterTreatment(FilteredRow.NULL) + .addCheck(completeness) + .addCheck(uniqueness) + .addCheck(uniquenessWhere) + + val result: VerificationResult = suite.run() + + assert(result.status == CheckStatus.Error) + + val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") + resultData.show(false) val expectedColumns: Set[String] = data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 assert(resultData.columns.toSet == expectedColumns) @@ -499,15 +537,15 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec "accept analysis config for mandatory analysis for checks with filters" in withSparkSession { sparkSession => import sparkSession.implicits._ - val df = getDfFull(sparkSession) + val df = getDfCompleteAndInCompleteColumns(sparkSession) val result = { val checkToSucceed = Check(CheckLevel.Error, "group-1") - .hasCompleteness("att1", _ > 0.5, null) // 1.0 - .where("att2 = \"c\"") + .hasCompleteness("att2", _ > 0.7, null) // 0.75 + .where("att1 = \"a\"") val uniquenessCheck = Check(CheckLevel.Error, "group-2") .isUnique("att1") - .where("item > 3") + .where("item < 3") VerificationSuite().onData(df).addCheck(checkToSucceed).addCheck(uniquenessCheck).run() @@ -519,8 +557,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec AnalyzerContext(result.metrics)) val expected = Seq( - ("Column", "att1", "Completeness (where: att2 = \"c\")", 1.0), - ("Column", "att1", "Uniqueness (where: item > 3)", 1.0)) + ("Column", "att2", "Completeness (where: att1 = \"a\")", 0.75), + ("Column", "att1", "Uniqueness (where: item < 3)", 1.0)) .toDF("entity", "instance", "name", "value") @@ -872,12 +910,12 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .run() val checkSuccessResult = verificationResult.checkResults(rangeCheck) -// checkSuccessResult.constraintResults.map(_.message) shouldBe List(None) + checkSuccessResult.constraintResults.map(_.message) shouldBe List(None) println(checkSuccessResult.constraintResults.map(_.message)) assert(checkSuccessResult.status == CheckStatus.Success) val reasonResult = verificationResult.checkResults(reasonCheck) - // checkSuccessResult.constraintResults.map(_.message) shouldBe List(None) + checkSuccessResult.constraintResults.map(_.message) 
shouldBe List(None) println(checkSuccessResult.constraintResults.map(_.message)) assert(checkSuccessResult.status == CheckStatus.Success) } diff --git a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala index de33f2327..f7084ccb0 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala @@ -19,6 +19,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn +import com.amazon.deequ.utilities.FilteredRow import com.amazon.deequ.utils.FixtureSupport import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -40,32 +41,34 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w Seq(true, true, true, true, false, true, true, false) } - "return row-level results for columns filtered using where" in withSparkSession { session => + "return row-level results for columns filtered as null" in withSparkSession { session => - val data = getDfForWhereClause(session) + val data = getDfCompleteAndInCompleteColumns(session) - val completenessCountry = Completeness("ZipCode", Option("State = \"CA\"")) - val state = completenessCountry.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = completenessCountry.computeMetricFrom(state) + // Explicitly setting RowLevelFilterTreatment for test purposes, but this should be set at the VerificationRunBuilder + val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")).withRowLevelFilterTreatment(FilteredRow.NULL) + val state = completenessAtt2.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) - // Address Line 3 is null only where Address Line 2 is null. 
With the where clause, completeness should be 1.0 - data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe - Seq(true, true, false, false) + val df = data.withColumn("new", metric.fullColumn.get) + df.show(false) + df.collect().map(_.getAs[Any]("new")).toSeq shouldBe + Seq(true, null, false, true, null, true) } - "return row-level results for columns filtered as null" in withSparkSession { session => + "return row-level results for columns filtered as true" in withSparkSession { session => val data = getDfCompleteAndInCompleteColumns(session) - val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")) + // Explicitly setting RowLevelFilterTreatment for test purposes, but this should be set at the VerificationRunBuilder + val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")).withRowLevelFilterTreatment(FilteredRow.TRUE) val state = completenessAtt2.computeStateFrom(data) val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) val df = data.withColumn("new", metric.fullColumn.get) - println("Filtered as null") df.show(false) - df.collect().map(_.getAs[Any]("new")).toSeq shouldBe - Seq(true, null, false, true, null, true) + df.collect().map(_.getAs[Any]("new")).toSeq shouldBe + Seq(true, true, false, true, true, true) } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala index e3492464b..bd4a39af4 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala @@ -21,6 +21,7 @@ import com.amazon.deequ.VerificationResult.UNIQUENESS_ID import com.amazon.deequ.analyzers.runners.AnalysisRunner import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn +import com.amazon.deequ.utilities.FilteredRow import com.amazon.deequ.utils.FixtureSupport import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession @@ -123,6 +124,7 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithUniqueColumns(session) val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4")) + .withRowLevelFilterTreatment(FilteredRow.NULL) val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4")) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) @@ -138,6 +140,7 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithUniqueColumns(session) val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2")) + .withRowLevelFilterTreatment(FilteredRow.NULL) val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2")) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) @@ -147,4 +150,38 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit resultDf .collect().map(_.getAs[Any]("new")) shouldBe Seq(null, null, true, true, true, true) } + + "return filtered row-level results for uniqueness true null" in withSparkSession { session => + + val data = getDfWithUniqueColumns(session) + + // Explicitly setting RowLevelFilterTreatment for test purposes, but this should be set at the VerificationRunBuilder + val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4")) + 
.withRowLevelFilterTreatment(FilteredRow.TRUE) + val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4")) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + // Adding column with UNIQUENESS_ID, since it's only added in VerificationResult.getRowLevelResults + val resultDf = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) + .withColumn("new", metric.fullColumn.get).orderBy("unique") + resultDf + .collect().map(_.getAs[Any]("new")) shouldBe Seq(true, true, true, true, true, true) + } + + "return filtered row-level results for uniqueness with true on multiple columns" in withSparkSession { session => + + val data = getDfWithUniqueColumns(session) + + // Explicitly setting RowLevelFilterTreatment for test purposes, but this should be set at the VerificationRunBuilder + val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2")) + .withRowLevelFilterTreatment(FilteredRow.TRUE) + val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2")) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + // Adding column with UNIQUENESS_ID, since it's only added in VerificationResult.getRowLevelResults + val resultDf = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) + .withColumn("new", metric.fullColumn.get).orderBy("unique") + resultDf + .collect().map(_.getAs[Any]("new")) shouldBe Seq(true, true, true, true, true, true) + } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala index f19e053bc..66bfd9693 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala @@ -198,10 +198,15 @@ class AnalysisRunnerTests extends AnyWordSpec (results, stat.jobCount) } - assert(separateResults.asInstanceOf[Set[DoubleMetric]].size == runnerResults.asInstanceOf[Set[DoubleMetric]].size) assert(numSeparateJobs == analyzers.length * 2) assert(numCombinedJobs == analyzers.length * 2) -// assert(separateResults == runnerResults.toString) + // assert(separateResults == runnerResults.toString) + // Used to be tested with the above line, but adding filters changed the order of the results. 
+ assert(separateResults.asInstanceOf[Set[DoubleMetric]].size == runnerResults.asInstanceOf[Set[DoubleMetric]].size) + separateResults.asInstanceOf[Set[DoubleMetric]].foreach( result => { + assert(runnerResults.toString.contains(result.toString)) + } + ) } "reuse existing results" in diff --git a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala index 749ec7a33..67cef4889 100644 --- a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala +++ b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala @@ -84,7 +84,7 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix assert(check4.getRowLevelConstraintColumnNames() == Seq("Completeness-att2", "Completeness-att2")) } - "return the correct check status for completeness with where" in withSparkSession { sparkSession => + "return the correct check status for completeness with where filter" in withSparkSession { sparkSession => val check = Check(CheckLevel.Error, "group-3") .hasCompleteness("ZipCode", _ > 0.6, None) // 1.0 with filter From 5f715aca571c90c5f0d19fe02d18f0968fbec8bd Mon Sep 17 00:00:00 2001 From: Edward Cho Date: Tue, 13 Feb 2024 10:51:42 -0500 Subject: [PATCH 5/8] Style fixes --- .../amazon/deequ/VerificationRunBuilder.scala | 2 +- .../amazon/deequ/analyzers/Uniqueness.scala | 3 ++- .../utilities/RowLevelFilterTreatement.scala | 18 ++++++++++++- .../amazon/deequ/VerificationSuiteTest.scala | 26 +++++++++++++------ .../deequ/analyzers/CompletenessTest.scala | 4 +-- .../deequ/analyzers/UniquenessTest.scala | 4 +-- .../runners/AnalysisRunnerTests.scala | 3 ++- .../com/amazon/deequ/checks/CheckTest.scala | 2 +- 8 files changed, 45 insertions(+), 17 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala index a4724dcf8..558000806 100644 --- a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala +++ b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala @@ -353,4 +353,4 @@ case class AnomalyCheckConfig( description: String, withTagValues: Map[String, String] = Map.empty, afterDate: Option[Long] = None, - beforeDate: Option[Long] = None) \ No newline at end of file + beforeDate: Option[Long] = None) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala index f87069d2e..f62476dac 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala @@ -42,7 +42,8 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None) val conditionColumn = where.map { expression => expr(expression) } val fullColumnUniqueness = conditionColumn.map { condition => { - when(not(condition), expr(rowLevelFilterTreatment.toString)).when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false) + when(not(condition), expr(rowLevelFilterTreatment.toString)) + .when(fullColumn.getOrElse(null).equalTo(1), true).otherwise(false) } }.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)) super.fromAggregationResult(result, offset, Option(fullColumnUniqueness)) diff --git a/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala b/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala index 8a2459bf3..45ce0ce90 100644 --- a/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala +++ 
b/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala @@ -1,3 +1,19 @@ +/** + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + package com.amazon.deequ.utilities import com.amazon.deequ.utilities.FilteredRow.FilteredRow @@ -29,4 +45,4 @@ class RowLevelFilterTreatmentImpl(initialFilterTreatment: FilteredRow) extends R object FilteredRow extends Enumeration { type FilteredRow = Value val NULL, TRUE = Value -} \ No newline at end of file +} diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index a4a521353..5701b456b 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -305,12 +305,17 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec assert(Seq(true, true, true, false, false, false).sameElements(rowLevel8)) } - "generate a result that contains row-level results with filter with true for filtered rows" in withSparkSession { session => + "generate a result that contains row-level results with true for filtered rows" in withSparkSession { session => val data = getDfCompleteAndInCompleteColumns(session) - val completeness = new Check(CheckLevel.Error, "rule1").hasCompleteness("att2", _ > 0.7, None).where("att1 = \"a\"") - val uniqueness = new Check(CheckLevel.Error, "rule2").hasUniqueness("att1", _ > 0.5, None) - val uniquenessWhere = new Check(CheckLevel.Error, "rule3").isUnique("att1").where("item < 3") + val completeness = new Check(CheckLevel.Error, "rule1") + .hasCompleteness("att2", _ > 0.7, None) + .where("att1 = \"a\"") + val uniqueness = new Check(CheckLevel.Error, "rule2") + .hasUniqueness("att1", _ > 0.5, None) + val uniquenessWhere = new Check(CheckLevel.Error, "rule3") + .isUnique("att1") + .where("item < 3") val expectedColumn1 = completeness.description val expectedColumn2 = uniqueness.description val expectedColumn3 = uniquenessWhere.description @@ -342,12 +347,17 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } - "generate a result that contains row-level results with filter with null for filtered rows" in withSparkSession { session => + "generate a result that contains row-level results with null for filtered rows" in withSparkSession { session => val data = getDfCompleteAndInCompleteColumns(session) - val completeness = new Check(CheckLevel.Error, "rule1").hasCompleteness("att2", _ > 0.7, None).where("att1 = \"a\"") - val uniqueness = new Check(CheckLevel.Error, "rule2").hasUniqueness("att1", _ > 0.5, None) - val uniquenessWhere = new Check(CheckLevel.Error, "rule3").isUnique("att1").where("item < 3") + val completeness = new Check(CheckLevel.Error, "rule1") + .hasCompleteness("att2", _ > 0.7, None) + .where("att1 = \"a\"") + val uniqueness = new Check(CheckLevel.Error, "rule2") + .hasUniqueness("att1", _ > 0.5, None) + val uniquenessWhere = new Check(CheckLevel.Error, "rule3") + 
.isUnique("att1") + .where("item < 3") val expectedColumn1 = completeness.description val expectedColumn2 = uniqueness.description val expectedColumn3 = uniquenessWhere.description diff --git a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala index f7084ccb0..cb2778a1b 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala @@ -45,7 +45,7 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w val data = getDfCompleteAndInCompleteColumns(session) - // Explicitly setting RowLevelFilterTreatment for test purposes, but this should be set at the VerificationRunBuilder + // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")).withRowLevelFilterTreatment(FilteredRow.NULL) val state = completenessAtt2.computeStateFrom(data) val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) @@ -60,7 +60,7 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w val data = getDfCompleteAndInCompleteColumns(session) - // Explicitly setting RowLevelFilterTreatment for test purposes, but this should be set at the VerificationRunBuilder + // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")).withRowLevelFilterTreatment(FilteredRow.TRUE) val state = completenessAtt2.computeStateFrom(data) val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) diff --git a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala index bd4a39af4..7be9b4b35 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala @@ -155,7 +155,7 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithUniqueColumns(session) - // Explicitly setting RowLevelFilterTreatment for test purposes, but this should be set at the VerificationRunBuilder + // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4")) .withRowLevelFilterTreatment(FilteredRow.TRUE) val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4")) @@ -172,7 +172,7 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithUniqueColumns(session) - // Explicitly setting RowLevelFilterTreatment for test purposes, but this should be set at the VerificationRunBuilder + // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2")) .withRowLevelFilterTreatment(FilteredRow.TRUE) val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2")) diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala index 66bfd9693..31b7365ad 100644 --- 
a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala @@ -202,7 +202,8 @@ class AnalysisRunnerTests extends AnyWordSpec assert(numCombinedJobs == analyzers.length * 2) // assert(separateResults == runnerResults.toString) // Used to be tested with the above line, but adding filters changed the order of the results. - assert(separateResults.asInstanceOf[Set[DoubleMetric]].size == runnerResults.asInstanceOf[Set[DoubleMetric]].size) + assert(separateResults.asInstanceOf[Set[DoubleMetric]].size == + runnerResults.asInstanceOf[Set[DoubleMetric]].size) separateResults.asInstanceOf[Set[DoubleMetric]].foreach( result => { assert(runnerResults.toString.contains(result.toString)) } diff --git a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala index 67cef4889..bc20b0954 100644 --- a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala +++ b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala @@ -535,7 +535,7 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix assertEvaluatesTo(numericRangeCheck9, numericRangeResults, CheckStatus.Success) } - "correctly evaluate range constraints when values have single quote(') in string" in withSparkSession { sparkSession => + "correctly evaluate range constraints when values have single quote in string" in withSparkSession { sparkSession => val rangeCheck = Check(CheckLevel.Error, "a") .isContainedIn("att2", Array("can't", "help", "but", "wouldn't")) From e5b7821e457a0a123dce15445b9617e9614b80f4 Mon Sep 17 00:00:00 2001 From: Edward Cho Date: Tue, 13 Feb 2024 15:20:50 -0500 Subject: [PATCH 6/8] Modify VerificationRunBuilder to have RowLevelFilterTreatment as variable instead of extending, create RowLevelAnalyzer trait --- .../com/amazon/deequ/VerificationRunBuilder.scala | 8 ++++---- .../com/amazon/deequ/analyzers/Analyzer.scala | 11 +---------- .../com/amazon/deequ/analyzers/Completeness.scala | 11 ++++++++++- .../amazon/deequ/analyzers/UniqueValueRatio.scala | 3 ++- .../com/amazon/deequ/analyzers/Uniqueness.scala | 15 ++++++++++++++- .../utilities/RowLevelFilterTreatement.scala | 4 ++++ 6 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala index 558000806..fb31651c8 100644 --- a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala +++ b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala @@ -22,14 +22,13 @@ import com.amazon.deequ.analyzers.{State, _} import com.amazon.deequ.checks.{Check, CheckLevel} import com.amazon.deequ.metrics.Metric import com.amazon.deequ.repository._ -import com.amazon.deequ.utilities.FilteredRow import com.amazon.deequ.utilities.FilteredRow.FilteredRow import com.amazon.deequ.utilities.RowLevelFilterTreatment import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl import org.apache.spark.sql.{DataFrame, SparkSession} /** A class to build a VerificationRun using a fluent API */ -class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatment { +class VerificationRunBuilder(val data: DataFrame) { protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty @@ -48,6 +47,7 @@ class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatmen protected var statePersister: Option[StatePersister] = None protected var stateLoader: 
Option[StateLoader] = None + protected var rowLevelFilterTreatment: RowLevelFilterTreatment = RowLevelFilterTreatment.sharedInstance protected def this(verificationRunBuilder: VerificationRunBuilder) { @@ -70,6 +70,7 @@ class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatmen stateLoader = verificationRunBuilder.stateLoader statePersister = verificationRunBuilder.statePersister + rowLevelFilterTreatment = verificationRunBuilder.rowLevelFilterTreatment } /** @@ -146,11 +147,10 @@ class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatmen */ def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = { RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow)) + rowLevelFilterTreatment = RowLevelFilterTreatment.sharedInstance this } - def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment - /** * Set a metrics repository associated with the current data to enable features like reusing * previously computed results and storing the results of the current run. diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index 327a7a14b..028579426 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -67,7 +67,7 @@ trait DoubleValuedState[S <: DoubleValuedState[S]] extends State[S] { } /** Common trait for all analyzers which generates metrics from states computed on data frames */ -trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable with RowLevelFilterTreatment { +trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable { /** * Compute the state (sufficient statistics) from the data @@ -178,15 +178,6 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable with RowLeve private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = { source.load[S](this).foreach { state => target.persist(this, state) } } - - @VisibleForTesting - private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = { - RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow)) - this - } - - def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment - } /** An analyzer that runs a set of aggregation functions over the data, diff --git a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala index 7107c834b..f385da45d 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala @@ -20,6 +20,10 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested} import org.apache.spark.sql.functions.sum import org.apache.spark.sql.types.{IntegerType, StructType} import Analyzers._ +import com.amazon.deequ.utilities.FilteredRow.FilteredRow +import com.amazon.deequ.utilities.RowLevelAnalyzer +import com.amazon.deequ.utilities.RowLevelFilterTreatment +import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.expr @@ -28,7 +32,7 @@ import org.apache.spark.sql.{Column, Row} /** Completeness is the fraction of non-null values in a column of a DataFrame. 
*/ case class Completeness(column: String, where: Option[String] = None) extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Completeness", column) with - FilterableAnalyzer { + FilterableAnalyzer with RowLevelAnalyzer { override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = { ifNoNullsIn(result, offset, howMany = 2) { _ => @@ -58,4 +62,9 @@ case class Completeness(column: String, where: Option[String] = None) extends conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, rowLevelFilterTreatment.toString) } + @VisibleForTesting + private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = { + RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow)) + this + } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala index 6cfdc6383..b3d1d7011 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala @@ -18,6 +18,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL import com.amazon.deequ.metrics.DoubleMetric +import com.amazon.deequ.utilities.RowLevelAnalyzer import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.not import org.apache.spark.sql.functions.when @@ -27,7 +28,7 @@ import org.apache.spark.sql.types.DoubleType case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None) extends ScanShareableFrequencyBasedAnalyzer("UniqueValueRatio", columns) - with FilterableAnalyzer { + with FilterableAnalyzer with RowLevelAnalyzer { override def aggregationFunctions(numRows: Long): Seq[Column] = { sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) :: count("*") :: Nil diff --git a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala index f62476dac..16ec6d7b1 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala @@ -18,6 +18,12 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL import com.amazon.deequ.metrics.DoubleMetric +import com.amazon.deequ.utilities.FilteredRow +import com.amazon.deequ.utilities.FilteredRow.FilteredRow +import com.amazon.deequ.utilities.RowLevelAnalyzer +import com.amazon.deequ.utilities.RowLevelFilterTreatment +import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl +import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.when @@ -32,7 +38,7 @@ import org.apache.spark.sql.types.DoubleType * values that occur exactly once. 
*/ case class Uniqueness(columns: Seq[String], where: Option[String] = None) extends ScanShareableFrequencyBasedAnalyzer("Uniqueness", columns) - with FilterableAnalyzer { + with FilterableAnalyzer with RowLevelAnalyzer { override def aggregationFunctions(numRows: Long): Seq[Column] = { (sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil @@ -50,6 +56,13 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None) } override def filterCondition: Option[String] = where + + + @VisibleForTesting + private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = { + RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow)) + this + } } object Uniqueness { diff --git a/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala b/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala index 45ce0ce90..c37e72e61 100644 --- a/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala +++ b/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala @@ -46,3 +46,7 @@ object FilteredRow extends Enumeration { type FilteredRow = Value val NULL, TRUE = Value } + +trait RowLevelAnalyzer extends RowLevelFilterTreatment { + def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment +} From 8e37b922057780f2cfdeda90563523a8a40f918a Mon Sep 17 00:00:00 2001 From: Edward Cho Date: Wed, 14 Feb 2024 15:02:53 -0500 Subject: [PATCH 7/8] Do row-level filtering in AnalyzerOptions rather than with RowLevelFilterTreatment trait --- .../amazon/deequ/VerificationRunBuilder.scala | 16 ----- .../com/amazon/deequ/analyzers/Analyzer.scala | 14 +++-- .../amazon/deequ/analyzers/Completeness.scala | 20 +++---- .../deequ/analyzers/UniqueValueRatio.scala | 30 ++++++---- .../amazon/deequ/analyzers/Uniqueness.scala | 36 +++++------ .../scala/com/amazon/deequ/checks/Check.scala | 60 ++++++++++++++++--- .../amazon/deequ/constraints/Constraint.scala | 14 +++-- .../utilities/RowLevelFilterTreatement.scala | 52 ---------------- .../amazon/deequ/VerificationResultTest.scala | 18 ++++-- .../amazon/deequ/VerificationSuiteTest.scala | 10 ++-- .../deequ/analyzers/CompletenessTest.scala | 6 +- .../deequ/analyzers/UniquenessTest.scala | 11 ++-- .../runners/AnalysisRunnerTests.scala | 2 +- .../runners/AnalyzerContextTest.scala | 5 +- .../repository/AnalysisResultSerdeTest.scala | 4 +- .../deequ/repository/AnalysisResultTest.scala | 5 +- ...sRepositoryMultipleResultsLoaderTest.scala | 5 +- .../ConstraintSuggestionResultTest.scala | 32 +++++----- 18 files changed, 169 insertions(+), 171 deletions(-) delete mode 100644 src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala diff --git a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala index fb31651c8..40caa4092 100644 --- a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala +++ b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala @@ -22,9 +22,6 @@ import com.amazon.deequ.analyzers.{State, _} import com.amazon.deequ.checks.{Check, CheckLevel} import com.amazon.deequ.metrics.Metric import com.amazon.deequ.repository._ -import com.amazon.deequ.utilities.FilteredRow.FilteredRow -import com.amazon.deequ.utilities.RowLevelFilterTreatment -import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl import org.apache.spark.sql.{DataFrame, SparkSession} /** A class to build a VerificationRun using a 
fluent API */ @@ -47,7 +44,6 @@ class VerificationRunBuilder(val data: DataFrame) { protected var statePersister: Option[StatePersister] = None protected var stateLoader: Option[StateLoader] = None - protected var rowLevelFilterTreatment: RowLevelFilterTreatment = RowLevelFilterTreatment.sharedInstance protected def this(verificationRunBuilder: VerificationRunBuilder) { @@ -70,7 +66,6 @@ class VerificationRunBuilder(val data: DataFrame) { stateLoader = verificationRunBuilder.stateLoader statePersister = verificationRunBuilder.statePersister - rowLevelFilterTreatment = verificationRunBuilder.rowLevelFilterTreatment } /** @@ -140,17 +135,6 @@ class VerificationRunBuilder(val data: DataFrame) { this } - /** - * Sets how row level results will be treated for the Verification run - * - * @param filteredRow enum to determine how filtered row level results are labeled (TRUE | NULL) - */ - def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = { - RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow)) - rowLevelFilterTreatment = RowLevelFilterTreatment.sharedInstance - this - } - /** * Set a metrics repository associated with the current data to enable features like reusing * previously computed results and storing the results of the current run. diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index 028579426..f1d7c35c8 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -17,6 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers._ +import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.amazon.deequ.analyzers.NullBehavior.NullBehavior import com.amazon.deequ.analyzers.runners._ import com.amazon.deequ.metrics.DoubleMetric @@ -24,11 +25,6 @@ import com.amazon.deequ.metrics.Entity import com.amazon.deequ.metrics.FullColumn import com.amazon.deequ.metrics.Metric import com.amazon.deequ.utilities.ColumnUtil.removeEscapeColumn -import com.amazon.deequ.utilities.FilteredRow -import com.amazon.deequ.utilities.FilteredRow.FilteredRow -import com.amazon.deequ.utilities.RowLevelFilterTreatment -import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl -import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row @@ -266,12 +262,18 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo } } -case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore) +case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore, + filteredRow: FilteredRow = FilteredRow.TRUE) object NullBehavior extends Enumeration { type NullBehavior = Value val Ignore, EmptyString, Fail = Value } +object FilteredRow extends Enumeration { + type FilteredRow = Value + val NULL, TRUE = Value +} + /** Base class for analyzers that compute ratios of matching predicates */ abstract class PredicateMatchingAnalyzer( name: String, diff --git a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala index f385da45d..399cbb06a 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala @@ -20,19 +20,17 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested} import 
org.apache.spark.sql.functions.sum import org.apache.spark.sql.types.{IntegerType, StructType} import Analyzers._ -import com.amazon.deequ.utilities.FilteredRow.FilteredRow -import com.amazon.deequ.utilities.RowLevelAnalyzer -import com.amazon.deequ.utilities.RowLevelFilterTreatment -import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl +import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.expr import org.apache.spark.sql.{Column, Row} /** Completeness is the fraction of non-null values in a column of a DataFrame. */ -case class Completeness(column: String, where: Option[String] = None) extends +case class Completeness(column: String, where: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Completeness", column) with - FilterableAnalyzer with RowLevelAnalyzer { + FilterableAnalyzer { override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = { ifNoNullsIn(result, offset, howMany = 2) { _ => @@ -59,12 +57,12 @@ case class Completeness(column: String, where: Option[String] = None) extends @VisibleForTesting private[deequ] def rowLevelResults: Column = { val whereCondition = where.map { expression => expr(expression)} - conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, rowLevelFilterTreatment.toString) + conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment.toString) } - @VisibleForTesting - private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = { - RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow)) - this + private def getRowLevelFilterTreatment: FilteredRow = { + analyzerOptions + .map { options => options.filteredRow } + .getOrElse(FilteredRow.TRUE) } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala index b3d1d7011..c2fce1f14 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala @@ -17,8 +17,8 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL +import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.amazon.deequ.metrics.DoubleMetric -import com.amazon.deequ.utilities.RowLevelAnalyzer import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.not import org.apache.spark.sql.functions.when @@ -26,9 +26,10 @@ import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.functions.{col, count, lit, sum} import org.apache.spark.sql.types.DoubleType -case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None) +case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) extends ScanShareableFrequencyBasedAnalyzer("UniqueValueRatio", columns) - with FilterableAnalyzer with RowLevelAnalyzer { + with FilterableAnalyzer { override def aggregationFunctions(numRows: Long): Seq[Column] = { sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) :: count("*") :: Nil @@ -38,17 +39,26 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None) val numUniqueValues = result.getDouble(offset) val numDistinctValues = 
result.getLong(offset + 1).toDouble val conditionColumn = where.map { expression => expr(expression) } - val fullColumnUniqueness = conditionColumn.map { - condition => { - when(not(condition), expr(rowLevelFilterTreatment.toString)) - .when((fullColumn.getOrElse(null)).equalTo(1), true) - .otherwise(false) + val fullColumnUniqueness = fullColumn.map { + rowLevelColumn => { + conditionColumn.map { + condition => { + when(not(condition), expr(getRowLevelFilterTreatment.toString)) + .when(rowLevelColumn.equalTo(1), true).otherwise(false) + } + }.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false)) } - }.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)) - toSuccessMetric(numUniqueValues / numDistinctValues, Option(fullColumnUniqueness)) + } + toSuccessMetric(numUniqueValues / numDistinctValues, fullColumnUniqueness) } override def filterCondition: Option[String] = where + + private def getRowLevelFilterTreatment: FilteredRow = { + analyzerOptions + .map { options => options.filteredRow } + .getOrElse(FilteredRow.TRUE) + } } object UniqueValueRatio { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala index 16ec6d7b1..78ba4c418 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala @@ -17,12 +17,8 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL +import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.amazon.deequ.metrics.DoubleMetric -import com.amazon.deequ.utilities.FilteredRow -import com.amazon.deequ.utilities.FilteredRow.FilteredRow -import com.amazon.deequ.utilities.RowLevelAnalyzer -import com.amazon.deequ.utilities.RowLevelFilterTreatment -import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.Column import org.apache.spark.sql.Row @@ -36,9 +32,10 @@ import org.apache.spark.sql.types.DoubleType /** Uniqueness is the fraction of unique values of a column(s), i.e., * values that occur exactly once. 
*/ -case class Uniqueness(columns: Seq[String], where: Option[String] = None) +case class Uniqueness(columns: Seq[String], where: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) extends ScanShareableFrequencyBasedAnalyzer("Uniqueness", columns) - with FilterableAnalyzer with RowLevelAnalyzer { + with FilterableAnalyzer { override def aggregationFunctions(numRows: Long): Seq[Column] = { (sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil @@ -46,22 +43,25 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None) override def fromAggregationResult(result: Row, offset: Int, fullColumn: Option[Column]): DoubleMetric = { val conditionColumn = where.map { expression => expr(expression) } - val fullColumnUniqueness = conditionColumn.map { - condition => { - when(not(condition), expr(rowLevelFilterTreatment.toString)) - .when(fullColumn.getOrElse(null).equalTo(1), true).otherwise(false) + val fullColumnUniqueness = fullColumn.map { + rowLevelColumn => { + conditionColumn.map { + condition => { + when(not(condition), expr(getRowLevelFilterTreatment.toString)) + .when(rowLevelColumn.equalTo(1), true).otherwise(false) + } + }.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false)) } - }.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)) - super.fromAggregationResult(result, offset, Option(fullColumnUniqueness)) + } + super.fromAggregationResult(result, offset, fullColumnUniqueness) } override def filterCondition: Option[String] = where - - @VisibleForTesting - private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = { - RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow)) - this + private def getRowLevelFilterTreatment: FilteredRow = { + analyzerOptions + .map { options => options.filteredRow } + .getOrElse(FilteredRow.TRUE) } } diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala index 884041469..9921e7f8e 100644 --- a/src/main/scala/com/amazon/deequ/checks/Check.scala +++ b/src/main/scala/com/amazon/deequ/checks/Check.scala @@ -132,10 +132,12 @@ case class Check( * * @param column Column to run the assertion on * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ - def isComplete(column: String, hint: Option[String] = None): CheckWithLastConstraintFilterable = { - addFilterableConstraint { filter => completenessConstraint(column, Check.IsOne, filter, hint) } + def isComplete(column: String, hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None): CheckWithLastConstraintFilterable = { + addFilterableConstraint { filter => completenessConstraint(column, Check.IsOne, filter, hint, analyzerOptions) } } /** @@ -146,14 +148,16 @@ case class Check( * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasCompleteness( column: String, assertion: Double => Boolean, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { - addFilterableConstraint { 
filter => completenessConstraint(column, assertion, filter, hint) } + addFilterableConstraint { filter => completenessConstraint(column, assertion, filter, hint, analyzerOptions) } } /** @@ -221,11 +225,13 @@ case class Check( * * @param column Column to run the assertion on * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ - def isUnique(column: String, hint: Option[String] = None): CheckWithLastConstraintFilterable = { + def isUnique(column: String, hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None): CheckWithLastConstraintFilterable = { addFilterableConstraint { filter => - uniquenessConstraint(Seq(column), Check.IsOne, filter, hint) } + uniquenessConstraint(Seq(column), Check.IsOne, filter, hint, analyzerOptions) } } /** @@ -269,6 +275,24 @@ case class Check( addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter) } } + /** + * Creates a constraint that asserts on uniqueness in a single or combined set of key columns. + * + * @param columns Key columns + * @param assertion Function that receives a double input parameter and returns a boolean. + * Refers to the fraction of unique values + * @param hint A hint to provide additional context why a constraint could have failed + * @return + */ + def hasUniqueness( + columns: Seq[String], + assertion: Double => Boolean, + hint: Option[String]) + : CheckWithLastConstraintFilterable = { + + addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter, hint) } + } + /** * Creates a constraint that asserts on uniqueness in a single or combined set of key columns. * @@ -276,15 +300,17 @@ case class Check( * @param assertion Function that receives a double input parameter and returns a boolean. * Refers to the fraction of unique values * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasUniqueness( columns: Seq[String], assertion: Double => Boolean, - hint: Option[String]) + hint: Option[String], + analyzerOptions: Option[AnalyzerOptions]) : CheckWithLastConstraintFilterable = { - addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter, hint) } + addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter, hint, analyzerOptions) } } /** @@ -314,6 +340,22 @@ case class Check( hasUniqueness(Seq(column), assertion, hint) } + /** + * Creates a constraint that asserts on the uniqueness of a key column. + * + * @param column Key column + * @param assertion Function that receives a double input parameter and returns a boolean. + * Refers to the fraction of unique values. + * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) + * @return + */ + def hasUniqueness(column: String, assertion: Double => Boolean, hint: Option[String], + analyzerOptions: Option[AnalyzerOptions]) + : CheckWithLastConstraintFilterable = { + hasUniqueness(Seq(column), assertion, hint, analyzerOptions) + } + /** * Creates a constraint on the distinctness in a single or combined set of key columns. 
* @@ -601,6 +643,7 @@ case class Check( * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasMinLength( @@ -619,6 +662,7 @@ case class Check( * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasMaxLength( diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index 74070687e..d17ee9abe 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -192,15 +192,17 @@ object Constraint { * @param assertion Function that receives a double input parameter (since the metric is * double metric) and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def completenessConstraint( column: String, assertion: Double => Boolean, where: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val completeness = Completeness(column, where) + val completeness = Completeness(column, where, analyzerOptions) this.fromAnalyzer(completeness, assertion, hint) } @@ -242,15 +244,17 @@ object Constraint { * (since the metric is double metric) and returns a boolean * @param where Additional filter to apply before the analyzer is run. 
* @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def uniquenessConstraint( columns: Seq[String], assertion: Double => Boolean, where: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val uniqueness = Uniqueness(columns, where) + val uniqueness = Uniqueness(columns, where, analyzerOptions) fromAnalyzer(uniqueness, assertion, hint) } @@ -528,6 +532,7 @@ object Constraint { * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def maxLengthConstraint( column: String, @@ -562,6 +567,7 @@ object Constraint { * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def minLengthConstraint( column: String, diff --git a/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala b/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala deleted file mode 100644 index c37e72e61..000000000 --- a/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not - * use this file except in compliance with the License. A copy of the License - * is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. 
- * - */ - -package com.amazon.deequ.utilities -import com.amazon.deequ.utilities.FilteredRow.FilteredRow - -/** - * Trait that defines how row level results will be treated when a filter is applied to an analyzer - */ -trait RowLevelFilterTreatment { - def rowLevelFilterTreatment: FilteredRow -} - -/** - * Companion object for RowLevelFilterTreatment - * Defines a sharedInstance that can be used throughout the VerificationRunBuilder - */ -object RowLevelFilterTreatment { - private var _sharedInstance: RowLevelFilterTreatment = new RowLevelFilterTreatmentImpl(FilteredRow.TRUE) - - def sharedInstance: RowLevelFilterTreatment = _sharedInstance - - def setSharedInstance(instance: RowLevelFilterTreatment): Unit = { - _sharedInstance = instance - } -} - -class RowLevelFilterTreatmentImpl(initialFilterTreatment: FilteredRow) extends RowLevelFilterTreatment { - override val rowLevelFilterTreatment: FilteredRow = initialFilterTreatment -} - -object FilteredRow extends Enumeration { - type FilteredRow = Value - val NULL, TRUE = Value -} - -trait RowLevelAnalyzer extends RowLevelFilterTreatment { - def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment -} diff --git a/src/test/scala/com/amazon/deequ/VerificationResultTest.scala b/src/test/scala/com/amazon/deequ/VerificationResultTest.scala index 93aa73201..0a90c8f77 100644 --- a/src/test/scala/com/amazon/deequ/VerificationResultTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationResultTest.scala @@ -78,6 +78,13 @@ class VerificationResultTest extends WordSpec with Matchers with SparkContextSpe val successMetricsResultsJson = VerificationResult.successMetricsAsJson(results) + val expectedJsonSet = Set("""{"entity":"Column","instance":"item","name":"Distinctness","value":1.0}""", + """{"entity": "Column", "instance":"att2","name":"Completeness","value":1.0}""", + """{"entity":"Column","instance":"att1","name":"Completeness","value":1.0}""", + """{"entity":"Multicolumn","instance":"att1,att2", + "name":"Uniqueness","value":0.25}""", + """{"entity":"Dataset","instance":"*","name":"Size","value":4.0}""") + val expectedJson = """[{"entity":"Column","instance":"item","name":"Distinctness","value":1.0}, |{"entity": "Column", "instance":"att2","name":"Completeness","value":1.0}, @@ -123,11 +130,11 @@ class VerificationResultTest extends WordSpec with Matchers with SparkContextSpe import session.implicits._ val expected = Seq( - ("group-1", "Error", "Success", "CompletenessConstraint(Completeness(att1,None))", + ("group-1", "Error", "Success", "CompletenessConstraint(Completeness(att1,None,None))", "Success", ""), ("group-2-E", "Error", "Error", "SizeConstraint(Size(None))", "Failure", "Value: 4 does not meet the constraint requirement! 
Should be greater than 5!"), - ("group-2-E", "Error", "Error", "CompletenessConstraint(Completeness(att2,None))", + ("group-2-E", "Error", "Error", "CompletenessConstraint(Completeness(att2,None,None))", "Success", ""), ("group-2-W", "Warning", "Warning", "DistinctnessConstraint(Distinctness(List(item),None))", @@ -150,7 +157,7 @@ class VerificationResultTest extends WordSpec with Matchers with SparkContextSpe val expectedJson = """[{"check":"group-1","check_level":"Error","check_status":"Success", - |"constraint":"CompletenessConstraint(Completeness(att1,None))", + |"constraint":"CompletenessConstraint(Completeness(att1,None,None))", |"constraint_status":"Success","constraint_message":""}, | |{"check":"group-2-E","check_level":"Error","check_status":"Error", @@ -159,7 +166,7 @@ class VerificationResultTest extends WordSpec with Matchers with SparkContextSpe | Should be greater than 5!"}, | |{"check":"group-2-E","check_level":"Error","check_status":"Error", - |"constraint":"CompletenessConstraint(Completeness(att2,None))", + |"constraint":"CompletenessConstraint(Completeness(att2,None,None))", |"constraint_status":"Success","constraint_message":""}, | |{"check":"group-2-W","check_level":"Warning","check_status":"Warning", @@ -214,7 +221,6 @@ class VerificationResultTest extends WordSpec with Matchers with SparkContextSpe } private[this] def assertSameResultsJson(jsonA: String, jsonB: String): Unit = { - assert(SimpleResultSerde.deserialize(jsonA) == - SimpleResultSerde.deserialize(jsonB)) + assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet)) } } diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index 5701b456b..932c82988 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -29,7 +29,6 @@ import com.amazon.deequ.metrics.Entity import com.amazon.deequ.repository.MetricsRepository import com.amazon.deequ.repository.ResultKey import com.amazon.deequ.repository.memory.InMemoryMetricsRepository -import com.amazon.deequ.utilities.FilteredRow import com.amazon.deequ.utils.CollectionUtils.SeqExtensions import com.amazon.deequ.utils.FixtureSupport import com.amazon.deequ.utils.TempFileUtils @@ -350,20 +349,21 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec "generate a result that contains row-level results with null for filtered rows" in withSparkSession { session => val data = getDfCompleteAndInCompleteColumns(session) + val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRow.NULL)) + val completeness = new Check(CheckLevel.Error, "rule1") - .hasCompleteness("att2", _ > 0.7, None) + .hasCompleteness("att2", _ > 0.7, None, analyzerOptions) .where("att1 = \"a\"") val uniqueness = new Check(CheckLevel.Error, "rule2") - .hasUniqueness("att1", _ > 0.5, None) + .hasUniqueness("att1", _ > 0.5, None, analyzerOptions) val uniquenessWhere = new Check(CheckLevel.Error, "rule3") - .isUnique("att1") + .isUnique("att1", None, analyzerOptions) .where("item < 3") val expectedColumn1 = completeness.description val expectedColumn2 = uniqueness.description val expectedColumn3 = uniquenessWhere.description val suite = new VerificationSuite().onData(data) - .withRowLevelFilterTreatment(FilteredRow.NULL) .addCheck(completeness) .addCheck(uniqueness) .addCheck(uniquenessWhere) diff --git 
a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala index cb2778a1b..54e26f867 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala @@ -19,7 +19,6 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn -import com.amazon.deequ.utilities.FilteredRow import com.amazon.deequ.utils.FixtureSupport import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -46,7 +45,8 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w val data = getDfCompleteAndInCompleteColumns(session) // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder - val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")).withRowLevelFilterTreatment(FilteredRow.NULL) + val completenessAtt2 = Completeness("att2", Option("att1 = \"a\""), + Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) val state = completenessAtt2.computeStateFrom(data) val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) @@ -61,7 +61,7 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w val data = getDfCompleteAndInCompleteColumns(session) // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder - val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")).withRowLevelFilterTreatment(FilteredRow.TRUE) + val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")) val state = completenessAtt2.computeStateFrom(data) val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) diff --git a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala index 7be9b4b35..d50995b55 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala @@ -21,7 +21,6 @@ import com.amazon.deequ.VerificationResult.UNIQUENESS_ID import com.amazon.deequ.analyzers.runners.AnalysisRunner import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn -import com.amazon.deequ.utilities.FilteredRow import com.amazon.deequ.utils.FixtureSupport import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession @@ -123,8 +122,8 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithUniqueColumns(session) - val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4")) - .withRowLevelFilterTreatment(FilteredRow.NULL) + val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4"), + Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4")) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) @@ -139,8 +138,8 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithUniqueColumns(session) - val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2")) - .withRowLevelFilterTreatment(FilteredRow.NULL) + val 
addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2"),
+ Option(AnalyzerOptions(filteredRow = FilteredRow.NULL)))
 val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2"))
 val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state)
@@ -157,7 +156,6 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit
 // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder
 val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4"))
- .withRowLevelFilterTreatment(FilteredRow.TRUE)
 val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4"))
 val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state)
@@ -174,7 +172,6 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit
 // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder
 val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2"))
- .withRowLevelFilterTreatment(FilteredRow.TRUE)
 val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2"))
 val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state)
diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala
index 31b7365ad..ce9bda69b 100644
--- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala
+++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala
@@ -284,7 +284,7 @@ class AnalysisRunnerTests extends AnyWordSpec
 assert(exception.getMessage == "Could not find all necessary results in the " +
 "MetricsRepository, the calculation of the metrics for these analyzers " +
- "would be needed: Uniqueness(List(item, att2),None), Size(None)")
+ "would be needed: Uniqueness(List(item, att2),None,None), Size(None)")
 }
 "save results if specified" in
diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala
index 254fac9b4..9133d5ae4 100644
--- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala
+++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala
@@ -145,7 +145,8 @@ class AnalyzerContextTest extends AnyWordSpec
 }
 private[this] def assertSameJson(jsonA: String, jsonB: String): Unit = {
- assert(SimpleResultSerde.deserialize(jsonA) ==
- SimpleResultSerde.deserialize(jsonB))
+ assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet))
+ // assert(SimpleResultSerde.deserialize(jsonA) ==
+ // SimpleResultSerde.deserialize(jsonB))
 }
 }
diff --git a/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala b/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala
index 6f1fa1874..05f4d47bd 100644
--- a/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala
+++ b/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala
@@ -363,7 +363,7 @@ class SimpleResultSerdeTest extends WordSpec with Matchers with SparkContextSpec
 .stripMargin.replaceAll("\n", "")
 // ordering of map entries is not guaranteed, so comparing strings is not an option
- assert(SimpleResultSerde.deserialize(sucessMetricsResultJson) ==
- SimpleResultSerde.deserialize(expected))
+ assert(SimpleResultSerde.deserialize(sucessMetricsResultJson).toSet.sameElements(
+ SimpleResultSerde.deserialize(expected).toSet))
 }
 }
diff --git a/src/test/scala/com/amazon/deequ/repository/AnalysisResultTest.scala b/src/test/scala/com/amazon/deequ/repository/AnalysisResultTest.scala
index 97d7a3c49..d4ce97fcb 100644
--- a/src/test/scala/com/amazon/deequ/repository/AnalysisResultTest.scala
+++ b/src/test/scala/com/amazon/deequ/repository/AnalysisResultTest.scala
@@ -344,7 +344,8 @@ class AnalysisResultTest extends AnyWordSpec
 }
 private[this] def assertSameJson(jsonA: String, jsonB: String): Unit = {
- assert(SimpleResultSerde.deserialize(jsonA) ==
- SimpleResultSerde.deserialize(jsonB))
+ assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet))
+// assert(SimpleResultSerde.deserialize(jsonA) ==
+// SimpleResultSerde.deserialize(jsonB))
 }
 }
diff --git a/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryMultipleResultsLoaderTest.scala b/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryMultipleResultsLoaderTest.scala
index 6e61b9385..592f27b0e 100644
--- a/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryMultipleResultsLoaderTest.scala
+++ b/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryMultipleResultsLoaderTest.scala
@@ -264,7 +264,8 @@ class MetricsRepositoryMultipleResultsLoaderTest extends AnyWordSpec with Matche
 }
 private[this] def assertSameJson(jsonA: String, jsonB: String): Unit = {
- assert(SimpleResultSerde.deserialize(jsonA) ==
- SimpleResultSerde.deserialize(jsonB))
+ assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet))
+ // assert(SimpleResultSerde.deserialize(jsonA) ==
+ // SimpleResultSerde.deserialize(jsonB))
 }
 }
diff --git a/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala b/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala
index 6a98bf3c6..9a82903e8 100644
--- a/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala
+++ b/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala
@@ -212,7 +212,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 """{
 | "constraint_suggestions": [
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(att2,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(att2,None,None))",
 | "column_name": "att2",
 | "current_value": "Completeness: 1.0",
 | "description": "'att2' is not null",
@@ -222,7 +222,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "code_for_constraint": ".isComplete(\"att2\")"
 | },
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(att1,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(att1,None,None))",
 | "column_name": "att1",
 | "current_value": "Completeness: 1.0",
 | "description": "'att1' is not null",
@@ -232,7 +232,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "code_for_constraint": ".isComplete(\"att1\")"
 | },
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(item,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(item,None,None))",
 | "column_name": "item",
 | "current_value": "Completeness: 1.0",
 | "description": "'item' is not null",
@@ -265,7 +265,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "code_for_constraint": ".isNonNegative(\"item\")"
 | },
 | {
- | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None))",
+ | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None,None))",
 | "column_name": "item",
 | "current_value": "ApproxDistinctness: 1.0",
 | "description": "'item' is unique",
@@ -294,7 +294,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 """{
 | "constraint_suggestions": [
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(att2,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(att2,None,None))",
 | "column_name": "att2",
 | "current_value": "Completeness: 1.0",
 | "description": "\u0027att2\u0027 is not null",
@@ -305,7 +305,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "constraint_result_on_test_set": "Failure"
 | },
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(att1,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(att1,None,None))",
 | "column_name": "att1",
 | "current_value": "Completeness: 1.0",
 | "description": "\u0027att1\u0027 is not null",
@@ -316,7 +316,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "constraint_result_on_test_set": "Failure"
 | },
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(item,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(item,None,None))",
 | "column_name": "item",
 | "current_value": "Completeness: 1.0",
 | "description": "\u0027item\u0027 is not null",
@@ -352,7 +352,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "constraint_result_on_test_set": "Failure"
 | },
 | {
- | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None))",
+ | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None,None))",
 | "column_name": "item",
 | "current_value": "ApproxDistinctness: 1.0",
 | "description": "\u0027item\u0027 is unique",
@@ -381,7 +381,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 """{
 | "constraint_suggestions": [
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(att2,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(att2,None,None))",
 | "column_name": "att2",
 | "current_value": "Completeness: 1.0",
 | "description": "\u0027att2\u0027 is not null",
@@ -392,7 +392,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "constraint_result_on_test_set": "Unknown"
 | },
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(att1,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(att1,None,None))",
 | "column_name": "att1",
 | "current_value": "Completeness: 1.0",
 | "description": "\u0027att1\u0027 is not null",
@@ -403,7 +403,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "constraint_result_on_test_set": "Unknown"
 | },
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(item,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(item,None,None))",
 | "column_name": "item",
 | "current_value": "Completeness: 1.0",
 | "description": "\u0027item\u0027 is not null",
@@ -439,7 +439,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "constraint_result_on_test_set": "Unknown"
 | },
 | {
- | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None))",
+ | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None,None))",
 | "column_name": "item",
 | "current_value": "ApproxDistinctness: 1.0",
 | "description": "\u0027item\u0027 is unique",
@@ -471,7 +471,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 """{
 | "constraint_suggestions": [
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(`item.one`,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(`item.one`,None,None))",
 | "column_name": "`item.one`",
 | "current_value": "Completeness: 1.0",
 | "description": "'`item.one`' is not null",
@@ -504,7 +504,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "code_for_constraint": ".isNonNegative(\"`item.one`\")"
 | },
 | {
- | "constraint_name": "UniquenessConstraint(Uniqueness(List(`item.one`),None))",
+ | "constraint_name": "UniquenessConstraint(Uniqueness(List(`item.one`),None,None))",
 | "column_name": "`item.one`",
 | "current_value": "ApproxDistinctness: 1.0",
 | "description": "'`item.one`' is unique",
@@ -515,7 +515,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "code_for_constraint": ".isUnique(\"`item.one`\")"
 | },
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(att2,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(att2,None,None))",
 | "column_name": "att2",
 | "current_value": "Completeness: 1.0",
 | "description": "'att2' is not null",
@@ -525,7 +525,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo
 | "code_for_constraint": ".isComplete(\"att2\")"
 | },
 | {
- | "constraint_name": "CompletenessConstraint(Completeness(att1,None))",
+ | "constraint_name": "CompletenessConstraint(Completeness(att1,None,None))",
 | "column_name": "att1",
 | "current_value": "Completeness: 1.0",
 | "description": "'att1' is not null",

From 6375a32b35e8fd92e14be866b63da96886855a5b Mon Sep 17 00:00:00 2001
From: Edward Cho
Date: Thu, 15 Feb 2024 12:43:39 -0500
Subject: [PATCH 8/8] Modify computeStateFrom to take in optional filterCondition

---
 .../scala/com/amazon/deequ/analyzers/Analyzer.scala | 10 ++--------
 .../scala/com/amazon/deequ/analyzers/CustomSql.scala | 6 +-----
 .../amazon/deequ/analyzers/DatasetMatchAnalyzer.scala | 6 +-----
 .../com/amazon/deequ/analyzers/GroupingAnalyzers.scala | 9 +++------
 .../scala/com/amazon/deequ/analyzers/Histogram.scala | 7 ++-----
 .../constraints/AnalysisBasedConstraintTest.scala | 6 +-----
 6 files changed, 10 insertions(+), 34 deletions(-)

diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
index f1d7c35c8..bc241fe72 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -70,9 +70,7 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
 * @param data data frame
 * @return
 */
- def computeStateFrom(data: DataFrame): Option[S]
-
- def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[S]
+ def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[S]
 /**
 * Compute the metric from the state (sufficient statistics)
 *
@@ -187,16 +185,12 @@ trait ScanShareableAnalyzer[S <: State[_], +M <: Metric[_]] extends Analyzer[S,
 private[deequ] def fromAggregationResult(result: Row, offset: Int): Option[S]
 /** Runs aggregation functions directly, without scan sharing */
- override def computeStateFrom(data: DataFrame): Option[S] = {
+ override def computeStateFrom(data: DataFrame, where: Option[String] = None): Option[S] = {
 val aggregations = aggregationFunctions()
 val result = data.agg(aggregations.head, aggregations.tail: _*).collect().head
 fromAggregationResult(result, 0)
 }
- override def computeStateFrom(data: DataFrame, where: Option[String]): Option[S] = {
- computeStateFrom(data)
- }
-
 /** Produces a metric from the aggregation result */
 private[deequ] def metricFromAggregationResult(
 result: Row,
diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
index bc14bd184..e07e2d11f 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
@@ -33,7 +33,7 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double
 * @param data data frame
 * @return
 */
- override def computeStateFrom(data: DataFrame): Option[CustomSqlState] = {
+ override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[CustomSqlState] = {
 Try {
 data.sqlContext.sql(expression)
@@ -55,10 +55,6 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double
 }
 }
- override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[CustomSqlState] = {
- computeStateFrom(data)
- }
-
 /**
 * Compute the metric from the state (sufficient statistics)
 *
diff --git a/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala
index bfcc5c06d..f2aefb57f 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala
@@ -69,7 +69,7 @@ case class DatasetMatchAnalyzer(dfToCompare: DataFrame,
 matchColumnMappings: Option[Map[String, String]] = None)
 extends Analyzer[DatasetMatchState, DoubleMetric] {
- override def computeStateFrom(data: DataFrame): Option[DatasetMatchState] = {
+ override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[DatasetMatchState] = {
 val result = if (matchColumnMappings.isDefined) {
 DataSynchronization.columnMatch(data, dfToCompare, columnMappings, matchColumnMappings.get, assertion)
@@ -86,10 +86,6 @@ case class DatasetMatchAnalyzer(dfToCompare: DataFrame,
 }
 }
- override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[DatasetMatchState] = {
- computeStateFrom(data)
- }
-
 override def computeMetricFrom(state: Option[DatasetMatchState]): DoubleMetric = {
 val metric = state match {
diff --git a/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala b/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala
index a7197e89c..30bd89621 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala
@@ -40,12 +40,9 @@ abstract class FrequencyBasedAnalyzer(columnsToGroupOn: Seq[String])
 override def groupingColumns(): Seq[String] = { columnsToGroupOn }
- override def computeStateFrom(data: DataFrame): Option[FrequenciesAndNumRows] = {
- Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns()))
- }
-
- override def computeStateFrom(data: DataFrame, where: Option[String]): Option[FrequenciesAndNumRows] = {
- Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns, where))
+ override def computeStateFrom(data: DataFrame,
+ filterCondition: Option[String] = None): Option[FrequenciesAndNumRows] = {
+ Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns(), filterCondition))
 }
 /** We need at least one grouping column, and all specified columns must exist */
diff --git a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala
index 277b52aea..742b2ba68 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala
@@ -59,7 +59,8 @@ case class Histogram(
 }
 }
- override def computeStateFrom(data: DataFrame): Option[FrequenciesAndNumRows] = {
+ override def computeStateFrom(data: DataFrame,
+ filterCondition: Option[String] = None): Option[FrequenciesAndNumRows] = {
 // TODO figure out a way to pass this in if its known before hand
 val totalCount = if (computeFrequenciesAsRatio) {
@@ -76,10 +77,6 @@ case class Histogram(
 Some(FrequenciesAndNumRows(frequencies, totalCount))
 }
- override def computeStateFrom(data: DataFrame, where: Option[String]): Option[FrequenciesAndNumRows] = {
- computeStateFrom(data)
- }
-
 override def computeMetricFrom(state: Option[FrequenciesAndNumRows]): HistogramMetric = {
 state match {
diff --git a/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala b/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala
index c9164ba6a..a7efbe180 100644
--- a/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala
+++ b/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala
@@ -68,14 +68,10 @@ class AnalysisBasedConstraintTest extends WordSpec with Matchers with SparkConte
 DoubleMetric(Entity.Column, "sample", column, value)
 }
- override def computeStateFrom(data: DataFrame): Option[NumMatches] = {
+ override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[NumMatches] = {
 throw new NotImplementedError()
 }
- override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[NumMatches] = {
- computeStateFrom(data)
- }
-
 override def computeMetricFrom(state: Option[NumMatches]): DoubleMetric = {
 throw new NotImplementedError()
 }
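
Note for readers of the series: a minimal, self-contained sketch of how the APIs touched above fit together once both patches are applied. This is illustrative only and not part of the patch: the SparkSession setup, the toy data frame, and the column names are invented, and the import paths are assumed from the file layout shown in the diffs; only the constructor and method shapes visible in the diffs (the third AnalyzerOptions argument and the single computeStateFrom with a defaulted filterCondition) are relied upon.

    import com.amazon.deequ.analyzers.{AnalyzerOptions, FilteredRow, FrequenciesAndNumRows, Uniqueness}
    import com.amazon.deequ.metrics.{DoubleMetric, FullColumn}
    import org.apache.spark.sql.SparkSession

    object RowLevelFilterSketch extends App {
      val spark = SparkSession.builder().master("local[*]").appName("row-level-filter-sketch").getOrCreate()
      import spark.implicits._

      // Toy data: "unique" acts as a key-like column, "nonUnique" repeats.
      val data = Seq((1, "a"), (2, "a"), (3, "b"), (4, "c")).toDF("unique", "nonUnique")

      // The third constructor argument (added by this series) carries the row-level options;
      // with FilteredRow.NULL, rows excluded by the where clause are reported as null in the
      // row-level results instead of being treated as passing.
      val analyzer = Uniqueness(Seq("nonUnique"), Option("unique > 2"),
        Option(AnalyzerOptions(filteredRow = FilteredRow.NULL)))

      // After PATCH 8/8 there is a single computeStateFrom; the second argument defaults to None,
      // so analyzer.computeStateFrom(data) also compiles for unfiltered use.
      val state: Option[FrequenciesAndNumRows] = analyzer.computeStateFrom(data, Option("unique > 2"))
      val metric: DoubleMetric with FullColumn = analyzer.computeMetricFrom(state)

      // The per-row outcome can be attached to the input for inspection.
      data.withColumn("nonUnique_is_unique", metric.fullColumn.get).show()
    }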
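
The repeated assertSameJson changes in the test files above replace an order-sensitive Seq comparison with a set-based one, apparently because the rows returned by SimpleResultSerde.deserialize need not come back in a stable order. A standalone illustration of the idea follows; it uses plain Scala collections with made-up values, and plain set equality where the patched tests call sameElements on the two sets, which serves the same order-insensitive purpose there.

    object JsonCompareSketch extends App {
      val expectedRows = Seq(
        Map("entity" -> "Column", "instance" -> "att1", "name" -> "Completeness", "value" -> 1.0),
        Map("entity" -> "Dataset", "instance" -> "*", "name" -> "Size", "value" -> 4.0))
      val actualRows = expectedRows.reverse // same rows, different order

      println(actualRows == expectedRows)             // false: Seq equality is order-sensitive
      println(actualRows.toSet == expectedRows.toSet) // true: comparing as sets ignores row order
    }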