From 5ed04cbfa308a4478de293c069f3b5a98a6c11c5 Mon Sep 17 00:00:00 2001 From: zeotuan <48720253+zeotuan@users.noreply.github.com> Date: Thu, 26 Sep 2024 20:02:18 +1000 Subject: [PATCH] add arrayConcatAggregator (#159) --- .../mrpowers/spark/daria/sql/functions.scala | 5 ++ .../spark/daria/sql/udafs/ArrayConcat.scala | 32 +++++++- .../spark/daria/sql/FunctionsTest.scala | 78 ++++++++++++++++++- .../daria/sql/udafs/ArrayConcatTest.scala | 41 ++++++++++ 4 files changed, 152 insertions(+), 4 deletions(-) diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala index b1b6ff9d..f314bcec 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala @@ -1,5 +1,6 @@ package com.github.mrpowers.spark.daria.sql +import com.github.mrpowers.spark.daria.sql.udafs.ArrayConcatAggregator import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.Column import org.apache.spark.sql.functions._ @@ -750,4 +751,8 @@ object functions { def excelEpochToDate(colName: String): Column = { excelEpochToDate(col(colName)) } + + def arrayConcat(col: Column): Column = { + flatten(collect_list(col)) + } } diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/udafs/ArrayConcat.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/udafs/ArrayConcat.scala index f67a04ec..27b7db68 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/udafs/ArrayConcat.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/udafs/ArrayConcat.scala @@ -1,8 +1,10 @@ package com.github.mrpowers.spark.daria.sql.udafs -import org.apache.spark.sql.Row -import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} -import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} +import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.types._ class ArrayConcat(elementSchema: DataType, nullable: Boolean = true) extends UserDefinedAggregateFunction { @@ -43,3 +45,27 @@ class ArrayConcat(elementSchema: DataType, nullable: Boolean = true) extends Use buffer.getAs[Seq[Any]](0) } } + +case class ArrayConcatAggregator[T: TypeTag]() extends Aggregator[Seq[T], Seq[T], Seq[T]] { + override def zero: Seq[T] = Seq.empty[T] + + override def reduce(b: Seq[T], a: Seq[T]): Seq[T] = { + if (a == null) { + return b + } + b ++ a + } + + override def merge(b1: Seq[T], b2: Seq[T]): Seq[T] = { + if (b2 == null) { + return b1 + } + b1 ++ b2 + } + + override def finish(reduction: Seq[T]): Seq[T] = reduction + + override def bufferEncoder: Encoder[Seq[T]] = ExpressionEncoder[Seq[T]]() + + override def outputEncoder: Encoder[Seq[T]] = ExpressionEncoder[Seq[T]]() +} diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala index abfe4f68..05d05306 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala @@ -1,12 +1,12 @@ package com.github.mrpowers.spark.daria.sql import java.sql.{Date, Timestamp} - import utest._ import com.github.mrpowers.spark.fast.tests.{ColumnComparer, DataFrameComparer} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import SparkSessionExt._ +import com.github.mrpowers.spark.daria.sql.functions.arrayConcat object FunctionsTest extends TestSuite with DataFrameComparer with ColumnComparer with SparkSessionTestWrapper { @@ -1429,5 +1429,81 @@ object FunctionsTest extends TestSuite with DataFrameComparer with ColumnCompare } } + "arrayConcat" - { + "arrayConcat array of string type" - { + + val actualDF = spark + .createDF( + List( + Array( + "snake", + "rat" + ), + null, + Array( + "cat", + "crazy" + ) + ), + List(("array", ArrayType(StringType), true)) + ) + .agg(arrayConcat(col("array")).as("array")) + + val expectedDF = spark + .createDF( + List( + Array( + "snake", + "rat", + "cat", + "crazy" + ) + ), + List(("array", ArrayType(StringType), true)) + ) + + assertSmallDataFrameEquality( + actualDF, + expectedDF, + ignoreNullable = true + ) + + } + + "arrayConcat array of Int type" - { + + val actualDF = spark + .createDF( + List( + Array( + 1, + 2 + ), + null, + Array( + 3, + 4 + ) + ), + List(("array", ArrayType(IntegerType), true)) + ) + .agg(arrayConcat(col("array")).as("array")) + + val expectedDF = spark + .createDF( + List(Array(1, 2, 3, 4)), + List(("array", ArrayType(IntegerType), true)) + ) + + assertSmallDataFrameEquality( + actualDF, + expectedDF, + ignoreNullable = true + ) + + } + + } + } } diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/udafs/ArrayConcatTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/udafs/ArrayConcatTest.scala index 0fbcaf18..c09fac75 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/udafs/ArrayConcatTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/udafs/ArrayConcatTest.scala @@ -54,6 +54,47 @@ object ArrayConcatTest extends TestSuite with DataFrameComparer with SparkSessio } + "concatenates rows of arrays using aggregator" - { + + val arrayConcat = udaf(new ArrayConcatAggregator[String]()) + + val actualDF = spark + .createDF( + List( + Array( + "snake", + "rat" + ), + null, + Array( + "cat", + "crazy" + ) + ), + List(("array", ArrayType(StringType), true)) + ) + .agg(arrayConcat(col("array")).as("array")) + + val expectedDF = spark + .createDF( + List( + Array( + "snake", + "rat", + "cat", + "crazy" + ) + ), + List(("array", ArrayType(StringType), true)) + ) + + assertSmallDataFrameEquality( + actualDF, + expectedDF + ) + + } + } }