From 209eba308d954fbbd55892737e3673a194ab2074 Mon Sep 17 00:00:00 2001 From: vpenikalapati Date: Sun, 29 Oct 2023 17:11:32 -0700 Subject: [PATCH 1/2] spark table repository --- pom.xml | 6 + .../sparktable/SparkMetricsRepository.scala | 135 ++++++++++++++++++ .../com/amazon/deequ/SparkContextSpec.scala | 32 +++++ .../SparkTableMetricsRepositoryTest.scala | 119 +++++++++++++++ 4 files changed, 292 insertions(+) create mode 100644 src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala create mode 100644 src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala diff --git a/pom.xml b/pom.xml index 262a670f6..d57cc1b52 100644 --- a/pom.xml +++ b/pom.xml @@ -155,6 +155,12 @@ test + + org.apache.iceberg + iceberg-spark-runtime-3.3_2.12 + 0.14.0 + test + diff --git a/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala new file mode 100644 index 000000000..5fb195b07 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala @@ -0,0 +1,135 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ +package com.amazon.deequ.repository.sparktable + +import com.amazon.deequ.analyzers.Analyzer +import com.amazon.deequ.analyzers.runners.AnalyzerContext +import com.amazon.deequ.metrics.Metric +import com.amazon.deequ.repository._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession} + +class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository { + + private val SCHEMA = StructType(Array( + StructField("result_key", StringType), + StructField("metric_name", StringType), + StructField("metric_value", StringType), + StructField("result_timestamp", StringType), + StructField("serialized_context", StringType) + )) + + override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = { + val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext))) + + val rows = analyzerContext.metricMap.map { case (analyzer, metric) => + Row(resultKey.toString, analyzer.toString, metric.value.toString, + resultKey.dataSetDate.toString, serializedContext) + }.toSeq + + val metricDF = session.createDataFrame(session.sparkContext.parallelize(rows), SCHEMA) + + metricDF.write + .mode(SaveMode.Append) + .saveAsTable(tableName) + } + + override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = { + val df: DataFrame = session.table(tableName) + val matchingRows = df.filter(col("result_key") === resultKey.toString).collect() + + if (matchingRows.isEmpty) { + None + } else { + val serializedContext = matchingRows(0).getAs[String]("serialized_context") + val analysisResult = AnalysisResultSerde.deserialize(serializedContext).head + Some(analysisResult.analyzerContext) + } + } + + override def load(): MetricsRepositoryMultipleResultsLoader = { + SparkTableMetricsRepositoryMultipleResultsLoader(session, tableName) + } + +} + + +case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession, + tableName: String, + tagValues: Option[Map[String, String]] = None, + analyzers: Option[Seq[Analyzer[_, Metric[_]]]] = None, + timeAfter: Option[Long] = None, + timeBefore: Option[Long] = None + ) extends MetricsRepositoryMultipleResultsLoader { + + override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader = + this.copy(tagValues = Some(tagValues)) + + override def forAnalyzers(analyzers: Seq[Analyzer[_, Metric[_]]]): MetricsRepositoryMultipleResultsLoader = + this.copy(analyzers = Some(analyzers)) + + override def after(dateTime: Long): MetricsRepositoryMultipleResultsLoader = + this.copy(timeAfter = Some(dateTime)) + + override def before(dateTime: Long): MetricsRepositoryMultipleResultsLoader = + this.copy(timeBefore = Some(dateTime)) + + override def get(): Seq[AnalysisResult] = { + val initialDF: DataFrame = session.table(tableName) + + initialDF.printSchema() + val tagValuesFilter: DataFrame => DataFrame = df => { + tagValues.map { tags => + tags.foldLeft(df) { (currentDF, tag) => + currentDF.filter(row => { + val ser = row.getAs[String]("serialized_context") + AnalysisResultSerde.deserialize(ser).exists(ar => { + val tags = ar.resultKey.tags + tags.contains(tag._1) && tags(tag._1) == tag._2 + }) + }) + } + }.getOrElse(df) + } + + val specificAnalyzersFilter: DataFrame => DataFrame = df => { + analyzers.map(analyzers => df.filter(col("metric_name").isin(analyzers.map(_.toString): _*))) + .getOrElse(df) + } + + val timeAfterFilter: DataFrame => DataFrame = df => { + timeAfter.map(time => df.filter(col("result_timestamp") > time.toString)).getOrElse(df) + } + + val timeBeforeFilter: DataFrame => DataFrame = df => { + timeBefore.map(time => df.filter(col("result_timestamp") < time.toString)).getOrElse(df) + } + + val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter) + .foldLeft(initialDF) { + (df, filter) => filter(df) + } + + // Convert the final DataFrame to the desired output format + filteredDF.collect().flatMap(row => { + val serializedContext = row.getAs[String]("serialized_context") + AnalysisResultSerde.deserialize(serializedContext) + }).toSeq + } + + +} diff --git a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala index b54e05eee..cff2c7448 100644 --- a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala +++ b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala @@ -19,11 +19,16 @@ package com.amazon.deequ import org.apache.spark.SparkContext import org.apache.spark.sql.{SQLContext, SparkSession} +import java.nio.file.{Files, Path} +import scala.collection.convert.ImplicitConversions.`iterator asScala` + /** * To be mixed with Tests so they can use a default spark context suitable for testing */ trait SparkContextSpec { + val warehouseDir: Path = Files.createTempDirectory("my_temp_dir_") + /** * @param testFun thunk to run with SparkSession as an argument */ @@ -37,6 +42,20 @@ trait SparkContextSpec { } } + def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = { + val session = setupSparkSession + session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog") + session.conf.set("spark.sql.catalog.local.type", "hadoop") + session.conf.set("spark.sql.catalog.local.warehouse", warehouseDir.toAbsolutePath.toString) + + try { + testFun(session) + } finally { + /* empty cache of RDD size, as the referred ids are only valid within a session */ + tearDownSparkSession(session) + } + } + /** * @param testFun thunk to run with SparkSession and SparkMonitor as an argument for the tests * that would like to get details on spark jobs submitted @@ -80,11 +99,22 @@ trait SparkContextSpec { .config("spark.sql.shuffle.partitions", 2.toString) .config("spark.sql.adaptive.enabled", value = false) .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.sql.warehouse.dir", warehouseDir.toAbsolutePath.toString) .getOrCreate() session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir")) session } + /** + * to cleanup temp directory used in test + * @param path - path to cleanup + */ + private def deleteDirectory(path: Path): Unit = { + if (Files.exists(path)) { + Files.walk(path).iterator().toList.reverse.foreach(Files.delete) + } + } + /** * Tears down the sparkSession * @@ -94,6 +124,8 @@ trait SparkContextSpec { private def tearDownSparkSession(session: SparkSession) = { session.stop() System.clearProperty("spark.driver.port") + deleteDirectory(warehouseDir) + } } diff --git a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala new file mode 100644 index 000000000..5b66ce305 --- /dev/null +++ b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala @@ -0,0 +1,119 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.repository.sparktable + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.analyzers.Size +import com.amazon.deequ.analyzers.runners.AnalyzerContext +import com.amazon.deequ.metrics.{DoubleMetric, Entity} +import com.amazon.deequ.repository.ResultKey +import com.amazon.deequ.utils.FixtureSupport +import org.scalatest.wordspec.AnyWordSpec + +import scala.util.Try + +class SparkTableMetricsRepositoryTest extends AnyWordSpec + with SparkContextSpec + with FixtureSupport { + + // private var spark: SparkSession = _ + // private var repository: SparkTableMetricsRepository = _ + private val analyzer = Size() + + "spark table metrics repository " should { + "save and load a single metric" in withSparkSession { spark => { + val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context = AnalyzerContext(Map(analyzer -> metric)) + + val repository = new SparkTableMetricsRepository(spark, "metrics_table") + // Save the metric + repository.save(resultKey, context) + + // Load the metric + val loadedContext = repository.loadByKey(resultKey) + + assert(loadedContext.isDefined) + assert(loadedContext.get.metric(analyzer).contains(metric)) + } + + } + + "save multiple metrics and load them" in withSparkSession { spark => { + val repository = new SparkTableMetricsRepository(spark, "metrics_table") + + val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context1 = AnalyzerContext(Map(analyzer -> metric)) + + val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue2")) + val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101)) + val context2 = AnalyzerContext(Map(analyzer -> metric2)) + + repository.save(resultKey1, context1) + repository.save(resultKey2, context2) + + val loadedMetrics = repository.load().get() + + assert(loadedMetrics.length == 2) + + loadedMetrics.flatMap(_.resultKey.tags) + } + } + + "save and load metrics with tag" in withSparkSession { spark => { + val repository = new SparkTableMetricsRepository(spark, "metrics_table") + + val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context1 = AnalyzerContext(Map(analyzer -> metric)) + + val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "B")) + val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101)) + val context2 = AnalyzerContext(Map(analyzer -> metric2)) + + repository.save(resultKey1, context1) + repository.save(resultKey2, context2) + val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get() + assert(loadedMetricsForTagA.length == 1) + // additional assertions to ensure the loaded metric is the one with tag "A" + + val loadedMetricsForMetricM1 = repository.load().forAnalyzers(Seq(analyzer)) + assert(loadedMetricsForTagA.length == 1) + + } + } + + "save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => { + val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context = AnalyzerContext(Map(analyzer -> metric)) + + val repository = new SparkTableMetricsRepository(spark, "local.metrics_table") + // Save the metric + repository.save(resultKey, context) + + // Load the metric + val loadedContext = repository.loadByKey(resultKey) + + assert(loadedContext.isDefined) + assert(loadedContext.get.metric(analyzer).contains(metric)) + } + + } + } +} From 46ded08b37145c11c2e4a7ebf51ece811ac04a24 Mon Sep 17 00:00:00 2001 From: Karthik Penikalapati Date: Sun, 12 Nov 2023 14:07:41 -0800 Subject: [PATCH 2/2] review comments --- .../sparktable/SparkMetricsRepository.scala | 53 +++++++++---------- .../com/amazon/deequ/SparkContextSpec.scala | 33 ++++++++---- .../SparkTableMetricsRepositoryTest.scala | 22 ++++---- 3 files changed, 58 insertions(+), 50 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala index 5fb195b07..c5c7290dc 100644 --- a/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala +++ b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala @@ -1,5 +1,5 @@ /** - * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. A copy of the License @@ -20,28 +20,22 @@ import com.amazon.deequ.analyzers.runners.AnalyzerContext import com.amazon.deequ.metrics.Metric import com.amazon.deequ.repository._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StringType, StructField, StructType} -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession} +import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository { - private val SCHEMA = StructType(Array( - StructField("result_key", StringType), - StructField("metric_name", StringType), - StructField("metric_value", StringType), - StructField("result_timestamp", StringType), - StructField("serialized_context", StringType) - )) + import session.implicits._ override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = { val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext))) - val rows = analyzerContext.metricMap.map { case (analyzer, metric) => - Row(resultKey.toString, analyzer.toString, metric.value.toString, - resultKey.dataSetDate.toString, serializedContext) - }.toSeq + val successfulMetrics = analyzerContext.metricMap + .filter { case (_, metric) => metric.value.isSuccess } - val metricDF = session.createDataFrame(session.sparkContext.parallelize(rows), SCHEMA) + val metricDF = successfulMetrics.map { case (analyzer, metric) => + SparkTableMetric(resultKey.toString, analyzer.toString, metric.value.toString, + resultKey.dataSetDate, serializedContext) + }.toSeq.toDF() metricDF.write .mode(SaveMode.Append) @@ -50,14 +44,13 @@ class SparkTableMetricsRepository(session: SparkSession, tableName: String) exte override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = { val df: DataFrame = session.table(tableName) - val matchingRows = df.filter(col("result_key") === resultKey.toString).collect() + val matchingRows = df.filter(col("resultKey") === resultKey.toString).collect() if (matchingRows.isEmpty) { None } else { - val serializedContext = matchingRows(0).getAs[String]("serialized_context") - val analysisResult = AnalysisResultSerde.deserialize(serializedContext).head - Some(analysisResult.analyzerContext) + val serializedContext = matchingRows(0).getAs[String]("serializedContext") + AnalysisResultSerde.deserialize(serializedContext).headOption.map(_.analyzerContext) } } @@ -67,13 +60,16 @@ class SparkTableMetricsRepository(session: SparkSession, tableName: String) exte } +case class SparkTableMetric(resultKey: String, metricName: String, metricValue: String, resultTimestamp: Long, + serializedContext: String) case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession, tableName: String, - tagValues: Option[Map[String, String]] = None, - analyzers: Option[Seq[Analyzer[_, Metric[_]]]] = None, - timeAfter: Option[Long] = None, - timeBefore: Option[Long] = None + private val tagValues: Option[Map[String, String]] = None, + private val analyzers: Option[Seq[Analyzer[_, Metric[_]]]] + = None, + private val timeAfter: Option[Long] = None, + private val timeBefore: Option[Long] = None ) extends MetricsRepositoryMultipleResultsLoader { override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader = @@ -91,12 +87,11 @@ case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSessio override def get(): Seq[AnalysisResult] = { val initialDF: DataFrame = session.table(tableName) - initialDF.printSchema() val tagValuesFilter: DataFrame => DataFrame = df => { tagValues.map { tags => tags.foldLeft(df) { (currentDF, tag) => currentDF.filter(row => { - val ser = row.getAs[String]("serialized_context") + val ser = row.getAs[String]("serializedContext") AnalysisResultSerde.deserialize(ser).exists(ar => { val tags = ar.resultKey.tags tags.contains(tag._1) && tags(tag._1) == tag._2 @@ -107,16 +102,16 @@ case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSessio } val specificAnalyzersFilter: DataFrame => DataFrame = df => { - analyzers.map(analyzers => df.filter(col("metric_name").isin(analyzers.map(_.toString): _*))) + analyzers.map(analyzers => df.filter(col("metricName").isin(analyzers.map(_.toString): _*))) .getOrElse(df) } val timeAfterFilter: DataFrame => DataFrame = df => { - timeAfter.map(time => df.filter(col("result_timestamp") > time.toString)).getOrElse(df) + timeAfter.map(time => df.filter(col("resultTimestamp") > time.toString)).getOrElse(df) } val timeBeforeFilter: DataFrame => DataFrame = df => { - timeBefore.map(time => df.filter(col("result_timestamp") < time.toString)).getOrElse(df) + timeBefore.map(time => df.filter(col("resultTimestamp") < time.toString)).getOrElse(df) } val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter) @@ -126,7 +121,7 @@ case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSessio // Convert the final DataFrame to the desired output format filteredDF.collect().flatMap(row => { - val serializedContext = row.getAs[String]("serialized_context") + val serializedContext = row.getAs[String]("serializedContext") AnalysisResultSerde.deserialize(serializedContext) }).toSeq } diff --git a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala index cff2c7448..81b1fd190 100644 --- a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala +++ b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala @@ -1,5 +1,5 @@ /** - * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. A copy of the License @@ -27,13 +27,13 @@ import scala.collection.convert.ImplicitConversions.`iterator asScala` */ trait SparkContextSpec { - val warehouseDir: Path = Files.createTempDirectory("my_temp_dir_") + val tmpWareHouseDir: Path = Files.createTempDirectory("deequ_tmp") /** * @param testFun thunk to run with SparkSession as an argument */ def withSparkSession(testFun: SparkSession => Any): Unit = { - val session = setupSparkSession + val session = setupSparkSession() try { testFun(session) } finally { @@ -42,11 +42,20 @@ trait SparkContextSpec { } } + def withSparkSessionCustomWareHouse(testFun: SparkSession => Any): Unit = { + val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString)) + try { + testFun(session) + } finally { + tearDownSparkSession(session) + } + } + def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = { - val session = setupSparkSession + val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString)) session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog") session.conf.set("spark.sql.catalog.local.type", "hadoop") - session.conf.set("spark.sql.catalog.local.warehouse", warehouseDir.toAbsolutePath.toString) + session.conf.set("spark.sql.catalog.local.warehouse", tmpWareHouseDir.toAbsolutePath.toString) try { testFun(session) @@ -63,7 +72,7 @@ trait SparkContextSpec { */ def withMonitorableSparkSession(testFun: (SparkSession, SparkMonitor) => Any): Unit = { val monitor = new SparkMonitor - val session = setupSparkSession + val session = setupSparkSession() session.sparkContext.addSparkListener(monitor) try { testFun(session, monitor) @@ -91,16 +100,18 @@ trait SparkContextSpec { * * @return sparkSession to be used */ - private def setupSparkSession = { - val session = SparkSession.builder() + private def setupSparkSession(wareHouseDir: Option[String] = None) = { + val sessionBuilder = SparkSession.builder() .master("local") .appName("test") .config("spark.ui.enabled", "false") .config("spark.sql.shuffle.partitions", 2.toString) .config("spark.sql.adaptive.enabled", value = false) .config("spark.driver.bindAddress", "127.0.0.1") - .config("spark.sql.warehouse.dir", warehouseDir.toAbsolutePath.toString) - .getOrCreate() + + val session = wareHouseDir.fold(sessionBuilder.getOrCreate())(sessionBuilder + .config("spark.sql.warehouse.dir", _).getOrCreate()) + session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir")) session } @@ -124,7 +135,7 @@ trait SparkContextSpec { private def tearDownSparkSession(session: SparkSession) = { session.stop() System.clearProperty("spark.driver.port") - deleteDirectory(warehouseDir) + deleteDirectory(tmpWareHouseDir) } diff --git a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala index 5b66ce305..667b5b502 100644 --- a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala +++ b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala @@ -1,5 +1,5 @@ /** - * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. A copy of the License @@ -35,7 +35,7 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec private val analyzer = Size() "spark table metrics repository " should { - "save and load a single metric" in withSparkSession { spark => { + "save and load a single metric" in withSparkSessionCustomWareHouse { spark => val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value")) val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) val context = AnalyzerContext(Map(analyzer -> metric)) @@ -49,11 +49,10 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec assert(loadedContext.isDefined) assert(loadedContext.get.metric(analyzer).contains(metric)) - } } - "save multiple metrics and load them" in withSparkSession { spark => { + "save multiple metrics and load them" in withSparkSessionCustomWareHouse { spark => val repository = new SparkTableMetricsRepository(spark, "metrics_table") val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1")) @@ -72,10 +71,10 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec assert(loadedMetrics.length == 2) loadedMetrics.flatMap(_.resultKey.tags) - } + } - "save and load metrics with tag" in withSparkSession { spark => { + "save and load metrics with tag" in withSparkSessionCustomWareHouse { spark => val repository = new SparkTableMetricsRepository(spark, "metrics_table") val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A")) @@ -90,12 +89,15 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec repository.save(resultKey2, context2) val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get() assert(loadedMetricsForTagA.length == 1) - // additional assertions to ensure the loaded metric is the one with tag "A" - val loadedMetricsForMetricM1 = repository.load().forAnalyzers(Seq(analyzer)) - assert(loadedMetricsForTagA.length == 1) + val tagsMapA = loadedMetricsForTagA.flatMap(_.resultKey.tags).toMap + assert(tagsMapA.size == 1, "should have 1 result") + assert(tagsMapA.contains("tag"), "should contain tag") + assert(tagsMapA("tag") == "A", "tag should be A") + + val loadedMetricsForAllMetrics = repository.load().forAnalyzers(Seq(analyzer)).get() + assert(loadedMetricsForAllMetrics.length == 2, "should have 2 results") - } } "save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => {