From e1550e2a256bbae4546c99549f12a2c39a175a5a Mon Sep 17 00:00:00 2001
From: Manu Zhang
Date: Thu, 25 Apr 2019 01:42:49 +0800
Subject: [PATCH] Add createEmptyDF to SparkSessionExt (#71)

---
 .../spark/daria/sql/SparkSessionExt.scala     | 15 +++++++
 .../spark/daria/sql/SparkSessionExtTest.scala | 44 +++++++++++++++++++
 2 files changed, 59 insertions(+)

diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/SparkSessionExt.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/SparkSessionExt.scala
index b3d0de90..9105c31c 100644
--- a/src/main/scala/com/github/mrpowers/spark/daria/sql/SparkSessionExt.scala
+++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/SparkSessionExt.scala
@@ -62,6 +62,21 @@ object SparkSessionExt {
       )
     }
 
+    /**
+     * Creates an empty DataFrame given schema fields
+     *
+     * This is a handy fallback when you fail to read from a data source
+     *
+     * val schema = List(StructField("col1", IntegerType))
+     * val df = Try {
+     *   spark.read.parquet("non-existent-path")
+     * }.getOrElse(spark.createEmptyDF(schema))
+     */
+    def createEmptyDF[T](fields: List[T]): DataFrame = {
+      spark.createDataFrame(spark.sparkContext.emptyRDD[Row],
+                            StructType(asSchema(fields)))
+    }
+
   }
 
 }
diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/SparkSessionExtTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/SparkSessionExtTest.scala
index b3eaf043..2d5807db 100644
--- a/src/test/scala/com/github/mrpowers/spark/daria/sql/SparkSessionExtTest.scala
+++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/SparkSessionExtTest.scala
@@ -200,6 +200,50 @@ object SparkSessionExtTest extends TestSuite with DataFrameComparer with SparkSe
     }
 
+  'createEmptyDF - {
+    "creates an empty DataFrame with a list of StructFields" - {
+      val actualDF =
+        spark.createEmptyDF(
+          List(
+            StructField(
+              "num1",
+              IntegerType,
+              true
+            ),
+            StructField(
+              "num2",
+              IntegerType,
+              true
+            )
+          )
+        )
+
+
+      val expectedSchema = List(
+        StructField(
+          "num1",
+          IntegerType,
+          true
+        ),
+        StructField(
+          "num2",
+          IntegerType,
+          true
+        )
+      )
+
+      val expectedDF =
+        spark.createDataFrame(
+          spark.sparkContext.parallelize(Seq.empty[Row]),
+          StructType(expectedSchema)
+        )
+
+      assertSmallDataFrameEquality(
+        actualDF,
+        expectedDF
+      )
+    }
+  }
 
   }
 
 }