diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
index caf90f42..b8dc2692 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.types.DoubleType
 
 import scala.util.Failure
 import scala.util.Success
+import scala.util.Try
 
 case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] {
   /**
@@ -33,18 +34,24 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double
    * @return
    */
   override def computeStateFrom(data: DataFrame): Option[CustomSqlState] = {
-    val dfSql = data.sqlContext.sql(expression)
-    val cols = dfSql.columns.toSeq
-    cols match {
-      case Seq(resultCol) =>
-        val dfSqlCast = dfSql.withColumn(resultCol, col(resultCol).cast(DoubleType))
-        val results: Seq[Row] = dfSqlCast.collect()
-        if (results.size != 1) {
-          Some(CustomSqlState(Right("Custom SQL did not return exactly 1 row")))
-        } else {
-          Some(CustomSqlState(Left(results.head.get(0).asInstanceOf[Double])))
+
+    Try {
+      data.sqlContext.sql(expression)
+    } match {
+      case Failure(e) => Some(CustomSqlState(Right(e.getMessage)))
+      case Success(dfSql) =>
+        val cols = dfSql.columns.toSeq
+        cols match {
+          case Seq(resultCol) =>
+            val dfSqlCast = dfSql.withColumn(resultCol, col(resultCol).cast(DoubleType))
+            val results: Seq[Row] = dfSqlCast.collect()
+            if (results.size != 1) {
+              Some(CustomSqlState(Right("Custom SQL did not return exactly 1 row")))
+            } else {
+              Some(CustomSqlState(Left(results.head.get(0).asInstanceOf[Double])))
+            }
+          case _ => Some(CustomSqlState(Right("Custom SQL did not return exactly 1 column")))
         }
-      case _ => Some(CustomSqlState(Right("Custom SQL did not return exactly 1 column")))
     }
   }
 
diff --git a/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala
index 8b399054..05337499 100644
--- a/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala
+++ b/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala
@@ -68,7 +68,21 @@ class CustomSqlTest extends AnyWordSpec with Matchers with SparkContextSpec with
         case Success(_) => fail("Should have failed")
         case Failure(exception) => exception.getMessage shouldBe "Custom SQL did not return exactly 1 column"
       }
+    }
+
+    "returns the error if the SQL statement has a syntax error" in withSparkSession { session =>
+      val data = getDfWithStringColumns(session)
+      data.createOrReplaceTempView("primary")
+      val sql = CustomSql("Select `foo` from primary")
+      val state = sql.computeStateFrom(data)
+      val metric = sql.computeMetricFrom(state)
+
+      metric.value.isFailure shouldBe true
+      metric.value match {
+        case Success(_) => fail("Should have failed")
+        case Failure(exception) => exception.getMessage should include("`foo`")
+      }
     }
 
   }
 }