Initialize NextJS Example #149

Merged · 20 commits · Jan 27, 2025

Changes from 1 commit
Commit: Use dynamic imports for onnxruntime-web
mirko314 committed Jan 17, 2025
commit 96d08c274106a08f4304a2dbe02bce5cb6c97b3e

packages/web/src/onnx.ts (33 changes: 22 additions & 11 deletions)
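In short: this commit replaces the eager top-level imports of both onnxruntime-web builds (ort_cpu and ort_gpu) with a cached getOrt() helper that dynamically imports only the build a session actually needs, so bundlers can split the heavy runtime out of the main chunk. It also renames ort_config to ortConfig and moves the type annotations from the ORT namespace to the named InferenceSession and Tensor imports.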
@@ -1,21 +1,32 @@
 export { createOnnxSession, runOnnxSession };
 
 import ndarray, { NdArray } from 'ndarray';
-import type ORT from 'onnxruntime-web';
-
-import * as ort_cpu from 'onnxruntime-web';
-import * as ort_gpu from 'onnxruntime-web/webgpu';
-
+import { InferenceSession, Tensor } from 'onnxruntime-web';
 import * as caps from './capabilities';
-import { loadAsUrl } from './resource';
 import { Config } from './schema';
+import { loadAsUrl, resolveChunkUrls } from './resource';
+
+type ORT = typeof import('onnxruntime-web');
+// use a dynamic import to avoid bundling the entire onnxruntime-web package
+let ort: ORT | null = null;
+const getOrt = async (useWebGPU: boolean): Promise<ORT> => {
+  if (ort !== null) {
+    return ort;
+  }
+  if (useWebGPU) {
+    ort = (await import('onnxruntime-web/webgpu')).default;
+  } else {
+    ort = (await import('onnxruntime-web')).default;
+  }
+  return ort;
+};
 
 async function createOnnxSession(model: any, config: Config) {
   const useWebGPU = config.device === 'gpu' && (await caps.webgpu());
   // BUG: proxyToWorker is not working for WASM/CPU Backend for now
   const proxyToWorker = useWebGPU && config.proxyToWorker;
   const executionProviders = [useWebGPU ? 'webgpu' : 'wasm'];
-  const ort = useWebGPU ? ort_gpu : ort_cpu;
+  const ort = await getOrt(useWebGPU);
 
   if (config.debug) {
     console.debug('\tUsing WebGPU:', useWebGPU);
@@ -45,14 +56,14 @@ async function createOnnxSession(model: any, config: Config) {
     console.debug('ort.env.wasm:', ort.env.wasm);
   }
 
-  const ort_config: ORT.InferenceSession.SessionOptions = {
+  const ortConfig: InferenceSession.SessionOptions = {
     executionProviders: executionProviders,
     graphOptimizationLevel: 'all',
     executionMode: 'parallel',
     enableCpuMemArena: true
   };
 
-  const session = await ort.InferenceSession.create(model, ort_config).catch(
+  const session = await ort.InferenceSession.create(model, ortConfig).catch(
     (e: any) => {
       throw new Error(
         `Failed to create session: "${e}". Please check if the publicPath is set correctly.`
@@ -69,7 +80,7 @@ async function runOnnxSession(
   config: Config
 ) {
   const useWebGPU = config.device === 'gpu' && (await caps.webgpu());
-  const ort = useWebGPU ? ort_gpu : ort_cpu;
+  const ort = await getOrt(useWebGPU);
 
   const feeds: Record<string, any> = {};
   for (const [key, tensor] of inputs) {
@@ -82,7 +93,7 @@ async function runOnnxSession(
   const outputData = await session.run(feeds, {});
   const outputKVPairs: NdArray<Float32Array>[] = [];
   for (const key of outputs) {
-    const output: ORT.Tensor = outputData[key];
+    const output: Tensor = outputData[key];
     const shape: number[] = output.dims as number[];
     const data: Float32Array = output.data as Float32Array;
     const tensor = ndarray(data, shape);
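
For orientation, here is a hedged sketch of how a caller might exercise the new lazy-loading path. It is not part of the commit: only createOnnxSession, runOnnxSession, Config, and the NdArray types come from the diff; the session-as-first-argument convention, the input/output names, and the tensor shape are illustrative assumptions.

// Hypothetical call site (not in this commit); assumed names are noted below.
import ndarray, { NdArray } from 'ndarray';
import { createOnnxSession, runOnnxSession } from './onnx';
import { Config } from './schema';

async function demo(model: ArrayBuffer, config: Config) {
  // The first call reaches getOrt(), which dynamically imports either
  // 'onnxruntime-web/webgpu' or 'onnxruntime-web' and caches the module,
  // so the runtime chunk is only fetched once inference is actually requested.
  const session = await createOnnxSession(model, config);

  // Placeholder input; a real model dictates its own names and shapes.
  const input: NdArray<Float32Array> = ndarray(
    new Float32Array(1 * 3 * 224 * 224),
    [1, 3, 224, 224]
  );

  // Later calls reuse the cached module; no second dynamic import occurs.
  const [output] = await runOnnxSession(
    session,
    [['input', input]],
    ['output'],
    config
  );
  return output; // NdArray<Float32Array>, per the loop in the last hunk
}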