Skip to content

Commit

Permalink
Merge pull request #162 from zeotuan/selectSortExpr
Browse files Browse the repository at this point in the history
feat: add support for ordering columns and StructFields, and Arrays of StructFields in DataFrame

- support custom column ordering based on implicit `Ordering`
- support column ordering within nested Struct column
- support column ordering within Arrays of Struct column
- support column ordering for arbitrarily nested Array
  • Loading branch information
zeotuan authored Oct 8, 2024
2 parents f06904f + 9c3af23 commit 617e062
Show file tree
Hide file tree
Showing 3 changed files with 468 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package com.github.mrpowers.spark.daria.sql

import com.github.mrpowers.spark.daria.sql.types.StructTypeHelpers
import com.github.mrpowers.spark.daria.sql.types.StructTypeHelpers.StructTypeOps
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame}

import scala.collection.mutable

case class DataFrameColumnsException(smth: String) extends Exception(smth)

object DataFrameExt {
Expand Down Expand Up @@ -407,6 +410,43 @@ object DataFrameExt {
StructType(loop(df.schema))
)
}
}

/**
* Sorts this DataFrame columns order according to the Ordering which results from transforming
* an implicitly given Ordering with a transformation function.
* This function will also sort [[StructType]] columns and [[ArrayType]]([[StructType]]) columns.
* @see [[scala.math.Ordering]]
* @param f the transformation function mapping elements of type [[StructField]]
* to some other domain `A`.
* @param ord the ordering assumed on domain `A`.
* @tparam A the target type of the transformation `f`, and the type where
* the ordering `ord` is defined.
* @return a DataFrame consisting of the fields of this DataFrame
* sorted according to the ordering where `x < y` if
* `ord.lt(f(x), f(y))`.
*
* @example {{{
* // Example DataFrame
* val df = spark.createDataFrame(
* Seq(
* ("John", 30, 2000.0),
* ("Jane", 25, 3000.0)
* )
* ).toDF("name", "age", "salary")
*
* // Sort columns by name
* val sortedByNameDF = df.sortColumnsBy(_.name)
* sortedByNameDF.show()
* // Output:
* // +---+----+------+
* // |age|name|salary|
* // +---+----+------+
* // | 30|John|2000.0|
* // | 25|Jane|3000.0|
* // +---+----+------+
* }}}
*/
def sortColumnsBy[A](f: StructField => A)(implicit ord: Ordering[A]): DataFrame =
df.select(df.schema.toSortedSelectExpr(f): _*)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package com.github.mrpowers.spark.daria.sql.types

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
import org.apache.spark.sql.functions._

import scala.annotation.tailrec
import scala.reflect.runtime.universe._

object StructTypeHelpers {
Expand Down Expand Up @@ -38,6 +39,33 @@ object StructTypeHelpers {
})
}

private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] = {
def childFieldToCol(childFieldType: DataType, childFieldName: String, parentCol: Column, firstLevel: Boolean = false): Column =
childFieldType match {
case st: StructType =>
struct(
st.fields
.sortBy(f)
.map(field =>
childFieldToCol(
field.dataType,
field.name,
field.dataType match {
case StructType(_) | ArrayType(_: StructType, _) => parentCol(field.name)
case _ => parentCol
}
).as(field.name)
): _*
).as(childFieldName)
case ArrayType(innerType, _) =>
transform(parentCol, childCol => childFieldToCol(innerType, childFieldName, childCol)).as(childFieldName)
case _ if firstLevel => parentCol
case _ if !firstLevel => parentCol(childFieldName)
}

schema.fields.sortBy(f).map(field => childFieldToCol(field.dataType, field.name, col(field.name), firstLevel = true))
}

/**
* gets a StructType from a Scala type and
* transforms field names from camel case to snake case
Expand All @@ -50,4 +78,7 @@ object StructTypeHelpers {
})
}

implicit class StructTypeOps(schema: StructType) {
def toSortedSelectExpr[A](f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] = schemaToSortedSelectExpr(schema, f)
}
}
Loading

0 comments on commit 617e062

Please sign in to comment.