From f0743641ebcff7f0071f4c8b4543797cd9704e9e Mon Sep 17 00:00:00 2001 From: skestle Date: Sat, 12 Jun 2021 06:35:58 +1200 Subject: [PATCH 1/2] Ensured that schema validation matches nested structures that are in different order (#144) * Highlighted issues with schema validation's order sensitivity of nested structures. * Implemented recursive schema checker Enables nested structures to be compared regardless of order. --- .../daria/sql/DataFrameSchemaChecker.scala | 33 ++++- .../spark/daria/sql/DariaValidatorTest.scala | 122 ++++++++++++++++++ 2 files changed, 153 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala index 40610988..7d30f371 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala @@ -1,13 +1,42 @@ package com.github.mrpowers.spark.daria.sql import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} + +import scala.annotation.tailrec +import scala.util.{Failure, Success, Try} case class InvalidDataFrameSchemaException(smth: String) extends Exception(smth) private[sql] class DataFrameSchemaChecker(df: DataFrame, requiredSchema: StructType) { + private def diff(required: Seq[StructField], schema: StructType): Seq[StructField] = { + required.filterNot(isPresentIn(schema)) + } + + private def isPresentIn(schema: StructType)(reqField: StructField): Boolean = { + Try(schema(reqField.name)) match { + case Success(namedField) => + val basicMatch = + namedField.name == reqField.name && + namedField.nullable == reqField.nullable && + namedField.metadata == reqField.metadata + + val contentMatch = reqField.dataType match { + case reqSchema: StructType => + namedField.dataType match { + case fieldSchema: StructType => + diff(reqSchema, fieldSchema).isEmpty + case _ => false + } + case _ => reqField == namedField + } + + basicMatch && contentMatch + case Failure(_) => false + } + } - val missingStructFields = requiredSchema.diff(df.schema) + val missingStructFields: Seq[StructField] = diff(requiredSchema, df.schema) def missingStructFieldsMessage(): String = { s"The [${missingStructFields.mkString(", ")}] StructFields are not included in the DataFrame with the following StructFields [${df.schema.toString()}]" diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/DariaValidatorTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/DariaValidatorTest.scala index b863424d..07e8500a 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/DariaValidatorTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/DariaValidatorTest.scala @@ -170,6 +170,128 @@ object DariaValidatorTest extends TestSuite with SparkSessionTestWrapper { } + "matches schema when fields are out of order" - { + val zoo = StructField( + "zoo", + StringType, + true + ) + val zaa = StructField( + "zaa", + StringType, + true + ) + val bar = StructField( + "bar", + StructType( + Seq( + zaa, + zoo + ) + ) + ) + val baz = StructField( + "baz", + StringType, + true + ) + val foo = StructField( + "foo", + StructType( + Seq( + baz, + bar + ) + ), + true + ) + val z = StructField( + "z", + StringType, + true + ) + + def validateSchemaEquality(s1: StructType, s2: StructType) = { + val df = spark + .createDataFrame( + 
spark.sparkContext.parallelize(Seq[Row]()), + s1 + ) + + df.printSchema() + spark + .createDataFrame( + spark.sparkContext.parallelize(Seq[Row]()), + s2 + ) + .printSchema() + + DariaValidator.validateSchema( + df, + s2 + ) + } + + // Shallow equality + validateSchemaEquality( + StructType( + Seq(z, foo) + ), + StructType( + Seq(foo, z) + ) + ) + + // Second level equality + val foo2 = StructField( + "foo", + StructType( + Seq( + bar, + baz + ) + ), + true + ) + validateSchemaEquality( + StructType( + Seq(z, foo) + ), + StructType( + Seq(z, foo2) + ) + ) + + // Third level equality - just to make sure + val bar2 = StructField( + "bar", + StructType( + Seq( + zoo, + zaa + ) + ) + ) + val foo3 = StructField( + "foo", + StructType( + Seq( + baz, + bar2 + ) + ), + true + ) + validateSchemaEquality( + StructType( + Seq(z, foo) + ), + StructType( + Seq(z, foo3) + ) + ) + } + } 'validateAbsenceOfColumns - { From c84ded821643faff9f6da3c4ae260a5914f0bdf3 Mon Sep 17 00:00:00 2001 From: skestle Date: Thu, 10 Jun 2021 14:13:40 +1200 Subject: [PATCH 2/2] Validated data frames with not null columns against null schema columns. Previously, a dataframe would be considered invalid if it had not-null data, but the schema allowed the data to be nullable --- .../daria/sql/DataFrameSchemaChecker.scala | 6 ++-- .../sql/DataFrameSchemaCheckerTest.scala | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala index 7d30f371..8cdeb4ad 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala @@ -18,9 +18,8 @@ private[sql] class DataFrameSchemaChecker(df: DataFrame, requiredSchema: StructT case Success(namedField) => val basicMatch = namedField.name == reqField.name && - namedField.nullable == reqField.nullable && + (!namedField.nullable || reqField.nullable) && namedField.metadata == reqField.metadata - val contentMatch = reqField.dataType match { case reqSchema: StructType => namedField.dataType match { @@ -28,7 +27,8 @@ private[sql] class DataFrameSchemaChecker(df: DataFrame, requiredSchema: StructT diff(reqSchema, fieldSchema).isEmpty case _ => false } - case _ => reqField == namedField + case namedField.dataType => true + case _ => false } basicMatch && contentMatch diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala index b7c6823c..a136aee9 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala @@ -363,6 +363,40 @@ object DataFrameSchemaCheckerTest extends TestSuite with SparkSessionTestWrapper } + "validates non-null column against a null schema" - { + val sourceSchema = List( + StructField( + "num", + IntegerType, + false + ) + ) + + val sourceDF = + spark.createDataFrame( + spark.sparkContext.parallelize(Seq[Row]()), + StructType(sourceSchema) + ) + + val requiredSchema = + StructType( + List( + StructField( + "num", + IntegerType, + true + ) + ) + ) + + val c = new DataFrameSchemaChecker( + sourceDF, + requiredSchema + ) + + c.validateSchema() + } + } }
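
Note (not part of the patches above): a minimal usage sketch of the behaviour these two patches enable, written against spark-daria's public DariaValidator API as used in the test suite. It assumes a SparkSession named spark is in scope (as SparkSessionTestWrapper provides for the tests), and the column names pet, name and age are invented for illustration only.

    // Sketch only: shows out-of-order nested fields plus a non-nullable column
    // satisfying a nullable required column. Assumes a SparkSession named `spark`
    // is in scope; the field names below are hypothetical.
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    import com.github.mrpowers.spark.daria.sql.DariaValidator

    // DataFrame schema: "pet" is NOT nullable and nests (name, age) in that order.
    val df = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq[Row]()),
      StructType(
        Seq(
          StructField(
            "pet",
            StructType(
              Seq(
                StructField("name", StringType, true),
                StructField("age", IntegerType, true)
              )
            ),
            false
          )
        )
      )
    )

    // Required schema: "pet" is nullable and nests (age, name) in the opposite order.
    val requiredSchema = StructType(
      Seq(
        StructField(
          "pet",
          StructType(
            Seq(
              StructField("age", IntegerType, true),
              StructField("name", StringType, true)
            )
          ),
          true
        )
      )
    )

    // Before these patches this threw InvalidDataFrameSchemaException because the
    // nested fields were compared positionally and nullability had to match exactly;
    // with the patches applied the call succeeds.
    DariaValidator.validateSchema(df, requiredSchema)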