Typed outer join
zeotuan committed Dec 31, 2024
1 parent a528229 commit 0a93bc2
Showing 5 changed files with 139 additions and 60 deletions.
@@ -4,10 +4,12 @@ import com.github.mrpowers.spark.fast.tests.DatasetUtils.DatasetOps
import com.github.mrpowers.spark.fast.tests.DatasetComparer.maxUnequalRowsToShow
import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

case class DatasetContentMismatch(smth: String) extends Exception(smth)
case class DatasetCountMismatch(smth: String) extends Exception(smth)
@@ -73,7 +75,7 @@ Expected DataFrame Row Count: '$expectedCount'
val e = expectedDS.collect().toSeq
if (!a.approximateSameElements(e, equals)) {
val arr = ("Actual Content", "Expected Content")
val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate)
val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, Left(a -> e), truncate)
throw DatasetContentMismatch(msg)
}
}
@@ -159,29 +161,33 @@ Expected DataFrame Row Count: '$expectedCount'
* partitioned in the same way and become unreliable when shuffling is involved.
* @param primaryKeys
* unique identifier for each row to ensure accurate comparison of rows
* @param checkKeyUniqueness
* if true, check that the primary key is actually unique in each dataset before comparing
*/
def assertLargeDatasetEqualityV2[T: ClassTag](
def assertLargeDatasetEqualityV2[T: ClassTag: TypeTag](
actualDS: Dataset[T],
expectedDS: Dataset[T],
equals: (T, T) => Boolean = (o1: T, o2: T) => o1.equals(o2),
equals: Either[(T, T) => Boolean, Option[Column]] = Right(None),
ignoreNullable: Boolean = false,
ignoreColumnNames: Boolean = false,
ignoreColumnOrder: Boolean = false,
ignoreMetadata: Boolean = true,
checkKeyUniqueness: Boolean = false,
primaryKeys: Seq[String] = Seq.empty,
truncate: Int = 500
): Unit = {
// first check if the schemas are equal
SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
assertLargeDatasetContentEqualityV2(actual, expectedDS, equals, primaryKeys, truncate)
assertLargeDatasetContentEqualityV2(actual, expectedDS, equals, primaryKeys, checkKeyUniqueness, truncate)
}

def assertLargeDatasetContentEqualityV2[T: ClassTag](
def assertLargeDatasetContentEqualityV2[T: ClassTag: TypeTag](
ds1: Dataset[T],
ds2: Dataset[T],
equals: (T, T) => Boolean,
equals: Either[(T, T) => Boolean, Option[Column]],
primaryKeys: Seq[String],
checkKeyUniqueness: Boolean,
truncate: Int
): Unit = {
try {
@@ -195,14 +201,32 @@ Expected DataFrame Row Count: '$expectedCount'
throw DatasetCountMismatch(countMismatchMessage(actualCount, expectedCount))
}

val unequalDf = ds1
.joinPair(ds2, primaryKeys)
.filter((p: (T, T)) => !equals(p._1, p._2))
if (primaryKeys.nonEmpty && checkKeyUniqueness) {
assert(ds1.isKeyUnique(primaryKeys), "Primary key is not unique in actual dataset")
assert(ds2.isKeyUnique(primaryKeys), "Primary key is not unique in expected dataset")
}

val joinedDf = ds1
.outerJoinWith(ds2, primaryKeys)

val unequalDS = equals match {
case Left(customEquals) =>
joinedDf.filter((p: (Option[T], Option[T])) =>
p match {
case (Some(l), Some(r)) => !customEquals(l, r)
case (None, None) => false
case _ => true
}
)

case Right(equalExprOption) =>
joinedDf.filter(equalExprOption.getOrElse(col("l") =!= col("r")))
}

if (!unequalDf.isEmpty) {
val (a, e) = unequalDf.take(maxUnequalRowsToShow).toSeq.unzip
if (!unequalDS.isEmpty) {
val joined = Right(unequalDS.take(truncate).toSeq)
val arr = ("Actual Content", "Expected Content")
val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate)
val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, joined, truncate)
throw DatasetContentMismatch(msg)
}

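A minimal usage sketch of the reworked assertion, assuming the signature introduced above and a test suite that mixes in DatasetComparer; the Person case class and the sample data are stand-ins, not part of this commit. A Left value supplies a typed row comparator applied to each joined pair, while a Right value supplies a Column predicate that is used directly as the mismatch filter over the joined "l"/"r" structs (Right(None), the default, falls back to col("l") =!= col("r")).

// Sketch only: assumes a SparkSession named `spark` is in scope, as in the test suite below.
import org.apache.spark.sql.functions.{col, lower}
import spark.implicits._

case class Person(name: String, age: Int) // stand-in fixture

val actualDS   = Seq(Person("Alice", 5), Person("Bob", 7)).toDS
val expectedDS = Seq(Person("alice", 5), Person("Bob", 7)).toDS

// Left: typed comparator; pairs where one side is missing for a primary key always count as diffs.
assertLargeDatasetEqualityV2(
  actualDS,
  expectedDS,
  equals = Left((a: Person, b: Person) => a.name.equalsIgnoreCase(b.name) && a.age == b.age),
  primaryKeys = Seq("age"),
  checkKeyUniqueness = true
)

// Right: Column predicate selecting mismatched rows; passes here because names match case-insensitively.
assertLargeDatasetEqualityV2(
  actualDS,
  expectedDS,
  equals = Right(Some(lower(col("l.name")) =!= lower(col("r.name")))),
  primaryKeys = Seq("age")
)
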
@@ -1,49 +1,74 @@
package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.OptionEncoder
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, TypedColumn}

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

private object DatasetUtils {
implicit class DatasetOps[T: ClassTag](ds: Dataset[T]) {
implicit class DatasetOps[T: ClassTag: TypeTag](ds: Dataset[T]) {
def zipWithIndex(indexName: String): DataFrame = ds
.orderBy()
.withColumn(indexName, row_number().over(Window.orderBy(monotonically_increasing_id())))
.select(ds.columns.map(col) :+ col(indexName): _*)

def joinPair(
other: Dataset[T],
primaryKeys: Seq[String]
): Dataset[(T, T)] = {
if (primaryKeys.nonEmpty) {
ds
.as("l")
.joinWith(other.as("r"), primaryKeys.map(k => col(s"l.$k") === col(s"r.$k")).reduce(_ && _))
/**
* Check if the primary key is actually unique
*/
def isKeyUnique(primaryKey: Seq[String]): Boolean =
ds.select(primaryKey.map(col): _*).distinct.count == ds.count

def outerJoinWith[P: ClassTag: TypeTag](
other: Dataset[P],
primaryKeys: Seq[String],
outerJoinType: String = "full"
): Dataset[(Option[T], Option[P])] = {
val (ds1, ds2, key) = if (primaryKeys.nonEmpty) {
(ds, other, primaryKeys)
} else {
val indexName = s"index_${java.util.UUID.randomUUID}"
val columns = ds.columns
val joined = ds
.zipWithIndex(indexName)
.alias("l")
.join(other.zipWithIndex(indexName).alias("r"), indexName)

val encoder: Encoder[T] = ds.encoder
val leftCols = columns.map(n => col(s"l.$n"))
val rightCols = columns.map(n => col(s"r.$n"))
val (pair1, pair2) =
if (columns.length == 1 && !(implicitly[ClassTag[T]].runtimeClass == classOf[Row]))
(leftCols.head, rightCols.head)
else
(struct(leftCols: _*), struct(rightCols: _*))

joined
.select(
pair1.as("l").as[T](encoder),
pair2.as("r").as[T](encoder)
)
(ds.zipWithIndex(indexName), other.zipWithIndex(indexName), Seq(indexName))
}

val joined = ds1
.as("l")
.join(ds2.as("r"), key, s"${outerJoinType}_outer")

joined.select(encoderToOptionTypedCol[T]("l", ds.schema, key), encoderToOptionTypedCol[P]("r", other.schema, key))
}
}

private def encoderToOptionTypedCol[P: ClassTag: TypeTag](
colName: String,
schema: StructType,
key: Seq[String]
): TypedColumn[Any, Option[P]] = {
val columns = schema.names.map(n => col(s"$colName.$n"))
val isRowType = implicitly[ClassTag[P]].runtimeClass == classOf[Row]
val unTypedColumn =
if (columns.length == 1 && !isRowType)
columns.head
else
when(key.map(k => col(s"$colName.$k").isNull).reduce(_ && _), lit(null)).otherwise(struct(columns: _*))

val enc: Encoder[Option[P]] = if (isRowType) {
ExpressionEncoder(OptionEncoder(RowEncoder.encoderFor(schema).asInstanceOf[AgnosticEncoder[P]]))
} else {
ExpressionEncoder()
}
unTypedColumn.as(colName).as[Option[P]](enc)
}

def encoderToRowCol(
colName: String,
schema: StructType
): TypedColumn[Any, Row] = {
val columns = schema.names.map(n => col(s"$colName.$n"))
struct(columns: _*).as(colName).as[Row](Encoders.row(schema))
}
}
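
A minimal sketch of what the new outerJoinWith helper produces, assuming the DatasetOps signature above; DatasetUtils is package-private, so this only compiles inside com.github.mrpowers.spark.fast.tests, and the ranges are illustrative (mirroring the "join test2" case at the bottom of this diff).

import com.github.mrpowers.spark.fast.tests.DatasetUtils.DatasetOps
import org.apache.spark.sql.Dataset

// Sketch only: assumes a SparkSession named `spark` is in scope.
val left  = spark.range(0, 6)    // ids 0..5
val right = spark.range(0, 6, 2) // ids 0, 2, 4

// Full outer join on the "id" key; a row missing on either side surfaces as None.
val joined: Dataset[(Option[java.lang.Long], Option[java.lang.Long])] =
  left.outerJoinWith(right, Seq("id"))

// Keep only the ids that differ between the two sides, e.g. (Some(1), None), (Some(3), None), (Some(5), None).
joined.filter(p => p._1 != p._2).show(false)
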
@@ -5,23 +5,25 @@ import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.Row

import scala.annotation.tailrec
import scala.reflect.ClassTag

object ProductUtil {
@tailrec
private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = {
product match {
case null => Seq.empty
case a: Array[_] => a
case i: Iterable[_] => i.toSeq
case r: Row => r.toSeq
case p: Product => p.productIterator.toSeq
case s => Seq(s)
case null | None => Seq.empty
case a: Array[_] => a
case i: Iterable[_] => i.toSeq
case r: Row => r.toSeq
case p: Product => p.productIterator.toSeq
case Some(ov) if !ov.isInstanceOf[Option[_]] => productOrRowToSeq(ov)
case s => Seq(s)
}
}
private[mrpowers] def showProductDiff[T: ClassTag](
header: (String, String),
actual: Seq[T],
expected: Seq[T],
data: Either[(Seq[T], Seq[T]), Seq[(Option[T], Option[T])]],
truncate: Int = 20,
minColWidth: Int = 3
): String = {
@@ -33,7 +35,10 @@

val sb = new StringBuilder

val fullJoin = actual.zipAll(expected, null, null)
val fullJoin = data match {
case Left((actual, expected)) => actual.zipAll(expected, null, null)
case Right(joined) => joined
}

val diff = fullJoin.map { case (actualRow, expectedRow) =>
if (actualRow == expectedRow) {
@@ -14,8 +14,7 @@ object SchemaComparer {
private def betterSchemaMismatchMessage(actualSchema: StructType, expectedSchema: StructType): String = {
showProductDiff(
("Actual Schema", "Expected Schema"),
actualSchema.fields,
expectedSchema.fields,
Left(actualSchema.fields.toSeq -> expectedSchema.fields.toSeq),
truncate = 200
)
}
@@ -2,10 +2,11 @@ package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.types._
import SparkSessionExt._
import com.github.mrpowers.spark.fast.tests.DatasetUtils.DatasetOps
import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch
import com.github.mrpowers.spark.fast.tests.TestUtilsExt.ExceptionOps
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lower}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.scalatest.freespec.AnyFreeSpec

class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
@@ -56,7 +57,7 @@ class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
Array("red", "green", "blue")
).toDS

assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = (a1: Array[String], a2: Array[String]) => a1.mkString == a2.mkString)
assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Left((a1: Array[String], a2: Array[String]) => a1.mkString == a2.mkString))
}

"can compare Dataset[Map[_]]" in {
@@ -70,7 +71,11 @@ class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
Map("apple" -> "banana", "apple1" -> "banana1")
).toDS

assertLargeDatasetEqualityV2(sourceDS, expectedDS)
assertLargeDatasetEqualityV2(
sourceDS,
expectedDS,
equals = Left((a1: Map[String, String], a2: Map[String, String]) => a1.mkString == a2.mkString)
)
}

"does nothing if the Datasets have the same schemas and content" in {
@@ -223,7 +228,7 @@ class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
Person("Alice", 5)
)
)
assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Person.caseInsensitivePersonEquals)
assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Left((p1: Person, p2: Person) => Person.caseInsensitivePersonEquals(p1, p2)))
}

"fails if custom comparator for returns false" in {
@@ -241,7 +246,7 @@
)

intercept[DatasetContentMismatch] {
assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Person.caseInsensitivePersonEquals)
assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Left((p1: Person, p2: Person) => Person.caseInsensitivePersonEquals(p1, p2)))
}
}

@@ -433,5 +438,26 @@ class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
assertLargeDatasetEqualityV2(ds2, ds1, ignoreMetadata = false)
}
}

"join test2" in {
val joinedOuterDS =
spark.range(0, 1000000, 1, 8).outerJoinWith(spark.range(0, 1000000, 2, 8), Seq("id"))
println("joinedOuterDS")
joinedOuterDS.explain()
joinedOuterDS.show(false)
val filteredDs = joinedOuterDS.filter(p => p._1 == p._2)
filteredDs.explain()
filteredDs.show(false)

val joinedOuterDF =
spark.range(0, 1000000, 1, 8).toDF().outerJoinWith(spark.range(0, 1000000, 2, 8).toDF(), Seq("id"))
println("joinedOuterDF")
joinedOuterDF.explain()
joinedOuterDF.show(false)

val filtedDf = joinedOuterDF.filter((p: (Option[Row], Option[Row])) => { p._1 == p._2 })
filtedDf.explain()
filtedDf.show(false)
}
}
}
