diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala index 7d30f37..8cdeb4a 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala @@ -18,9 +18,8 @@ private[sql] class DataFrameSchemaChecker(df: DataFrame, requiredSchema: StructT case Success(namedField) => val basicMatch = namedField.name == reqField.name && - namedField.nullable == reqField.nullable && + (!namedField.nullable || reqField.nullable) && namedField.metadata == reqField.metadata - val contentMatch = reqField.dataType match { case reqSchema: StructType => namedField.dataType match { @@ -28,7 +27,8 @@ private[sql] class DataFrameSchemaChecker(df: DataFrame, requiredSchema: StructT diff(reqSchema, fieldSchema).isEmpty case _ => false } - case _ => reqField == namedField + case namedField.dataType => true + case _ => false } basicMatch && contentMatch diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala index b7c6823..a136aee 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaCheckerTest.scala @@ -363,6 +363,40 @@ object DataFrameSchemaCheckerTest extends TestSuite with SparkSessionTestWrapper } + "validates non-null column against a null schema" - { + val sourceSchema = List( + StructField( + "num", + IntegerType, + false + ) + ) + + val sourceDF = + spark.createDataFrame( + spark.sparkContext.parallelize(Seq[Row]()), + StructType(sourceSchema) + ) + + val requiredSchema = + StructType( + List( + StructField( + "num", + IntegerType, + true + ) + ) + ) + + val c = new DataFrameSchemaChecker( + sourceDF, + requiredSchema + ) + + c.validateSchema() + } + } }