From 46ded08b37145c11c2e4a7ebf51ece811ac04a24 Mon Sep 17 00:00:00 2001 From: Karthik Penikalapati Date: Sun, 12 Nov 2023 14:07:41 -0800 Subject: [PATCH] review comments --- .../sparktable/SparkMetricsRepository.scala | 53 +++++++++---------- .../com/amazon/deequ/SparkContextSpec.scala | 33 ++++++++---- .../SparkTableMetricsRepositoryTest.scala | 22 ++++---- 3 files changed, 58 insertions(+), 50 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala index 5fb195b0..c5c7290d 100644 --- a/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala +++ b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala @@ -1,5 +1,5 @@ /** - * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. A copy of the License @@ -20,28 +20,22 @@ import com.amazon.deequ.analyzers.runners.AnalyzerContext import com.amazon.deequ.metrics.Metric import com.amazon.deequ.repository._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StringType, StructField, StructType} -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession} +import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository { - private val SCHEMA = StructType(Array( - StructField("result_key", StringType), - StructField("metric_name", StringType), - StructField("metric_value", StringType), - StructField("result_timestamp", StringType), - StructField("serialized_context", StringType) - )) + import session.implicits._ override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = { val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext))) - val rows = analyzerContext.metricMap.map { case (analyzer, metric) => - Row(resultKey.toString, analyzer.toString, metric.value.toString, - resultKey.dataSetDate.toString, serializedContext) - }.toSeq + val successfulMetrics = analyzerContext.metricMap + .filter { case (_, metric) => metric.value.isSuccess } - val metricDF = session.createDataFrame(session.sparkContext.parallelize(rows), SCHEMA) + val metricDF = successfulMetrics.map { case (analyzer, metric) => + SparkTableMetric(resultKey.toString, analyzer.toString, metric.value.toString, + resultKey.dataSetDate, serializedContext) + }.toSeq.toDF() metricDF.write .mode(SaveMode.Append) @@ -50,14 +44,13 @@ class SparkTableMetricsRepository(session: SparkSession, tableName: String) exte override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = { val df: DataFrame = session.table(tableName) - val matchingRows = df.filter(col("result_key") === resultKey.toString).collect() + val matchingRows = df.filter(col("resultKey") === resultKey.toString).collect() if (matchingRows.isEmpty) { None } else { - val serializedContext = matchingRows(0).getAs[String]("serialized_context") - val analysisResult = AnalysisResultSerde.deserialize(serializedContext).head - Some(analysisResult.analyzerContext) + val serializedContext = matchingRows(0).getAs[String]("serializedContext") + AnalysisResultSerde.deserialize(serializedContext).headOption.map(_.analyzerContext) } } @@ -67,13 +60,16 @@ class SparkTableMetricsRepository(session: SparkSession, tableName: String) exte } +case class SparkTableMetric(resultKey: String, metricName: String, metricValue: String, resultTimestamp: Long, + serializedContext: String) case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession, tableName: String, - tagValues: Option[Map[String, String]] = None, - analyzers: Option[Seq[Analyzer[_, Metric[_]]]] = None, - timeAfter: Option[Long] = None, - timeBefore: Option[Long] = None + private val tagValues: Option[Map[String, String]] = None, + private val analyzers: Option[Seq[Analyzer[_, Metric[_]]]] + = None, + private val timeAfter: Option[Long] = None, + private val timeBefore: Option[Long] = None ) extends MetricsRepositoryMultipleResultsLoader { override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader = @@ -91,12 +87,11 @@ case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSessio override def get(): Seq[AnalysisResult] = { val initialDF: DataFrame = session.table(tableName) - initialDF.printSchema() val tagValuesFilter: DataFrame => DataFrame = df => { tagValues.map { tags => tags.foldLeft(df) { (currentDF, tag) => currentDF.filter(row => { - val ser = row.getAs[String]("serialized_context") + val ser = row.getAs[String]("serializedContext") AnalysisResultSerde.deserialize(ser).exists(ar => { val tags = ar.resultKey.tags tags.contains(tag._1) && tags(tag._1) == tag._2 @@ -107,16 +102,16 @@ case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSessio } val specificAnalyzersFilter: DataFrame => DataFrame = df => { - analyzers.map(analyzers => df.filter(col("metric_name").isin(analyzers.map(_.toString): _*))) + analyzers.map(analyzers => df.filter(col("metricName").isin(analyzers.map(_.toString): _*))) .getOrElse(df) } val timeAfterFilter: DataFrame => DataFrame = df => { - timeAfter.map(time => df.filter(col("result_timestamp") > time.toString)).getOrElse(df) + timeAfter.map(time => df.filter(col("resultTimestamp") > time.toString)).getOrElse(df) } val timeBeforeFilter: DataFrame => DataFrame = df => { - timeBefore.map(time => df.filter(col("result_timestamp") < time.toString)).getOrElse(df) + timeBefore.map(time => df.filter(col("resultTimestamp") < time.toString)).getOrElse(df) } val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter) @@ -126,7 +121,7 @@ case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSessio // Convert the final DataFrame to the desired output format filteredDF.collect().flatMap(row => { - val serializedContext = row.getAs[String]("serialized_context") + val serializedContext = row.getAs[String]("serializedContext") AnalysisResultSerde.deserialize(serializedContext) }).toSeq } diff --git a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala index cff2c744..81b1fd19 100644 --- a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala +++ b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala @@ -1,5 +1,5 @@ /** - * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. A copy of the License @@ -27,13 +27,13 @@ import scala.collection.convert.ImplicitConversions.`iterator asScala` */ trait SparkContextSpec { - val warehouseDir: Path = Files.createTempDirectory("my_temp_dir_") + val tmpWareHouseDir: Path = Files.createTempDirectory("deequ_tmp") /** * @param testFun thunk to run with SparkSession as an argument */ def withSparkSession(testFun: SparkSession => Any): Unit = { - val session = setupSparkSession + val session = setupSparkSession() try { testFun(session) } finally { @@ -42,11 +42,20 @@ trait SparkContextSpec { } } + def withSparkSessionCustomWareHouse(testFun: SparkSession => Any): Unit = { + val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString)) + try { + testFun(session) + } finally { + tearDownSparkSession(session) + } + } + def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = { - val session = setupSparkSession + val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString)) session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog") session.conf.set("spark.sql.catalog.local.type", "hadoop") - session.conf.set("spark.sql.catalog.local.warehouse", warehouseDir.toAbsolutePath.toString) + session.conf.set("spark.sql.catalog.local.warehouse", tmpWareHouseDir.toAbsolutePath.toString) try { testFun(session) @@ -63,7 +72,7 @@ trait SparkContextSpec { */ def withMonitorableSparkSession(testFun: (SparkSession, SparkMonitor) => Any): Unit = { val monitor = new SparkMonitor - val session = setupSparkSession + val session = setupSparkSession() session.sparkContext.addSparkListener(monitor) try { testFun(session, monitor) @@ -91,16 +100,18 @@ trait SparkContextSpec { * * @return sparkSession to be used */ - private def setupSparkSession = { - val session = SparkSession.builder() + private def setupSparkSession(wareHouseDir: Option[String] = None) = { + val sessionBuilder = SparkSession.builder() .master("local") .appName("test") .config("spark.ui.enabled", "false") .config("spark.sql.shuffle.partitions", 2.toString) .config("spark.sql.adaptive.enabled", value = false) .config("spark.driver.bindAddress", "127.0.0.1") - .config("spark.sql.warehouse.dir", warehouseDir.toAbsolutePath.toString) - .getOrCreate() + + val session = wareHouseDir.fold(sessionBuilder.getOrCreate())(sessionBuilder + .config("spark.sql.warehouse.dir", _).getOrCreate()) + session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir")) session } @@ -124,7 +135,7 @@ trait SparkContextSpec { private def tearDownSparkSession(session: SparkSession) = { session.stop() System.clearProperty("spark.driver.port") - deleteDirectory(warehouseDir) + deleteDirectory(tmpWareHouseDir) } diff --git a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala index 5b66ce30..667b5b50 100644 --- a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala +++ b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala @@ -1,5 +1,5 @@ /** - * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. A copy of the License @@ -35,7 +35,7 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec private val analyzer = Size() "spark table metrics repository " should { - "save and load a single metric" in withSparkSession { spark => { + "save and load a single metric" in withSparkSessionCustomWareHouse { spark => val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value")) val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) val context = AnalyzerContext(Map(analyzer -> metric)) @@ -49,11 +49,10 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec assert(loadedContext.isDefined) assert(loadedContext.get.metric(analyzer).contains(metric)) - } } - "save multiple metrics and load them" in withSparkSession { spark => { + "save multiple metrics and load them" in withSparkSessionCustomWareHouse { spark => val repository = new SparkTableMetricsRepository(spark, "metrics_table") val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1")) @@ -72,10 +71,10 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec assert(loadedMetrics.length == 2) loadedMetrics.flatMap(_.resultKey.tags) - } + } - "save and load metrics with tag" in withSparkSession { spark => { + "save and load metrics with tag" in withSparkSessionCustomWareHouse { spark => val repository = new SparkTableMetricsRepository(spark, "metrics_table") val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A")) @@ -90,12 +89,15 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec repository.save(resultKey2, context2) val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get() assert(loadedMetricsForTagA.length == 1) - // additional assertions to ensure the loaded metric is the one with tag "A" - val loadedMetricsForMetricM1 = repository.load().forAnalyzers(Seq(analyzer)) - assert(loadedMetricsForTagA.length == 1) + val tagsMapA = loadedMetricsForTagA.flatMap(_.resultKey.tags).toMap + assert(tagsMapA.size == 1, "should have 1 result") + assert(tagsMapA.contains("tag"), "should contain tag") + assert(tagsMapA("tag") == "A", "tag should be A") + + val loadedMetricsForAllMetrics = repository.load().forAnalyzers(Seq(analyzer)).get() + assert(loadedMetricsForAllMetrics.length == 2, "should have 2 results") - } } "save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => {