review comments
VenkataKarthikP committed Nov 12, 2023
1 parent 209eba3 commit 46ded08
Showing 3 changed files with 58 additions and 50 deletions.
@@ -1,5 +1,5 @@
/**
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
@@ -20,28 +20,22 @@ import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.repository._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository {

private val SCHEMA = StructType(Array(
StructField("result_key", StringType),
StructField("metric_name", StringType),
StructField("metric_value", StringType),
StructField("result_timestamp", StringType),
StructField("serialized_context", StringType)
))
import session.implicits._

override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = {
val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext)))

val rows = analyzerContext.metricMap.map { case (analyzer, metric) =>
Row(resultKey.toString, analyzer.toString, metric.value.toString,
resultKey.dataSetDate.toString, serializedContext)
}.toSeq
val successfulMetrics = analyzerContext.metricMap
.filter { case (_, metric) => metric.value.isSuccess }

val metricDF = session.createDataFrame(session.sparkContext.parallelize(rows), SCHEMA)
val metricDF = successfulMetrics.map { case (analyzer, metric) =>
SparkTableMetric(resultKey.toString, analyzer.toString, metric.value.toString,
resultKey.dataSetDate, serializedContext)
}.toSeq.toDF()

metricDF.write
.mode(SaveMode.Append)
@@ -50,14 +44,13 @@ class SparkTableMetricsRepository(session: SparkSession, tableName: String) exte

override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = {
val df: DataFrame = session.table(tableName)
val matchingRows = df.filter(col("result_key") === resultKey.toString).collect()
val matchingRows = df.filter(col("resultKey") === resultKey.toString).collect()

if (matchingRows.isEmpty) {
None
} else {
val serializedContext = matchingRows(0).getAs[String]("serialized_context")
val analysisResult = AnalysisResultSerde.deserialize(serializedContext).head
Some(analysisResult.analyzerContext)
val serializedContext = matchingRows(0).getAs[String]("serializedContext")
AnalysisResultSerde.deserialize(serializedContext).headOption.map(_.analyzerContext)
}
}

@@ -67,13 +60,16 @@ class SparkTableMetricsRepository(session: SparkSession, tableName: String) exte

}

case class SparkTableMetric(resultKey: String, metricName: String, metricValue: String, resultTimestamp: Long,
serializedContext: String)

case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession,
tableName: String,
tagValues: Option[Map[String, String]] = None,
analyzers: Option[Seq[Analyzer[_, Metric[_]]]] = None,
timeAfter: Option[Long] = None,
timeBefore: Option[Long] = None
private val tagValues: Option[Map[String, String]] = None,
private val analyzers: Option[Seq[Analyzer[_, Metric[_]]]]
= None,
private val timeAfter: Option[Long] = None,
private val timeBefore: Option[Long] = None
) extends MetricsRepositoryMultipleResultsLoader {

override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader =
@@ -91,12 +87,11 @@ case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSessio
override def get(): Seq[AnalysisResult] = {
val initialDF: DataFrame = session.table(tableName)

initialDF.printSchema()
val tagValuesFilter: DataFrame => DataFrame = df => {
tagValues.map { tags =>
tags.foldLeft(df) { (currentDF, tag) =>
currentDF.filter(row => {
val ser = row.getAs[String]("serialized_context")
val ser = row.getAs[String]("serializedContext")
AnalysisResultSerde.deserialize(ser).exists(ar => {
val tags = ar.resultKey.tags
tags.contains(tag._1) && tags(tag._1) == tag._2
@@ -107,16 +102,16 @@ case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSessio
}

val specificAnalyzersFilter: DataFrame => DataFrame = df => {
analyzers.map(analyzers => df.filter(col("metric_name").isin(analyzers.map(_.toString): _*)))
analyzers.map(analyzers => df.filter(col("metricName").isin(analyzers.map(_.toString): _*)))
.getOrElse(df)
}

val timeAfterFilter: DataFrame => DataFrame = df => {
timeAfter.map(time => df.filter(col("result_timestamp") > time.toString)).getOrElse(df)
timeAfter.map(time => df.filter(col("resultTimestamp") > time.toString)).getOrElse(df)
}

val timeBeforeFilter: DataFrame => DataFrame = df => {
timeBefore.map(time => df.filter(col("result_timestamp") < time.toString)).getOrElse(df)
timeBefore.map(time => df.filter(col("resultTimestamp") < time.toString)).getOrElse(df)
}

val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter)
@@ -126,7 +121,7 @@ case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSessio

// Convert the final DataFrame to the desired output format
filteredDF.collect().flatMap(row => {
val serializedContext = row.getAs[String]("serialized_context")
val serializedContext = row.getAs[String]("serializedContext")
AnalysisResultSerde.deserialize(serializedContext)
}).toSeq
}
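For context, the sketch below shows how the repository changed above might be used end to end, mirroring the tests later in this commit. It is an illustration only: the import path of SparkTableMetricsRepository, the table name "metrics_table", and the example metric values are assumptions, not part of the diff.

```scala
// Usage sketch for SparkTableMetricsRepository, mirroring the tests in this commit.
// Assumptions: the import path of SparkTableMetricsRepository and the table name "metrics_table".
import com.amazon.deequ.analyzers.Size
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.metrics.{DoubleMetric, Entity}
import com.amazon.deequ.repository.ResultKey
import com.amazon.deequ.repository.sparktable.SparkTableMetricsRepository // assumed package
import org.apache.spark.sql.SparkSession

import scala.util.Try

object SparkTableMetricsRepositoryExample {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local")
      .appName("metrics-repository-example")
      .getOrCreate()

    val repository = new SparkTableMetricsRepository(spark, "metrics_table")

    // Save one analyzer result under a result key; only successful metrics are persisted.
    val analyzer = Size()
    val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "A"))
    val context = AnalyzerContext(Map(analyzer -> DoubleMetric(Entity.Column, "m1", "", Try(100))))
    repository.save(resultKey, context)

    // Load the analyzer context back by its exact key.
    val loadedByKey: Option[AnalyzerContext] = repository.loadByKey(resultKey)
    println(loadedByKey.map(_.metricMap))

    // Load multiple results, filtered by tag values and by analyzer.
    val filteredResults = repository.load()
      .withTagValues(Map("tag" -> "A"))
      .forAnalyzers(Seq(analyzer))
      .get()
    println(filteredResults.length)

    spark.stop()
  }
}
```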
33 changes: 22 additions & 11 deletions src/test/scala/com/amazon/deequ/SparkContextSpec.scala
@@ -1,5 +1,5 @@
/**
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
@@ -27,13 +27,13 @@ import scala.collection.convert.ImplicitConversions.`iterator asScala`
*/
trait SparkContextSpec {

val warehouseDir: Path = Files.createTempDirectory("my_temp_dir_")
val tmpWareHouseDir: Path = Files.createTempDirectory("deequ_tmp")

/**
* @param testFun thunk to run with SparkSession as an argument
*/
def withSparkSession(testFun: SparkSession => Any): Unit = {
val session = setupSparkSession
val session = setupSparkSession()
try {
testFun(session)
} finally {
@@ -42,11 +42,20 @@ trait SparkContextSpec {
}
}

def withSparkSessionCustomWareHouse(testFun: SparkSession => Any): Unit = {
val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString))
try {
testFun(session)
} finally {
tearDownSparkSession(session)
}
}

def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = {
val session = setupSparkSession
val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString))
session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
session.conf.set("spark.sql.catalog.local.type", "hadoop")
session.conf.set("spark.sql.catalog.local.warehouse", warehouseDir.toAbsolutePath.toString)
session.conf.set("spark.sql.catalog.local.warehouse", tmpWareHouseDir.toAbsolutePath.toString)

try {
testFun(session)
@@ -63,7 +72,7 @@ trait SparkContextSpec {
*/
def withMonitorableSparkSession(testFun: (SparkSession, SparkMonitor) => Any): Unit = {
val monitor = new SparkMonitor
val session = setupSparkSession
val session = setupSparkSession()
session.sparkContext.addSparkListener(monitor)
try {
testFun(session, monitor)
@@ -91,16 +100,18 @@ trait SparkContextSpec {
*
* @return sparkSession to be used
*/
private def setupSparkSession = {
val session = SparkSession.builder()
private def setupSparkSession(wareHouseDir: Option[String] = None) = {
val sessionBuilder = SparkSession.builder()
.master("local")
.appName("test")
.config("spark.ui.enabled", "false")
.config("spark.sql.shuffle.partitions", 2.toString)
.config("spark.sql.adaptive.enabled", value = false)
.config("spark.driver.bindAddress", "127.0.0.1")
.config("spark.sql.warehouse.dir", warehouseDir.toAbsolutePath.toString)
.getOrCreate()

val session = wareHouseDir.fold(sessionBuilder.getOrCreate())(sessionBuilder
.config("spark.sql.warehouse.dir", _).getOrCreate())

session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir"))
session
}
@@ -124,7 +135,7 @@ trait SparkContextSpec {
private def tearDownSparkSession(session: SparkSession) = {
session.stop()
System.clearProperty("spark.driver.port")
deleteDirectory(warehouseDir)
deleteDirectory(tmpWareHouseDir)

}

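As a brief illustration of the new helper in SparkContextSpec, a suite that mixes in the trait can opt into the temporary warehouse directory via withSparkSessionCustomWareHouse. The suite name and assertion below are assumptions for illustration only.

```scala
// Hypothetical suite illustrating withSparkSessionCustomWareHouse; not part of this commit.
package com.amazon.deequ // SparkContextSpec lives in this package

import org.scalatest.wordspec.AnyWordSpec

class CustomWarehouseExampleSpec extends AnyWordSpec with SparkContextSpec {

  "setupSparkSession" should {
    "use the temporary warehouse directory when one is requested" in
      withSparkSessionCustomWareHouse { spark =>
        // The helper passes tmpWareHouseDir to the session builder, so the configured
        // warehouse path should point at that temporary directory.
        val configuredWarehouse = spark.conf.get("spark.sql.warehouse.dir")
        assert(configuredWarehouse.contains(tmpWareHouseDir.getFileName.toString))
      }
  }
}
```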
@@ -1,5 +1,5 @@
/**
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
@@ -35,7 +35,7 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec
private val analyzer = Size()

"spark table metrics repository " should {
"save and load a single metric" in withSparkSession { spark => {
"save and load a single metric" in withSparkSessionCustomWareHouse { spark =>
val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context = AnalyzerContext(Map(analyzer -> metric))
@@ -49,11 +49,10 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec

assert(loadedContext.isDefined)
assert(loadedContext.get.metric(analyzer).contains(metric))
}

}

"save multiple metrics and load them" in withSparkSession { spark => {
"save multiple metrics and load them" in withSparkSessionCustomWareHouse { spark =>
val repository = new SparkTableMetricsRepository(spark, "metrics_table")

val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1"))
@@ -72,10 +71,10 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec
assert(loadedMetrics.length == 2)

loadedMetrics.flatMap(_.resultKey.tags)
}

}

"save and load metrics with tag" in withSparkSession { spark => {
"save and load metrics with tag" in withSparkSessionCustomWareHouse { spark =>
val repository = new SparkTableMetricsRepository(spark, "metrics_table")

val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A"))
@@ -90,12 +89,15 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec
repository.save(resultKey2, context2)
val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get()
assert(loadedMetricsForTagA.length == 1)
// additional assertions to ensure the loaded metric is the one with tag "A"

val loadedMetricsForMetricM1 = repository.load().forAnalyzers(Seq(analyzer))
assert(loadedMetricsForTagA.length == 1)
val tagsMapA = loadedMetricsForTagA.flatMap(_.resultKey.tags).toMap
assert(tagsMapA.size == 1, "should have 1 result")
assert(tagsMapA.contains("tag"), "should contain tag")
assert(tagsMapA("tag") == "A", "tag should be A")

val loadedMetricsForAllMetrics = repository.load().forAnalyzers(Seq(analyzer)).get()
assert(loadedMetricsForAllMetrics.length == 2, "should have 2 results")

}
}

"save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => {
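The Iceberg test body is truncated above. As a rough illustration only, a round-trip test against an Iceberg table might look like the hypothetical sketch below; the table identifier "local.db.metrics_table", the package and suite name, and the iceberg-spark runtime dependency are assumptions, while the "local" catalog name matches the configuration in withSparkSessionIcebergCatalog.

```scala
// Hypothetical sketch of an Iceberg-backed round trip; not the actual truncated test body.
// Assumptions: the iceberg-spark runtime is on the classpath, the package below is where
// SparkTableMetricsRepository lives, and "local.db.metrics_table" is a valid identifier in
// the "local" catalog configured by withSparkSessionIcebergCatalog.
package com.amazon.deequ.repository.sparktable

import com.amazon.deequ.SparkContextSpec
import com.amazon.deequ.analyzers.Size
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.metrics.{DoubleMetric, Entity}
import com.amazon.deequ.repository.ResultKey
import org.scalatest.wordspec.AnyWordSpec

import scala.util.Try

class IcebergRoundTripExampleSpec extends AnyWordSpec with SparkContextSpec {

  "SparkTableMetricsRepository" should {
    "round-trip a metric through an Iceberg table" in withSparkSessionIcebergCatalog { spark =>
      val analyzer = Size()
      val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
      val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
      val context = AnalyzerContext(Map(analyzer -> metric))

      val repository = new SparkTableMetricsRepository(spark, "local.db.metrics_table")
      repository.save(resultKey, context)

      val loaded = repository.loadByKey(resultKey)
      assert(loaded.isDefined)
      assert(loaded.get.metric(analyzer).contains(metric))
    }
  }
}
```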
