diff --git a/build.gradle.kts b/build.gradle.kts index a6bde9f5..3caf1253 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -6,6 +6,7 @@ import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnLockMismatchReport import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnPlugin import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnRootExtension import org.jetbrains.kotlin.gradle.tasks.KotlinCompilationTask +import org.jetbrains.kotlin.utils.addToStdlib.applyIf group = "io.kinference" version = "0.2.22" @@ -35,21 +36,23 @@ subprojects { apply { plugin("org.jetbrains.kotlin.multiplatform") - - plugin("maven-publish") plugin("idea") } - publishing { - repositories { - maven { - name = "SpacePackages" - url = uri("https://packages.jetbrains.team/maven/p/ki/maven") + applyIf(path != ":examples") { + apply(plugin = "maven-publish") + + publishing { + repositories { + maven { + name = "SpacePackages" + url = uri("https://packages.jetbrains.team/maven/p/ki/maven") - credentials { - username = System.getenv("JB_SPACE_CLIENT_ID") - password = System.getenv("JB_SPACE_CLIENT_SECRET") + credentials { + username = System.getenv("JB_SPACE_CLIENT_ID") + password = System.getenv("JB_SPACE_CLIENT_SECRET") + } } } } @@ -83,9 +86,3 @@ subprojects { targetCompatibility = jvmTargetVersion.toString() } } - -project(":examples") { - tasks.withType().configureEach { - onlyIf { false } - } -} diff --git a/examples/src/jvmMain/kotlin/io/kinference/examples/Utils.kt b/examples/src/jvmMain/kotlin/io/kinference/examples/Utils.kt index 30fb5643..eb9cbcd9 100644 --- a/examples/src/jvmMain/kotlin/io/kinference/examples/Utils.kt +++ b/examples/src/jvmMain/kotlin/io/kinference/examples/Utils.kt @@ -20,7 +20,7 @@ import java.io.File * It is used in various functions to check for existing files or directories, * create new ones if they do not exist, and manage the caching of downloaded files. */ -val cacheDirectory = System.getProperty("user.dir") + "/cache/" +val cacheDirectory = System.getProperty("user.dir") + "/.cache/" /** * Downloads a file from the given URL and saves it with the specified file name. diff --git a/examples/src/jvmMain/kotlin/io/kinference/examples/classification/KIClassificationMain.kt b/examples/src/jvmMain/kotlin/io/kinference/examples/classification/KIClassificationMain.kt index b7203297..00ace87b 100644 --- a/examples/src/jvmMain/kotlin/io/kinference/examples/classification/KIClassificationMain.kt +++ b/examples/src/jvmMain/kotlin/io/kinference/examples/classification/KIClassificationMain.kt @@ -104,11 +104,10 @@ suspend fun main() { println("Downloading synset from: $synsetUrl") downloadFile(synsetUrl, "synset.txt") - val modelBytes = CommonDataLoader.bytes("$cacheDirectory/$modelName.onnx".toPath()) val classLabels = File("$cacheDirectory/synset.txt").readLines() println("Loading model...") - val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator) + val model = KIEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath(), optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator) println("Creating inputs...") val inputTensors = createInputs() diff --git a/examples/src/jvmMain/kotlin/io/kinference/examples/classification/ORTClassificationMain.kt b/examples/src/jvmMain/kotlin/io/kinference/examples/classification/ORTClassificationMain.kt index 2ccf757f..c8f3b0d5 100644 --- a/examples/src/jvmMain/kotlin/io/kinference/examples/classification/ORTClassificationMain.kt +++ b/examples/src/jvmMain/kotlin/io/kinference/examples/classification/ORTClassificationMain.kt @@ -103,11 +103,10 @@ suspend fun main() { println("Downloading synset from: $synsetUrl") downloadFile(synsetUrl, "synset.txt") - val modelBytes = CommonDataLoader.bytes("$cacheDirectory/$modelName.onnx".toPath()) val classLabels = File("$cacheDirectory/synset.txt").readLines() println("Loading model...") - val model = ORTEngine.loadModel(modelBytes) + val model = ORTEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath()) println("Creating inputs...") val inputTensors = createInputs() diff --git a/examples/src/jvmMain/kotlin/io/kinference/examples/lm/KIGPT2Main.kt b/examples/src/jvmMain/kotlin/io/kinference/examples/lm/KIGPT2Main.kt index e142fc57..81e106ee 100644 --- a/examples/src/jvmMain/kotlin/io/kinference/examples/lm/KIGPT2Main.kt +++ b/examples/src/jvmMain/kotlin/io/kinference/examples/lm/KIGPT2Main.kt @@ -24,10 +24,8 @@ suspend fun main() { println("Downloading model from: $modelUrl") downloadFile(modelUrl, "$modelName.onnx") //GPT-2 from model zoo is around 650 Mb, adjust your timeout if needed - val modelBytes = CommonDataLoader.bytes("${cacheDirectory}/$modelName.onnx".toPath()) - println("Loading model...") - val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator) + val model = KIEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath(), optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator) val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024")) val testString = "Neurogenesis is most active during embryonic development and is responsible for producing " + diff --git a/examples/src/jvmMain/kotlin/io/kinference/examples/lm/ORTGPT2Main.kt b/examples/src/jvmMain/kotlin/io/kinference/examples/lm/ORTGPT2Main.kt index c1b9c7d3..dd063413 100644 --- a/examples/src/jvmMain/kotlin/io/kinference/examples/lm/ORTGPT2Main.kt +++ b/examples/src/jvmMain/kotlin/io/kinference/examples/lm/ORTGPT2Main.kt @@ -27,10 +27,8 @@ suspend fun main() { println("Downloading model from: $modelUrl") downloadFile(modelUrl, "$modelName.onnx") //GPT-2 from model zoo is around 650 Mb, adjust your timeout if needed - val modelBytes = CommonDataLoader.bytes("${cacheDirectory}/$modelName.onnx".toPath()) - println("Loading model...") - val model = ORTEngine.loadModel(modelBytes) + val model = ORTEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath()) val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024")) val testString = "Neurogenesis is most active during embryonic development and is responsible for producing " + @@ -63,12 +61,12 @@ suspend fun main() { } private suspend fun convertToKITensorMap(outputs: Map>): Map { - return outputs.map { (key, value) -> - val ortTensor = value as ORTTensor + return outputs.map { (name, ortTensor) -> + val ortTensor = ortTensor as ORTTensor val data = ortTensor.toFloatArray() val shape = ortTensor.shape.toIntArray() val ndArray = FloatNDArray(shape) { idx: InlineInt -> data[idx.value] } - val tensor = ndArray.asTensor(key) - return@map key to tensor + val kiTensor = ndArray.asTensor(name) + return@map name to kiTensor }.toMap() }