Skip to content

Commit

Permalink
JBAI-6945 [ndarray] Introduced functional interface ScalarBroadcastFu…
Browse files Browse the repository at this point in the history
…n instead of lambda, so InlineInt inside changed to regular Int without additional boxing operations.
  • Loading branch information
dmitriyb committed Sep 24, 2024
1 parent 460f929 commit 4bdb061
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import io.kinference.primitives.annotations.GenerateNameFromPrimitives
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import io.kinference.primitives.types.PrimitiveType
import io.kinference.utils.inlines.InlineInt

@GenerateNameFromPrimitives
internal fun broadcastTwoTensorsPrimitive(
Expand Down Expand Up @@ -45,14 +44,14 @@ internal fun broadcastTwoTensorsPrimitive(
val rightBlocks = right.array.blocks
val destBlocks = dest.array.blocks

val leftIsScalarFun = { leftOffset: InlineInt, rightOffset: InlineInt, destOffset: InlineInt, axisToBroadcastIdx: InlineInt ->
val shapeIdx = axisToBroadcastIdx.value * 2
val leftIsScalarFun = ScalarBroadcastFun { leftOffset, rightOffset, destOffset, axisToBroadcastIdx ->
val shapeIdx = axisToBroadcastIdx * 2
val batchSize = destBroadcastingShape[shapeIdx]

for (batchIdx in 0 until batchSize) {
val leftBatchOffset = leftOffset.value + leftOffsets[shapeIdx] * batchIdx
val rightBatchOffset = rightOffset.value + rightOffsets[shapeIdx] * batchIdx
val destBatchOffset = destOffset.value + destOffsets[shapeIdx] * batchIdx
val leftBatchOffset = leftOffset + leftOffsets[shapeIdx] * batchIdx
val rightBatchOffset = rightOffset + rightOffsets[shapeIdx] * batchIdx
val destBatchOffset = destOffset + destOffsets[shapeIdx] * batchIdx

val leftScalar = leftBlocks[leftBatchOffset][0]

Expand All @@ -67,14 +66,14 @@ internal fun broadcastTwoTensorsPrimitive(
}
}

val rightIsScalarFun = { leftOffset: InlineInt, rightOffset: InlineInt, destOffset: InlineInt, axisToBroadcastIdx: InlineInt ->
val shapeIdx = axisToBroadcastIdx.value * 2
val rightIsScalarFun = ScalarBroadcastFun { leftOffset, rightOffset, destOffset, axisToBroadcastIdx ->
val shapeIdx = axisToBroadcastIdx * 2
val batchSize = destBroadcastingShape[shapeIdx]

for (batchIdx in 0 until batchSize) {
val leftBatchOffset = leftOffset.value + leftOffsets[shapeIdx] * batchIdx
val rightBatchOffset = rightOffset.value + rightOffsets[shapeIdx] * batchIdx
val destBatchOffset = destOffset.value + destOffsets[shapeIdx] * batchIdx
val leftBatchOffset = leftOffset + leftOffsets[shapeIdx] * batchIdx
val rightBatchOffset = rightOffset + rightOffsets[shapeIdx] * batchIdx
val destBatchOffset = destOffset + destOffsets[shapeIdx] * batchIdx

val rightScalar = rightBlocks[rightBatchOffset][0]

Expand All @@ -89,27 +88,27 @@ internal fun broadcastTwoTensorsPrimitive(
}
}

val defaultFun = { leftOffset: InlineInt, rightOffset: InlineInt, destOffset: InlineInt, axisToBroadcastIdx: InlineInt ->
val defaultFun = ScalarBroadcastFun { leftOffset, rightOffset, destOffset, _ ->
for (blockIdx in 0 until destBlocksInRow) {
val leftBlock = leftBlocks[leftOffset.value + blockIdx]
val rightBlock = rightBlocks[rightOffset.value + blockIdx]
val destBlock = destBlocks[destOffset.value + blockIdx]
val leftBlock = leftBlocks[leftOffset + blockIdx]
val rightBlock = rightBlocks[rightOffset + blockIdx]
val destBlock = destBlocks[destOffset + blockIdx]

for (idx in destBlock.indices) {
destBlock[idx] = op(leftBlock[idx], rightBlock[idx])
}
}
}

val broadcastingFun = when {
val broadcastingFun: ScalarBroadcastFun = when {
leftIsScalar -> leftIsScalarFun
rightIsScalar -> rightIsScalarFun
else -> defaultFun
}

fun broadcast(leftOffset: Int, rightOffset: Int, destOffset: Int, axisToBroadcastIdx: Int) {
if (axisToBroadcastIdx == totalAxesToBroadcast) {
broadcastingFun(InlineInt(leftOffset), InlineInt(rightOffset), InlineInt(destOffset), InlineInt(axisToBroadcastIdx))
broadcastingFun(leftOffset, rightOffset, destOffset, axisToBroadcastIdx)
} else {
val shapeIdx = axisToBroadcastIdx * 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package io.kinference.ndarray.extensions.broadcasting
import io.kinference.ndarray.arrays.NDArrayCore
import io.kinference.ndarray.extensions.utils.calculateBlock

internal fun interface ScalarBroadcastFun {
operator fun invoke(leftOffset: Int, rightOffset: Int, destOffset: Int, axisToBroadcastIdx: Int)
}

internal data class BroadcastingInfo(
val broadcastingShapes: Array<IntArray>,
val broadcastingDestShape: IntArray,
Expand Down Expand Up @@ -89,8 +93,6 @@ internal data class BroadcastingInfo(
}
}



internal fun makeOffsets(shape: IntArray, blocksInRow: Int): IntArray {
val offsets = IntArray(shape.size)
offsets[offsets.lastIndex - 1] = blocksInRow
Expand Down

0 comments on commit 4bdb061

Please sign in to comment.