diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala index 4061098..8cdeb4a 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala @@ -1,13 +1,42 @@ package com.github.mrpowers.spark.daria.sql import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} + +import scala.annotation.tailrec +import scala.util.{Failure, Success, Try} case class InvalidDataFrameSchemaException(smth: String) extends Exception(smth) private[sql] class DataFrameSchemaChecker(df: DataFrame, requiredSchema: StructType) { + private def diff(required: Seq[StructField], schema: StructType): Seq[StructField] = { + required.filterNot(isPresentIn(schema)) + } + + private def isPresentIn(schema: StructType)(reqField: StructField): Boolean = { + Try(schema(reqField.name)) match { + case Success(namedField) => + val basicMatch = + namedField.name == reqField.name && + (!namedField.nullable || reqField.nullable) && + namedField.metadata == reqField.metadata + val contentMatch = reqField.dataType match { + case reqSchema: StructType => + namedField.dataType match { + case fieldSchema: StructType => + diff(reqSchema, fieldSchema).isEmpty + case _ => false + } + case namedField.dataType => true + case _ => false + } + + basicMatch && contentMatch + case Failure(_) => false + } + } - val missingStructFields = requiredSchema.diff(df.schema) + val missingStructFields: Seq[StructField] = diff(requiredSchema, df.schema) def missingStructFieldsMessage(): String = { s"The [${missingStructFields.mkString(", ")}] StructFields are not included in the DataFrame with the following StructFields [${df.schema.toString()}]" diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/DariaValidatorTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/DariaValidatorTest.scala index b863424..07e8500 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/DariaValidatorTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/DariaValidatorTest.scala @@ -170,6 +170,128 @@ object DariaValidatorTest extends TestSuite with SparkSessionTestWrapper { } + "matches schema when fields are out of order" - { + val zoo = StructField( + "zoo", + StringType, + true + ) + val zaa = StructField( + "zaa", + StringType, + true + ) + val bar = StructField( + "bar", + StructType( + Seq( + zaa, + zoo + ) + ) + ) + val baz = StructField( + "baz", + StringType, + true + ) + val foo = StructField( + "foo", + StructType( + Seq( + baz, + bar + ) + ), + true + ) + val z = StructField( + "z", + StringType, + true + ) + + def validateSchemaEquality(s1: StructType, s2: StructType) = { + val df = spark + .createDataFrame( + spark.sparkContext.parallelize(Seq[Row]()), + s1 + ) + + df.printSchema() + spark + .createDataFrame( + spark.sparkContext.parallelize(Seq[Row]()), + s2 + ) + .printSchema() + + DariaValidator.validateSchema( + df, + s2 + ) + } + + // Shallow equality + validateSchemaEquality( + StructType( + Seq(z, foo) + ), + StructType( + Seq(foo, z) + ) + ) + + // Second level equality + val foo2 = StructField( + "foo", + StructType( + Seq( + bar, + baz + ) + ), + true + ) + validateSchemaEquality( + StructType( + Seq(z, foo) + ), + StructType( + Seq(z, foo2) + ) + ) + + // Third level equality - just to make sure + val bar2 = StructField( + "bar", + StructType( + Seq( + zoo, + zaa + ) + ) + ) + val foo3 = StructField( + "foo", + StructType( + Seq( + baz, + bar2 + ) + ), + true + ) + validateSchemaEquality( + StructType( + Seq(z, foo) + ), + StructType( + Seq(z, foo3) + ) + ) + } + } 'validateAbsenceOfColumns - { diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala index b7c6823..a136aee 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala @@ -363,6 +363,40 @@ object DataFrameSchemaCheckerTest extends TestSuite with SparkSessionTestWrapper } + "validates non-null column against a null schema" - { + val sourceSchema = List( + StructField( + "num", + IntegerType, + false + ) + ) + + val sourceDF = + spark.createDataFrame( + spark.sparkContext.parallelize(Seq[Row]()), + StructType(sourceSchema) + ) + + val requiredSchema = + StructType( + List( + StructField( + "num", + IntegerType, + true + ) + ) + ) + + val c = new DataFrameSchemaChecker( + sourceDF, + requiredSchema + ) + + c.validateSchema() + } + } }