Commit 6d67891

JBAI-5829 [examples] Replaced loading models using bytes with direct file paths in ORT and KI example modules. Updated cache directory usage and adjusted build script for conditional plugin application.
dmitriyb committed Oct 1, 2024
1 parent 740da0f commit 6d67891
Showing 6 changed files with 22 additions and 31 deletions.
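
For context, a minimal sketch of the loading pattern this commit moves the examples to. The identifiers (CommonDataLoader, KIEngine.loadModel, PredictionConfigs.DefaultAutoAllocator, toPath(), cacheDirectory, modelName) are taken from the diffs below; the wrapper function and its imports are illustrative only and assumed to match the example modules, which are not shown here in full.

    // Sketch only — assumes the same imports as the KI example modules in this commit.
    suspend fun loadExampleModel(cacheDirectory: String, modelName: String) {
        // Before this commit: read the whole model file into memory, then hand the bytes to the engine.
        // val modelBytes = CommonDataLoader.bytes("$cacheDirectory/$modelName.onnx".toPath())
        // val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)

        // After this commit: pass the file path directly and let the engine read the file itself.
        val model = KIEngine.loadModel(
            "$cacheDirectory/$modelName.onnx".toPath(),
            optimize = true,
            predictionConfig = PredictionConfigs.DefaultAutoAllocator
        )
        // ... run inference with `model` as in the example modules.
    }
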
29 changes: 13 additions & 16 deletions build.gradle.kts
@@ -6,6 +6,7 @@ import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnLockMismatchReport
 import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnPlugin
 import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnRootExtension
 import org.jetbrains.kotlin.gradle.tasks.KotlinCompilationTask
+import org.jetbrains.kotlin.utils.addToStdlib.applyIf
 
 group = "io.kinference"
 version = "0.2.22"

@@ -35,21 +36,23 @@ subprojects {
 
     apply {
         plugin("org.jetbrains.kotlin.multiplatform")
-        plugin("maven-publish")
         plugin("idea")
     }
 
 
-    publishing {
-        repositories {
-            maven {
-                name = "SpacePackages"
-                url = uri("https://packages.jetbrains.team/maven/p/ki/maven")
+    applyIf(path != ":examples") {
+        apply(plugin = "maven-publish")
+        publishing {
+            repositories {
+                maven {
+                    name = "SpacePackages"
+                    url = uri("https://packages.jetbrains.team/maven/p/ki/maven")
 
-                credentials {
-                    username = System.getenv("JB_SPACE_CLIENT_ID")
-                    password = System.getenv("JB_SPACE_CLIENT_SECRET")
+                    credentials {
+                        username = System.getenv("JB_SPACE_CLIENT_ID")
+                        password = System.getenv("JB_SPACE_CLIENT_SECRET")
+                    }
+                }
                 }
             }
-        }
     }

@@ -83,9 +86,3 @@ subprojects {
         targetCompatibility = jvmTargetVersion.toString()
     }
 }
-
-project(":examples") {
-    tasks.withType<PublishToMavenRepository>().configureEach {
-        onlyIf { false }
-    }
-}
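
The applyIf(path != ":examples") { ... } wrapper above (imported from org.jetbrains.kotlin.utils.addToStdlib) applies the enclosed configuration only when the condition holds, so the :examples module no longer gets the maven-publish plugin and the old project(":examples") onlyIf { false } workaround can be dropped. A hypothetical equivalent using a plain if, for readers who prefer not to rely on the compiler utility, might look like:

    // Sketch only — mirrors the intent of the diff above with a plain `if`.
    subprojects {
        if (path != ":examples") {
            apply(plugin = "maven-publish")
            publishing {
                // repository and credentials configuration as in the diff above
            }
        }
    }
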

@@ -20,7 +20,7 @@ import java.io.File
  * It is used in various functions to check for existing files or directories,
  * create new ones if they do not exist, and manage the caching of downloaded files.
  */
-val cacheDirectory = System.getProperty("user.dir") + "/cache/"
+val cacheDirectory = System.getProperty("user.dir") + "/.cache/"
 
 /**
  * Downloads a file from the given URL and saves it with the specified file name.

@@ -104,11 +104,10 @@ suspend fun main() {
     println("Downloading synset from: $synsetUrl")
     downloadFile(synsetUrl, "synset.txt")
 
-    val modelBytes = CommonDataLoader.bytes("$cacheDirectory/$modelName.onnx".toPath())
     val classLabels = File("$cacheDirectory/synset.txt").readLines()
 
     println("Loading model...")
-    val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)
+    val model = KIEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath(), optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)
     println("Creating inputs...")
     val inputTensors = createInputs()
 

@@ -103,11 +103,10 @@ suspend fun main() {
     println("Downloading synset from: $synsetUrl")
     downloadFile(synsetUrl, "synset.txt")
 
-    val modelBytes = CommonDataLoader.bytes("$cacheDirectory/$modelName.onnx".toPath())
     val classLabels = File("$cacheDirectory/synset.txt").readLines()
 
     println("Loading model...")
-    val model = ORTEngine.loadModel(modelBytes)
+    val model = ORTEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath())
     println("Creating inputs...")
     val inputTensors = createInputs()
 

@@ -24,10 +24,8 @@ suspend fun main() {
     println("Downloading model from: $modelUrl")
     downloadFile(modelUrl, "$modelName.onnx") //GPT-2 from model zoo is around 650 Mb, adjust your timeout if needed
 
-    val modelBytes = CommonDataLoader.bytes("${cacheDirectory}/$modelName.onnx".toPath())
-
     println("Loading model...")
-    val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)
+    val model = KIEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath(), optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)
 
     val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024"))
     val testString = "Neurogenesis is most active during embryonic development and is responsible for producing " +

@@ -27,10 +27,8 @@ suspend fun main() {
     println("Downloading model from: $modelUrl")
     downloadFile(modelUrl, "$modelName.onnx") //GPT-2 from model zoo is around 650 Mb, adjust your timeout if needed
 
-    val modelBytes = CommonDataLoader.bytes("${cacheDirectory}/$modelName.onnx".toPath())
-
     println("Loading model...")
-    val model = ORTEngine.loadModel(modelBytes)
+    val model = ORTEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath())
 
     val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024"))
     val testString = "Neurogenesis is most active during embryonic development and is responsible for producing " +

@@ -63,12 +61,12 @@
 }
 
 private suspend fun convertToKITensorMap(outputs: Map<String, ORTData<*>>): Map<String, KITensor> {
-    return outputs.map { (key, value) ->
-        val ortTensor = value as ORTTensor
+    return outputs.map { (name, ortTensor) ->
+        val ortTensor = ortTensor as ORTTensor
         val data = ortTensor.toFloatArray()
         val shape = ortTensor.shape.toIntArray()
         val ndArray = FloatNDArray(shape) { idx: InlineInt -> data[idx.value] }
-        val tensor = ndArray.asTensor(key)
-        return@map key to tensor
+        val kiTensor = ndArray.asTensor(name)
+        return@map name to kiTensor
     }.toMap()
 }
