diff --git a/pom.xml b/pom.xml
index 47508e57..5847dc92 100644
--- a/pom.xml
+++ b/pom.xml
@@ -147,6 +147,12 @@
test
+
+ org.apache.iceberg
+ iceberg-spark-runtime-3.3_2.12
+ 0.14.0
+ test
+
diff --git a/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala
new file mode 100644
index 00000000..c5c7290d
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala
@@ -0,0 +1,130 @@
+/**
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License
+ * is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ *
+ */
+package com.amazon.deequ.repository.sparktable
+
+import com.amazon.deequ.analyzers.Analyzer
+import com.amazon.deequ.analyzers.runners.AnalyzerContext
+import com.amazon.deequ.metrics.Metric
+import com.amazon.deequ.repository._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
+
+class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository {
+
+ import session.implicits._
+
+ override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = {
+ val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext)))
+
+ val successfulMetrics = analyzerContext.metricMap
+ .filter { case (_, metric) => metric.value.isSuccess }
+
+ val metricDF = successfulMetrics.map { case (analyzer, metric) =>
+ SparkTableMetric(resultKey.toString, analyzer.toString, metric.value.toString,
+ resultKey.dataSetDate, serializedContext)
+ }.toSeq.toDF()
+
+ metricDF.write
+ .mode(SaveMode.Append)
+ .saveAsTable(tableName)
+ }
+
+ override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = {
+ val df: DataFrame = session.table(tableName)
+ val matchingRows = df.filter(col("resultKey") === resultKey.toString).collect()
+
+ if (matchingRows.isEmpty) {
+ None
+ } else {
+ val serializedContext = matchingRows(0).getAs[String]("serializedContext")
+ AnalysisResultSerde.deserialize(serializedContext).headOption.map(_.analyzerContext)
+ }
+ }
+
+ override def load(): MetricsRepositoryMultipleResultsLoader = {
+ SparkTableMetricsRepositoryMultipleResultsLoader(session, tableName)
+ }
+
+}
+
+case class SparkTableMetric(resultKey: String, metricName: String, metricValue: String, resultTimestamp: Long,
+ serializedContext: String)
+
+case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession,
+ tableName: String,
+ private val tagValues: Option[Map[String, String]] = None,
+ private val analyzers: Option[Seq[Analyzer[_, Metric[_]]]]
+ = None,
+ private val timeAfter: Option[Long] = None,
+ private val timeBefore: Option[Long] = None
+ ) extends MetricsRepositoryMultipleResultsLoader {
+
+ override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader =
+ this.copy(tagValues = Some(tagValues))
+
+ override def forAnalyzers(analyzers: Seq[Analyzer[_, Metric[_]]]): MetricsRepositoryMultipleResultsLoader =
+ this.copy(analyzers = Some(analyzers))
+
+ override def after(dateTime: Long): MetricsRepositoryMultipleResultsLoader =
+ this.copy(timeAfter = Some(dateTime))
+
+ override def before(dateTime: Long): MetricsRepositoryMultipleResultsLoader =
+ this.copy(timeBefore = Some(dateTime))
+
+ override def get(): Seq[AnalysisResult] = {
+ val initialDF: DataFrame = session.table(tableName)
+
+ val tagValuesFilter: DataFrame => DataFrame = df => {
+ tagValues.map { tags =>
+ tags.foldLeft(df) { (currentDF, tag) =>
+ currentDF.filter(row => {
+ val ser = row.getAs[String]("serializedContext")
+ AnalysisResultSerde.deserialize(ser).exists(ar => {
+ val tags = ar.resultKey.tags
+ tags.contains(tag._1) && tags(tag._1) == tag._2
+ })
+ })
+ }
+ }.getOrElse(df)
+ }
+
+ val specificAnalyzersFilter: DataFrame => DataFrame = df => {
+ analyzers.map(analyzers => df.filter(col("metricName").isin(analyzers.map(_.toString): _*)))
+ .getOrElse(df)
+ }
+
+ val timeAfterFilter: DataFrame => DataFrame = df => {
+ timeAfter.map(time => df.filter(col("resultTimestamp") > time.toString)).getOrElse(df)
+ }
+
+ val timeBeforeFilter: DataFrame => DataFrame = df => {
+ timeBefore.map(time => df.filter(col("resultTimestamp") < time.toString)).getOrElse(df)
+ }
+
+ val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter)
+ .foldLeft(initialDF) {
+ (df, filter) => filter(df)
+ }
+
+ // Convert the final DataFrame to the desired output format
+ filteredDF.collect().flatMap(row => {
+ val serializedContext = row.getAs[String]("serializedContext")
+ AnalysisResultSerde.deserialize(serializedContext)
+ }).toSeq
+ }
+
+
+}
diff --git a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala
index b54e05ee..81b1fd19 100644
--- a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala
+++ b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala
@@ -1,5 +1,5 @@
/**
- * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
@@ -19,16 +19,44 @@ package com.amazon.deequ
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, SparkSession}
+import java.nio.file.{Files, Path}
+import scala.collection.convert.ImplicitConversions.`iterator asScala`
+
/**
* To be mixed with Tests so they can use a default spark context suitable for testing
*/
trait SparkContextSpec {
+ val tmpWareHouseDir: Path = Files.createTempDirectory("deequ_tmp")
+
/**
* @param testFun thunk to run with SparkSession as an argument
*/
def withSparkSession(testFun: SparkSession => Any): Unit = {
- val session = setupSparkSession
+ val session = setupSparkSession()
+ try {
+ testFun(session)
+ } finally {
+ /* empty cache of RDD size, as the referred ids are only valid within a session */
+ tearDownSparkSession(session)
+ }
+ }
+
+ def withSparkSessionCustomWareHouse(testFun: SparkSession => Any): Unit = {
+ val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString))
+ try {
+ testFun(session)
+ } finally {
+ tearDownSparkSession(session)
+ }
+ }
+
+ def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = {
+ val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString))
+ session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
+ session.conf.set("spark.sql.catalog.local.type", "hadoop")
+ session.conf.set("spark.sql.catalog.local.warehouse", tmpWareHouseDir.toAbsolutePath.toString)
+
try {
testFun(session)
} finally {
@@ -44,7 +72,7 @@ trait SparkContextSpec {
*/
def withMonitorableSparkSession(testFun: (SparkSession, SparkMonitor) => Any): Unit = {
val monitor = new SparkMonitor
- val session = setupSparkSession
+ val session = setupSparkSession()
session.sparkContext.addSparkListener(monitor)
try {
testFun(session, monitor)
@@ -72,19 +100,32 @@ trait SparkContextSpec {
*
* @return sparkSession to be used
*/
- private def setupSparkSession = {
- val session = SparkSession.builder()
+ private def setupSparkSession(wareHouseDir: Option[String] = None) = {
+ val sessionBuilder = SparkSession.builder()
.master("local")
.appName("test")
.config("spark.ui.enabled", "false")
.config("spark.sql.shuffle.partitions", 2.toString)
.config("spark.sql.adaptive.enabled", value = false)
.config("spark.driver.bindAddress", "127.0.0.1")
- .getOrCreate()
+
+ val session = wareHouseDir.fold(sessionBuilder.getOrCreate())(sessionBuilder
+ .config("spark.sql.warehouse.dir", _).getOrCreate())
+
session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir"))
session
}
+ /**
+ * to cleanup temp directory used in test
+ * @param path - path to cleanup
+ */
+ private def deleteDirectory(path: Path): Unit = {
+ if (Files.exists(path)) {
+ Files.walk(path).iterator().toList.reverse.foreach(Files.delete)
+ }
+ }
+
/**
* Tears down the sparkSession
*
@@ -94,6 +135,8 @@ trait SparkContextSpec {
private def tearDownSparkSession(session: SparkSession) = {
session.stop()
System.clearProperty("spark.driver.port")
+ deleteDirectory(tmpWareHouseDir)
+
}
}
diff --git a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala
new file mode 100644
index 00000000..667b5b50
--- /dev/null
+++ b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala
@@ -0,0 +1,121 @@
+/**
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License
+ * is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ *
+ */
+
+package com.amazon.deequ.repository.sparktable
+
+import com.amazon.deequ.SparkContextSpec
+import com.amazon.deequ.analyzers.Size
+import com.amazon.deequ.analyzers.runners.AnalyzerContext
+import com.amazon.deequ.metrics.{DoubleMetric, Entity}
+import com.amazon.deequ.repository.ResultKey
+import com.amazon.deequ.utils.FixtureSupport
+import org.scalatest.wordspec.AnyWordSpec
+
+import scala.util.Try
+
+class SparkTableMetricsRepositoryTest extends AnyWordSpec
+ with SparkContextSpec
+ with FixtureSupport {
+
+ // private var spark: SparkSession = _
+ // private var repository: SparkTableMetricsRepository = _
+ private val analyzer = Size()
+
+ "spark table metrics repository " should {
+ "save and load a single metric" in withSparkSessionCustomWareHouse { spark =>
+ val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
+ val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
+ val context = AnalyzerContext(Map(analyzer -> metric))
+
+ val repository = new SparkTableMetricsRepository(spark, "metrics_table")
+ // Save the metric
+ repository.save(resultKey, context)
+
+ // Load the metric
+ val loadedContext = repository.loadByKey(resultKey)
+
+ assert(loadedContext.isDefined)
+ assert(loadedContext.get.metric(analyzer).contains(metric))
+
+ }
+
+ "save multiple metrics and load them" in withSparkSessionCustomWareHouse { spark =>
+ val repository = new SparkTableMetricsRepository(spark, "metrics_table")
+
+ val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1"))
+ val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
+ val context1 = AnalyzerContext(Map(analyzer -> metric))
+
+ val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue2"))
+ val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101))
+ val context2 = AnalyzerContext(Map(analyzer -> metric2))
+
+ repository.save(resultKey1, context1)
+ repository.save(resultKey2, context2)
+
+ val loadedMetrics = repository.load().get()
+
+ assert(loadedMetrics.length == 2)
+
+ loadedMetrics.flatMap(_.resultKey.tags)
+
+ }
+
+ "save and load metrics with tag" in withSparkSessionCustomWareHouse { spark =>
+ val repository = new SparkTableMetricsRepository(spark, "metrics_table")
+
+ val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A"))
+ val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
+ val context1 = AnalyzerContext(Map(analyzer -> metric))
+
+ val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "B"))
+ val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101))
+ val context2 = AnalyzerContext(Map(analyzer -> metric2))
+
+ repository.save(resultKey1, context1)
+ repository.save(resultKey2, context2)
+ val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get()
+ assert(loadedMetricsForTagA.length == 1)
+
+ val tagsMapA = loadedMetricsForTagA.flatMap(_.resultKey.tags).toMap
+ assert(tagsMapA.size == 1, "should have 1 result")
+ assert(tagsMapA.contains("tag"), "should contain tag")
+ assert(tagsMapA("tag") == "A", "tag should be A")
+
+ val loadedMetricsForAllMetrics = repository.load().forAnalyzers(Seq(analyzer)).get()
+ assert(loadedMetricsForAllMetrics.length == 2, "should have 2 results")
+
+ }
+
+ "save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => {
+ val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
+ val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
+ val context = AnalyzerContext(Map(analyzer -> metric))
+
+ val repository = new SparkTableMetricsRepository(spark, "local.metrics_table")
+ // Save the metric
+ repository.save(resultKey, context)
+
+ // Load the metric
+ val loadedContext = repository.loadByKey(resultKey)
+
+ assert(loadedContext.isDefined)
+ assert(loadedContext.get.metric(analyzer).contains(metric))
+ }
+
+ }
+ }
+}