From a1b0846498a008911e5edc5d697d58367e1f2a1b Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sun, 29 Sep 2024 18:10:25 +1000 Subject: [PATCH 01/14] add select sorted expr --- .../spark/daria/sql/DataFrameExt.scala | 10 +- .../daria/sql/types/StructTypeHelpers.scala | 41 ++-- .../spark/daria/sql/DataFrameExtTest.scala | 194 ++++++++++++++++++ 3 files changed, 227 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala index aa95e269..aa3e6cda 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala @@ -6,6 +6,8 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} import org.apache.spark.sql.{Column, DataFrame} +import scala.collection.mutable + case class DataFrameColumnsException(smth: String) extends Exception(smth) object DataFrameExt { @@ -266,9 +268,9 @@ object DataFrameExt { * This StackOverflow answer provides a detailed description how to use flattenSchema: https://stackoverflow.com/a/50402697/1125159 */ def flattenSchema(delimiter: String = "."): DataFrame = { - val renamedCols: Array[Column] = StructTypeHelpers + val renamedCols = StructTypeHelpers .flattenSchema(df.schema) - .map(name => col(name.toString).as(name.toString.replace(".", delimiter))) + .map(c => c.as(c.toString.replace(".", delimiter))) df.select(renamedCols: _*) } @@ -407,6 +409,8 @@ object DataFrameExt { StructType(loop(df.schema)) ) } - } + def selectSortedCols: DataFrame = df + .select(StructTypeHelpers.schemaToSortedSelectExpr(df.schema): _*) + } } diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index 72b3e934..2b460760 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -2,7 +2,7 @@ package com.github.mrpowers.spark.daria.sql.types import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} import org.apache.spark.sql.functions._ import scala.reflect.runtime.universe._ @@ -21,21 +21,32 @@ object StructTypeHelpers { } } - def flattenSchema(schema: StructType, prefix: String = ""): Array[Column] = { - schema.fields.flatMap(structField => { - val codeColName = - if (prefix.isEmpty) structField.name - else prefix + "." + structField.name - - structField.dataType match { - case st: StructType => - flattenSchema( - schema = st, - prefix = codeColName - ) - case _ => Array(col(codeColName)) + def flattenSchema(schema: StructType, baseField: String = "", sortFields: Boolean = false): Seq[Column] = { + val fields = if (sortFields) schema.fields.sortBy(_.name) else schema.fields + fields.foldLeft(Seq.empty[Column]) { case(acc, field) => + val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" + field.dataType match { + case t: StructType => + acc ++ flattenSchema(t, baseField = colName, sortFields = sortFields) + case _ => + acc :+ col(colName) } - }) + } + } + + def schemaToSortedSelectExpr(schema: StructType, baseField: String = ""): Seq[Column] = { + val result = schema.fields.sortBy(_.name).sortBy(_.name).foldLeft(Seq.empty[Column]) { case(acc, field) => + val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" + field.dataType match { + case t: StructType => + acc :+ struct(schemaToSortedSelectExpr(t, baseField = colName): _*).as(field.name) + case ArrayType(t: StructType, _) => + acc :+ arrays_zip(schemaToSortedSelectExpr(t, baseField = colName): _*).as(field.name) + case _ => + acc :+ col(colName) + } + } + result } /** diff --git a/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala b/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala index 7ad238ce..38e6be3d 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala @@ -961,6 +961,199 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi } + 'selectSortedCols - { + + "select col with all field sorted" - { + + val data = Seq( + Row( + Row( + "bayVal", + "baxVal", + Row("this", "yoVal"), + "is" + ), + Seq( + Row("yVal", "xVal"), + Row("yVal1", "xVal1") + ), + "something", + "cool", + ";)" + ) + ) + + val schema = StructType( + Seq( + StructField( + "foo", + StructType( + Seq( + StructField( + "bay", + StringType, + true + ), + StructField( + "bax", + StringType, + true + ), + StructField( + "bar", + StructType( + Seq( + StructField( + "zoo", + StringType, + true + ), + StructField( + "yoo", + StringType, + true + ) + ) + ) + ), + StructField( + "baz", + StringType, + true + ), + ) + ), + true + ), + StructField( + "w", + ArrayType(StructType(Seq(StructField("y", StringType, true), StructField("x", StringType, true)))), + true + ), + StructField( + "x", + StringType, + true + ), + StructField( + "y", + StringType, + true + ), + StructField( + "z", + StringType, + true + ) + ) + ) + + val df = spark + .createDataFrame( + spark.sparkContext.parallelize(data), + StructType(schema) + ) + .selectSortedCols + + val expectedData = Seq( + Row( + Row( + Row("yoVal", "this"), + "baxVal", + "bayVal", + "is" + ), + Seq( + Row("xVal", "yVal"), + Row("xVal1", "yVal1") + ), + "something", + "cool", + ";)" + ) + ) + + val expectedSchema = StructType( + Seq( + StructField( + "foo", + StructType( + Seq( + StructField( + "bar", + StructType( + Seq( + StructField( + "yoo", + StringType, + true + ), + StructField( + "zoo", + StringType, + true + ) + ) + ), + false + ), + StructField( + "bax", + StringType, + true + ), + StructField( + "bay", + StringType, + true + ), + StructField( + "baz", + StringType, + true + ) + ) + ), + false + ), + StructField( + "w", + ArrayType(StructType(Seq(StructField("x", StringType, true), StructField("y", StringType, true))), false), + true + ), + StructField( + "x", + StringType, + true + ), + StructField( + "y", + StringType, + true + ), + StructField( + "z", + StringType, + true + ) + ) + ) + + val expectedDF = spark + .createDataFrame( + spark.sparkContext.parallelize(expectedData), + StructType(expectedSchema) + ) + + assertSmallDataFrameEquality( + df, + expectedDF, + ignoreNullable = true + ) + } + + } + 'structureSchema - { "structure schema with default delimiter" - { val data = Seq( @@ -1125,6 +1318,7 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi } } + 'composeTrans - { def withCountry()(df: DataFrame): DataFrame = { From 4a9e8278d4151079f7e70756e7d3b393c4fa2bd1 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Mon, 30 Sep 2024 09:42:24 +1000 Subject: [PATCH 02/14] support flattenSchema on array of StructType --- .../spark/daria/sql/types/StructTypeHelpers.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index 2b460760..2b7e6097 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -21,13 +21,14 @@ object StructTypeHelpers { } } - def flattenSchema(schema: StructType, baseField: String = "", sortFields: Boolean = false): Seq[Column] = { - val fields = if (sortFields) schema.fields.sortBy(_.name) else schema.fields - fields.foldLeft(Seq.empty[Column]) { case(acc, field) => + def flattenSchema(schema: StructType, baseField: String = "", flattenArrayType: Boolean = false): Seq[Column] = { + schema.fields.foldLeft(Seq.empty[Column]) { case(acc, field) => val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" field.dataType match { case t: StructType => - acc ++ flattenSchema(t, baseField = colName, sortFields = sortFields) + acc ++ flattenSchema(t, colName) + case ArrayType(t: StructType, _) if flattenArrayType => + acc ++ flattenSchema(t, colName) case _ => acc :+ col(colName) } From 58154d174f5f53c068a982ec25a62be7ffc12402 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Mon, 30 Sep 2024 13:32:14 +1000 Subject: [PATCH 03/14] Add more scala like API to sort columns --- .../spark/daria/sql/DataFrameExt.scala | 22 +++++++++++++++++-- .../daria/sql/types/StructTypeHelpers.scala | 14 +++++++----- .../spark/daria/sql/DataFrameExtTest.scala | 2 +- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala index aa3e6cda..a4805a53 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala @@ -1,6 +1,7 @@ package com.github.mrpowers.spark.daria.sql import com.github.mrpowers.spark.daria.sql.types.StructTypeHelpers +import com.github.mrpowers.spark.daria.sql.types.StructTypeHelpers.StructTypeOps import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} @@ -410,7 +411,24 @@ object DataFrameExt { ) } - def selectSortedCols: DataFrame = df - .select(StructTypeHelpers.schemaToSortedSelectExpr(df.schema): _*) + /** Sorts this DataFrame columns order according to the Ordering which results from transforming + * an implicitly given Ordering with a transformation function. + * @see [[scala.math.Ordering]] + * @param f the transformation function mapping elements of type [[StructField]] + * to some other domain `A`. + * @param ord the ordering assumed on domain `A`. + * @tparam A the target type of the transformation `f`, and the type where + * the ordering `ord` is defined. + * @return a DataFrame consisting of the fields of this DataFrame + * sorted according to the ordering where `x < y` if + * `ord.lt(f(x), f(y))`. + * + * @example {{{ + * // this works because scala.Ordering will implicitly provide an Ordering[String] + * df.sortColumnsBy(_.name) + * }}} + */ + def sortColumnsBy[A](f: StructField => A)(implicit ord: Ordering[A]): DataFrame = df + .select(df.schema.toSortedSelectExpr(f): _*) } } diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index 2b7e6097..65d74423 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -35,19 +35,18 @@ object StructTypeHelpers { } } - def schemaToSortedSelectExpr(schema: StructType, baseField: String = ""): Seq[Column] = { - val result = schema.fields.sortBy(_.name).sortBy(_.name).foldLeft(Seq.empty[Column]) { case(acc, field) => + private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A, baseField: String = "")(implicit ord: Ordering[A]): Seq[Column] = { + schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { case(acc, field) => val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" field.dataType match { case t: StructType => - acc :+ struct(schemaToSortedSelectExpr(t, baseField = colName): _*).as(field.name) + acc :+ struct(schemaToSortedSelectExpr(t, f, colName): _*).as(field.name) case ArrayType(t: StructType, _) => - acc :+ arrays_zip(schemaToSortedSelectExpr(t, baseField = colName): _*).as(field.name) + acc :+ arrays_zip(schemaToSortedSelectExpr(t, f, colName): _*).as(field.name) case _ => acc :+ col(colName) } } - result } /** @@ -62,4 +61,9 @@ object StructTypeHelpers { }) } + implicit class StructTypeOps(schema: StructType) { + def toSortedSelectExpr[A](f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] = { + schemaToSortedSelectExpr(schema, f) + } + } } diff --git a/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala b/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala index 38e6be3d..ca27c120 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala @@ -1053,7 +1053,7 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi spark.sparkContext.parallelize(data), StructType(schema) ) - .selectSortedCols + .sortColumnsBy(_.name) val expectedData = Seq( Row( From f7d1babed4826aaedf386b039146c2f8f95ef740 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Mon, 30 Sep 2024 13:44:33 +1000 Subject: [PATCH 04/14] more descriptive doc string example --- .../spark/daria/sql/DataFrameExt.scala | 57 ++++++++++++------- .../daria/sql/types/StructTypeHelpers.scala | 42 +++++++------- 2 files changed, 60 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala index a4805a53..60b1cb90 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala @@ -411,24 +411,43 @@ object DataFrameExt { ) } - /** Sorts this DataFrame columns order according to the Ordering which results from transforming - * an implicitly given Ordering with a transformation function. - * @see [[scala.math.Ordering]] - * @param f the transformation function mapping elements of type [[StructField]] - * to some other domain `A`. - * @param ord the ordering assumed on domain `A`. - * @tparam A the target type of the transformation `f`, and the type where - * the ordering `ord` is defined. - * @return a DataFrame consisting of the fields of this DataFrame - * sorted according to the ordering where `x < y` if - * `ord.lt(f(x), f(y))`. - * - * @example {{{ - * // this works because scala.Ordering will implicitly provide an Ordering[String] - * df.sortColumnsBy(_.name) - * }}} - */ - def sortColumnsBy[A](f: StructField => A)(implicit ord: Ordering[A]): DataFrame = df - .select(df.schema.toSortedSelectExpr(f): _*) + /** + * Sorts this DataFrame columns order according to the Ordering which results from transforming + * an implicitly given Ordering with a transformation function. + * This function will also sort [[StructType]] columns and [[ArrayType]]([[StructType]]) columns. + * @see [[scala.math.Ordering]] + * @param f the transformation function mapping elements of type [[StructField]] + * to some other domain `A`. + * @param ord the ordering assumed on domain `A`. + * @tparam A the target type of the transformation `f`, and the type where + * the ordering `ord` is defined. + * @return a DataFrame consisting of the fields of this DataFrame + * sorted according to the ordering where `x < y` if + * `ord.lt(f(x), f(y))`. + * + * @example {{{ + * // Example DataFrame + * val df = spark.createDataFrame( + * Seq( + * ("John", 30, 2000.0), + * ("Jane", 25, 3000.0) + * ) + * ).toDF("name", "age", "salary") + * + * // Sort columns by name + * val sortedByNameDF = df.sortColumnsBy(_.name) + * sortedByNameDF.show() + * // Output: + * // +---+----+------+ + * // |age|name|salary| + * // +---+----+------+ + * // | 30|John|2000.0| + * // | 25|Jane|3000.0| + * // +---+----+------+ + * }}} + */ + def sortColumnsBy[A](f: StructField => A)(implicit ord: Ordering[A]): DataFrame = + df + .select(df.schema.toSortedSelectExpr(f): _*) } } diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index 65d74423..c77dc841 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -22,30 +22,32 @@ object StructTypeHelpers { } def flattenSchema(schema: StructType, baseField: String = "", flattenArrayType: Boolean = false): Seq[Column] = { - schema.fields.foldLeft(Seq.empty[Column]) { case(acc, field) => - val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" - field.dataType match { - case t: StructType => - acc ++ flattenSchema(t, colName) - case ArrayType(t: StructType, _) if flattenArrayType => - acc ++ flattenSchema(t, colName) - case _ => - acc :+ col(colName) - } + schema.fields.foldLeft(Seq.empty[Column]) { + case (acc, field) => + val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" + field.dataType match { + case t: StructType => + acc ++ flattenSchema(t, colName) + case ArrayType(t: StructType, _) if flattenArrayType => + acc ++ flattenSchema(t, colName) + case _ => + acc :+ col(colName) + } } } private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A, baseField: String = "")(implicit ord: Ordering[A]): Seq[Column] = { - schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { case(acc, field) => - val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" - field.dataType match { - case t: StructType => - acc :+ struct(schemaToSortedSelectExpr(t, f, colName): _*).as(field.name) - case ArrayType(t: StructType, _) => - acc :+ arrays_zip(schemaToSortedSelectExpr(t, f, colName): _*).as(field.name) - case _ => - acc :+ col(colName) - } + schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { + case (acc, field) => + val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" + field.dataType match { + case t: StructType => + acc :+ struct(schemaToSortedSelectExpr(t, f, colName): _*).as(field.name) + case ArrayType(t: StructType, _) => + acc :+ arrays_zip(schemaToSortedSelectExpr(t, f, colName): _*).as(field.name) + case _ => + acc :+ col(colName) + } } } From 20741bba538f18f982d1f07f46d2f205e06b1706 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Mon, 30 Sep 2024 20:29:44 +1000 Subject: [PATCH 05/14] Revert unrelated changes --- .../spark/daria/sql/DataFrameExt.scala | 4 +-- .../daria/sql/types/StructTypeHelpers.scala | 28 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala index 60b1cb90..4f0b12fc 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala @@ -269,9 +269,9 @@ object DataFrameExt { * This StackOverflow answer provides a detailed description how to use flattenSchema: https://stackoverflow.com/a/50402697/1125159 */ def flattenSchema(delimiter: String = "."): DataFrame = { - val renamedCols = StructTypeHelpers + val renamedCols: Array[Column] = StructTypeHelpers .flattenSchema(df.schema) - .map(c => c.as(c.toString.replace(".", delimiter))) + .map(name => col(name.toString).as(name.toString.replace(".", delimiter))) df.select(renamedCols: _*) } diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index c77dc841..9aaa08eb 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -21,19 +21,21 @@ object StructTypeHelpers { } } - def flattenSchema(schema: StructType, baseField: String = "", flattenArrayType: Boolean = false): Seq[Column] = { - schema.fields.foldLeft(Seq.empty[Column]) { - case (acc, field) => - val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" - field.dataType match { - case t: StructType => - acc ++ flattenSchema(t, colName) - case ArrayType(t: StructType, _) if flattenArrayType => - acc ++ flattenSchema(t, colName) - case _ => - acc :+ col(colName) - } - } + def flattenSchema(schema: StructType, prefix: String = ""): Array[Column] = { + schema.fields.flatMap(structField => { + val codeColName = + if (prefix.isEmpty) structField.name + else prefix + "." + structField.name + + structField.dataType match { + case st: StructType => + flattenSchema( + schema = st, + prefix = codeColName + ) + case _ => Array(col(codeColName)) + } + }) } private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A, baseField: String = "")(implicit ord: Ordering[A]): Seq[Column] = { From a8024f1e1e613292ace590490b19c1fe2777ec2f Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Tue, 1 Oct 2024 08:43:28 +1000 Subject: [PATCH 06/14] Support arbitrarily nested array --- .../daria/sql/types/StructTypeHelpers.scala | 41 ++++++++++---- .../spark/daria/sql/DataFrameExtTest.scala | 54 ++++++++++++++++++- 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index 9aaa08eb..8c6a2c98 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -5,6 +5,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} import org.apache.spark.sql.functions._ +import scala.annotation.tailrec import scala.reflect.runtime.universe._ object StructTypeHelpers { @@ -39,20 +40,40 @@ object StructTypeHelpers { } private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A, baseField: String = "")(implicit ord: Ordering[A]): Seq[Column] = { + def handleDataType(t: DataType, colName: String, simpleName: String): Column = + t match { + case st: StructType => + struct(schemaToSortedSelectExpr(st, f, colName): _*).as(simpleName) + case ArrayType(_, _) => + handleArrayType(t, col(colName), simpleName).as(simpleName) + case _ => + col(colName) + } + + // For handling reordering of nested arrays + def handleArrayType(t: DataType, innerCol: Column, simpleName: String): Column = + t match { + case ArrayType(innerType: ArrayType, _) => transform(innerCol, outer => handleArrayType(innerType, outer, simpleName)) + case ArrayType(innerType: StructType, _) => + val cols = schemaToSortedSelectExpr(innerType, f) + transform(innerCol, innerCol1 => struct(cols.map(c => innerCol1.getField(c.toString).as(c.toString)): _*)).as(simpleName) + case _ => innerCol + } + schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { case (acc, field) => val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" - field.dataType match { - case t: StructType => - acc :+ struct(schemaToSortedSelectExpr(t, f, colName): _*).as(field.name) - case ArrayType(t: StructType, _) => - acc :+ arrays_zip(schemaToSortedSelectExpr(t, f, colName): _*).as(field.name) - case _ => - acc :+ col(colName) - } + acc :+ handleDataType(field.dataType, colName, field.name) } } + @tailrec + private def getInnerMostType(dType: DataType, nDim: Int = 0): (DataType, Int) = + dType match { + case at: ArrayType => getInnerMostType(at.elementType, nDim + 1) + case t => (t, nDim) + } + /** * gets a StructType from a Scala type and * transforms field names from camel case to snake case @@ -66,8 +87,6 @@ object StructTypeHelpers { } implicit class StructTypeOps(schema: StructType) { - def toSortedSelectExpr[A](f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] = { - schemaToSortedSelectExpr(schema, f) - } + def toSortedSelectExpr[A](f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] = schemaToSortedSelectExpr(schema, f) } } diff --git a/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala b/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala index ca27c120..6cbd05d3 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala @@ -963,7 +963,7 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi 'selectSortedCols - { - "select col with all field sorted" - { + "select col with all field sorted by name" - { val data = Seq( Row( @@ -973,6 +973,11 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi Row("this", "yoVal"), "is" ), + Seq( + Seq( + Row("yVal", "xVal"), + Row("yVal1", "xVal1")), + ), Seq( Row("yVal", "xVal"), Row("yVal1", "xVal1") @@ -1025,6 +1030,11 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi ), true ), + StructField( + "v", + ArrayType(ArrayType(StructType(Seq(StructField("v2", StringType, true), StructField("v1", StringType, true))))), + true + ), StructField( "w", ArrayType(StructType(Seq(StructField("y", StringType, true), StructField("x", StringType, true)))), @@ -1063,6 +1073,12 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi "bayVal", "is" ), + Seq( + Seq( + Row("xVal", "yVal"), + Row("xVal1", "yVal1") + ) + ), Seq( Row("xVal", "yVal"), Row("xVal1", "yVal1") @@ -1116,6 +1132,11 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi ), false ), + StructField( + "v", + ArrayType(ArrayType(StructType(Seq(StructField("v1", StringType, true), StructField("v2", StringType, true))), false)), + true + ), StructField( "w", ArrayType(StructType(Seq(StructField("x", StringType, true), StructField("y", StringType, true))), false), @@ -1152,6 +1173,37 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi ) } + "select col with all field sorted by custom dataType ordering" - { + + val actualDF = + spark.createDF( + List(("this", 1, 1L, 1.0)), + List( + ("a", StringType, true), + ("b", IntegerType, true), + ("c", LongType, true), + ("d", DoubleType, true), + ) + ).sortColumnsBy(_.dataType)(Ordering.by(_.json)) + + val expectedDF = + spark.createDF( + List((1.0, 1, 1L, "this")), + List( + ("d", DoubleType, true), + ("b", IntegerType, true), + ("c", LongType, true), + ("a", StringType, true), + ) + ) + + assertSmallDataFrameEquality( + actualDF, + expectedDF, + ignoreNullable = true + ) + } + } 'structureSchema - { From ff841518dc34da63bd481197cba6433125e74b29 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Tue, 1 Oct 2024 08:46:51 +1000 Subject: [PATCH 07/14] Remove unused helper func --- .../mrpowers/spark/daria/sql/types/StructTypeHelpers.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index 8c6a2c98..89117bad 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -67,13 +67,6 @@ object StructTypeHelpers { } } - @tailrec - private def getInnerMostType(dType: DataType, nDim: Int = 0): (DataType, Int) = - dType match { - case at: ArrayType => getInnerMostType(at.elementType, nDim + 1) - case t => (t, nDim) - } - /** * gets a StructType from a Scala type and * transforms field names from camel case to snake case From cb658b553c8f608e93fbba4033b42f5011e45ffb Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Thu, 3 Oct 2024 20:41:03 +1000 Subject: [PATCH 08/14] Test for nested array --- .../daria/sql/types/StructTypeHelpers.scala | 2 +- .../spark/daria/sql/DataFrameExtTest.scala | 164 ++++++++++++++++-- 2 files changed, 152 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index 89117bad..e64077f7 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -56,7 +56,7 @@ object StructTypeHelpers { case ArrayType(innerType: ArrayType, _) => transform(innerCol, outer => handleArrayType(innerType, outer, simpleName)) case ArrayType(innerType: StructType, _) => val cols = schemaToSortedSelectExpr(innerType, f) - transform(innerCol, innerCol1 => struct(cols.map(c => innerCol1.getField(c.toString).as(c.toString)): _*)).as(simpleName) + transform(innerCol, innerCol1 => struct(cols.map(c => innerCol1.getField(c.toString).as(c.toString)): _*)) case _ => innerCol } diff --git a/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala b/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala index 6cbd05d3..3623d636 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala @@ -974,9 +974,7 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi "is" ), Seq( - Seq( - Row("yVal", "xVal"), - Row("yVal1", "xVal1")), + Seq(Row("yVal", "xVal"), Row("yVal1", "xVal1")) ), Seq( Row("yVal", "xVal"), @@ -1025,7 +1023,7 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi "baz", StringType, true - ), + ) ) ), true @@ -1173,18 +1171,158 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi ) } + "select col with all field sorted by name for arbitrarily nested array" - { + val data = Seq( + Row( + Row( + Seq( + Row("a4Val", "a2Val") + ) + ), + Seq( + Seq(Row("yVal", "xVal"), Row("yVal1", "xVal1")) + ), + Seq( + Row("yVal", "xVal"), + Row("yVal1", "xVal1") + ) +// Seq( +// Row( +// Seq( +// Row("x4Val", "x3Val") +// ) +// ) +// ) + ) + ) + + val schema = StructType( + Seq( + StructField( + "a1", + StructType( + Seq( + StructField( + "a2", + ArrayType(StructType(Seq(StructField("a4", StringType, true), StructField("a2", StringType, true))), containsNull = false), + true + ) + ) + ), + false + ), + StructField( + "v", + ArrayType(ArrayType(StructType(Seq(StructField("v2", StringType, true), StructField("v1", StringType, true))))), + true + ), + StructField( + "w", + ArrayType(StructType(Seq(StructField("y", StringType, true), StructField("x", StringType, true)))), + true + ) +// StructField( +// "x", +// ArrayType(StructType(Seq(StructField("x1", ArrayType(StructType(Seq(StructField("x4", StringType, true), StructField("x3", StringType, true)))), true)))), +// true +// ), + ) + ) + + val df = spark + .createDataFrame( + spark.sparkContext.parallelize(data), + StructType(schema) + ) + .sortColumnsBy(_.name) + + val expectedData = Seq( + Row( + Row( + Seq( + Row("a2Val", "a4Val") + ) + ), + Seq( + Seq( + Row("xVal", "yVal"), + Row("xVal1", "yVal1") + ) + ), + Seq( + Row("xVal", "yVal"), + Row("xVal1", "yVal1") + ) +// Seq( +// Row( +// Seq( +// Row("x3Val", "x4Val") +// ) +// ) +// ) + ) + ) + + val expectedSchema = StructType( + Seq( + StructField( + "a1", + StructType( + Seq( + StructField( + "a2", + ArrayType(StructType(Seq(StructField("a2", StringType, true), StructField("a4", StringType, true))), containsNull = false), + true + ) + ) + ), + false + ), + StructField( + "v", + ArrayType(ArrayType(StructType(Seq(StructField("v1", StringType, true), StructField("v2", StringType, true))), false)), + true + ), + StructField( + "w", + ArrayType(StructType(Seq(StructField("x", StringType, true), StructField("y", StringType, true))), false), + true + ) +// StructField( +// "x", +// ArrayType(StructType(Seq(StructField("x1", ArrayType(StructType(Seq(StructField("x3", StringType, true), StructField("x4", StringType, true)))), true)))), +// true +// ) + ) + ) + + val expectedDF = spark + .createDataFrame( + spark.sparkContext.parallelize(expectedData), + StructType(expectedSchema) + ) + + assertSmallDataFrameEquality( + df, + expectedDF, + ignoreNullable = true + ) + } + "select col with all field sorted by custom dataType ordering" - { val actualDF = - spark.createDF( - List(("this", 1, 1L, 1.0)), - List( - ("a", StringType, true), - ("b", IntegerType, true), - ("c", LongType, true), - ("d", DoubleType, true), + spark + .createDF( + List(("this", 1, 1L, 1.0)), + List( + ("a", StringType, true), + ("b", IntegerType, true), + ("c", LongType, true), + ("d", DoubleType, true) + ) ) - ).sortColumnsBy(_.dataType)(Ordering.by(_.json)) + .sortColumnsBy(_.dataType)(Ordering.by(_.json)) val expectedDF = spark.createDF( @@ -1193,7 +1331,7 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi ("d", DoubleType, true), ("b", IntegerType, true), ("c", LongType, true), - ("a", StringType, true), + ("a", StringType, true) ) ) From b9287521c5aabd6dbf64f5b2f3428728eea929bc Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sat, 5 Oct 2024 15:48:53 +1000 Subject: [PATCH 09/14] Support Array(Struct(Array)) cases --- .../spark/daria/sql/DataFrameExt.scala | 3 +- .../spark/daria/sql/DataFrameExtTest.scala | 59 +++++++++++-------- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala index 4f0b12fc..e61c1775 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/DataFrameExt.scala @@ -447,7 +447,6 @@ object DataFrameExt { * }}} */ def sortColumnsBy[A](f: StructField => A)(implicit ord: Ordering[A]): DataFrame = - df - .select(df.schema.toSortedSelectExpr(f): _*) + df.select(df.schema.toSortedSelectExpr(f): _*) } } diff --git a/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala b/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala index 3623d636..abcfcf42 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala @@ -1185,14 +1185,14 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi Seq( Row("yVal", "xVal"), Row("yVal1", "xVal1") + ), + Seq( + Row( + Seq( + Row("x4Val", "x3Val") + ) + ) ) -// Seq( -// Row( -// Seq( -// Row("x4Val", "x3Val") -// ) -// ) -// ) ) ) @@ -1220,12 +1220,16 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi "w", ArrayType(StructType(Seq(StructField("y", StringType, true), StructField("x", StringType, true)))), true + ), + StructField( + "x", + ArrayType( + StructType( + Seq(StructField("x1", ArrayType(StructType(Seq(StructField("b", StringType, true), StructField("a", StringType, true)))), true)) + ) + ), + true ) -// StructField( -// "x", -// ArrayType(StructType(Seq(StructField("x1", ArrayType(StructType(Seq(StructField("x4", StringType, true), StructField("x3", StringType, true)))), true)))), -// true -// ), ) ) @@ -1252,14 +1256,14 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi Seq( Row("xVal", "yVal"), Row("xVal1", "yVal1") + ), + Seq( + Row( + Seq( + Row("x3Val", "x4Val") + ) + ) ) -// Seq( -// Row( -// Seq( -// Row("x3Val", "x4Val") -// ) -// ) -// ) ) ) @@ -1287,12 +1291,19 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessi "w", ArrayType(StructType(Seq(StructField("x", StringType, true), StructField("y", StringType, true))), false), true + ), + StructField( + "x", + ArrayType( + StructType( + Seq( + StructField("x1", ArrayType(StructType(Seq(StructField("a", StringType, true), StructField("b", StringType, true))), false), true) + ) + ), + false + ), + true ) -// StructField( -// "x", -// ArrayType(StructType(Seq(StructField("x1", ArrayType(StructType(Seq(StructField("x3", StringType, true), StructField("x4", StringType, true)))), true)))), -// true -// ) ) ) From 3d833488ae66f8f1dd9aa35c077351fd3fe10244 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sat, 5 Oct 2024 15:51:26 +1000 Subject: [PATCH 10/14] Support Array(Struct(Array(Struct))) cases --- .../daria/sql/types/StructTypeHelpers.scala | 45 ++++++++++++++----- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index e64077f7..f2a9a8ba 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -40,31 +40,52 @@ object StructTypeHelpers { } private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A, baseField: String = "")(implicit ord: Ordering[A]): Seq[Column] = { - def handleDataType(t: DataType, colName: String, simpleName: String): Column = + def handleNestedType(t: DataType, name: String, outerCol: Column): Column = t match { case st: StructType => - struct(schemaToSortedSelectExpr(st, f, colName): _*).as(simpleName) + val sortedFields = st.fields.sortBy(f) + struct( + sortedFields.map(f => + f.dataType match { + case st: StructType => + handleNestedType(st, f.name, outerCol(f.name)).as(f.name) + case ArrayType(innerType: StructType, _) => + handleArrayType(f.dataType, name, outerCol(f.name)).as(f.name) + case _ => + handleNestedType(f.dataType, f.name, outerCol).as(f.name) + } + ): _* + ).as(name) case ArrayType(_, _) => - handleArrayType(t, col(colName), simpleName).as(simpleName) + handleArrayType(t, name, outerCol).as(name) case _ => - col(colName) + outerCol(name) } // For handling reordering of nested arrays - def handleArrayType(t: DataType, innerCol: Column, simpleName: String): Column = + def handleArrayType(t: DataType, name: String, outer: Column): Column = t match { - case ArrayType(innerType: ArrayType, _) => transform(innerCol, outer => handleArrayType(innerType, outer, simpleName)) + case ArrayType(innerType: ArrayType, _) => + transform(outer, inner => handleArrayType(innerType, name, inner)).as(name) case ArrayType(innerType: StructType, _) => - val cols = schemaToSortedSelectExpr(innerType, f) - transform(innerCol, innerCol1 => struct(cols.map(c => innerCol1.getField(c.toString).as(c.toString)): _*)) - case _ => innerCol + transform(outer, inner => handleNestedType(innerType, name, inner).as(name)).as(name) + case _ => outer.as(name) } - schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { + val result = schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { case (acc, field) => - val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" - acc :+ handleDataType(field.dataType, colName, field.name) + val name = field.name + val sortedCol = field.dataType match { + case st: StructType => + handleNestedType(st, name, col(name)) + case arr: ArrayType => + handleArrayType(arr, name, col(name)) + case _ => col(name) + } + + acc :+ sortedCol } + result } /** From ffcc14add15c73dee2ff8e5be8dadea1dbf9ca21 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sat, 5 Oct 2024 16:35:45 +1000 Subject: [PATCH 11/14] Simplify handleNestedType --- .../daria/sql/types/StructTypeHelpers.scala | 45 ++++++++----------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index f2a9a8ba..8b01c041 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -39,27 +39,27 @@ object StructTypeHelpers { }) } - private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A, baseField: String = "")(implicit ord: Ordering[A]): Seq[Column] = { - def handleNestedType(t: DataType, name: String, outerCol: Column): Column = + private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] = { + def handleNestedType(t: DataType, name: String, outerCol: Column, firstLevel: Boolean = false): Column = t match { case st: StructType => - val sortedFields = st.fields.sortBy(f) struct( - sortedFields.map(f => - f.dataType match { - case st: StructType => - handleNestedType(st, f.name, outerCol(f.name)).as(f.name) - case ArrayType(innerType: StructType, _) => - handleArrayType(f.dataType, name, outerCol(f.name)).as(f.name) - case _ => - handleNestedType(f.dataType, f.name, outerCol).as(f.name) - } - ): _* + st.fields + .sortBy(f) + .map(field => + handleNestedType( + field.dataType, + field.name, + field.dataType match { + case StructType(_) | ArrayType(_: StructType, _) => outerCol(field.name) + case _ => outerCol + } + ).as(field.name) + ): _* ).as(name) - case ArrayType(_, _) => - handleArrayType(t, name, outerCol).as(name) - case _ => - outerCol(name) + case ArrayType(_, _) => handleArrayType(t, name, outerCol).as(name) + case _ if firstLevel => outerCol + case _ if !firstLevel => outerCol(name) } // For handling reordering of nested arrays @@ -74,16 +74,7 @@ object StructTypeHelpers { val result = schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { case (acc, field) => - val name = field.name - val sortedCol = field.dataType match { - case st: StructType => - handleNestedType(st, name, col(name)) - case arr: ArrayType => - handleArrayType(arr, name, col(name)) - case _ => col(name) - } - - acc :+ sortedCol + acc :+ handleNestedType(field.dataType, field.name, col(field.name), firstLevel = true) } result } From ea71fcbb0bf092b4ff48f9604ca84119e305561d Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Mon, 7 Oct 2024 21:09:30 +1100 Subject: [PATCH 12/14] simplify nested field logic, Improve variable naming --- .../daria/sql/types/StructTypeHelpers.scala | 34 +++++++------------ 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index 8b01c041..dee701ee 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -40,43 +40,33 @@ object StructTypeHelpers { } private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] = { - def handleNestedType(t: DataType, name: String, outerCol: Column, firstLevel: Boolean = false): Column = - t match { + def childFieldToCol(childFieldType: DataType, childFieldName: String, parentCol: Column, firstLevel: Boolean = false): Column = + childFieldType match { case st: StructType => struct( st.fields .sortBy(f) .map(field => - handleNestedType( + childFieldToCol( field.dataType, field.name, field.dataType match { - case StructType(_) | ArrayType(_: StructType, _) => outerCol(field.name) - case _ => outerCol + case StructType(_) | ArrayType(_: StructType, _) => parentCol(field.name) + case _ => parentCol } ).as(field.name) ): _* - ).as(name) - case ArrayType(_, _) => handleArrayType(t, name, outerCol).as(name) - case _ if firstLevel => outerCol - case _ if !firstLevel => outerCol(name) + ).as(childFieldName) + case ArrayType(innerType, _) => + transform(parentCol, childCol => childFieldToCol(innerType, childFieldName, childCol)).as(childFieldName) + case _ if firstLevel => parentCol + case _ if !firstLevel => parentCol(childFieldName) } - // For handling reordering of nested arrays - def handleArrayType(t: DataType, name: String, outer: Column): Column = - t match { - case ArrayType(innerType: ArrayType, _) => - transform(outer, inner => handleArrayType(innerType, name, inner)).as(name) - case ArrayType(innerType: StructType, _) => - transform(outer, inner => handleNestedType(innerType, name, inner).as(name)).as(name) - case _ => outer.as(name) - } - - val result = schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { + schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { case (acc, field) => - acc :+ handleNestedType(field.dataType, field.name, col(field.name), firstLevel = true) + acc :+ childFieldToCol(field.dataType, field.name, col(field.name), firstLevel = true) } - result } /** From eb0d8e7651d40d4fc8c000c831d6d96ee16e21e3 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Mon, 7 Oct 2024 21:12:04 +1100 Subject: [PATCH 13/14] Fix Formatting --- .../mrpowers/spark/daria/sql/types/StructTypeHelpers.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index dee701ee..d51d0496 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -57,7 +57,7 @@ object StructTypeHelpers { ).as(field.name) ): _* ).as(childFieldName) - case ArrayType(innerType, _) => + case ArrayType(innerType, _) => transform(parentCol, childCol => childFieldToCol(innerType, childFieldName, childCol)).as(childFieldName) case _ if firstLevel => parentCol case _ if !firstLevel => parentCol(childFieldName) From 9c3af23d5915ff32839787227907036543a34f96 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Tue, 8 Oct 2024 09:20:54 +1100 Subject: [PATCH 14/14] Simplify with map --- .../mrpowers/spark/daria/sql/types/StructTypeHelpers.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index d51d0496..f7801b01 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -63,10 +63,7 @@ object StructTypeHelpers { case _ if !firstLevel => parentCol(childFieldName) } - schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { - case (acc, field) => - acc :+ childFieldToCol(field.dataType, field.name, col(field.name), firstLevel = true) - } + schema.fields.sortBy(f).map(field => childFieldToCol(field.dataType, field.name, col(field.name), firstLevel = true)) } /**