Skip to content

Commit

Permalink
JBAI-4393 [ndarray] Added Fastutil support for more efficient primiti…
Browse files Browse the repository at this point in the history
…ve handling in primitive array storage classes.
  • Loading branch information
dmitriyb committed Aug 29, 2024
1 parent b83f7f8 commit 450a39e
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 14 deletions.
2 changes: 2 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ okio = "3.6.0"
onnxruntime = "1.17.0.patched-1"
slf4j = "2.0.9"
wire = "4.9.3"
fastutil = "8.5.14"

# JS Dependencies
loglevel = "1.8.1"
Expand All @@ -36,3 +37,4 @@ onnxruntime-gpu = { module = "com.microsoft.onnxruntime:onnxruntime_gpu", versio
slf4j-api = { module = "org.slf4j:slf4j-api", version.ref = "slf4j" }
slf4j-simple = { module = "org.slf4j:slf4j-simple", version.ref = "slf4j" }
wire-runtime = { module = "com.squareup.wire:wire-runtime", version.ref = "wire" }
fastutil-core = { module = "it.unimi.dsi:fastutil-core", version.ref = "fastutil" }
1 change: 1 addition & 0 deletions ndarray/ndarray-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ kotlin {
implementation(libs.kotlinx.coroutines.core)
implementation(libs.kotlinx.atomicfu)
api(libs.apache.commons.math4.core)
api(libs.fastutil.core)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ const val ERF_COEF_3 = 1.421413741
const val ERF_COEF_4 = -1.453152027
const val ERF_COEF_5 = 1.061405429

const val INIT_STORAGE_SIZE = 64

internal fun IntArray.swap(leftIdx: Int, rightIdx: Int) {
val temp = get(leftIdx)
this[leftIdx] = this[rightIdx]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
@file:GeneratePrimitives(DataType.ALL)
package io.kinference.ndarray.arrays.memory.storage

import io.kinference.ndarray.INIT_STORAGE_SIZE
import io.kinference.ndarray.arrays.memory.MemoryManager
import io.kinference.ndarray.extensions.constants.PrimitiveConstants
import io.kinference.ndarray.extensions.utils.getOrPut
import io.kinference.primitives.annotations.GenerateNameFromPrimitives
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import io.kinference.primitives.types.PrimitiveArray
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap

@GenerateNameFromPrimitives
internal class PrimitiveAutoHandlingArrayStorage : TypedAutoHandlingStorage {
private val used = HashMap<Int, ArrayDeque<PrimitiveArray>>(8)
private val unused = HashMap<Int, ArrayDeque<PrimitiveArray>>(8)
private val used = Int2ObjectOpenHashMap<ArrayDeque<PrimitiveArray>>(INIT_STORAGE_SIZE)
private val unused = Int2ObjectOpenHashMap<ArrayDeque<PrimitiveArray>>(INIT_STORAGE_SIZE)

companion object {
private val type = DataType.CurrentPrimitive
}

fun getBlock(blocksNum: Int, blockSize: Int, limiter: MemoryManager): Array<PrimitiveArray> {
internal fun getBlock(blocksNum: Int, blockSize: Int, limiter: MemoryManager): Array<PrimitiveArray> {
val unusedQueue = unused.getOrPut(blockSize) { ArrayDeque(blocksNum) }
val usedQueue = used.getOrPut(blockSize) { ArrayDeque(blocksNum) }

val blocks = if (limiter.checkMemoryLimitAndAdd(type, blockSize * blocksNum)) {
val blocks = if (limiter.checkMemoryLimitAndAdd(type, size = blockSize * blocksNum)) {
Array(blocksNum) {
unusedQueue.removeFirstOrNull()?.apply {
fill(PrimitiveConstants.ZERO)
} ?: PrimitiveArray(blockSize)
}
} else {
Array(blocksNum) {
PrimitiveArray(blockSize)
}
Array(blocksNum) { PrimitiveArray(blockSize) }
}

usedQueue.addAll(blocks)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
@file:GeneratePrimitives(DataType.ALL)
package io.kinference.ndarray.arrays.memory.storage

import io.kinference.ndarray.INIT_STORAGE_SIZE
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.arrays.memory.MemoryManager
import io.kinference.ndarray.arrays.tiled.PrimitiveTiledArray
import io.kinference.ndarray.blockSizeByStrides
import io.kinference.ndarray.extensions.constants.PrimitiveConstants
import io.kinference.ndarray.extensions.utils.getOrPut
import io.kinference.primitives.annotations.*
import io.kinference.primitives.types.DataType
import io.kinference.primitives.types.PrimitiveArray
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap

@GenerateNameFromPrimitives
internal class PrimitiveManualHandlingArrayStorage : TypedManualHandlingStorage {
private val storage = HashMap<Int, ArrayDeque<PrimitiveArray>>(8)
private val storage = Int2ObjectOpenHashMap<ArrayDeque<PrimitiveArray>>(INIT_STORAGE_SIZE)

companion object {
private val type = DataType.CurrentPrimitive
Expand All @@ -22,14 +25,12 @@ internal class PrimitiveManualHandlingArrayStorage : TypedManualHandlingStorage
override fun getNDArray(strides: Strides, fillZeros: Boolean, limiter: MemoryManager): MutableNDArrayCore {
val blockSize = blockSizeByStrides(strides)
val blocksNum = strides.linearSize / blockSize
val blocks = if (limiter.checkMemoryLimitAndAdd(type, blockSize * blocksNum)) {
val blocks = if (limiter.checkMemoryLimitAndAdd(type, size = blockSize * blocksNum)) {
val queue = storage.getOrPut(blockSize) { ArrayDeque(blocksNum) }
Array(blocksNum) {
val block = queue.removeFirstOrNull()
if (fillZeros) {
block?.fill(PrimitiveConstants.ZERO)
}
block ?: PrimitiveArray(blockSize)
queue.removeFirstOrNull()?.apply {
fill(PrimitiveConstants.ZERO)
} ?: PrimitiveArray(blockSize)
}
} else {
Array(blocksNum) { PrimitiveArray(blockSize) }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.kinference.ndarray.extensions.utils

import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap

/***
* Calculates the total size of the tensor with such shape.
*/
Expand Down Expand Up @@ -50,3 +52,14 @@ internal fun computeColumnMajorIndex(
internal fun isInPadding(actual: Int, bound: Int) : Boolean {
return actual < 0 || actual >= bound
}

inline fun <V> Int2ObjectOpenHashMap<V>.getOrPut(key: Int, defaultValue: () -> V): V {
val existingValue = this[key]
return if (existingValue != null) {
existingValue
} else {
val value = defaultValue()
put(key, value)
value
}
}

0 comments on commit 450a39e

Please sign in to comment.