Add RowLevelFilterTreatment trait and object to determine how filtered rows will be labeled (default True)
eycho-am committed Feb 13, 2024
1 parent 5268cb9 commit 274a446
Showing 11 changed files with 178 additions and 44 deletions.
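
In short, the commit lets callers choose whether rows excluded by a check's where filter are labeled TRUE (the default) or NULL in row-level results. A minimal usage sketch based on the test changes in this commit; the input DataFrame `data` is a placeholder:

import com.amazon.deequ.VerificationSuite
import com.amazon.deequ.checks.{Check, CheckLevel}
import com.amazon.deequ.utilities.FilteredRow

val completeness = new Check(CheckLevel.Error, "rule1")
  .hasCompleteness("att2", _ > 0.7, None)
  .where("att1 = \"a\"")

// Label rows excluded by the where clause with null instead of the default true.
// `data` is any Spark DataFrame under test (placeholder).
val result = new VerificationSuite()
  .onData(data)
  .withRowLevelFilterTreatment(FilteredRow.NULL)
  .addCheck(completeness)
  .run()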
21 changes: 18 additions & 3 deletions src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala
@@ -22,10 +22,14 @@ import com.amazon.deequ.analyzers.{State, _}
import com.amazon.deequ.checks.{Check, CheckLevel}
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.repository._
import com.amazon.deequ.utilities.FilteredRow
import com.amazon.deequ.utilities.FilteredRow.FilteredRow
import com.amazon.deequ.utilities.RowLevelFilterTreatment
import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
import org.apache.spark.sql.{DataFrame, SparkSession}

/** A class to build a VerificationRun using a fluent API */
class VerificationRunBuilder(val data: DataFrame) {
class VerificationRunBuilder(val data: DataFrame) extends RowLevelFilterTreatment {

protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty

@@ -135,6 +139,18 @@ class VerificationRunBuilder(val data: DataFrame) {
this
}

/**
* Sets how row level results will be treated for the Verification run
*
* @param filteredRow enum to determine how filtered row level results are labeled (TRUE | NULL)
*/
def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
this
}

def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment

/**
* Set a metrics repository associated with the current data to enable features like reusing
* previously computed results and storing the results of the current run.
@@ -159,7 +175,6 @@ class VerificationRunBuilder(val data: DataFrame) {
new VerificationRunBuilderWithSparkSession(this, Option(sparkSession))
}


def run(): VerificationResult = {
VerificationSuite().doVerificationRun(
data,
Expand Down Expand Up @@ -338,4 +353,4 @@ case class AnomalyCheckConfig(
description: String,
withTagValues: Map[String, String] = Map.empty,
afterDate: Option[Long] = None,
beforeDate: Option[Long] = None)
beforeDate: Option[Long] = None)
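
Note that the builder does not hold the setting itself: withRowLevelFilterTreatment writes to the shared instance on the RowLevelFilterTreatment companion object (defined in the new utilities file further down), and rowLevelFilterTreatment reads it back. The equivalent direct calls, a sketch using only names introduced by this commit:

import com.amazon.deequ.utilities.{FilteredRow, RowLevelFilterTreatment, RowLevelFilterTreatmentImpl}

// Overwrite the process-wide default (TRUE) with NULL treatment
RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(FilteredRow.NULL))
RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment  // now FilteredRow.NULL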
24 changes: 15 additions & 9 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -17,14 +17,18 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.analyzers.NullBehavior.NullBehavior
import com.amazon.deequ.analyzers.runners._
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.Entity
import com.amazon.deequ.metrics.FullColumn
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.utilities.ColumnUtil.removeEscapeColumn
import com.amazon.deequ.utilities.FilteredRow
import com.amazon.deequ.utilities.FilteredRow.FilteredRow
import com.amazon.deequ.utilities.RowLevelFilterTreatment
import com.amazon.deequ.utilities.RowLevelFilterTreatmentImpl
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
@@ -63,7 +67,7 @@ trait DoubleValuedState[S <: DoubleValuedState[S]] extends State[S] {
}

/** Common trait for all analyzers which generates metrics from states computed on data frames */
trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable with RowLevelFilterTreatment {

/**
* Compute the state (sufficient statistics) from the data
@@ -175,6 +179,14 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
source.load[S](this).foreach { state => target.persist(this, state) }
}

@VisibleForTesting
private[deequ] def withRowLevelFilterTreatment(filteredRow: FilteredRow): this.type = {
RowLevelFilterTreatment.setSharedInstance(new RowLevelFilterTreatmentImpl(filteredRow))
this
}

def rowLevelFilterTreatment: FilteredRow.Value = RowLevelFilterTreatment.sharedInstance.rowLevelFilterTreatment

}

/** An analyzer that runs a set of aggregation functions over the data,
@@ -263,18 +275,12 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo
}
}

case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore,
filteredRow: FilteredRow = FilteredRow.NULL)
case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore)
object NullBehavior extends Enumeration {
type NullBehavior = Value
val Ignore, EmptyString, Fail = Value
}

object FilteredRow extends Enumeration {
type FilteredRow = Value
val NULL, TRUE = Value
}

/** Base class for analyzers that compute ratios of matching predicates */
abstract class PredicateMatchingAnalyzer(
name: String,
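
Since Analyzer now mixes in RowLevelFilterTreatment, every analyzer reads the same shared setting, and the filteredRow field is dropped from AnalyzerOptions. The @VisibleForTesting hook above lets tests override the treatment on a single analyzer; because it is private[deequ], this sketch assumes code inside the deequ package, mirroring CompletenessTest further down:

import com.amazon.deequ.analyzers.Completeness
import com.amazon.deequ.utilities.FilteredRow

// Sets the shared instance to NULL treatment before computing row-level results
val completeness = Completeness("att2", Option("att1 = \"a\""))
  .withRowLevelFilterTreatment(FilteredRow.NULL)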
5 changes: 1 addition & 4 deletions src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
@@ -20,9 +20,7 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested}
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.{IntegerType, StructType}
import Analyzers._
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.{Column, Row}
@@ -56,9 +54,8 @@ case class Completeness(column: String, where: Option[String] = None) extends

@VisibleForTesting
private[deequ] def rowLevelResults: Column = {
val filteredRow = FilteredRow.NULL
val whereCondition = where.map { expression => expr(expression)}
conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, filteredRow.toString)
conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, rowLevelFilterTreatment.toString)
}

}
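
The hard-coded FilteredRow.NULL is gone; rowLevelResults now substitutes whatever treatment is configured. The helper conditionalSelectionFilteredFromColumns is not part of this diff, so the following Spark expression is only a sketch of the likely semantics, matching the expected values in CompletenessTest below:

import org.apache.spark.sql.functions.{col, expr, not, when}

// Rows failing the filter get the treatment label ("NULL" or "TRUE");
// rows passing it get att2 IS NOT NULL
val treatment = "NULL"
val rowLevel = when(not(expr("att1 = \"a\"")), expr(treatment))
  .otherwise(col("att2").isNotNull)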
@@ -39,7 +39,9 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None)
val conditionColumn = where.map { expression => expr(expression) }
val fullColumnUniqueness = conditionColumn.map {
condition => {
when(not(condition), expr(FilteredRow.NULL.toString)).when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)
when(not(condition), expr(rowLevelFilterTreatment.toString))
.when((fullColumn.getOrElse(null)).equalTo(1), true)
.otherwise(false)
}
}.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false))
toSuccessMetric(numUniqueValues / numDistinctValues, Option(fullColumnUniqueness))
3 changes: 1 addition & 2 deletions src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala
@@ -17,7 +17,6 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
import com.amazon.deequ.analyzers.Analyzers.conditionalCount
import com.amazon.deequ.metrics.DoubleMetric
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
@@ -43,7 +42,7 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None)
val conditionColumn = where.map { expression => expr(expression) }
val fullColumnUniqueness = conditionColumn.map {
condition => {
when(not(condition), expr(FilteredRow.NULL.toString)).when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)
when(not(condition), expr(rowLevelFilterTreatment.toString)).when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)
}
}.getOrElse(when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false))
super.fromAggregationResult(result, offset, Option(fullColumnUniqueness))
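
Uniqueness (and UniqueValueRatio above) follow the same pattern: rows excluded by the where clause take the configured label, and remaining rows are true exactly when their value occurs once. A sketch of that expression; countCol is a hypothetical stand-in for the per-value count column (fullColumn) that the analyzer builds:

import org.apache.spark.sql.functions.{col, expr, not, when}

val treatment = "NULL"  // or "TRUE", per rowLevelFilterTreatment
val rowLevelUnique = when(not(expr("item < 3")), expr(treatment))
  .when(col("countCol").equalTo(1), true)  // countCol is a placeholder name
  .otherwise(false)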
@@ -0,0 +1,32 @@
package com.amazon.deequ.utilities
import com.amazon.deequ.utilities.FilteredRow.FilteredRow

/**
* Trait that defines how row level results will be treated when a filter is applied to an analyzer
*/
trait RowLevelFilterTreatment {
def rowLevelFilterTreatment: FilteredRow
}

/**
* Companion object for RowLevelFilterTreatment
* Defines a sharedInstance that can be used throughout the VerificationRunBuilder
*/
object RowLevelFilterTreatment {
private var _sharedInstance: RowLevelFilterTreatment = new RowLevelFilterTreatmentImpl(FilteredRow.TRUE)

def sharedInstance: RowLevelFilterTreatment = _sharedInstance

def setSharedInstance(instance: RowLevelFilterTreatment): Unit = {
_sharedInstance = instance
}
}

class RowLevelFilterTreatmentImpl(initialFilterTreatment: FilteredRow) extends RowLevelFilterTreatment {
override val rowLevelFilterTreatment: FilteredRow = initialFilterTreatment
}

object FilteredRow extends Enumeration {
type FilteredRow = Value
val NULL, TRUE = Value
}
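
The FilteredRow enumeration (NULL | TRUE) moves here from Analyzer.scala, and the shared instance defaults to TRUE, which gives the commit its "default True" behavior. Because the instance is a mutable JVM-wide singleton rather than per-run state, the last call to setSharedInstance wins for every subsequent run in the same JVM; a sketch with placeholder DataFrames df1/df2 and a placeholder check someCheck:

VerificationSuite().onData(df1)
  .withRowLevelFilterTreatment(FilteredRow.NULL)
  .addCheck(someCheck)
  .run()

// A later run that never sets the treatment still reads the shared instance,
// so its filtered rows are also labeled null rather than the TRUE default
VerificationSuite().onData(df2).addCheck(someCheck).run()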
56 changes: 47 additions & 9 deletions src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
@@ -29,6 +29,7 @@ import com.amazon.deequ.metrics.Entity
import com.amazon.deequ.repository.MetricsRepository
import com.amazon.deequ.repository.ResultKey
import com.amazon.deequ.repository.memory.InMemoryMetricsRepository
import com.amazon.deequ.utilities.FilteredRow
import com.amazon.deequ.utils.CollectionUtils.SeqExtensions
import com.amazon.deequ.utils.FixtureSupport
import com.amazon.deequ.utils.TempFileUtils
@@ -303,7 +304,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
assert(Seq(true, true, true, false, false, false).sameElements(rowLevel8))
}

"generate a result that contains row-level results with filter with null for filtered rows" in withSparkSession { session =>
"generate a result that contains row-level results with filter with true for filtered rows" in withSparkSession { session =>
val data = getDfCompleteAndInCompleteColumns(session)

val completeness = new Check(CheckLevel.Error, "rule1").hasCompleteness("att2", _ > 0.7, None).where("att1 = \"a\"")
@@ -324,7 +325,44 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
assert(result.status == CheckStatus.Error)

val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item")
resultData.show(false)
val expectedColumns: Set[String] =
data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3
assert(resultData.columns.toSet == expectedColumns)

val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0))
assert(Seq(true, true, false, true, true, true).sameElements(rowLevel1))

val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0))
assert(Seq(false, false, false, false, false, false).sameElements(rowLevel2))

val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0))
assert(Seq(true, true, true, true, true, true).sameElements(rowLevel3))

}

"generate a result that contains row-level results with filter with null for filtered rows" in withSparkSession { session =>
val data = getDfCompleteAndInCompleteColumns(session)

val completeness = new Check(CheckLevel.Error, "rule1").hasCompleteness("att2", _ > 0.7, None).where("att1 = \"a\"")
val uniqueness = new Check(CheckLevel.Error, "rule2").hasUniqueness("att1", _ > 0.5, None)
val uniquenessWhere = new Check(CheckLevel.Error, "rule3").isUnique("att1").where("item < 3")
val expectedColumn1 = completeness.description
val expectedColumn2 = uniqueness.description
val expectedColumn3 = uniquenessWhere.description

val suite = new VerificationSuite().onData(data)
.withRowLevelFilterTreatment(FilteredRow.NULL)
.addCheck(completeness)
.addCheck(uniqueness)
.addCheck(uniquenessWhere)

val result: VerificationResult = suite.run()

assert(result.status == CheckStatus.Error)

val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item")
resultData.show(false)
val expectedColumns: Set[String] =
data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3
assert(resultData.columns.toSet == expectedColumns)
@@ -498,15 +536,15 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
"accept analysis config for mandatory analysis for checks with filters" in withSparkSession { sparkSession =>

import sparkSession.implicits._
val df = getDfFull(sparkSession)
val df = getDfCompleteAndInCompleteColumns(sparkSession)

val result = {
val checkToSucceed = Check(CheckLevel.Error, "group-1")
.hasCompleteness("att1", _ > 0.5, null) // 1.0
.where("att2 = \"c\"")
.hasCompleteness("att2", _ > 0.7, null) // 0.75
.where("att1 = \"a\"")
val uniquenessCheck = Check(CheckLevel.Error, "group-2")
.isUnique("att1")
.where("item > 3")
.where("item < 3")


VerificationSuite().onData(df).addCheck(checkToSucceed).addCheck(uniquenessCheck).run()
@@ -518,8 +556,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
AnalyzerContext(result.metrics))

val expected = Seq(
("Column", "att1", "Completeness (where: att2 = \"c\")", 1.0),
("Column", "att1", "Uniqueness (where: item > 3)", 1.0))
("Column", "att2", "Completeness (where: att1 = \"a\")", 0.75),
("Column", "att1", "Uniqueness (where: item < 3)", 1.0))
.toDF("entity", "instance", "name", "value")


@@ -871,12 +909,12 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
.run()

val checkSuccessResult = verificationResult.checkResults(rangeCheck)
// checkSuccessResult.constraintResults.map(_.message) shouldBe List(None)
checkSuccessResult.constraintResults.map(_.message) shouldBe List(None)
println(checkSuccessResult.constraintResults.map(_.message))
assert(checkSuccessResult.status == CheckStatus.Success)

val reasonResult = verificationResult.checkResults(reasonCheck)
// checkSuccessResult.constraintResults.map(_.message) shouldBe List(None)
checkSuccessResult.constraintResults.map(_.message) shouldBe List(None)
println(checkSuccessResult.constraintResults.map(_.message))
assert(checkSuccessResult.status == CheckStatus.Success)
}
29 changes: 16 additions & 13 deletions src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala
@@ -19,6 +19,7 @@ package com.amazon.deequ.analyzers
import com.amazon.deequ.SparkContextSpec
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.FullColumn
import com.amazon.deequ.utilities.FilteredRow
import com.amazon.deequ.utils.FixtureSupport
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
@@ -40,32 +41,34 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w
Seq(true, true, true, true, false, true, true, false)
}

"return row-level results for columns filtered using where" in withSparkSession { session =>
"return row-level results for columns filtered as null" in withSparkSession { session =>

val data = getDfForWhereClause(session)
val data = getDfCompleteAndInCompleteColumns(session)

val completenessCountry = Completeness("ZipCode", Option("State = \"CA\""))
val state = completenessCountry.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = completenessCountry.computeMetricFrom(state)
// Explicitly setting RowLevelFilterTreatment for test purposes, but this should be set at the VerificationRunBuilder
val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")).withRowLevelFilterTreatment(FilteredRow.NULL)
val state = completenessAtt2.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state)

// Address Line 3 is null only where Address Line 2 is null. With the where clause, completeness should be 1.0
data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe
Seq(true, true, false, false)
val df = data.withColumn("new", metric.fullColumn.get)
df.show(false)
df.collect().map(_.getAs[Any]("new")).toSeq shouldBe
Seq(true, null, false, true, null, true)
}

"return row-level results for columns filtered as null" in withSparkSession { session =>
"return row-level results for columns filtered as true" in withSparkSession { session =>

val data = getDfCompleteAndInCompleteColumns(session)

val completenessAtt2 = Completeness("att2", Option("att1 = \"a\""))
// Explicitly setting RowLevelFilterTreatment for test purposes, but this should be set at the VerificationRunBuilder
val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")).withRowLevelFilterTreatment(FilteredRow.TRUE)
val state = completenessAtt2.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state)

val df = data.withColumn("new", metric.fullColumn.get)
println("Filtered as null")
df.show(false)
df.collect().map(_.getAs[Any]("new")).toSeq shouldBe
Seq(true, null, false, true, null, true)
df.collect().map(_.getAs[Any]("new")).toSeq shouldBe
Seq(true, true, false, true, true, true)
}
}
}
(Remaining changed files not loaded in this view.)
