
Commit

auxiliary data structure
GideonPotok committed May 23, 2024
1 parent 7894588 commit 3f254dd
Showing 2 changed files with 114 additions and 18 deletions.
@@ -33,7 +33,7 @@ case class Mode(
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0,
reverseOpt: Option[Boolean] = None)
- extends TypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes
+ extends TypedAggregateWithHashMapAsBufferPlus with ImplicitCastInputTypes
with SupportsOrderingWithinGroup with UnaryLike[Expression] {

def this(child: Expression) = this(child, 0, 0)
@@ -42,9 +42,6 @@ case class Mode(
this(child, 0, 0, Some(reverse))
}

- private lazy val binaryKeys: scala.collection.mutable.Map[UTF8String, UTF8String] =
- scala.collection.mutable.Map.empty
-
// Returns null for empty inputs
override def nullable: Boolean = true

@@ -82,8 +79,8 @@ case class Mode(
override def prettyName: String = "mode"

override def update(
- buffer: OpenHashMap[AnyRef, Long],
- input: InternalRow): OpenHashMap[AnyRef, Long] = {
+ buffers: (OpenHashMap[AnyRef, Long], OpenHashMap[AnyRef, AnyRef]),
+ input: InternalRow): (OpenHashMap[AnyRef, Long], OpenHashMap[AnyRef, AnyRef]) = {
val key = child.eval(input)

val keyNew = child.dataType match {
@@ -96,32 +93,35 @@ case class Mode(
case key: UTF8String =>
CollationFactory.getCollationKey(key, collationId)
}
- if(!binaryKeys.contains(keyNew)) {
- binaryKeys.put(keyNew, UTF8String.fromString(key.toString))
+ if(!buffers._2.contains(keyNew)) {
+ buffers._2.update(keyNew, UTF8String.fromString(key.toString))
}
keyNew
case _ => key
}
if (key != null) {
- buffer.changeValue(InternalRow.copyValue(keyNew).asInstanceOf[AnyRef], 1L, _ + 1L)
+ buffers._1.changeValue(InternalRow.copyValue(keyNew).asInstanceOf[AnyRef], 1L, _ + 1L)
}
- buffer
+ buffers
}

override def merge(
- buffer: OpenHashMap[AnyRef, Long],
- other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = {
- other.foreach { case (key, count) =>
- buffer.changeValue(key, count, _ + count)
+ buffer: (OpenHashMap[AnyRef, Long], OpenHashMap[AnyRef, AnyRef]),
+ other: (OpenHashMap[AnyRef, Long], OpenHashMap[AnyRef, AnyRef])): (OpenHashMap[AnyRef, Long], OpenHashMap[AnyRef, AnyRef]) = {
+ other._1.foreach { case (key, count) =>
+ buffer._1.changeValue(key, count, _ + count)
}
+ other._2.foreach { case (key, v) =>
+ buffer._2.changeValue(key, v, _)
+ }
buffer
}

- override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
- if (buffer.isEmpty) {
+ override def eval(buffer: (OpenHashMap[AnyRef, Long], OpenHashMap[AnyRef, AnyRef])): Any = {
+ if (buffer._1.isEmpty) {
return null
}
- val collationAwareBuffer = buffer
+ val collationAwareBuffer = buffer._1
val v = reverseOpt.map { reverse =>
val defaultKeyOrdering = if (reverse) {
PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
@@ -132,7 +132,7 @@ case class Mode(
collationAwareBuffer.maxBy { case (key, count) => (count, key) }(ordering)
}.getOrElse(collationAwareBuffer.maxBy(_._2))._1

- binaryKeys.get(v match {
+ buffer._2.get(v match {
case key: UTF8String => key
case key: String => UTF8String.fromString(key)
}).getOrElse(v)
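
The change above keeps two maps side by side in the aggregation buffer: counts are accumulated under the key returned by CollationFactory.getCollationKey, while the auxiliary map remembers one original UTF8String per collation key so that eval can report the mode in an original spelling rather than as the normalized key. Below is a rough standalone sketch of that idea, not the committed implementation: plain Scala collections stand in for Spark's OpenHashMap, a lower-casing function stands in for CollationFactory.getCollationKey, and all names are hypothetical.

```scala
object CollationModeSketch {
  // Stand-in for CollationFactory.getCollationKey: a case-insensitive "collation".
  private def collationKey(s: String): String = s.toLowerCase

  // Count values under their collation key, but remember one original spelling per key,
  // mirroring the (counts, auxiliary) pair of maps used by Mode above.
  def mode(values: Seq[String]): Option[String] = {
    val counts = scala.collection.mutable.Map.empty[String, Long]      // collation key -> count
    val originals = scala.collection.mutable.Map.empty[String, String] // collation key -> first-seen original
    values.foreach { v =>
      val k = collationKey(v)
      if (!originals.contains(k)) originals(k) = v // keep a representative original string
      counts(k) = counts.getOrElse(k, 0L) + 1L     // accumulate under the normalized key
    }
    // Pick the most frequent collation key, then map it back to an original string.
    counts.maxByOption(_._2).map { case (k, _) => originals(k) }
  }
}

// Example: "a", "A" and "A" share a collation key, so the mode is reported with an
// original spelling: CollationModeSketch.mode(Seq("a", "A", "A", "b")) == Some("a")
```
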
@@ -639,6 +639,98 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
}
}

/**
 * A special [[TypedImperativeAggregate]] that uses a pair of maps as its internal aggregation
 * buffer: a counts map (`OpenHashMap[AnyRef, T]`) plus an auxiliary `OpenHashMap[AnyRef, AnyRef]`.
 */
abstract class TypedAggregateWithHashMapAsBufferGenericPlus[T]
extends TypedImperativeAggregate[(OpenHashMap[AnyRef, T], OpenHashMap[AnyRef, AnyRef])] {
override def createAggregationBuffer(): (OpenHashMap[AnyRef, T], OpenHashMap[AnyRef, AnyRef]) = {
// Initialize new counts map and auxiliary map instances here.
(new OpenHashMap[AnyRef, T](), new OpenHashMap[AnyRef, AnyRef]())
}

val t: T

protected def child: Expression

private lazy val projection = UnsafeProjection.create(Array[DataType](child.dataType, tToDataType))

private def tToDataType: DataType = t match {
case _: Long => LongType
case _: String => StringType
}

override def serialize(objs: (OpenHashMap[AnyRef, T], OpenHashMap[AnyRef, AnyRef])): Array[Byte] = {
val buffer = new Array[Byte](4 << 10) // 4K
val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(bos)
try {
// Write pairs in counts map to byte buffer.
objs._1.foreach { case (key, count) =>
val row = InternalRow.apply(key, count)
val unsafeRow = projection.apply(row)
out.writeInt(unsafeRow.getSizeInBytes)
unsafeRow.writeToStream(out, buffer)
}
out.writeInt(-1)
objs._2.foreach { case (key, v) =>
val row = InternalRow.apply(key, v)
val unsafeRow = projection.apply(row)
out.writeInt(unsafeRow.getSizeInBytes)
unsafeRow.writeToStream(out, buffer)
}
out.writeInt(-1)
out.flush()


bos.toByteArray
} finally {
out.close()
bos.close()
}
}

override def deserialize(bytes: Array[Byte]): (OpenHashMap[AnyRef, T], OpenHashMap[AnyRef, AnyRef]) = {
val bis = new ByteArrayInputStream(bytes)
val ins = new DataInputStream(bis)
try {
val counts = new OpenHashMap[AnyRef, T]
// Read unsafeRow size and content in bytes.
var sizeOfNextRow = ins.readInt()
while (sizeOfNextRow >= 0) {
val bs = new Array[Byte](sizeOfNextRow)
ins.readFully(bs)
val row = new UnsafeRow(2)
row.pointTo(bs, sizeOfNextRow)
// Insert the pairs into counts map.
val key = row.get(0, child.dataType)
val count = row.get(1, tToDataType).asInstanceOf[T]
counts.update(key, count)
sizeOfNextRow = ins.readInt()
}
val other = new OpenHashMap[AnyRef, AnyRef]()
sizeOfNextRow = ins.readInt()
while (sizeOfNextRow >= 0) {
val bs = new Array[Byte](sizeOfNextRow)
ins.readFully(bs)
val row = new UnsafeRow(2)
row.pointTo(bs, sizeOfNextRow)
// Insert the pairs into the auxiliary map.
val key = row.get(0, child.dataType)
val v = row.get(1, child.dataType)
other.update(key, v)
sizeOfNextRow = ins.readInt()
}

(counts, other)
} finally {
ins.close()
bis.close()
}
}
}

/**
* A special [[TypedImperativeAggregate]] that uses `OpenHashMap[AnyRef, Long]` as internal
* aggregation buffer.
@@ -712,4 +804,8 @@ abstract class TypedAggregateWithHashMapAsBufferGeneric[T]

abstract class TypedAggregateWithHashMapAsBuffer extends TypedAggregateWithHashMapAsBufferGeneric[Long] {
override val t: Long = 0L
}

abstract class TypedAggregateWithHashMapAsBufferPlus extends TypedAggregateWithHashMapAsBufferGenericPlus[Long] {
override val t: Long = 0L
}
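
The serialize and deserialize methods of the new buffer class write each of the two maps as a sequence of length-prefixed UnsafeRows and terminate each section with a -1 sentinel. The following is a small self-contained sketch of just that framing, under the assumption that each record has already been rendered to bytes; it uses plain byte arrays instead of UnsafeRows, and the object and method names are hypothetical.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object LengthPrefixedFramingSketch {
  // Write every section as (length, bytes)* followed by a -1 sentinel, mirroring the
  // out.writeInt(unsafeRow.getSizeInBytes) / out.writeInt(-1) pattern in serialize above.
  def write(sections: Seq[Seq[Array[Byte]]]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    try {
      sections.foreach { records =>
        records.foreach { bytes =>
          out.writeInt(bytes.length) // length prefix
          out.write(bytes)           // payload
        }
        out.writeInt(-1) // sentinel closing this section
      }
      out.flush()
      bos.toByteArray
    } finally {
      out.close()
    }
  }

  // Read the sections back by looping on the length prefix until the -1 sentinel,
  // the same shape as the readInt()/readFully() loops in deserialize above.
  def read(bytes: Array[Byte], numSections: Int): Seq[Seq[Array[Byte]]] = {
    val ins = new DataInputStream(new ByteArrayInputStream(bytes))
    try {
      (0 until numSections).map { _ =>
        val records = scala.collection.mutable.ArrayBuffer.empty[Array[Byte]]
        var size = ins.readInt()
        while (size >= 0) {
          val bs = new Array[Byte](size)
          ins.readFully(bs)
          records += bs
          size = ins.readInt()
        }
        records.toSeq
      }
    } finally {
      ins.close()
    }
  }
}
```

A round trip such as read(write(Seq(sectionA, sectionB)), numSections = 2) returns the two sections in their original order, which is the property the aggregate relies on when it rebuilds the counts map and the auxiliary map from the serialized buffer.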
