
Commit

mockup
GideonPotok committed Jun 4, 2024
1 parent a49ccd4 commit 86fd2f3
Showing 5 changed files with 52 additions and 20 deletions.
@@ -37,7 +37,7 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](

def this() = this(64)

- protected var _keySet = new OpenHashSet[K](initialCapacity)
+ protected var _keySet = new OpenHashSet[K](initialCapacity, 0.7, Some(true))

// Init in constructor (instead of in declaration) to work around a Scala compiler specialization
// bug that would generate two arrays (one for Object and one for specialized T).
@@ -22,6 +22,8 @@ import scala.reflect._
import com.google.common.hash.Hashing

import org.apache.spark.annotation.Private
+ import org.apache.spark.sql.catalyst.util.CollationFactory
+ import org.apache.spark.unsafe.types.UTF8String

/**
* A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
@@ -43,7 +45,9 @@ import org.apache.spark.annotation.Private
@Private
class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
initialCapacity: Int,
- loadFactor: Double)
+ loadFactor: Double,
+ specialPassedInHasher: Option[Boolean] = None
+ )
extends Serializable {

require(initialCapacity <= OpenHashSet.MAX_CAPACITY,
@@ -67,7 +71,9 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
case ClassTag.Int => new IntHasher().asInstanceOf[Hasher[T]]
case ClassTag.Double => new DoubleHasher().asInstanceOf[Hasher[T]]
case ClassTag.Float => new FloatHasher().asInstanceOf[Hasher[T]]
- case _ => new Hasher[T]
+ case _ => specialPassedInHasher.map(_ =>
+     new CollationStringHasher()).getOrElse(
+     new Hasher[T]).asInstanceOf[Hasher[T]]
}

protected var _capacity = nextPowerOf2(initialCapacity)
@@ -314,6 +320,15 @@ object OpenHashSet {
override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o)
}

+ class CollationStringHasher extends Hasher[String] {
+   override def hash(o: String): Int = {
+     // scalastyle:off println
+     println("Hash function for collation string is called")
+     // scalastyle:on println
+     CollationFactory.fetchCollation(2).hashFunction.applyAsLong(UTF8String.fromString(o)).toInt
+   }
+ }

private def grow1(newSize: Int): Unit = {}
private def move1(oldPos: Int, newPos: Int): Unit = { }

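For context, here is a minimal standalone sketch of the hasher selection this hunk introduces. The object name HasherSelectionSketch and the simplified classes below are illustrative only, and the lower-casing hash is just a placeholder for the real CollationFactory hash function: the point is that when the caller passes the new specialPassedInHasher flag, string keys are hashed so that values equal under the collation land in the same bucket.

object HasherSelectionSketch {
  // Simplified stand-ins for the real Spark classes; only the selection logic is modelled.
  class Hasher[T] extends Serializable {
    def hash(o: T): Int = o.hashCode()
  }
  class CollationStringHasher extends Hasher[String] {
    // Placeholder for CollationFactory.fetchCollation(...).hashFunction.
    override def hash(o: String): Int = o.toLowerCase.hashCode
  }

  // Mirrors the fallback case of the pattern match above: the collation-aware
  // hasher is chosen only when an Option is passed, otherwise the default Hasher is used.
  def pickHasher[T](specialPassedInHasher: Option[Boolean]): Hasher[T] =
    specialPassedInHasher
      .map(_ => new CollationStringHasher().asInstanceOf[Hasher[T]])
      .getOrElse(new Hasher[T])

  def main(args: Array[String]): Unit = {
    val collationAware = pickHasher[String](Some(true))
    // Equal under a case-insensitive collation => same hash value.
    assert(collationAware.hash("Spark") == collationAware.hash("SPARK"))

    val default = pickHasher[String](None)
    println(default.hash("Spark") == default.hash("SPARK")) // usually false
  }
}

Note that the mockup keys off the presence of the Option rather than its value, so Some(false) would select the collation-aware hasher just as Some(true) does, and the hard-coded fetchCollation(2) in the hunk above pins one collation instead of threading a collation id through.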
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
- import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
+ import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
@@ -49,21 +49,6 @@ case class Mode(

override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

- override def checkInputDataTypes(): TypeCheckResult = {
-   val defaultCheck = super.checkInputDataTypes()
-   if (defaultCheck.isFailure) {
-     defaultCheck
-   } else if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
-       child.dataType.isInstanceOf[StringType]) {
-     TypeCheckResult.TypeCheckSuccess
-   } else {
-     TypeCheckResult.TypeCheckFailure(
-       "The input to the function 'mode' was a complex type with non-binary collated fields," +
-       " which are currently not supported by 'mode'."
-     )
-   }
- }

override def prettyName: String = "mode"

override def update(
@@ -98,6 +83,11 @@ case class Mode(
case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
modeMap
+ case complexType if !UnsafeRowUtils.isBinaryStable(child.dataType) =>
+   val modeMap = buffer.toSeq.groupMapReduce {
+     case (k, _) => UnsafeRowUtils.getBinaryStableKey(k, child.dataType)
+   }(x => x)((x, y) => (x._1, x._2 + y._2)).values
+   modeMap
case _ => buffer
}
reverseOpt.map { reverse =>
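As a rough illustration of what the new branch computes, here is the same groupMapReduce shape on plain Scala collections. The object name ModeGroupingSketch is made up for this example, and a lower-casing stableKey helper stands in for UnsafeRowUtils.getBinaryStableKey: entries whose stable keys match are merged, one representative value is kept, and the counts are summed before the mode is taken.

object ModeGroupingSketch {
  def main(args: Array[String]): Unit = {
    // Value -> count pairs, as Mode's aggregation buffer would hold them.
    val buffer = Seq(("a", 2L), ("A", 3L), ("b", 4L))

    // Stand-in for getBinaryStableKey: a case-insensitive grouping key.
    def stableKey(v: String): String = v.toLowerCase

    // Same shape as the groupMapReduce in the hunk above: group by the stable key,
    // keep one representative (value, count) pair, and add the counts together.
    val modeMap = buffer
      .groupMapReduce { case (k, _) => stableKey(k) }(x => x)((x, y) => (x._1, x._2 + y._2))
      .values

    // "a" and "A" collapse into a single entry with count 5, which now beats "b" with 4.
    println(modeMap.maxBy(_._2)) // (a,5)
  }
}

Which of the collation-equal originals gets reported as the mode depends on which pair the reduce keeps as the representative for its group.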
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._
+ import org.apache.spark.unsafe.types.UTF8String

object UnsafeRowUtils {

@@ -129,8 +130,34 @@
None
}

+ def getBinaryStableKey(k: AnyRef, dataType: DataType): AnyRef = {
+   dataType match {
+     case complexType: StructType if !isBinaryStable(complexType) =>
+       val row = k.asInstanceOf[UnsafeRow]
+       val key = new Array[AnyRef](complexType.fields.length)
+       complexType.fields.zipWithIndex.foreach { case (field, index) =>
+         key(index) = getBinaryStableKey(row.get(index, field.dataType), field.dataType)
+       }
+       key
+     case complexType: ArrayType if !isBinaryStable(complexType.elementType) =>
+       val array = k.asInstanceOf[ArrayData]
+       val key = new Array[AnyRef](array.numElements())
+       array.foreach(complexType.elementType, (i, any) => {
+         key.update(i,
+           getBinaryStableKey(array.get(i, complexType.elementType), complexType.elementType))
+       })
+       key
+     case c: StringType if
+         !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
+       val collationId = c.collationId
+       CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
+     case _ => k
+   }
+ }

/**
* Wrapper of validateStructuralIntegrityWithReasonImpl, add more information for debugging
*
* @param row The input UnsafeRow to be validated
* @param expectedSchema The expected schema that should match with the UnsafeRow
* @return None if all the checks pass. An error message if the row is not matched with the schema
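To make the recursion in getBinaryStableKey easier to follow, here is a small self-contained sketch. The object name StableKeySketch and the toy Value/Str/Struct/Arr types are invented for this example, standing in for Catalyst's DataType, UnsafeRow and ArrayData, and lower-casing stands in for CollationFactory.getCollationKey: containers are walked element by element, collated string leaves are replaced by a collation key, and everything else passes through unchanged, so two values that are equal under the collation produce identical keys.

object StableKeySketch {
  // Toy value model standing in for Catalyst rows, arrays and strings.
  sealed trait Value
  case class Str(s: String, caseInsensitiveCollation: Boolean) extends Value
  case class Struct(fields: Seq[Value]) extends Value
  case class Arr(elems: Seq[Value]) extends Value

  // Analogous to getBinaryStableKey: recurse into containers, normalize
  // collated string leaves, return everything else as-is.
  def stableKey(v: Value): Any = v match {
    case Struct(fields) => fields.map(stableKey)
    case Arr(elems)     => elems.map(stableKey)
    case Str(s, true)   => s.toLowerCase // stand-in for the collation key
    case Str(s, false)  => s
  }

  def main(args: Array[String]): Unit = {
    val a = Struct(Seq(
      Str("Hello", caseInsensitiveCollation = true),
      Arr(Seq(Str("x", caseInsensitiveCollation = false)))))
    val b = Struct(Seq(
      Str("HELLO", caseInsensitiveCollation = true),
      Arr(Seq(Str("x", caseInsensitiveCollation = false)))))
    assert(stableKey(a) == stableKey(b)) // equal under the collation => identical keys
  }
}

The sketch returns Seq so that structurally equal keys compare equal under ==; the hunk above builds Array[AnyRef], and JVM arrays compare by reference, which is worth keeping in mind when these keys are later used for grouping.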
@@ -1719,7 +1719,7 @@ class CollationSQLExpressionsSuite
t.collationId + ", f2: INT>) USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(mode(i).f1) FROM ${tableName}"
if(t.collationId == "utf8_binary_lcase" || t.collationId == "unicode_ci") {
if(t.collationId == "") { // || t.collationId == "unicode_ci") {
// Cannot resolve "mode(i)" due to data type mismatch:
// Input to function mode was a complex type with strings collated on non-binary
// collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;
