From 717f7fc2803061670431583d4d98922eac007a00 Mon Sep 17 00:00:00 2001 From: penikala <11186040+VenkataKarthikP@users.noreply.github.com> Date: Tue, 28 Nov 2023 11:36:17 -0800 Subject: [PATCH] MetricsRepository using Spark tables as the data source (#518) * spark table repository * review comments --------- Co-authored-by: vpenikalapati --- pom.xml | 6 + .../sparktable/SparkMetricsRepository.scala | 130 ++++++++++++++++++ .../com/amazon/deequ/SparkContextSpec.scala | 55 +++++++- .../SparkTableMetricsRepositoryTest.scala | 121 ++++++++++++++++ 4 files changed, 306 insertions(+), 6 deletions(-) create mode 100644 src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala create mode 100644 src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala diff --git a/pom.xml b/pom.xml index 47508e571..5847dc922 100644 --- a/pom.xml +++ b/pom.xml @@ -147,6 +147,12 @@ test + + org.apache.iceberg + iceberg-spark-runtime-3.3_2.12 + 0.14.0 + test + diff --git a/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala new file mode 100644 index 000000000..c5c7290dc --- /dev/null +++ b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala @@ -0,0 +1,130 @@ +/** + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ +package com.amazon.deequ.repository.sparktable + +import com.amazon.deequ.analyzers.Analyzer +import com.amazon.deequ.analyzers.runners.AnalyzerContext +import com.amazon.deequ.metrics.Metric +import com.amazon.deequ.repository._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} + +class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository { + + import session.implicits._ + + override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = { + val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext))) + + val successfulMetrics = analyzerContext.metricMap + .filter { case (_, metric) => metric.value.isSuccess } + + val metricDF = successfulMetrics.map { case (analyzer, metric) => + SparkTableMetric(resultKey.toString, analyzer.toString, metric.value.toString, + resultKey.dataSetDate, serializedContext) + }.toSeq.toDF() + + metricDF.write + .mode(SaveMode.Append) + .saveAsTable(tableName) + } + + override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = { + val df: DataFrame = session.table(tableName) + val matchingRows = df.filter(col("resultKey") === resultKey.toString).collect() + + if (matchingRows.isEmpty) { + None + } else { + val serializedContext = matchingRows(0).getAs[String]("serializedContext") + AnalysisResultSerde.deserialize(serializedContext).headOption.map(_.analyzerContext) + } + } + + override def load(): MetricsRepositoryMultipleResultsLoader = { + SparkTableMetricsRepositoryMultipleResultsLoader(session, tableName) + } + +} + +case class SparkTableMetric(resultKey: String, metricName: String, metricValue: String, resultTimestamp: Long, + serializedContext: String) + +case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession, + tableName: String, + private val tagValues: Option[Map[String, String]] = None, + private val analyzers: Option[Seq[Analyzer[_, Metric[_]]]] + = None, + private val timeAfter: Option[Long] = None, + private val timeBefore: Option[Long] = None + ) extends MetricsRepositoryMultipleResultsLoader { + + override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader = + this.copy(tagValues = Some(tagValues)) + + override def forAnalyzers(analyzers: Seq[Analyzer[_, Metric[_]]]): MetricsRepositoryMultipleResultsLoader = + this.copy(analyzers = Some(analyzers)) + + override def after(dateTime: Long): MetricsRepositoryMultipleResultsLoader = + this.copy(timeAfter = Some(dateTime)) + + override def before(dateTime: Long): MetricsRepositoryMultipleResultsLoader = + this.copy(timeBefore = Some(dateTime)) + + override def get(): Seq[AnalysisResult] = { + val initialDF: DataFrame = session.table(tableName) + + val tagValuesFilter: DataFrame => DataFrame = df => { + tagValues.map { tags => + tags.foldLeft(df) { (currentDF, tag) => + currentDF.filter(row => { + val ser = row.getAs[String]("serializedContext") + AnalysisResultSerde.deserialize(ser).exists(ar => { + val tags = ar.resultKey.tags + tags.contains(tag._1) && tags(tag._1) == tag._2 + }) + }) + } + }.getOrElse(df) + } + + val specificAnalyzersFilter: DataFrame => DataFrame = df => { + analyzers.map(analyzers => df.filter(col("metricName").isin(analyzers.map(_.toString): _*))) + .getOrElse(df) + } + + val timeAfterFilter: DataFrame => DataFrame = df => { + timeAfter.map(time => df.filter(col("resultTimestamp") > time.toString)).getOrElse(df) + } + + val timeBeforeFilter: DataFrame => DataFrame = df => { + timeBefore.map(time => df.filter(col("resultTimestamp") < time.toString)).getOrElse(df) + } + + val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter) + .foldLeft(initialDF) { + (df, filter) => filter(df) + } + + // Convert the final DataFrame to the desired output format + filteredDF.collect().flatMap(row => { + val serializedContext = row.getAs[String]("serializedContext") + AnalysisResultSerde.deserialize(serializedContext) + }).toSeq + } + + +} diff --git a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala index b54e05eee..81b1fd190 100644 --- a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala +++ b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala @@ -1,5 +1,5 @@ /** - * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. A copy of the License @@ -19,16 +19,44 @@ package com.amazon.deequ import org.apache.spark.SparkContext import org.apache.spark.sql.{SQLContext, SparkSession} +import java.nio.file.{Files, Path} +import scala.collection.convert.ImplicitConversions.`iterator asScala` + /** * To be mixed with Tests so they can use a default spark context suitable for testing */ trait SparkContextSpec { + val tmpWareHouseDir: Path = Files.createTempDirectory("deequ_tmp") + /** * @param testFun thunk to run with SparkSession as an argument */ def withSparkSession(testFun: SparkSession => Any): Unit = { - val session = setupSparkSession + val session = setupSparkSession() + try { + testFun(session) + } finally { + /* empty cache of RDD size, as the referred ids are only valid within a session */ + tearDownSparkSession(session) + } + } + + def withSparkSessionCustomWareHouse(testFun: SparkSession => Any): Unit = { + val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString)) + try { + testFun(session) + } finally { + tearDownSparkSession(session) + } + } + + def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = { + val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString)) + session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog") + session.conf.set("spark.sql.catalog.local.type", "hadoop") + session.conf.set("spark.sql.catalog.local.warehouse", tmpWareHouseDir.toAbsolutePath.toString) + try { testFun(session) } finally { @@ -44,7 +72,7 @@ trait SparkContextSpec { */ def withMonitorableSparkSession(testFun: (SparkSession, SparkMonitor) => Any): Unit = { val monitor = new SparkMonitor - val session = setupSparkSession + val session = setupSparkSession() session.sparkContext.addSparkListener(monitor) try { testFun(session, monitor) @@ -72,19 +100,32 @@ trait SparkContextSpec { * * @return sparkSession to be used */ - private def setupSparkSession = { - val session = SparkSession.builder() + private def setupSparkSession(wareHouseDir: Option[String] = None) = { + val sessionBuilder = SparkSession.builder() .master("local") .appName("test") .config("spark.ui.enabled", "false") .config("spark.sql.shuffle.partitions", 2.toString) .config("spark.sql.adaptive.enabled", value = false) .config("spark.driver.bindAddress", "127.0.0.1") - .getOrCreate() + + val session = wareHouseDir.fold(sessionBuilder.getOrCreate())(sessionBuilder + .config("spark.sql.warehouse.dir", _).getOrCreate()) + session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir")) session } + /** + * to cleanup temp directory used in test + * @param path - path to cleanup + */ + private def deleteDirectory(path: Path): Unit = { + if (Files.exists(path)) { + Files.walk(path).iterator().toList.reverse.foreach(Files.delete) + } + } + /** * Tears down the sparkSession * @@ -94,6 +135,8 @@ trait SparkContextSpec { private def tearDownSparkSession(session: SparkSession) = { session.stop() System.clearProperty("spark.driver.port") + deleteDirectory(tmpWareHouseDir) + } } diff --git a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala new file mode 100644 index 000000000..667b5b502 --- /dev/null +++ b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala @@ -0,0 +1,121 @@ +/** + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.repository.sparktable + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.analyzers.Size +import com.amazon.deequ.analyzers.runners.AnalyzerContext +import com.amazon.deequ.metrics.{DoubleMetric, Entity} +import com.amazon.deequ.repository.ResultKey +import com.amazon.deequ.utils.FixtureSupport +import org.scalatest.wordspec.AnyWordSpec + +import scala.util.Try + +class SparkTableMetricsRepositoryTest extends AnyWordSpec + with SparkContextSpec + with FixtureSupport { + + // private var spark: SparkSession = _ + // private var repository: SparkTableMetricsRepository = _ + private val analyzer = Size() + + "spark table metrics repository " should { + "save and load a single metric" in withSparkSessionCustomWareHouse { spark => + val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context = AnalyzerContext(Map(analyzer -> metric)) + + val repository = new SparkTableMetricsRepository(spark, "metrics_table") + // Save the metric + repository.save(resultKey, context) + + // Load the metric + val loadedContext = repository.loadByKey(resultKey) + + assert(loadedContext.isDefined) + assert(loadedContext.get.metric(analyzer).contains(metric)) + + } + + "save multiple metrics and load them" in withSparkSessionCustomWareHouse { spark => + val repository = new SparkTableMetricsRepository(spark, "metrics_table") + + val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context1 = AnalyzerContext(Map(analyzer -> metric)) + + val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue2")) + val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101)) + val context2 = AnalyzerContext(Map(analyzer -> metric2)) + + repository.save(resultKey1, context1) + repository.save(resultKey2, context2) + + val loadedMetrics = repository.load().get() + + assert(loadedMetrics.length == 2) + + loadedMetrics.flatMap(_.resultKey.tags) + + } + + "save and load metrics with tag" in withSparkSessionCustomWareHouse { spark => + val repository = new SparkTableMetricsRepository(spark, "metrics_table") + + val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context1 = AnalyzerContext(Map(analyzer -> metric)) + + val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "B")) + val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101)) + val context2 = AnalyzerContext(Map(analyzer -> metric2)) + + repository.save(resultKey1, context1) + repository.save(resultKey2, context2) + val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get() + assert(loadedMetricsForTagA.length == 1) + + val tagsMapA = loadedMetricsForTagA.flatMap(_.resultKey.tags).toMap + assert(tagsMapA.size == 1, "should have 1 result") + assert(tagsMapA.contains("tag"), "should contain tag") + assert(tagsMapA("tag") == "A", "tag should be A") + + val loadedMetricsForAllMetrics = repository.load().forAnalyzers(Seq(analyzer)).get() + assert(loadedMetricsForAllMetrics.length == 2, "should have 2 results") + + } + + "save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => { + val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context = AnalyzerContext(Map(analyzer -> metric)) + + val repository = new SparkTableMetricsRepository(spark, "local.metrics_table") + // Save the metric + repository.save(resultKey, context) + + // Load the metric + val loadedContext = repository.loadByKey(resultKey) + + assert(loadedContext.isDefined) + assert(loadedContext.get.metric(analyzer).contains(metric)) + } + + } + } +}