diff --git a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala
index 55800080..fb31651c 100644
--- a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala
+++ b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala
@@ -22,14 +22,13 @@ import com.amazon.deequ.analyzers.{State, _}
 import com.amazon.deequ.checks.{Check, CheckLevel}
 import com.amazon.deequ.metrics.Metric
 import com.amazon.deequ.repository._
-import com.amazon.deequ.utilities.FilteredRow
 import com.amazon.deequ.utilities.FilteredRow.FilteredRow
 import com.amazon.deequ.utilities.RowLevelFilterTreatment
 import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
 import org.apache.spark.sql.{DataFrame, SparkSession}
 
 /** A class to build a VerificationRun using a fluent API */
-class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatment {
+class VerificationRunBuilder(val data: DataFrame) {
 
   protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty
 
@@ -48,6 +47,7 @@ class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatmen
 
   protected var statePersister: Option[StatePersister] = None
   protected var stateLoader: Option[StateLoader] = None
+  protected var rowLevelFilterTreatment: RowLevelFilterTreatment = RowLevelFilterTreatment.sharedInstance
 
   protected def this(verificationRunBuilder: VerificationRunBuilder) {
 
@@ -70,6 +70,7 @@ class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatmen
 
     stateLoader = verificationRunBuilder.stateLoader
     statePersister = verificationRunBuilder.statePersister
+    rowLevelFilterTreatment = verificationRunBuilder.rowLevelFilterTreatment
   }
 
   /**
@@ -146,11 +147,10 @@ class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatmen
    */
   def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
     RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
+    rowLevelFilterTreatment = RowLevelFilterTreatment.sharedInstance
    this
   }
 
-  def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment
-
   /**
    * Set a metrics repository associated with the current data to enable features like reusing
    * previously computed results and storing the results of the current run.
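For context, a minimal usage sketch of the builder-level setting this hunk introduces. The DataFrame `df`, the check name, and the column/filter values are hypothetical; only `withRowLevelFilterTreatment` and the `FilteredRow` enum come from this change, and the row-level semantics described in the comment are the assumed intent of the feature:

```scala
import com.amazon.deequ.VerificationSuite
import com.amazon.deequ.checks.{Check, CheckLevel}
import com.amazon.deequ.utilities.FilteredRow

// Hypothetical data and check; the builder call chain is the point.
// With FilteredRow.TRUE, rows excluded by a constraint's where-filter are
// presumably marked true in the row-level results instead of the default null.
val result = VerificationSuite()
  .onData(df) // df: org.apache.spark.sql.DataFrame, assumed in scope
  .addCheck(
    Check(CheckLevel.Error, "completeness on filtered rows")
      .isComplete("item").where("region = 'EU'"))
  .withRowLevelFilterTreatment(FilteredRow.TRUE)
  .run()
```

Because `withRowLevelFilterTreatment` now also caches `RowLevelFilterTreatment.sharedInstance` into the builder's own `rowLevelFilterTreatment` field, the builder no longer needs to mix in `RowLevelFilterTreatment` itself.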
diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
index 327a7a14..02857942 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -67,7 +67,7 @@ trait DoubleValuedState[S <: DoubleValuedState[S]] extends State[S] {
 }
 
 /** Common trait for all analyzers which generates metrics from states computed on data frames */
-trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable with RowLevelFilterTreatment {
+trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
 
   /**
    * Compute the state (sufficient statistics) from the data
@@ -178,15 +178,6 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable with RowLeve
   private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = {
     source.load[S](this).foreach { state => target.persist(this, state) }
   }
-
-  @VisibleForTesting
-  private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
-    RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
-    this
-  }
-
-  def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment
-
 }
 
 /** An analyzer that runs a set of aggregation functions over the data,
diff --git a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
index 7107c834..f385da45 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
@@ -20,6 +20,10 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested}
 import org.apache.spark.sql.functions.sum
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import Analyzers._
+import com.amazon.deequ.utilities.FilteredRow.FilteredRow
+import com.amazon.deequ.utilities.RowLevelAnalyzer
+import com.amazon.deequ.utilities.RowLevelFilterTreatment
+import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
 import com.google.common.annotations.VisibleForTesting
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.functions.expr
@@ -28,7 +32,7 @@ import org.apache.spark.sql.{Column, Row}
 
 /** Completeness is the fraction of non-null values in a column of a DataFrame.
   */
 case class Completeness(column: String, where: Option[String] = None) extends
   StandardScanShareableAnalyzer[NumMatchesAndCount]("Completeness", column) with
-  FilterableAnalyzer {
+  FilterableAnalyzer with RowLevelAnalyzer {
 
   override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {
     ifNoNullsIn(result, offset, howMany = 2) { _ =>
@@ -58,4 +62,9 @@ case class Completeness(column: String, where: Option[String] = None) extends
     conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, rowLevelFilterTreatment.toString)
   }
 
+  @VisibleForTesting
+  private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
+    RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
+    this
+  }
 }
diff --git a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala
index 6cfdc638..b3d1d701 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala
@@ -18,6 +18,7 @@ package com.amazon.deequ.analyzers
 
 import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
 import com.amazon.deequ.metrics.DoubleMetric
+import com.amazon.deequ.utilities.RowLevelAnalyzer
 import org.apache.spark.sql.functions.expr
 import org.apache.spark.sql.functions.not
 import org.apache.spark.sql.functions.when
@@ -27,7 +28,7 @@ import org.apache.spark.sql.types.DoubleType
 
 case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None)
   extends ScanShareableFrequencyBasedAnalyzer("UniqueValueRatio", columns)
-  with FilterableAnalyzer {
+  with FilterableAnalyzer with RowLevelAnalyzer {
 
   override def aggregationFunctions(numRows: Long): Seq[Column] = {
     sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) :: count("*") :: Nil
diff --git a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala
index f62476da..16ec6d7b 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala
@@ -18,6 +18,12 @@ package com.amazon.deequ.analyzers
 
 import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
 import com.amazon.deequ.metrics.DoubleMetric
+import com.amazon.deequ.utilities.FilteredRow
+import com.amazon.deequ.utilities.FilteredRow.FilteredRow
+import com.amazon.deequ.utilities.RowLevelAnalyzer
+import com.amazon.deequ.utilities.RowLevelFilterTreatment
+import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
+import com.google.common.annotations.VisibleForTesting
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.when
@@ -32,7 +38,7 @@ import org.apache.spark.sql.types.DoubleType
   * values that occur exactly once.
   */
 case class Uniqueness(columns: Seq[String], where: Option[String] = None)
   extends ScanShareableFrequencyBasedAnalyzer("Uniqueness", columns)
-  with FilterableAnalyzer {
+  with FilterableAnalyzer with RowLevelAnalyzer {
 
   override def aggregationFunctions(numRows: Long): Seq[Column] = {
     (sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil
@@ -50,6 +56,13 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None)
   }
 
   override def filterCondition: Option[String] = where
+
+
+  @VisibleForTesting
+  private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
+    RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
+    this
+  }
 }
 
 object Uniqueness {
diff --git a/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala b/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala
index 45ce0ce9..c37e72e6 100644
--- a/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala
+++ b/src/main/scala/com/amazon/deequ/utilities/RowLevelFilterTreatement.scala
@@ -46,3 +46,7 @@ object FilteredRow extends Enumeration {
   type FilteredRow = Value
   val NULL, TRUE = Value
 }
+
+trait RowLevelAnalyzer extends RowLevelFilterTreatment {
+  def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment
+}
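Taken together: `Analyzer` no longer carries the treatment itself; the analyzers that need it mix in the new `RowLevelAnalyzer` trait, which always defers to the process-wide `RowLevelFilterTreatment.sharedInstance`, while the `@VisibleForTesting` hooks mutate that same shared instance. A minimal sketch of the resulting behavior, assuming the column names and filter are hypothetical and noting that `withRowLevelFilterTreatment` is `private[deequ]`, so this only compiles inside that package:

```scala
package com.amazon.deequ

import com.amazon.deequ.analyzers.{Completeness, Uniqueness}
import com.amazon.deequ.utilities.FilteredRow

object RowLevelTreatmentSketch {
  def main(args: Array[String]): Unit = {
    // Setting the treatment through the test-only hook on one analyzer...
    val completeness = Completeness("att1", Some("att2 = 'yes'"))
      .withRowLevelFilterTreatment(FilteredRow.TRUE)

    // ...mutates the shared instance, so every RowLevelAnalyzer observes it,
    // including analyzers that were never configured themselves.
    val uniqueness = Uniqueness(Seq("att1"))
    assert(completeness.rowLevelFilterTreatment == FilteredRow.TRUE)
    assert(uniqueness.rowLevelFilterTreatment == FilteredRow.TRUE)
  }
}
```

The sketch also makes the design trade-off visible: the shared instance is global mutable state, which appears to be why `VerificationRunBuilder` now caches its own copy in the `rowLevelFilterTreatment` field rather than reading the singleton at run time.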