From e48f97aab7cc0db957eb7aeabcafed405aa10624 Mon Sep 17 00:00:00 2001 From: Edward Cho <114528615+eycho-am@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:29:16 -0500 Subject: [PATCH] Feature: Add Row Level Result Treatment Options for Miminum and Maximum (#535) * Address comments on PR #532 * Add filtered row-level result support for Minimum, Maximum, Compliance, PatternMatch, MinLength, MaxLength analyzers * Refactored criterion for MinLength and MaxLength analyzers to separate rowLevelResults logic --- .../amazon/deequ/VerificationRunBuilder.scala | 2 +- .../com/amazon/deequ/analyzers/Analyzer.scala | 32 ++- .../amazon/deequ/analyzers/Completeness.scala | 11 +- .../amazon/deequ/analyzers/Compliance.scala | 19 +- .../deequ/analyzers/GroupingAnalyzers.scala | 3 +- .../amazon/deequ/analyzers/MaxLength.scala | 33 ++- .../com/amazon/deequ/analyzers/Maximum.scala | 18 +- .../amazon/deequ/analyzers/MinLength.scala | 33 ++- .../com/amazon/deequ/analyzers/Minimum.scala | 23 +- .../amazon/deequ/analyzers/PatternMatch.scala | 26 +- .../deequ/analyzers/UniqueValueRatio.scala | 10 +- .../amazon/deequ/analyzers/Uniqueness.scala | 10 +- .../scala/com/amazon/deequ/checks/Check.scala | 64 +++-- .../amazon/deequ/constraints/Constraint.scala | 24 +- .../amazon/deequ/VerificationSuiteTest.scala | 243 ++++++++++++++++-- .../deequ/analyzers/CompletenessTest.scala | 2 +- .../deequ/analyzers/ComplianceTest.scala | 159 +++++++++++- .../deequ/analyzers/MaxLengthTest.scala | 78 ++++++ .../amazon/deequ/analyzers/MaximumTest.scala | 31 +++ .../deequ/analyzers/MinLengthTest.scala | 87 ++++++- .../amazon/deequ/analyzers/MinimumTest.scala | 33 +++ .../deequ/analyzers/PatternMatchTest.scala | 55 +++- .../deequ/analyzers/UniquenessTest.scala | 4 +- .../runners/AnalysisRunnerTests.scala | 7 +- .../runners/AnalyzerContextTest.scala | 2 - .../ConstraintSuggestionResultTest.scala | 8 +- .../amazon/deequ/utils/FixtureSupport.scala | 45 ++-- 27 files changed, 924 insertions(+), 138 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala index 929b2319b..f34b7f6ee 100644 --- a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala +++ b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala @@ -25,7 +25,7 @@ import com.amazon.deequ.repository._ import org.apache.spark.sql.{DataFrame, SparkSession} /** A class to build a VerificationRun using a fluent API */ -class VerificationRunBuilder(val data: DataFrame) { +class VerificationRunBuilder(val data: DataFrame) { protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index bc241fe72..dd5fb07e9 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -17,7 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers._ -import com.amazon.deequ.analyzers.FilteredRow.FilteredRow +import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome import com.amazon.deequ.analyzers.NullBehavior.NullBehavior import com.amazon.deequ.analyzers.runners._ import com.amazon.deequ.metrics.DoubleMetric @@ -172,6 +172,12 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable { private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = { source.load[S](this).foreach { state => target.persist(this, state) } } + + private[deequ] def getRowLevelFilterTreatment(analyzerOptions: Option[AnalyzerOptions]): FilteredRowOutcome = { + analyzerOptions + .map { options => options.filteredRow } + .getOrElse(FilteredRowOutcome.TRUE) + } } /** An analyzer that runs a set of aggregation functions over the data, @@ -257,15 +263,19 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo } case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore, - filteredRow: FilteredRow = FilteredRow.TRUE) + filteredRow: FilteredRowOutcome = FilteredRowOutcome.TRUE) object NullBehavior extends Enumeration { type NullBehavior = Value val Ignore, EmptyString, Fail = Value } -object FilteredRow extends Enumeration { - type FilteredRow = Value +object FilteredRowOutcome extends Enumeration { + type FilteredRowOutcome = Value val NULL, TRUE = Value + + implicit class FilteredRowOutcomeOps(value: FilteredRowOutcome) { + def getExpression: Column = expr(value.toString) + } } /** Base class for analyzers that compute ratios of matching predicates */ @@ -484,6 +494,12 @@ private[deequ] object Analyzers { .getOrElse(selection) } + def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Boolean): Column = { + where + .map { condition => when(condition, replaceWith).otherwise(selection) } + .getOrElse(selection) + } + def conditionalSelection(selection: Column, where: Option[String], replaceWith: Double): Column = { conditionSelectionGivenColumn(selection, where.map(expr), replaceWith) } @@ -500,12 +516,12 @@ private[deequ] object Analyzers { def conditionalSelectionFilteredFromColumns( selection: Column, conditionColumn: Option[Column], - filterTreatment: String) + filterTreatment: FilteredRowOutcome) : Column = { conditionColumn - .map { condition => { - when(not(condition), expr(filterTreatment)).when(condition, selection) - } } + .map { condition => + when(not(condition), filterTreatment.getExpression).when(condition, selection) + } .getOrElse(selection) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala index 399cbb06a..3a262d7cc 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala @@ -20,7 +20,6 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested} import org.apache.spark.sql.functions.sum import org.apache.spark.sql.types.{IntegerType, StructType} import Analyzers._ -import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.expr @@ -54,15 +53,9 @@ case class Completeness(column: String, where: Option[String] = None, @VisibleForTesting // required by some tests that compare analyzer results to an expected state private[deequ] def criterion: Column = conditionalSelection(column, where).isNotNull - @VisibleForTesting private[deequ] def rowLevelResults: Column = { val whereCondition = where.map { expression => expr(expression)} - conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment.toString) - } - - private def getRowLevelFilterTreatment: FilteredRow = { - analyzerOptions - .map { options => options.filteredRow } - .getOrElse(FilteredRow.TRUE) + conditionalSelectionFilteredFromColumns( + col(column).isNotNull, whereCondition, getRowLevelFilterTreatment(analyzerOptions)) } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala b/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala index ec242fe6c..0edf01970 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.functions._ import Analyzers._ import com.amazon.deequ.analyzers.Preconditions.hasColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.types.DoubleType /** * Compliance is a measure of the fraction of rows that complies with the given column constraint. @@ -40,14 +41,15 @@ import com.google.common.annotations.VisibleForTesting case class Compliance(instance: String, predicate: String, where: Option[String] = None, - columns: List[String] = List.empty[String]) + columns: List[String] = List.empty[String], + analyzerOptions: Option[AnalyzerOptions] = None) extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance) with FilterableAnalyzer { override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = { ifNoNullsIn(result, offset, howMany = 2) { _ => - NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion)) + NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults)) } } @@ -65,6 +67,19 @@ case class Compliance(instance: String, conditionalSelection(expr(predicate), where).cast(IntegerType) } + private def rowLevelResults: Column = { + val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) + val whereNotCondition = where.map { expression => not(expr(expression)) } + + filteredRowOutcome match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true).cast(IntegerType) + case _ => + // The default behavior when using filtering for rows is to treat them as nulls. No special treatment needed. + criterion + } + } + override protected def additionalPreconditions(): Seq[StructType => Unit] = columns.map(hasColumn) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala b/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala index 30bd89621..c830d0189 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala @@ -93,9 +93,8 @@ object FrequencyBasedAnalyzer { val fullColumn: Column = { val window = Window.partitionBy(columnsToGroupBy: _*) where.map { - condition => { + condition => count(when(expr(condition), UNIQUENESS_ID)).over(window) - } }.getOrElse(count(UNIQUENESS_ID).over(window)) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala index 47ed71a69..19c9ca9b7 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala @@ -23,8 +23,10 @@ import com.amazon.deequ.analyzers.Preconditions.isString import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.length import org.apache.spark.sql.functions.max +import org.apache.spark.sql.functions.not import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.StructType @@ -33,12 +35,12 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - max(criterion(getNullBehavior)) :: Nil + max(criterion) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = { ifNoNullsIn(result, offset) { _ => - MaxState(result.getDouble(offset), Some(criterion(getNullBehavior))) + MaxState(result.getDouble(offset), Some(rowLevelResults)) } } @@ -48,15 +50,34 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where - private def criterion(nullBehavior: NullBehavior): Column = { + private[deequ] def criterion: Column = { + transformColForNullBehavior + } + + private[deequ] def rowLevelResults: Column = { + transformColForFilteredRow(criterion) + } + + private def transformColForFilteredRow(col: Column): Column = { + val whereNotCondition = where.map { expression => not(expr(expression)) } + getRowLevelFilterTreatment(analyzerOptions) match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MinValue) + case _ => + conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null) + } + } + + private def transformColForNullBehavior: Column = { val isNullCheck = col(column).isNull - nullBehavior match { + val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) + getNullBehavior match { case NullBehavior.Fail => - val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue) case NullBehavior.EmptyString => length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType) - case _ => length(conditionalSelection(column, where)).cast(DoubleType) + case _ => + colLengths } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala index 24a1ae965..c5cc33f94 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.types.{DoubleType, StructType} import Analyzers._ import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.not case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MaxState] with FullColumn { @@ -36,7 +38,7 @@ case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = } } -case class Maximum(column: String, where: Option[String] = None) +case class Maximum(column: String, where: Option[String] = None, analyzerOptions: Option[AnalyzerOptions] = None) extends StandardScanShareableAnalyzer[MaxState]("Maximum", column) with FilterableAnalyzer { @@ -47,7 +49,7 @@ case class Maximum(column: String, where: Option[String] = None) override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = { ifNoNullsIn(result, offset) { _ => - MaxState(result.getDouble(offset), Some(criterion)) + MaxState(result.getDouble(offset), Some(rowLevelResults)) } } @@ -60,5 +62,17 @@ case class Maximum(column: String, where: Option[String] = None) @VisibleForTesting private def criterion: Column = conditionalSelection(column, where).cast(DoubleType) + private[deequ] def rowLevelResults: Column = { + val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) + val whereNotCondition = where.map { expression => not(expr(expression)) } + + filteredRowOutcome match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MinValue).cast(DoubleType) + case _ => + criterion + } + } + } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala index b63c4b4be..c155cca94 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala @@ -23,8 +23,10 @@ import com.amazon.deequ.analyzers.Preconditions.isString import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.length import org.apache.spark.sql.functions.min +import org.apache.spark.sql.functions.not import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.StructType @@ -33,12 +35,12 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - min(criterion(getNullBehavior)) :: Nil + min(criterion) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = { ifNoNullsIn(result, offset) { _ => - MinState(result.getDouble(offset), Some(criterion(getNullBehavior))) + MinState(result.getDouble(offset), Some(rowLevelResults)) } } @@ -48,15 +50,34 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where - private[deequ] def criterion(nullBehavior: NullBehavior): Column = { + private[deequ] def criterion: Column = { + transformColForNullBehavior + } + + private[deequ] def rowLevelResults: Column = { + transformColForFilteredRow(criterion) + } + + private def transformColForFilteredRow(col: Column): Column = { + val whereNotCondition = where.map { expression => not(expr(expression)) } + getRowLevelFilterTreatment(analyzerOptions) match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MaxValue) + case _ => + conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null) + } + } + + private def transformColForNullBehavior: Column = { val isNullCheck = col(column).isNull - nullBehavior match { + val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) + getNullBehavior match { case NullBehavior.Fail => - val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue) case NullBehavior.EmptyString => length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType) - case _ => length(conditionalSelection(column, where)).cast(DoubleType) + case _ => + colLengths } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala index feac13f88..18640dc12 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.types.{DoubleType, StructType} import Analyzers._ import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.not case class MinState(minValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MinState] with FullColumn { @@ -36,7 +38,7 @@ case class MinState(minValue: Double, override val fullColumn: Option[Column] = } } -case class Minimum(column: String, where: Option[String] = None) +case class Minimum(column: String, where: Option[String] = None, analyzerOptions: Option[AnalyzerOptions] = None) extends StandardScanShareableAnalyzer[MinState]("Minimum", column) with FilterableAnalyzer { @@ -45,9 +47,8 @@ case class Minimum(column: String, where: Option[String] = None) } override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = { - ifNoNullsIn(result, offset) { _ => - MinState(result.getDouble(offset), Some(criterion)) + MinState(result.getDouble(offset), Some(rowLevelResults)) } } @@ -58,5 +59,19 @@ case class Minimum(column: String, where: Option[String] = None) override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = conditionalSelection(column, where).cast(DoubleType) + private def criterion: Column = { + conditionalSelection(column, where).cast(DoubleType) + } + + private[deequ] def rowLevelResults: Column = { + val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) + val whereNotCondition = where.map { expression => not(expr(expression)) } + + filteredRowOutcome match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MaxValue).cast(DoubleType) + case _ => + criterion + } + } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/PatternMatch.scala b/src/main/scala/com/amazon/deequ/analyzers/PatternMatch.scala index 47fb08737..eb62f9675 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/PatternMatch.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/PatternMatch.scala @@ -19,6 +19,8 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers._ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isString} import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.not import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.functions.{col, lit, regexp_extract, sum, when} import org.apache.spark.sql.types.{BooleanType, IntegerType, StructType} @@ -36,13 +38,14 @@ import scala.util.matching.Regex * @param pattern The regular expression to check for * @param where Additional filter to apply before the analyzer is run. */ -case class PatternMatch(column: String, pattern: Regex, where: Option[String] = None) +case class PatternMatch(column: String, pattern: Regex, where: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) extends StandardScanShareableAnalyzer[NumMatchesAndCount]("PatternMatch", column) with FilterableAnalyzer { override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = { ifNoNullsIn(result, offset, howMany = 2) { _ => - NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion.cast(BooleanType))) + NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults.cast(BooleanType))) } } @@ -77,12 +80,25 @@ case class PatternMatch(column: String, pattern: Regex, where: Option[String] = @VisibleForTesting // required by some tests that compare analyzer results to an expected state private[deequ] def criterion: Column = { - val expression = when(regexp_extract(col(column), pattern.toString(), 0) =!= lit(""), 1) - .otherwise(0) - conditionalSelection(expression, where).cast(IntegerType) + conditionalSelection(getPatternMatchExpression, where).cast(IntegerType) } + private[deequ] def rowLevelResults: Column = { + val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) + val whereNotCondition = where.map { expression => not(expr(expression)) } + filteredRowOutcome match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(getPatternMatchExpression, whereNotCondition, replaceWith = 1).cast(IntegerType) + case _ => + // The default behavior when using filtering for rows is to treat them as nulls. No special treatment needed. + criterion + } + } + + private def getPatternMatchExpression: Column = { + when(regexp_extract(col(column), pattern.toString(), 0) =!= lit(""), 1).otherwise(0) + } } object Patterns { diff --git a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala index c2fce1f14..02b682b9d 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala @@ -17,7 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL -import com.amazon.deequ.analyzers.FilteredRow.FilteredRow +import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome import com.amazon.deequ.metrics.DoubleMetric import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.not @@ -43,7 +43,7 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None, rowLevelColumn => { conditionColumn.map { condition => { - when(not(condition), expr(getRowLevelFilterTreatment.toString)) + when(not(condition), getRowLevelFilterTreatment(analyzerOptions).getExpression) .when(rowLevelColumn.equalTo(1), true).otherwise(false) } }.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false)) @@ -53,12 +53,6 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None, } override def filterCondition: Option[String] = where - - private def getRowLevelFilterTreatment: FilteredRow = { - analyzerOptions - .map { options => options.filteredRow } - .getOrElse(FilteredRow.TRUE) - } } object UniqueValueRatio { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala index 78ba4c418..b46b6d324 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala @@ -17,7 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL -import com.amazon.deequ.analyzers.FilteredRow.FilteredRow +import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome import com.amazon.deequ.metrics.DoubleMetric import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.Column @@ -47,7 +47,7 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None, rowLevelColumn => { conditionColumn.map { condition => { - when(not(condition), expr(getRowLevelFilterTreatment.toString)) + when(not(condition), getRowLevelFilterTreatment(analyzerOptions).getExpression) .when(rowLevelColumn.equalTo(1), true).otherwise(false) } }.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false)) @@ -57,12 +57,6 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None, } override def filterCondition: Option[String] = where - - private def getRowLevelFilterTreatment: FilteredRow = { - analyzerOptions - .map { options => options.filteredRow } - .getOrElse(FilteredRow.TRUE) - } } object Uniqueness { diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala index bdae62ab7..ccfd9badc 100644 --- a/src/main/scala/com/amazon/deequ/checks/Check.scala +++ b/src/main/scala/com/amazon/deequ/checks/Check.scala @@ -716,15 +716,17 @@ case class Check( * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasMin( column: String, assertion: Double => Boolean, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { - addFilterableConstraint { filter => minConstraint(column, assertion, filter, hint) } + addFilterableConstraint { filter => minConstraint(column, assertion, filter, hint, analyzerOptions) } } /** @@ -733,15 +735,17 @@ case class Check( * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasMax( column: String, assertion: Double => Boolean, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { - addFilterableConstraint { filter => maxConstraint(column, assertion, filter, hint) } + addFilterableConstraint { filter => maxConstraint(column, assertion, filter, hint, analyzerOptions) } } /** @@ -845,6 +849,7 @@ case class Check( * name the metrics for the analysis being done. * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def satisfies( @@ -852,11 +857,12 @@ case class Check( constraintName: String, assertion: Double => Boolean = Check.IsOne, hint: Option[String] = None, - columns: List[String] = List.empty[String]) + columns: List[String] = List.empty[String], + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { addFilterableConstraint { filter => - complianceConstraint(constraintName, columnCondition, assertion, filter, hint, columns) + complianceConstraint(constraintName, columnCondition, assertion, filter, hint, columns, analyzerOptions) } } @@ -868,6 +874,7 @@ case class Check( * @param pattern The columns values will be checked for a match against this pattern. * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasPattern( @@ -875,11 +882,12 @@ case class Check( pattern: Regex, assertion: Double => Boolean = Check.IsOne, name: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { addFilterableConstraint { filter => - Constraint.patternMatchConstraint(column, pattern, assertion, filter, name, hint) + Constraint.patternMatchConstraint(column, pattern, assertion, filter, name, hint, analyzerOptions) } } @@ -1118,8 +1126,7 @@ case class Check( allowedValues: Array[String]) : CheckWithLastConstraintFilterable = { - - isContainedIn(column, allowedValues, Check.IsOne, None) + isContainedIn(column, allowedValues, Check.IsOne, None, None) } // We can't use default values here as you can't combine default values and overloading in Scala @@ -1137,7 +1144,7 @@ case class Check( hint: Option[String]) : CheckWithLastConstraintFilterable = { - isContainedIn(column, allowedValues, Check.IsOne, hint) + isContainedIn(column, allowedValues, Check.IsOne, hint, None) } // We can't use default values here as you can't combine default values and overloading in Scala @@ -1155,8 +1162,27 @@ case class Check( assertion: Double => Boolean) : CheckWithLastConstraintFilterable = { + isContainedIn(column, allowedValues, assertion, None, None) + } + + // We can't use default values here as you can't combine default values and overloading in Scala + /** + * Asserts that every non-null value in a column is contained in a set of predefined values + * + * @param column Column to run the assertion on + * @param allowedValues Allowed values for the column + * @param assertion Function that receives a double input parameter and returns a boolean + * @param hint A hint to provide additional context why a constraint could have failed + * @return + */ + def isContainedIn( + column: String, + allowedValues: Array[String], + assertion: Double => Boolean, + hint: Option[String]) + : CheckWithLastConstraintFilterable = { - isContainedIn(column, allowedValues, assertion, None) + isContainedIn(column, allowedValues, assertion, hint, None) } // We can't use default values here as you can't combine default values and overloading in Scala @@ -1167,23 +1193,24 @@ case class Check( * @param allowedValues Allowed values for the column * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def isContainedIn( column: String, allowedValues: Array[String], assertion: Double => Boolean, - hint: Option[String]) + hint: Option[String], + analyzerOptions: Option[AnalyzerOptions]) : CheckWithLastConstraintFilterable = { - val valueList = allowedValues .map { _.replaceAll("'", "\\\\\'") } .mkString("'", "','", "'") val predicate = s"`$column` IS NULL OR `$column` IN ($valueList)" satisfies(predicate, s"$column contained in ${allowedValues.mkString(",")}", - assertion, hint, List(column)) + assertion, hint, List(column), analyzerOptions) } /** @@ -1195,6 +1222,7 @@ case class Check( * @param includeLowerBound is a value equal to the lower bound allows? * @param includeUpperBound is a value equal to the upper bound allowed? * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def isContainedIn( @@ -1203,7 +1231,8 @@ case class Check( upperBound: Double, includeLowerBound: Boolean = true, includeUpperBound: Boolean = true, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { val leftOperand = if (includeLowerBound) ">=" else ">" @@ -1212,7 +1241,8 @@ case class Check( val predicate = s"`$column` IS NULL OR " + s"(`$column` $leftOperand $lowerBound AND `$column` $rightOperand $upperBound)" - satisfies(predicate, s"$column between $lowerBound and $upperBound", hint = hint, columns = List(column)) + satisfies(predicate, s"$column between $lowerBound and $upperBound", hint = hint, + columns = List(column), analyzerOptions = analyzerOptions) } /** diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index b9a15901b..fec0842f7 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -371,6 +371,7 @@ object Constraint { * metrics for the analysis being done. * @param column Data frame column which is a combination of expression and the column name * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def complianceConstraint( name: String, @@ -378,10 +379,11 @@ object Constraint { assertion: Double => Boolean, where: Option[String] = None, hint: Option[String] = None, - columns: List[String] = List.empty[String]) + columns: List[String] = List.empty[String], + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val compliance = Compliance(name, column, where, columns) + val compliance = Compliance(name, column, where, columns, analyzerOptions) fromAnalyzer(compliance, assertion, hint) } @@ -406,6 +408,7 @@ object Constraint { * @param pattern The regex pattern to check compliance for * @param column Data frame column which is a combination of expression and the column name * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def patternMatchConstraint( column: String, @@ -413,10 +416,11 @@ object Constraint { assertion: Double => Boolean, where: Option[String] = None, name: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val patternMatch = PatternMatch(column, pattern, where) + val patternMatch = PatternMatch(column, pattern, where, analyzerOptions) fromAnalyzer(patternMatch, pattern, assertion, name, hint) } @@ -637,16 +641,18 @@ object Constraint { * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * */ def minConstraint( column: String, assertion: Double => Boolean, where: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val minimum = Minimum(column, where) + val minimum = Minimum(column, where, analyzerOptions) fromAnalyzer(minimum, assertion, hint) } @@ -670,15 +676,17 @@ object Constraint { * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def maxConstraint( column: String, assertion: Double => Boolean, where: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val maximum = Maximum(column, where) + val maximum = Maximum(column, where, analyzerOptions) fromAnalyzer(maximum, assertion, hint) } diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index 7588ee914..1fb8ab74d 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -305,7 +305,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } "generate a result that contains row-level results with true for filtered rows" in withSparkSession { session => - val data = getDfCompleteAndInCompleteColumns(session) + val data = getDfCompleteAndInCompleteColumnsWithIntId(session) val completeness = new Check(CheckLevel.Error, "rule1") .hasCompleteness("att2", _ > 0.7, None) @@ -315,15 +315,31 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val uniquenessWhere = new Check(CheckLevel.Error, "rule3") .isUnique("att1") .where("item < 3") + val min = new Check(CheckLevel.Error, "rule4") + .hasMin("item", _ > 3, None) + .where("item > 3") + val max = new Check(CheckLevel.Error, "rule5") + .hasMax("item", _ < 4, None) + .where("item < 4") + val patternMatch = new Check(CheckLevel.Error, "rule6") + .hasPattern("att2", """(^f)""".r) + .where("item < 4") + val expectedColumn1 = completeness.description val expectedColumn2 = uniqueness.description val expectedColumn3 = uniquenessWhere.description + val expectedColumn4 = min.description + val expectedColumn5 = max.description + val expectedColumn6 = patternMatch.description val suite = new VerificationSuite().onData(data) .addCheck(completeness) .addCheck(uniqueness) .addCheck(uniquenessWhere) + .addCheck(min) + .addCheck(max) + .addCheck(patternMatch) val result: VerificationResult = suite.run() @@ -332,24 +348,38 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") resultData.show(false) val expectedColumns: Set[String] = - data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + + expectedColumn4 + expectedColumn5 + expectedColumn6 assert(resultData.columns.toSet == expectedColumns) + // filtered rows 2,5 (where att1 = "a") val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) assert(Seq(true, true, false, true, true, true).sameElements(rowLevel1)) val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) assert(Seq(false, false, false, false, false, false).sameElements(rowLevel2)) + // filtered rows 3,4,5,6 (where item < 3) val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) assert(Seq(true, true, true, true, true, true).sameElements(rowLevel3)) + // filtered rows 1, 2, 3 (where item > 3) + val minRowLevel = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(minRowLevel)) + + // filtered rows 4, 5, 6 (where item < 4) + val maxRowLevel = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(maxRowLevel)) + + // filtered rows 4, 5, 6 (where item < 4) + val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, false, false, true, true, true).sameElements(rowLevel6)) } "generate a result that contains row-level results with null for filtered rows" in withSparkSession { session => - val data = getDfCompleteAndInCompleteColumns(session) + val data = getDfCompleteAndInCompleteColumnsWithIntId(session) - val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRow.NULL)) + val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)) val completeness = new Check(CheckLevel.Error, "rule1") .hasCompleteness("att2", _ > 0.7, None, analyzerOptions) @@ -359,14 +389,30 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val uniquenessWhere = new Check(CheckLevel.Error, "rule3") .isUnique("att1", None, analyzerOptions) .where("item < 3") + val min = new Check(CheckLevel.Error, "rule4") + .hasMin("item", _ > 3, None, analyzerOptions) + .where("item > 3") + val max = new Check(CheckLevel.Error, "rule5") + .hasMax("item", _ < 4, None, analyzerOptions) + .where("item < 4") + val patternMatch = new Check(CheckLevel.Error, "rule6") + .hasPattern("att2", """(^f)""".r, analyzerOptions = analyzerOptions) + .where("item < 4") + val expectedColumn1 = completeness.description val expectedColumn2 = uniqueness.description val expectedColumn3 = uniquenessWhere.description + val expectedColumn4 = min.description + val expectedColumn5 = max.description + val expectedColumn6 = patternMatch.description val suite = new VerificationSuite().onData(data) .addCheck(completeness) .addCheck(uniqueness) .addCheck(uniquenessWhere) + .addCheck(min) + .addCheck(max) + .addCheck(patternMatch) val result: VerificationResult = suite.run() @@ -375,7 +421,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") resultData.show(false) val expectedColumns: Set[String] = - data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + + expectedColumn4 + expectedColumn5 + expectedColumn6 assert(resultData.columns.toSet == expectedColumns) val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) @@ -384,9 +431,92 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) assert(Seq(false, false, false, false, false, false).sameElements(rowLevel2)) + // filtered rows 3,4,5,6 (where item < 3) val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) assert(Seq(true, true, null, null, null, null).sameElements(rowLevel3)) + // filtered rows 1, 2, 3 (where item > 3) + val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) + assert(Seq(null, null, null, true, true, true).sameElements(rowLevel4)) + + // filtered rows 4, 5, 6 (where item < 4) + val rowLevel5 = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, null, null, null).sameElements(rowLevel5)) + + // filtered rows 4, 5, 6 (where item < 4) + val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, false, false, null, null, null).sameElements(rowLevel6)) + } + + "generate a result that contains compliance row-level results " in withSparkSession { session => + val data = getDfWithNumericValues(session) + val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)) + + val complianceRange = new Check(CheckLevel.Error, "rule1") + .isContainedIn("attNull", 0, 6, false, false) + val complianceFilteredRange = new Check(CheckLevel.Error, "rule2") + .isContainedIn("attNull", 0, 6, false, false) + .where("att1 < 4") + val complianceFilteredRangeNull = new Check(CheckLevel.Error, "rule3") + .isContainedIn("attNull", 0, 6, false, false, + analyzerOptions = analyzerOptions) + .where("att1 < 4") + val complianceInArray = new Check(CheckLevel.Error, "rule4") + .isContainedIn("att2", Array("5", "6", "7")) + val complianceInArrayFiltered = new Check(CheckLevel.Error, "rule5") + .isContainedIn("att2", Array("5", "6", "7")) + .where("att1 > 3") + val complianceInArrayFilteredNull = new Check(CheckLevel.Error, "rule6") + .isContainedIn("att2", Array("5", "6", "7"), Check.IsOne, None, analyzerOptions) + .where("att1 > 3") + + val expectedColumn1 = complianceRange.description + val expectedColumn2 = complianceFilteredRange.description + val expectedColumn3 = complianceFilteredRangeNull.description + val expectedColumn4 = complianceInArray.description + val expectedColumn5 = complianceInArrayFiltered.description + val expectedColumn6 = complianceInArrayFilteredNull.description + + val suite = new VerificationSuite().onData(data) + .addCheck(complianceRange) + .addCheck(complianceFilteredRange) + .addCheck(complianceFilteredRangeNull) + .addCheck(complianceInArray) + .addCheck(complianceInArrayFiltered) + .addCheck(complianceInArrayFilteredNull) + + val result: VerificationResult = suite.run() + + assert(result.status == CheckStatus.Error) + + val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") + resultData.show(false) + val expectedColumns: Set[String] = + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + + expectedColumn4 + expectedColumn5 + expectedColumn6 + assert(resultData.columns.toSet == expectedColumns) + + val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, false, false).sameElements(rowLevel1)) + + // filtered rows 4, 5, 6 (where att1 < 4) as true + val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(rowLevel2)) + + // filtered rows 4, 5, 6 (where att1 < 4) as null + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, null, null, null).sameElements(rowLevel3)) + + val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) + assert(Seq(false, false, false, true, true, true).sameElements(rowLevel4)) + + // filtered rows 1,2,3 (where att1 > 3) as true + val rowLevel5 = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(rowLevel5)) + + // filtered rows 1,2,3 (where att1 > 3) as null + val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) + assert(Seq(null, null, null, true, true, true).sameElements(rowLevel6)) } "generate a result that contains row-level results for null column values" in withSparkSession { session => @@ -422,20 +552,16 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + expectedColumn4 assert(resultData.columns.toSet == expectedColumns) - val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => - if (r == null) null else r.getAs[Boolean](0)) + val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) assert(Seq(false, null, true, true, null, true).sameElements(rowLevel1)) - val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => - if (r == null) null else r.getAs[Boolean](0)) + val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) assert(Seq(true, null, true, false, null, false).sameElements(rowLevel2)) - val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => - if (r == null) null else r.getAs[Boolean](0)) + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) assert(Seq(true, true, false, true, false, true).sameElements(rowLevel3)) - val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => - if (r == null) null else r.getAs[Boolean](0)) + val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) assert(Seq(false, null, false, true, null, true).sameElements(rowLevel4)) } @@ -446,12 +572,37 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .hasMinLength("att2", _ >= 1, analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail))) val maxLength = new Check(CheckLevel.Error, "rule2") .hasMaxLength("att2", _ <= 1, analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail))) + // filtered rows as null + val minLengthFilterNull = new Check(CheckLevel.Error, "rule3") + .hasMinLength("att2", _ >= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) + .where("val1 < 5") + val maxLengthFilterNull = new Check(CheckLevel.Error, "rule4") + .hasMaxLength("att2", _ <= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) + .where("val1 < 5") + val minLengthFilterTrue = new Check(CheckLevel.Error, "rule5") + .hasMinLength("att2", _ >= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.TRUE))) + .where("val1 < 5") + val maxLengthFilterTrue = new Check(CheckLevel.Error, "rule6") + .hasMaxLength("att2", _ <= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.TRUE))) + .where("val1 < 5") val expectedColumn1 = minLength.description val expectedColumn2 = maxLength.description + val expectedColumn3 = minLengthFilterNull.description + val expectedColumn4 = maxLengthFilterNull.description + val expectedColumn5 = minLengthFilterTrue.description + val expectedColumn6 = maxLengthFilterTrue.description val suite = new VerificationSuite().onData(data) .addCheck(minLength) .addCheck(maxLength) + .addCheck(minLengthFilterNull) + .addCheck(maxLengthFilterNull) + .addCheck(minLengthFilterTrue) + .addCheck(maxLengthFilterTrue) val result: VerificationResult = suite.run() @@ -461,7 +612,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec resultData.show() val expectedColumns: Set[String] = - data.columns.toSet + expectedColumn1 + expectedColumn2 + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + + expectedColumn4 + expectedColumn5 + expectedColumn6 assert(resultData.columns.toSet == expectedColumns) val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getBoolean(0)) @@ -469,6 +621,22 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getBoolean(0)) assert(Seq(true, true, false, true, false, true).sameElements(rowLevel2)) + + // filtered last two rows where(val1 < 5) + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, null, null).sameElements(rowLevel3)) + + // filtered last two rows where(val1 < 5) + val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, null, null).sameElements(rowLevel4)) + + // filtered last two rows where(val1 < 5) + val rowLevel5 = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, true, true).sameElements(rowLevel5)) + + // filtered last two rows where(val1 < 5) + val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, true, true).sameElements(rowLevel6)) } "generate a result that contains length row-level results with nullBehavior empty" in withSparkSession { session => @@ -480,12 +648,38 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec // nulls should succeed since length 0 is < 2 val maxLength = new Check(CheckLevel.Error, "rule2") .hasMaxLength("att2", _ < 2, analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString))) + // filtered rows as null + val minLengthFilterNull = new Check(CheckLevel.Error, "rule3") + .hasMinLength("att2", _ >= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) + .where("val1 < 5") + val maxLengthFilterNull = new Check(CheckLevel.Error, "rule4") + .hasMaxLength("att2", _ < 2, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) + .where("val1 < 5") + val minLengthFilterTrue = new Check(CheckLevel.Error, "rule5") + .hasMinLength("att2", _ >= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.TRUE))) + .where("val1 < 5") + val maxLengthFilterTrue = new Check(CheckLevel.Error, "rule6") + .hasMaxLength("att2", _ < 2, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.TRUE))) + .where("val1 < 5") + val expectedColumn1 = minLength.description val expectedColumn2 = maxLength.description + val expectedColumn3 = minLengthFilterNull.description + val expectedColumn4 = maxLengthFilterNull.description + val expectedColumn5 = minLengthFilterTrue.description + val expectedColumn6 = maxLengthFilterTrue.description val suite = new VerificationSuite().onData(data) .addCheck(minLength) .addCheck(maxLength) + .addCheck(minLengthFilterNull) + .addCheck(maxLengthFilterNull) + .addCheck(minLengthFilterTrue) + .addCheck(maxLengthFilterTrue) val result: VerificationResult = suite.run() @@ -495,7 +689,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec resultData.show() val expectedColumns: Set[String] = - data.columns.toSet + expectedColumn1 + expectedColumn2 + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + + expectedColumn4 + expectedColumn5 + expectedColumn6 assert(resultData.columns.toSet == expectedColumns) @@ -504,6 +699,22 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getBoolean(0)) assert(Seq(true, true, true, true, true, true).sameElements(rowLevel2)) + + // filtered last two rows where(val1 < 5) + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, null, null).sameElements(rowLevel3)) + + // filtered last two rows where(val1 < 5) + val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, null, null).sameElements(rowLevel4)) + + // filtered last two rows where(val1 < 5) + val rowLevel5 = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, true, true).sameElements(rowLevel5)) + + // filtered last two rows where(val1 < 5) + val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(rowLevel6)) } "accept analysis config for mandatory analysis" in withSparkSession { sparkSession => @@ -1124,7 +1335,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec checkFailedResultStringType.constraintResults.map(_.message) shouldBe List(Some("Empty state for analyzer Compliance(name between 1.0 and 3.0,`name`" + " IS NULL OR (`name` >= 1.0 AND `name` <= 3.0)," + - "None,List(name)), all input values were NULL.")) + "None,List(name),None), all input values were NULL.")) assert(checkFailedResultStringType.status == CheckStatus.Error) val checkFailedCompletenessResult = verificationResult.checkResults(complianceCheckThatShouldFailCompleteness) diff --git a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala index 54e26f867..b5b0d5094 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala @@ -46,7 +46,7 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder val completenessAtt2 = Completeness("att2", Option("att1 = \"a\""), - Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) + Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) val state = completenessAtt2.computeStateFrom(data) val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) diff --git a/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala b/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala index c572a4bd8..5aa4033ba 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala @@ -35,7 +35,8 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val state = att1Compliance.computeStateFrom(data) val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Int]("new")) shouldBe Seq(0, 0, 0, 1, 1, 1) + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Int]("new") + ) shouldBe Seq(0, 0, 0, 1, 1, 1) } "return row-level results for null columns" in withSparkSession { session => @@ -49,6 +50,162 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit data.withColumn("new", metric.fullColumn.get).collect().map(r => if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 1, 1, 1) } + + "return row-level results filtered with null" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 0, 1, 1) + } + + "return row-level results filtered with true" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(1, 1, 1, 0, 1, 1) + } + + "return row-level results for compliance in bounds" in withSparkSession { session => + val column = "att1" + val leftOperand = ">=" + val rightOperand = "<=" + val lowerBound = 2 + val upperBound = 5 + val predicate = s"`$column` IS NULL OR " + + s"(`$column` $leftOperand $lowerBound AND `$column` $rightOperand $upperBound)" + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance(predicate, s"$column between $lowerBound and $upperBound", columns = List("att3")) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 0) + } + + "return row-level results for compliance in bounds filtered as null" in withSparkSession { session => + val column = "att1" + val leftOperand = ">=" + val rightOperand = "<=" + val lowerBound = 2 + val upperBound = 5 + val predicate = s"`$column` IS NULL OR " + + s"(`$column` $leftOperand $lowerBound AND `$column` $rightOperand $upperBound)" + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance(predicate, s"$column between $lowerBound and $upperBound", + where = Option("att1 < 4"), columns = List("att3"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, null, null, null) + } + + "return row-level results for compliance in bounds filtered as true" in withSparkSession { session => + val column = "att1" + val leftOperand = ">=" + val rightOperand = "<=" + val lowerBound = 2 + val upperBound = 5 + val predicate = s"`$column` IS NULL OR " + + s"(`$column` $leftOperand $lowerBound AND `$column` $rightOperand $upperBound)" + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance(s"$column between $lowerBound and $upperBound", predicate, + where = Option("att1 < 4"), columns = List("att3"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 1) + } + + "return row-level results for compliance in array" in withSparkSession { session => + val column = "att1" + val allowedValues = Array("3", "4", "5") + val valueList = allowedValues + .map { + _.replaceAll("'", "\\\\\'") + } + .mkString("'", "','", "'") + + val predicate = s"`$column` IS NULL OR `$column` IN ($valueList)" + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance(s"$column contained in ${allowedValues.mkString(",")}", predicate, + columns = List("att3")) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 0) + } + + "return row-level results for compliance in array filtered as null" in withSparkSession { session => + val column = "att1" + val allowedValues = Array("3", "4", "5") + val valueList = allowedValues + .map { + _.replaceAll("'", "\\\\\'") + } + .mkString("'", "','", "'") + + val predicate = s"`$column` IS NULL OR `$column` IN ($valueList)" + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance(s"$column contained in ${allowedValues.mkString(",")}", predicate, + where = Option("att1 < 5"), columns = List("att3"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, null, null) + } + + "return row-level results for compliance in array filtered as true" in withSparkSession { session => + val column = "att1" + val allowedValues = Array("3", "4", "5") + val valueList = allowedValues + .map { + _.replaceAll("'", "\\\\\'") + } + .mkString("'", "','", "'") + + val predicate = s"`$column` IS NULL OR `$column` IN ($valueList)" + + val data = getDfWithNumericValues(session) + + + val att1Compliance = Compliance(s"$column contained in ${allowedValues.mkString(",")}", predicate, + where = Option("att1 < 5"), columns = List("att3"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 1) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala index f1fec85a4..456bd4c67 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala @@ -74,6 +74,84 @@ class MaxLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with data.withColumn("new", metric.fullColumn.get) .collect().map(_.getAs[Double]("new")) shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) } + + "return row-level results with NullBehavior fail and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Fail))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, Double.MaxValue, 1.0, Double.MinValue, Double.MinValue) + } + + "return row-level results with NullBehavior fail and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, Double.MaxValue, 1.0, null, null) + } + + "return row-level results with NullBehavior empty and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.EmptyString))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, 0.0, 1.0, Double.MinValue, Double.MinValue) + } + + "return row-level results with NullBehavior empty and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, null, null) + } + + "return row-level results with NullBehavior ignore and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Ignore))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, null, 1.0, Double.MinValue, Double.MinValue) + } + + "return row-level results with NullBehavior ignore and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.Ignore, FilteredRowOutcome.NULL))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, null, 1.0, null, null) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala index d88b4b532..6ac90f735 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala @@ -51,5 +51,36 @@ class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with F if (r == null) null else r.getAs[Double]("new")) shouldBe Seq(null, null, null, 5.0, 6.0, 7.0) } + + "return row-level results for columns with where clause filtered as true" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Maximum = Maximum("att1", Option("item < 4")) + val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4")) + val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state) + + val result = data.withColumn("new", metric.fullColumn.get) + result.show(false) + result.collect().map(r => + if (r == null) null else r.getAs[Double]("new")) shouldBe + Seq(1.0, 2.0, 3.0, Double.MinValue, Double.MinValue, Double.MinValue) + } + + "return row-level results for columns with where clause filtered as null" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Maximum = Maximum("att1", Option("item < 4"), + Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4")) + val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state) + + val result = data.withColumn("new", metric.fullColumn.get) + result.show(false) + result.collect().map(r => + if (r == null) null else r.getAs[Double]("new")) shouldBe + Seq(1.0, 2.0, 3.0, null, null, null) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala index 84228e7e7..0f88e377f 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala @@ -56,13 +56,13 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with val data = getEmptyColumnDataDf(session) // It's null in two rows - val addressLength = MinLength("att3") + val addressLength = MinLength("att3", None, Option(AnalyzerOptions(NullBehavior.Fail))) val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get) - .collect().map( r => if (r == null) null else r.getAs[Double]("new") - ) shouldBe Seq(1.0, 1.0, null, 1.0, null, 1.0) + .collect().map(_.getAs[Double]("new") + ) shouldBe Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MinValue, 1.0) } "return row-level results for null columns with NullBehavior empty option" in withSparkSession { session => @@ -70,7 +70,7 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with val data = getEmptyColumnDataDf(session) // It's null in two rows - val addressLength = MinLength("att3") + val addressLength = MinLength("att3", None, Option(AnalyzerOptions(NullBehavior.EmptyString))) val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) @@ -89,5 +89,84 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with data.withColumn("new", metric.fullColumn.get) .collect().map(_.getAs[Double]("new")) shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) } + + "return row-level results with NullBehavior fail and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Fail))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MaxValue, Double.MaxValue) + } + + "return row-level results with NullBehavior fail and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, Double.MinValue, 1.0, null, null) + } + + "return row-level results with NullBehavior empty and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.EmptyString))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, 0.0, 1.0, Double.MaxValue, Double.MaxValue) + } + + "return row-level results with NullBehavior empty and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, null, null) + } + + "return row-level results NullBehavior ignore and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Ignore))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, null, 1.0, Double.MaxValue, Double.MaxValue) + } + + "return row-level results with NullBehavior ignore and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.Ignore, FilteredRowOutcome.NULL))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, null, 1.0, null, null) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala index 6d495aa0f..435542e8c 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala @@ -52,6 +52,39 @@ class MinimumTest extends AnyWordSpec with Matchers with SparkContextSpec with F if (r == null) null else r.getAs[Double]("new")) shouldBe Seq(null, null, null, 5.0, 6.0, 7.0) } + + "return row-level results for columns with where clause filtered as true" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Minimum = Minimum("att1", Option("item < 4")) + val state: Option[MinState] = att1Minimum.computeStateFrom(data, Option("item < 4")) + print(state) + val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state) + + val result = data.withColumn("new", metric.fullColumn.get) + result.show(false) + result.collect().map(r => + if (r == null) null else r.getAs[Double]("new")) shouldBe + Seq(1.0, 2.0, 3.0, Double.MaxValue, Double.MaxValue, Double.MaxValue) + } + + "return row-level results for columns with where clause filtered as null" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Minimum = Minimum("att1", Option("item < 4"), + Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state: Option[MinState] = att1Minimum.computeStateFrom(data, Option("item < 4")) + print(state) + val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state) + + val result = data.withColumn("new", metric.fullColumn.get) + result.show(false) + result.collect().map(r => + if (r == null) null else r.getAs[Double]("new")) shouldBe + Seq(1.0, 2.0, 3.0, null, null, null) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/PatternMatchTest.scala b/src/test/scala/com/amazon/deequ/analyzers/PatternMatchTest.scala index e01235597..94d439674 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/PatternMatchTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/PatternMatchTest.scala @@ -33,10 +33,35 @@ class PatternMatchTest extends AnyWordSpec with Matchers with SparkContextSpec w val state = patternMatchCountry.computeStateFrom(data) val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe Seq(true, true, true, true, true, true, true, true) } + "return row-level results for non-null columns starts with digit" in withSparkSession { session => + + val data = getDfWithStringColumns(session) + + val patternMatchCountry = PatternMatch("Address Line 1", """(^[0-4])""".r) + val state = patternMatchCountry.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe + Seq(false, false, true, true, false, false, true, true) + } + + "return row-level results for non-null columns starts with digit filtered as true" in withSparkSession { session => + + val data = getDfWithStringColumns(session) + + val patternMatchCountry = PatternMatch("Address Line 1", """(^[0-4])""".r, where = Option("id < 5"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) + val state = patternMatchCountry.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe + Seq(false, false, true, true, false, true, true, true) + } + "return row-level results for columns with nulls" in withSparkSession { session => val data = getDfWithStringColumns(session) @@ -45,8 +70,34 @@ class PatternMatchTest extends AnyWordSpec with Matchers with SparkContextSpec w val state = patternMatchCountry.computeStateFrom(data) val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe Seq(true, true, true, true, false, true, true, false) } + + "return row-level results for columns with nulls filtered as true" in withSparkSession { session => + + val data = getDfWithStringColumns(session) + + val patternMatchCountry = PatternMatch("Address Line 2", """\w""".r, where = Option("id < 5"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) + val state = patternMatchCountry.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe + Seq(true, true, true, true, false, true, true, true) + } + + "return row-level results for columns with nulls filtered as null" in withSparkSession { session => + + val data = getDfWithStringColumns(session) + + val patternMatchCountry = PatternMatch("Address Line 2", """\w""".r, where = Option("id < 5"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state = patternMatchCountry.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe + Seq(true, true, true, true, false, null, null, null) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala index d50995b55..4aea6bb27 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala @@ -123,7 +123,7 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithUniqueColumns(session) val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4"), - Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) + Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4")) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) @@ -139,7 +139,7 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithUniqueColumns(session) val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2"), - Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) + Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2")) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala index ce9bda69b..193dbaebe 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala @@ -204,10 +204,9 @@ class AnalysisRunnerTests extends AnyWordSpec // Used to be tested with the above line, but adding filters changed the order of the results. assert(separateResults.asInstanceOf[Set[DoubleMetric]].size == runnerResults.asInstanceOf[Set[DoubleMetric]].size) - separateResults.asInstanceOf[Set[DoubleMetric]].foreach( result => { - assert(runnerResults.toString.contains(result.toString)) - } - ) + separateResults.asInstanceOf[Set[DoubleMetric]].foreach( result => + assert(runnerResults.toString.contains(result.toString)) + ) } "reuse existing results" in diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala index 9133d5ae4..3054d141c 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala @@ -146,7 +146,5 @@ class AnalyzerContextTest extends AnyWordSpec private[this] def assertSameJson(jsonA: String, jsonB: String): Unit = { assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet)) - // assert(SimpleResultSerde.deserialize(jsonA) == - // SimpleResultSerde.deserialize(jsonB)) } } diff --git a/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala b/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala index 9a82903e8..0cd76c8de 100644 --- a/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala +++ b/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala @@ -255,7 +255,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | }, | { | "constraint_name": "ComplianceConstraint(Compliance(\u0027item\u0027 has no - | negative values,item \u003e\u003d 0,None,List(item)))", + | negative values,item \u003e\u003d 0,None,List(item),None))", | "column_name": "item", | "current_value": "Minimum: 1.0", | "description": "\u0027item\u0027 has no negative values", @@ -341,7 +341,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | }, | { | "constraint_name": "ComplianceConstraint(Compliance(\u0027item\u0027 has no - | negative values,item \u003e\u003d 0,None,List(item)))", + | negative values,item \u003e\u003d 0,None,List(item),None))", | "column_name": "item", | "current_value": "Minimum: 1.0", | "description": "\u0027item\u0027 has no negative values", @@ -428,7 +428,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | }, | { | "constraint_name": "ComplianceConstraint(Compliance(\u0027item\u0027 has no - | negative values,item \u003e\u003d 0,None,List(item)))", + | negative values,item \u003e\u003d 0,None,List(item),None))", | "column_name": "item", | "current_value": "Minimum: 1.0", | "description": "\u0027item\u0027 has no negative values", @@ -494,7 +494,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | }, | { | "constraint_name": "ComplianceConstraint(Compliance(\u0027`item.one`\u0027 has no - | negative values,`item.one` \u003e\u003d 0,None,List(`item.one`)))", + | negative values,`item.one` \u003e\u003d 0,None,List(`item.one`),None))", | "column_name": "`item.one`", | "current_value": "Minimum: 1.0", | "description": "\u0027`item.one`\u0027 has no negative values", diff --git a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala index 601134a53..5c56ed4b0 100644 --- a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala +++ b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala @@ -32,13 +32,13 @@ trait FixtureSupport { import sparkSession.implicits._ Seq( - ("", "a", "f"), - ("", "b", "d"), - ("", "a", null), - ("", "a", "f"), - ("", "b", null), - ("", "a", "f") - ).toDF("att1", "att2", "att3") + (0, "", "a", "f"), + (1, "", "b", "d"), + (2, "", "a", null), + (3, "", "a", "f"), + (4, "", "b", null), + (5, "", "a", "f") + ).toDF("id", "att1", "att2", "att3") } def getDfEmpty(sparkSession: SparkSession): DataFrame = { @@ -159,6 +159,19 @@ trait FixtureSupport { ).toDF("item", "att1", "att2") } + def getDfCompleteAndInCompleteColumnsWithIntId(sparkSession: SparkSession): DataFrame = { + import sparkSession.implicits._ + + Seq( + (1, "a", "f"), + (2, "b", "d"), + (3, "a", null), + (4, "a", "f"), + (5, "b", null), + (6, "a", "f") + ).toDF("item", "att1", "att2") + } + def getDfCompleteAndInCompleteColumnsWithSpacesInNames(sparkSession: SparkSession): DataFrame = { import sparkSession.implicits._ @@ -399,16 +412,16 @@ trait FixtureSupport { import sparkSession.implicits._ Seq( - ("India", "Xavier House, 2nd Floor", "St. Peter Colony, Perry Road", "Bandra (West)"), - ("India", "503 Godavari", "Sir Pochkhanwala Road", "Worli"), - ("India", "4/4 Seema Society", "N Dutta Road, Four Bungalows", "Andheri"), - ("India", "1001D Abhishek Apartments", "Juhu Versova Road", "Andheri"), - ("India", "95, Hill Road", null, null), - ("India", "90 Cuffe Parade", "Taj President Hotel", "Cuffe Parade"), - ("India", "4, Seven PM", "Sir Pochkhanwala Rd", "Worli"), - ("India", "1453 Sahar Road", null, null) + (0, "India", "Xavier House, 2nd Floor", "St. Peter Colony, Perry Road", "Bandra (West)"), + (1, "India", "503 Godavari", "Sir Pochkhanwala Road", "Worli"), + (2, "India", "4/4 Seema Society", "N Dutta Road, Four Bungalows", "Andheri"), + (3, "India", "1001D Abhishek Apartments", "Juhu Versova Road", "Andheri"), + (4, "India", "95, Hill Road", null, null), + (5, "India", "90 Cuffe Parade", "Taj President Hotel", "Cuffe Parade"), + (6, "India", "4, Seven PM", "Sir Pochkhanwala Rd", "Worli"), + (7, "India", "1453 Sahar Road", null, null) ) - .toDF("Country", "Address Line 1", "Address Line 2", "Address Line 3") + .toDF("id", "Country", "Address Line 1", "Address Line 2", "Address Line 3") } def getDfWithPeriodInName(sparkSession: SparkSession): DataFrame = {