Skip to content

Commit

Permalink
Modify VerificationRunBuilder to have RowLevelFilterTreatment as vari…
Browse files Browse the repository at this point in the history
…able instead of extending, create RowLevelAnalyzer trait
  • Loading branch information
eycho-am committed Feb 14, 2024
1 parent 5f715ac commit e5b7821
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 17 deletions.
8 changes: 4 additions & 4 deletions src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ import com.amazon.deequ.analyzers.{State, _}
import com.amazon.deequ.checks.{Check, CheckLevel}
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.repository._
import com.amazon.deequ.utilities.FilteredRow
import com.amazon.deequ.utilities.FilteredRow.FilteredRow
import com.amazon.deequ.utilities.RowLevelFilterTreatment
import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
import org.apache.spark.sql.{DataFrame, SparkSession}

/** A class to build a VerificationRun using a fluent API */
class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatment {
class VerificationRunBuilder(val data: DataFrame) {

protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty

Expand All @@ -48,6 +47,7 @@ class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatmen

protected var statePersister: Option[StatePersister] = None
protected var stateLoader: Option[StateLoader] = None
protected var rowLevelFilterTreatment: RowLevelFilterTreatment = RowLevelFilterTreatment.sharedInstance

protected def this(verificationRunBuilder: VerificationRunBuilder) {

Expand All @@ -70,6 +70,7 @@ class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatmen

stateLoader = verificationRunBuilder.stateLoader
statePersister = verificationRunBuilder.statePersister
rowLevelFilterTreatment = verificationRunBuilder.rowLevelFilterTreatment
}

/**
Expand Down Expand Up @@ -146,11 +147,10 @@ class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatmen
*/
def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
rowLevelFilterTreatment = RowLevelFilterTreatment.sharedInstance
this
}

def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment

/**
* Set a metrics repository associated with the current data to enable features like reusing
* previously computed results and storing the results of the current run.
Expand Down
11 changes: 1 addition & 10 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ trait DoubleValuedState[S <: DoubleValuedState[S]] extends State[S] {
}

/** Common trait for all analyzers which generates metrics from states computed on data frames */
trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable with RowLevelFilterTreatment {
trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {

/**
* Compute the state (sufficient statistics) from the data
Expand Down Expand Up @@ -178,15 +178,6 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable with RowLeve
private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = {
source.load[S](this).foreach { state => target.persist(this, state) }
}

@VisibleForTesting
private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
this
}

def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment

}

/** An analyzer that runs a set of aggregation functions over the data,
Expand Down
11 changes: 10 additions & 1 deletion src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested}
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.{IntegerType, StructType}
import Analyzers._
import com.amazon.deequ.utilities.FilteredRow.FilteredRow
import com.amazon.deequ.utilities.RowLevelAnalyzer
import com.amazon.deequ.utilities.RowLevelFilterTreatment
import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
Expand All @@ -28,7 +32,7 @@ import org.apache.spark.sql.{Column, Row}
/** Completeness is the fraction of non-null values in a column of a DataFrame. */
case class Completeness(column: String, where: Option[String] = None) extends
StandardScanShareableAnalyzer[NumMatchesAndCount]("Completeness", column) with
FilterableAnalyzer {
FilterableAnalyzer with RowLevelAnalyzer {

override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {
ifNoNullsIn(result, offset, howMany = 2) { _ =>
Expand Down Expand Up @@ -58,4 +62,9 @@ case class Completeness(column: String, where: Option[String] = None) extends
conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, rowLevelFilterTreatment.toString)
}

@VisibleForTesting
private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
this
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.utilities.RowLevelAnalyzer
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.functions.when
Expand All @@ -27,7 +28,7 @@ import org.apache.spark.sql.types.DoubleType

case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None)
extends ScanShareableFrequencyBasedAnalyzer("UniqueValueRatio", columns)
with FilterableAnalyzer {
with FilterableAnalyzer with RowLevelAnalyzer {

override def aggregationFunctions(numRows: Long): Seq[Column] = {
sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) :: count("*") :: Nil
Expand Down
15 changes: 14 additions & 1 deletion src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.utilities.FilteredRow
import com.amazon.deequ.utilities.FilteredRow.FilteredRow
import com.amazon.deequ.utilities.RowLevelAnalyzer
import com.amazon.deequ.utilities.RowLevelFilterTreatment
import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.when
Expand All @@ -32,7 +38,7 @@ import org.apache.spark.sql.types.DoubleType
* values that occur exactly once. */
case class Uniqueness(columns: Seq[String], where: Option[String] = None)
extends ScanShareableFrequencyBasedAnalyzer("Uniqueness", columns)
with FilterableAnalyzer {
with FilterableAnalyzer with RowLevelAnalyzer {

override def aggregationFunctions(numRows: Long): Seq[Column] = {
(sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil
Expand All @@ -50,6 +56,13 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None)
}

override def filterCondition: Option[String] = where


@VisibleForTesting
private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
this
}
}

object Uniqueness {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,7 @@ object FilteredRow extends Enumeration {
type FilteredRow = Value
val NULL, TRUE = Value
}

trait RowLevelAnalyzer extends RowLevelFilterTreatment {
def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment
}

0 comments on commit e5b7821

Please sign in to comment.