support flatten array
zeotuan committed Jan 5, 2025
1 parent ffa6005 commit d69be53
Showing 3 changed files with 145 additions and 21 deletions.
com/github/mrpowers/spark/daria/sql/DataFrameExt.scala

@@ -7,7 +7,9 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame}

+import scala.annotation.tailrec
import scala.collection.mutable
+import scala.util.matching.Regex

case class DataFrameColumnsException(smth: String) extends Exception(smth)

@@ -268,11 +270,30 @@ object DataFrameExt {
* Converts all the StructType columns to regular columns
* This StackOverflow answer provides a detailed description of how to use flattenSchema: https://stackoverflow.com/a/50402697/1125159
*/
-  def flattenSchema(delimiter: String = "."): DataFrame = {
-    val renamedCols: Array[Column] = StructTypeHelpers
-      .flattenSchema(df.schema)
-      .map(name => col(name.toString).as(name.toString.replace(".", delimiter)))
-    df.select(renamedCols: _*)
-  }
+  def flattenSchema(
+      delimiter: String = ".",
+      flattenArrayType: Boolean = false
+  ): DataFrame = {
+    if (delimiter == "." && flattenArrayType) {
+      throw new IllegalArgumentException("Cannot use '.' as delimiter when flattening ArrayType columns")
+    }
+
+    @tailrec
+    def flatten(df: DataFrame): DataFrame = {
+      if (StructTypeHelpers.containComplexFields(df.schema)) {
+        val renamedCols = StructTypeHelpers
+          .flattenSchema(df.schema, "", flattenArrayType)
+          .map { case (c, n) => c.as(sanitizeColName(n, delimiter)) }
+        val flattened = df.select(renamedCols: _*)
+        flatten(flattened)
+      } else df
+    }
+    flatten(df)
+  }
+
+  private def sanitizeColName(name: String, rc: String = "_"): String = {
+    val pattern: Regex = "[^a-zA-Z0-9_]".r
+    pattern.replaceAllIn(name, rc)
+  }

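For reference, a minimal usage sketch of the new API (the df value here is assumed, not part of the commit):

import com.github.mrpowers.spark.daria.sql.DataFrameExt._

// Structs and (optionally) arrays of structs are flattened repeatedly until
// no complex fields remain; sanitizeColName replaces every character outside
// [a-zA-Z0-9_] with the delimiter, so "v.v2" becomes "v_v2".
val flat = df.flattenSchema(delimiter = "_", flattenArrayType = true)

// The struct-only behaviour remains the default:
val structsOnly = df.flattenSchema()

// Guard added by this commit: '.' is rejected when arrays are flattened.
// df.flattenSchema(".", flattenArrayType = true)  // throws IllegalArgumentException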
com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala

@@ -2,10 +2,9 @@ package com.github.mrpowers.spark.daria.sql.types

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
import org.apache.spark.sql.functions._

-import scala.annotation.tailrec
import scala.reflect.runtime.universe._

object StructTypeHelpers {
@@ -22,21 +21,32 @@
}
}

-  def flattenSchema(schema: StructType, prefix: String = ""): Array[Column] = {
-    schema.fields.flatMap(structField => {
-      val codeColName =
-        if (prefix.isEmpty) structField.name
-        else prefix + "." + structField.name
-
-      structField.dataType match {
-        case st: StructType =>
-          flattenSchema(
-            schema = st,
-            prefix = codeColName
-          )
-        case _ => Array(col(codeColName))
-      }
-    })
-  }
+  def flattenSchema(
+      schema: StructType,
+      baseField: String = "",
+      flattenArrayType: Boolean = false
+  ): Seq[(Column, String)] = {
+    schema.fields.flatMap { field =>
+      val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}"
+      field.dataType match {
+        case t: StructType =>
+          flattenSchema(t, colName)
+        case ArrayType(t: StructType, _) if flattenArrayType =>
+          flattenSchema(t, colName).map { case (c, n) => explode(c) -> n }
+        case ArrayType(_: ArrayType, _) if flattenArrayType =>
+          Seq(explode(col(colName)) -> colName)
+        case _ =>
+          Seq(col(colName) -> colName)
+      }
+    }
+  }

+  def containComplexFields(schema: StructType): Boolean = {
+    schema.fields.exists {
+      case StructField(_, _: StructType, _, _) => true
+      case StructField(_, _: ArrayType, _, _) => true
+      case _ => false
+    }
+  }

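Two details are worth noting here. First, the StructType branch recurses without forwarding flattenArrayType, so arrays nested inside structs pass through one call unchanged; the @tailrec loop in DataFrameExt.flattenSchema re-invokes the helper once they surface as top-level columns. Second, containComplexFields counts StructType and ArrayType fields as complex but not MapType, even though MapType is now imported. A small sketch of what the helper returns, on a hypothetical schema that is not part of the commit:

import org.apache.spark.sql.types._
import com.github.mrpowers.spark.daria.sql.types.StructTypeHelpers

val schema = StructType(Seq(
  StructField("a", StructType(Seq(StructField("b", IntegerType)))),
  StructField("c", ArrayType(StructType(Seq(StructField("d", IntegerType)))))
))

// Struct leaves become plain columns; array-of-struct leaves are wrapped in
// explode, one Column per leaf field.
StructTypeHelpers
  .flattenSchema(schema, "", flattenArrayType = true)
  .foreach { case (c, n) => println(s"$n -> $c") } // roughly: a.b -> a.b, c.d -> explode(c.d)

// Loop guard used by DataFrameExt: true while any StructType or ArrayType
// field remains.
assert(StructTypeHelpers.containComplexFields(schema))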
com/github/mrpowers/spark/daria/sql/DataFrameExtTest.scala

@@ -5,6 +5,7 @@ import SparkSessionExt._
import org.apache.spark.sql.types.{StructType, _}
import org.apache.spark.sql.{DataFrame, Row}
import DataFrameExt._
+import com.github.mrpowers.spark.daria.sql.types.StructTypeHelpers
import com.github.mrpowers.spark.fast.tests.DataFrameComparer
import org.apache.spark.sql.functions._

@@ -959,6 +960,98 @@ object DataFrameExtTest extends TestSuite with DataFrameComparer with SparkSessionTestWrapper {
)
}

"flatten dataframe with nested array" - {

val data = Seq(
Row(
Seq(
Seq(Row("vVal", "vVal1"), Row("vVal2", "vVal3")),
Seq(Row("vVal4", "vVal5"), Row("vVal6", "vVal7"))
),
Seq(
Row("wVal1", "wVal2"),
Row("wVal3", "wVal4")
)
)
)

val schema = StructType(
Seq(
StructField(
"v",
ArrayType(ArrayType(StructType(Seq(StructField("v2", StringType, true), StructField("v1", StringType, true))))),
true
),
StructField(
"w",
ArrayType(StructType(Seq(StructField("y", StringType, true), StructField("x", StringType, true)))),
true
)
)
)

val df = spark
.createDataFrame(
spark.sparkContext.parallelize(data),
StructType(schema)
)

val actualDf = df.flattenSchema("_", flattenArrayType = true)

val expectedData = Seq(
Row("vVal", "vVal1", "wVal1", "wVal2"),
Row("vVal", "vVal3", "wVal1", "wVal2"),
Row("vVal2", "vVal1", "wVal1", "wVal2"),
Row("vVal2", "vVal3", "wVal1", "wVal2"),
Row("vVal", "vVal1", "wVal1", "wVal4"),
Row("vVal", "vVal3", "wVal1", "wVal4"),
Row("vVal2", "vVal1", "wVal1", "wVal4"),
Row("vVal2", "vVal3", "wVal1", "wVal4"),
Row("vVal", "vVal1", "wVal3", "wVal2"),
Row("vVal", "vVal3", "wVal3", "wVal2"),
Row("vVal2", "vVal1", "wVal3", "wVal2"),
Row("vVal2", "vVal3", "wVal3", "wVal2"),
Row("vVal", "vVal1", "wVal3", "wVal4"),
Row("vVal", "vVal3", "wVal3", "wVal4"),
Row("vVal2", "vVal1", "wVal3", "wVal4"),
Row("vVal2", "vVal3", "wVal3", "wVal4"),
Row("vVal4", "vVal5", "wVal1", "wVal2"),
Row("vVal4", "vVal7", "wVal1", "wVal2"),
Row("vVal6", "vVal5", "wVal1", "wVal2"),
Row("vVal6", "vVal7", "wVal1", "wVal2"),
Row("vVal4", "vVal5", "wVal1", "wVal4"),
Row("vVal4", "vVal7", "wVal1", "wVal4"),
Row("vVal6", "vVal5", "wVal1", "wVal4"),
Row("vVal6", "vVal7", "wVal1", "wVal4"),
Row("vVal4", "vVal5", "wVal3", "wVal2"),
Row("vVal4", "vVal7", "wVal3", "wVal2"),
Row("vVal6", "vVal5", "wVal3", "wVal2"),
Row("vVal6", "vVal7", "wVal3", "wVal2"),
Row("vVal4", "vVal5", "wVal3", "wVal4"),
Row("vVal4", "vVal7", "wVal3", "wVal4"),
Row("vVal6", "vVal5", "wVal3", "wVal4"),
Row("vVal6", "vVal7", "wVal3", "wVal4")
)

val expectedSchema = StructType(
Seq(
StructField("v_v2", StringType, true),
StructField("v_v1", StringType, true),
StructField("w_y", StringType, true),
StructField("w_x", StringType, true)
)
)

val expectedDF = spark.createDataFrame(spark.sparkContext.parallelize(expectedData), expectedSchema)

assertSmallDataFrameEquality(
actualDf,
expectedDF,
ignoreNullable = true
)

}

}
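A note on the expected data in the new test: every struct leaf inside an array gets its own explode, so within each element of the outer v array the flattened rows form a cross product of the leaf values: v2 in {vVal, vVal2}, v1 in {vVal1, vVal3}, y in {wVal1, wVal3}, x in {wVal2, wVal4}, giving 2 * 2 * 2 * 2 = 16 combinations per outer element, and the two outer elements of v bring the total to 32 rows. A plain-Scala check of that count (illustrative only, not part of the commit):

// Cross product of the leaf values from the first outer element of v,
// matching the 32 expected rows above.
val v2 = Seq("vVal", "vVal2"); val v1 = Seq("vVal1", "vVal3")
val y  = Seq("wVal1", "wVal3"); val x  = Seq("wVal2", "wVal4")
val perOuterElement = for (a <- v2; b <- v1; c <- y; d <- x) yield (a, b, c, d)
assert(perOuterElement.size * 2 == 32) // 16 combos, times two outer v elements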

