Do row-level filtering in AnalyzerOptions rather than with RowLevelFilterTreatment trait

eycho-am committed Feb 14, 2024
1 parent e5b7821 commit f856816
Showing 18 changed files with 167 additions and 171 deletions.
16 changes: 0 additions & 16 deletions src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala
@@ -22,9 +22,6 @@ import com.amazon.deequ.analyzers.{State, _}
import com.amazon.deequ.checks.{Check, CheckLevel}
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.repository._
import com.amazon.deequ.utilities.FilteredRow.FilteredRow
import com.amazon.deequ.utilities.RowLevelFilterTreatment
import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
import org.apache.spark.sql.{DataFrame, SparkSession}

/** A class to build a VerificationRun using a fluent API */
@@ -47,7 +44,6 @@ class VerificationRunBuilder(val data: DataFrame) {

protected var statePersister: Option[StatePersister] = None
protected var stateLoader: Option[StateLoader] = None
protected var rowLevelFilterTreatment: RowLevelFilterTreatment = RowLevelFilterTreatment.sharedInstance

protected def this(verificationRunBuilder: VerificationRunBuilder) {

@@ -70,7 +66,6 @@ class VerificationRunBuilder(val data: DataFrame) {

stateLoader = verificationRunBuilder.stateLoader
statePersister = verificationRunBuilder.statePersister
rowLevelFilterTreatment = verificationRunBuilder.rowLevelFilterTreatment
}

/**
@@ -140,17 +135,6 @@
this
}

/**
* Sets how row level results will be treated for the Verification run
*
* @param filteredRow enum to determine how filtered row level results are labeled (TRUE | NULL)
*/
def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
rowLevelFilterTreatment = RowLevelFilterTreatment.sharedInstance
this
}

/**
* Set a metrics repository associated with the current data to enable features like reusing
* previously computed results and storing the results of the current run.
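
The builder-level setter above is gone because it mutated a process-wide shared instance. A minimal before/after sketch of the migration, assuming a DataFrame df (the check label and column name are hypothetical):

import com.amazon.deequ.VerificationSuite
import com.amazon.deequ.analyzers.{AnalyzerOptions, FilteredRow}
import com.amazon.deequ.checks.{Check, CheckLevel}

// Before this commit: global, mutable state set through the builder
// VerificationSuite().onData(df).withRowLevelFilterTreatment(FilteredRow.NULL)

// After this commit: the treatment travels with the individual constraint
val check = Check(CheckLevel.Error, "row-level review")
  .isComplete("id", analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRow.NULL)))
val result = VerificationSuite().onData(df).addCheck(check).run()
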
13 changes: 7 additions & 6 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -17,18 +17,14 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.analyzers.NullBehavior.NullBehavior
import com.amazon.deequ.analyzers.runners._
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.Entity
import com.amazon.deequ.metrics.FullColumn
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.utilities.ColumnUtil.removeEscapeColumn
import com.amazon.deequ.utilities.FilteredRow
import com.amazon.deequ.utilities.FilteredRow.FilteredRow
import com.amazon.deequ.utilities.RowLevelFilterTreatment
import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
@@ -266,12 +262,17 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo
}
}

case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore)
case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore, filteredRow: FilteredRow = FilteredRow.TRUE)
object NullBehavior extends Enumeration {
type NullBehavior = Value
val Ignore, EmptyString, Fail = Value
}

object FilteredRow extends Enumeration {
type FilteredRow = Value
val NULL, TRUE = Value
}

/** Base class for analyzers that compute ratios of matching predicates */
abstract class PredicateMatchingAnalyzer(
name: String,
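
Unlike the removed shared instance, AnalyzerOptions is plain immutable data, so concurrent runs can no longer observe each other's settings. A small sketch of the two knobs, using the defaults from this diff:

import com.amazon.deequ.analyzers.{AnalyzerOptions, FilteredRow, NullBehavior}

// Defaults: nulls ignored, filtered-out rows labeled true
val defaults = AnalyzerOptions()

// Filtered-out rows surface as null in row-level results instead
val strict = AnalyzerOptions(nullBehavior = NullBehavior.Ignore, filteredRow = FilteredRow.NULL)
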
19 changes: 8 additions & 11 deletions src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
@@ -20,19 +20,16 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested}
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.{IntegerType, StructType}
import Analyzers._
import com.amazon.deequ.utilities.FilteredRow.FilteredRow
import com.amazon.deequ.utilities.RowLevelAnalyzer
import com.amazon.deequ.utilities.RowLevelFilterTreatment
import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.{Column, Row}

/** Completeness is the fraction of non-null values in a column of a DataFrame. */
case class Completeness(column: String, where: Option[String] = None) extends
case class Completeness(column: String, where: Option[String] = None, analyzerOptions: Option[AnalyzerOptions] = None) extends
StandardScanShareableAnalyzer[NumMatchesAndCount]("Completeness", column) with
FilterableAnalyzer with RowLevelAnalyzer {
FilterableAnalyzer {

override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {
ifNoNullsIn(result, offset, howMany = 2) { _ =>
@@ -59,12 +56,12 @@ case class Completeness(column: String, where: Option[String] = None) extends
@VisibleForTesting
private[deequ] def rowLevelResults: Column = {
val whereCondition = where.map { expression => expr(expression)}
conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, rowLevelFilterTreatment.toString)
conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment.toString)
}

@VisibleForTesting
private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
this
private def getRowLevelFilterTreatment: FilteredRow = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRow.TRUE)
}
}
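
Each Completeness instance now resolves its own treatment, falling back to the old behavior (FilteredRow.TRUE) when no options are supplied. A usage sketch with a made-up column and filter:

import com.amazon.deequ.analyzers.{AnalyzerOptions, Completeness, FilteredRow}

// Rows outside the where-filter are labeled null rather than true
val completeness = Completeness(
  "email",
  where = Some("country = 'US'"),
  analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRow.NULL))
)
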
30 changes: 20 additions & 10 deletions src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala
@@ -17,18 +17,19 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.utilities.RowLevelAnalyzer
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, count, lit, sum}
import org.apache.spark.sql.types.DoubleType

case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None)
case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None)
extends ScanShareableFrequencyBasedAnalyzer("UniqueValueRatio", columns)
with FilterableAnalyzer with RowLevelAnalyzer {
with FilterableAnalyzer {

override def aggregationFunctions(numRows: Long): Seq[Column] = {
sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) :: count("*") :: Nil
@@ -38,17 +39,26 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None)
val numUniqueValues = result.getDouble(offset)
val numDistinctValues = result.getLong(offset + 1).toDouble
val conditionColumn = where.map { expression => expr(expression) }
val fullColumnUniqueness = conditionColumn.map {
condition => {
when(not(condition), expr(rowLevelFilterTreatment.toString))
.when((fullColumn.getOrElse(null)).equalTo(1), true)
.otherwise(false)
val fullColumnUniqueness = fullColumn.map {
rowLevelColumn => {
conditionColumn.map {
condition => {
when(not(condition), expr(getRowLevelFilterTreatment.toString))
.when(rowLevelColumn.equalTo(1), true).otherwise(false)
}
}.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false))
}
}.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false))
toSuccessMetric(numUniqueValues / numDistinctValues, Option(fullColumnUniqueness))
}
toSuccessMetric(numUniqueValues / numDistinctValues, fullColumnUniqueness)
}

override def filterCondition: Option[String] = where

private def getRowLevelFilterTreatment: FilteredRow = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRow.TRUE)
}
}

object UniqueValueRatio {
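
Note that mapping over fullColumn (instead of the old fullColumn.getOrElse(null)) lets a missing row-level column pass through as None rather than dereferencing null. The resulting row-level expression has roughly this shape, sketched assuming a per-row occurrence count cnt and filter status = 'active':

import com.amazon.deequ.analyzers.FilteredRow
import org.apache.spark.sql.functions.{col, expr, not, when}

val inScope = expr("status = 'active'")
val rowLevel = when(not(inScope), expr(FilteredRow.NULL.toString)) // "NULL" parses to a SQL null literal
  .when(col("cnt").equalTo(1), true) // in scope and value occurs exactly once
  .otherwise(false) // in scope but duplicated
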
36 changes: 18 additions & 18 deletions src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala
@@ -17,12 +17,8 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.utilities.FilteredRow
import com.amazon.deequ.utilities.FilteredRow.FilteredRow
import com.amazon.deequ.utilities.RowLevelAnalyzer
import com.amazon.deequ.utilities.RowLevelFilterTreatment
import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
@@ -36,32 +32,36 @@ import org.apache.spark.sql.types.DoubleType

/** Uniqueness is the fraction of unique values of a column(s), i.e.,
* values that occur exactly once. */
case class Uniqueness(columns: Seq[String], where: Option[String] = None)
case class Uniqueness(columns: Seq[String], where: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None)
extends ScanShareableFrequencyBasedAnalyzer("Uniqueness", columns)
with FilterableAnalyzer with RowLevelAnalyzer {
with FilterableAnalyzer {

override def aggregationFunctions(numRows: Long): Seq[Column] = {
(sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int, fullColumn: Option[Column]): DoubleMetric = {
val conditionColumn = where.map { expression => expr(expression) }
val fullColumnUniqueness = conditionColumn.map {
condition => {
when(not(condition), expr(rowLevelFilterTreatment.toString))
.when(fullColumn.getOrElse(null).equalTo(1), true).otherwise(false)
val fullColumnUniqueness = fullColumn.map {
rowLevelColumn => {
conditionColumn.map {
condition => {
when(not(condition), expr(getRowLevelFilterTreatment.toString))
.when(rowLevelColumn.equalTo(1), true).otherwise(false)
}
}.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false))
}
}.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false))
super.fromAggregationResult(result, offset, Option(fullColumnUniqueness))
}
super.fromAggregationResult(result, offset, fullColumnUniqueness)
}

override def filterCondition: Option[String] = where


@VisibleForTesting
private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
this
private def getRowLevelFilterTreatment: FilteredRow = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRow.TRUE)
}
}
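
Uniqueness mirrors UniqueValueRatio: the same Option-safe mapping over fullColumn and the same per-instance fallback to FilteredRow.TRUE. A usage sketch with hypothetical key columns:

import com.amazon.deequ.analyzers.{AnalyzerOptions, FilteredRow, Uniqueness}

val uniqueness = Uniqueness(
  Seq("order_id", "line_no"),
  where = Some("is_deleted = false"),
  analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRow.NULL))
)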

60 changes: 52 additions & 8 deletions src/main/scala/com/amazon/deequ/checks/Check.scala
@@ -132,10 +132,12 @@ case class Check(
*
* @param column Column to run the assertion on
* @param hint A hint to provide additional context why a constraint could have failed
* @param analyzerOptions Options to configure analyzer behavior (NullBehavior, FilteredRow)
* @return
*/
def isComplete(column: String, hint: Option[String] = None): CheckWithLastConstraintFilterable = {
addFilterableConstraint { filter => completenessConstraint(column, Check.IsOne, filter, hint) }
def isComplete(column: String, hint: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None): CheckWithLastConstraintFilterable = {
addFilterableConstraint { filter => completenessConstraint(column, Check.IsOne, filter, hint, analyzerOptions) }
}

/**
@@ -146,14 +148,16 @@
* @param column Column to run the assertion on
* @param assertion Function that receives a double input parameter and returns a boolean
* @param hint A hint to provide additional context why a constraint could have failed
* @param analyzerOptions Options to configure analyzer behavior (NullBehavior, FilteredRow)
* @return
*/
def hasCompleteness(
column: String,
assertion: Double => Boolean,
hint: Option[String] = None)
hint: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None)
: CheckWithLastConstraintFilterable = {
addFilterableConstraint { filter => completenessConstraint(column, assertion, filter, hint) }
addFilterableConstraint { filter => completenessConstraint(column, assertion, filter, hint, analyzerOptions) }
}
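
At the Check level the options are simply forwarded to the underlying Completeness analyzer. A sketch combining a scoping filter with the new parameter (threshold and names are illustrative; imports as in the earlier sketch):

val usEmails = Check(CheckLevel.Error, "US email completeness")
  .hasCompleteness("email", _ >= 0.95, None,
    Some(AnalyzerOptions(filteredRow = FilteredRow.NULL)))
  .where("country = 'US'")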

/**
Expand Down Expand Up @@ -221,11 +225,13 @@ case class Check(
*
* @param column Column to run the assertion on
* @param hint A hint to provide additional context why a constraint could have failed
* @param analyzerOptions Options to configure analyzer behavior (NullBehavior, FilteredRow)
* @return
*/
def isUnique(column: String, hint: Option[String] = None): CheckWithLastConstraintFilterable = {
def isUnique(column: String, hint: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None): CheckWithLastConstraintFilterable = {
addFilterableConstraint { filter =>
uniquenessConstraint(Seq(column), Check.IsOne, filter, hint) }
uniquenessConstraint(Seq(column), Check.IsOne, filter, hint, analyzerOptions) }
}
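
Likewise for isUnique: a sketch in which only recent rows are in scope and earlier rows are labeled null (column and date are made up):

val idCheck = Check(CheckLevel.Error, "id uniqueness")
  .isUnique("id", analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRow.NULL)))
  .where("event_date >= '2024-01-01'")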

/**
@@ -269,22 +275,42 @@
addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter) }
}

/**
* Creates a constraint that asserts on uniqueness in a single or combined set of key columns.
*
* @param columns Key columns
* @param assertion Function that receives a double input parameter and returns a boolean.
* Refers to the fraction of unique values
* @param hint A hint to provide additional context why a constraint could have failed
* @return
*/
def hasUniqueness(
columns: Seq[String],
assertion: Double => Boolean,
hint: Option[String])
: CheckWithLastConstraintFilterable = {

addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter, hint) }
}

/**
* Creates a constraint that asserts on uniqueness in a single or combined set of key columns.
*
* @param columns Key columns
* @param assertion Function that receives a double input parameter and returns a boolean.
* Refers to the fraction of unique values
* @param hint A hint to provide additional context why a constraint could have failed
* @param analyzerOptions Options to configure analyzer behavior (NullBehavior, FilteredRow)
* @return
*/
def hasUniqueness(
columns: Seq[String],
assertion: Double => Boolean,
hint: Option[String])
hint: Option[String],
analyzerOptions: Option[AnalyzerOptions])
: CheckWithLastConstraintFilterable = {

addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter, hint) }
addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter, hint, analyzerOptions) }
}

/**
@@ -314,6 +340,22 @@
hasUniqueness(Seq(column), assertion, hint)
}

/**
* Creates a constraint that asserts on the uniqueness of a key column.
*
* @param column Key column
* @param assertion Function that receives a double input parameter and returns a boolean.
* Refers to the fraction of unique values.
* @param hint A hint to provide additional context why a constraint could have failed
* @param analyzerOptions Options to configure analyzer behavior (NullBehavior, FilteredRow)
* @return
*/
def hasUniqueness(column: String, assertion: Double => Boolean, hint: Option[String],
analyzerOptions: Option[AnalyzerOptions])
: CheckWithLastConstraintFilterable = {
hasUniqueness(Seq(column), assertion, hint, analyzerOptions)
}
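
Because the new overloads append analyzerOptions after the hint with no default value, callers supply both positionally, as in this sketch with hypothetical key columns:

val keyCheck = Check(CheckLevel.Warning, "composite key")
  .hasUniqueness(Seq("order_id", "line_no"), _ == 1.0,
    Some("order line keys should be unique"),
    Some(AnalyzerOptions(filteredRow = FilteredRow.NULL)))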

/**
* Creates a constraint on the distinctness in a single or combined set of key columns.
*
@@ -601,6 +643,7 @@ case class Check(
* @param column Column to run the assertion on
* @param assertion Function that receives a double input parameter and returns a boolean
* @param hint A hint to provide additional context why a constraint could have failed
* @param analyzerOptions Options to configure analyzer behavior (NullBehavior, FilteredRow)
* @return
*/
def hasMinLength(
@@ -619,6 +662,7 @@
* @param column Column to run the assertion on
* @param assertion Function that receives a double input parameter and returns a boolean
* @param hint A hint to provide additional context why a constraint could have failed
* @param analyzerOptions Options to configure analyzer behavior (NullBehavior, FilteredRow)
* @return
*/
def hasMaxLength(
12 more changed files not shown.
