From 791606fb8130ee1c2e8645966624d2035aeaaf4a Mon Sep 17 00:00:00 2001
From: Stephen Kestle
Date: Fri, 14 May 2021 21:16:30 +1200
Subject: [PATCH] Implemented recursive schema checker

Enables nested structures to be compared regardless of order.
---
 .../daria/sql/DataFrameSchemaChecker.scala    | 51 +++++++++++++++++--
 1 file changed, 49 insertions(+), 2 deletions(-)

diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala
index 4061098..7d30f37 100644
--- a/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala
+++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameSchemaChecker.scala
@@ -1,13 +1,60 @@
 package com.github.mrpowers.spark.daria.sql
 
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
+
+import scala.annotation.tailrec
+import scala.util.{Failure, Success, Try}
 
 case class InvalidDataFrameSchemaException(smth: String) extends Exception(smth)
 
 private[sql] class DataFrameSchemaChecker(df: DataFrame, requiredSchema: StructType) {
+  /**
+   * Lists the required fields that `schema` does not satisfy.
+   *
+   * Fields are looked up by name, so their position within `schema` does not
+   * affect the result.
+   *
+   * @param required fields that must be present
+   * @param schema   schema to search
+   * @return the elements of `required` that `schema` does not satisfy, in their
+   *         original order
+   */
+  private def diff(required: Seq[StructField], schema: StructType): Seq[StructField] =
+    required.filterNot(isPresentIn(schema))
+
+  /**
+   * Tests whether `schema` has a field named like `reqField` that satisfies it.
+   *
+   * When both fields are structs, nullability and metadata must be equal and
+   * every field of the required struct must in turn be present in the actual
+   * struct; fields that only the actual struct has are ignored. A required
+   * struct is never satisfied by a non-struct field. Any other required field
+   * must equal the actual field.
+   *
+   * @param schema   schema expected to contain the field
+   * @param reqField field being looked for
+   * @return `true` when a satisfying field exists, `false` otherwise
+   */
+  private def isPresentIn(schema: StructType)(reqField: StructField): Boolean = {
+    // Test membership up front: an absent field is an expected outcome, so it
+    // should not surface as a lookup exception that then has to be caught.
+    schema.fieldNames.contains(reqField.name) && {
+      val namedField = schema(reqField.name)
+      (reqField.dataType, namedField.dataType) match {
+        case (reqSchema: StructType, fieldSchema: StructType) =>
+          // The names already agree because the field was looked up by name.
+          namedField.nullable == reqField.nullable &&
+            namedField.metadata == reqField.metadata &&
+            diff(reqSchema, fieldSchema).isEmpty
+        case (_: StructType, _) => false
+        case _ => reqField == namedField
+      }
+    }
+  }
 
-  val missingStructFields = requiredSchema.diff(df.schema)
+  /** Required fields that the DataFrame's schema does not satisfy. */
+  val missingStructFields: Seq[StructField] = diff(requiredSchema, df.schema)
 
   def missingStructFieldsMessage(): String = {
     s"The [${missingStructFields.mkString(", ")}] StructFields are not included in the DataFrame with the following StructFields [${df.schema.toString()}]"