From a186448de5542d0c6d1335590998162cab9bea7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Mon, 5 Aug 2024 16:56:24 +0200 Subject: [PATCH] Add experimental ImportResolver (#2298) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #2294 Signed-off-by: Bjørn Erik Pedersen --- experimental/importresolver.go | 19 ++++ experimental/importresolver_example_test.go | 101 ++++++++++++++++++ experimental/importresolver_test.go | 63 +++++++++++ experimental/testdata/inoutdispatcher.wasm | Bin 0 -> 217 bytes experimental/testdata/inoutdispatcher.wat | 38 +++++++ .../testdata/inoutdispatcherclient.wasm | Bin 0 -> 97 bytes .../testdata/inoutdispatcherclient.wat | 7 ++ internal/expctxkeys/importresolver.go | 6 ++ internal/wasm/store.go | 20 +++- internal/wasm/store_test.go | 39 ++++--- internal/wasm/table_test.go | 9 +- 11 files changed, 276 insertions(+), 26 deletions(-) create mode 100644 experimental/importresolver.go create mode 100644 experimental/importresolver_example_test.go create mode 100644 experimental/importresolver_test.go create mode 100644 experimental/testdata/inoutdispatcher.wasm create mode 100644 experimental/testdata/inoutdispatcher.wat create mode 100644 experimental/testdata/inoutdispatcherclient.wasm create mode 100644 experimental/testdata/inoutdispatcherclient.wat create mode 100644 internal/expctxkeys/importresolver.go diff --git a/experimental/importresolver.go b/experimental/importresolver.go new file mode 100644 index 0000000000..36c0e22b15 --- /dev/null +++ b/experimental/importresolver.go @@ -0,0 +1,19 @@ +package experimental + +import ( + "context" + + "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/internal/expctxkeys" +) + +// ImportResolver is an experimental func type that, if set, +// will be used as the first step in resolving imports. +// See issue 2294. +// If the import name is not found, it should return nil. +type ImportResolver func(name string) api.Module + +// WithImportResolver returns a new context with the given ImportResolver. +func WithImportResolver(ctx context.Context, resolver ImportResolver) context.Context { + return context.WithValue(ctx, expctxkeys.ImportResolverKey{}, resolver) +} diff --git a/experimental/importresolver_example_test.go b/experimental/importresolver_example_test.go new file mode 100644 index 0000000000..3c48212f0d --- /dev/null +++ b/experimental/importresolver_example_test.go @@ -0,0 +1,101 @@ +package experimental_test + +import ( + "bytes" + "context" + _ "embed" + "fmt" + "log" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/experimental" + "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" +) + +var ( + // These wasm files were generated by the following: + // cd testdata + // wat2wasm --debug-names inoutdispatcher.wat + // wat2wasm --debug-names inoutdispatcherclient.wat + + //go:embed testdata/inoutdispatcher.wasm + inoutdispatcherWasm []byte + //go:embed testdata/inoutdispatcherclient.wasm + inoutdispatcherclientWasm []byte +) + +func Example_importResolver() { + ctx := context.Background() + + r := wazero.NewRuntime(ctx) + defer r.Close(ctx) + + // The client imports the inoutdispatcher module that reads from stdin and writes to stdout. + // This means that we need multiple instances of the inoutdispatcher module to have different stdin/stdout. + // This example demonstrates a way to do that. + type mod struct { + in bytes.Buffer + out bytes.Buffer + + client api.Module + } + + wasi_snapshot_preview1.MustInstantiate(ctx, r) + + const numInstances = 3 + mods := make([]*mod, numInstances) + for i := range mods { + mods[i] = &mod{} + m := mods[i] + idm, err := r.CompileModule(ctx, inoutdispatcherWasm) + if err != nil { + log.Panicln(err) + } + idcm, err := r.CompileModule(ctx, inoutdispatcherclientWasm) + if err != nil { + log.Panicln(err) + } + + const inoutDispatcherModuleName = "inoutdispatcher" + + dispatcherInstance, err := r.InstantiateModule(ctx, idm, + wazero.NewModuleConfig(). + WithStdin(&m.in). + WithStdout(&m.out). + WithName("")) // Makes it an anonymous module. + if err != nil { + log.Panicln(err) + } + + ctx = experimental.WithImportResolver(ctx, func(name string) api.Module { + if name == inoutDispatcherModuleName { + return dispatcherInstance + } + return nil + }) + + m.client, err = r.InstantiateModule(ctx, idcm, wazero.NewModuleConfig().WithName(fmt.Sprintf("m%d", i))) + if err != nil { + log.Panicln(err) + } + + } + + for i, m := range mods { + m.in.WriteString(fmt.Sprintf("Module instance #%d", i)) + _, err := m.client.ExportedFunction("dispatch").Call(ctx) + if err != nil { + log.Panicln(err) + } + } + + for i, m := range mods { + fmt.Printf("out%d: %s\n", i, m.out.String()) + } + + // Output: + // out0: Module instance #0 + // out1: Module instance #1 + // out2: Module instance #2 +} diff --git a/experimental/importresolver_test.go b/experimental/importresolver_test.go new file mode 100644 index 0000000000..313bda05bd --- /dev/null +++ b/experimental/importresolver_test.go @@ -0,0 +1,63 @@ +package experimental_test + +import ( + "context" + "fmt" + "testing" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/experimental" + "github.com/tetratelabs/wazero/internal/testing/binaryencoding" + "github.com/tetratelabs/wazero/internal/testing/require" + "github.com/tetratelabs/wazero/internal/wasm" +) + +func TestImportResolver(t *testing.T) { + ctx := context.Background() + + r := wazero.NewRuntime(ctx) + defer r.Close(ctx) + + for i := 0; i < 5; i++ { + var callCount int + start := func(ctx context.Context) { + callCount++ + } + modImport, err := r.NewHostModuleBuilder(fmt.Sprintf("env%d", i)). + NewFunctionBuilder().WithFunc(start).Export("start"). + Compile(ctx) + require.NoError(t, err) + // Anonymous module, it will be resolved by the import resolver. + instanceImport, err := r.InstantiateModule(ctx, modImport, wazero.NewModuleConfig().WithName("")) + require.NoError(t, err) + + resolveImport := func(name string) api.Module { + if name == "env" { + return instanceImport + } + return nil + } + + // Set the import resolver in the context. + ctx = experimental.WithImportResolver(context.Background(), resolveImport) + + one := uint32(1) + binary := binaryencoding.EncodeModule(&wasm.Module{ + TypeSection: []wasm.FunctionType{{}}, + ImportSection: []wasm.Import{{Module: "env", Name: "start", Type: wasm.ExternTypeFunc, DescFunc: 0}}, + FunctionSection: []wasm.Index{0}, + CodeSection: []wasm.Code{ + {Body: []byte{wasm.OpcodeCall, 0, wasm.OpcodeEnd}}, // Call the imported env.start. + }, + StartSection: &one, + }) + + modMain, err := r.CompileModule(ctx, binary) + require.NoError(t, err) + + _, err = r.InstantiateModule(ctx, modMain, wazero.NewModuleConfig()) + require.NoError(t, err) + require.Equal(t, 1, callCount) + } +} diff --git a/experimental/testdata/inoutdispatcher.wasm b/experimental/testdata/inoutdispatcher.wasm new file mode 100644 index 0000000000000000000000000000000000000000..33403e3db52659729b51c109d73d3b9ef2b5ec73 GIT binary patch literal 217 zcmZ{fu?~Vj5JYE}GZ60z3S(hyVQXV!Ot{4O8*q^`Tmce}NXgIoAufe>PBY1yNoE%= zwgi9-S85i~TmgLL?c`!8W9RzVJjb#h{44nBm_F)q@U8~_6f!AHLl`}fFwshpC^eMT zA@**(H{kV!&aldw6T@cq4RJ1#!Y+_(mJ2;dgqtc1Ye5S}PCEmwHTJ4=t)~s+dky_o JIZfmlFg|t&Fhl?V literal 0 HcmV?d00001 diff --git a/experimental/testdata/inoutdispatcher.wat b/experimental/testdata/inoutdispatcher.wat new file mode 100644 index 0000000000..485c39c9ef --- /dev/null +++ b/experimental/testdata/inoutdispatcher.wat @@ -0,0 +1,38 @@ +(module + (import "wasi_snapshot_preview1" "fd_read" (func $fd_read (param i32 i32 i32 i32) (result i32))) + (import "wasi_snapshot_preview1" "fd_write" (func $fd_write (param i32 i32 i32 i32) (result i32))) + (memory 1 1 ) + (func (export "dispatch") + ;; Buffer of 100 chars to read into. + (i32.store (i32.const 4) (i32.const 12)) + (i32.store (i32.const 8) (i32.const 100)) + + (block $done + (loop $read + ;; Read from stdin. + (call $fd_read + (i32.const 0) ;; fd; 0 is stdin. + (i32.const 4) ;; iovs + (i32.const 1) ;; iovs_len + (i32.const 8) ;; nread + ) + + ;; If nread is 0, we're done. + (if (i32.eq (i32.load (i32.const 8)) (i32.const 0)) + (then br $done) + ) + + ;; Write to stdout. + (drop (call $fd_write + (i32.const 1) ;; fd; 1 is stdout. + (i32.const 4) ;; iovs + (i32.const 1) ;; iovs_len + (i32.const 0) ;; nwritten + )) + (br $read) + + ) + ) + ) + +) diff --git a/experimental/testdata/inoutdispatcherclient.wasm b/experimental/testdata/inoutdispatcherclient.wasm new file mode 100644 index 0000000000000000000000000000000000000000..8da09dc22155e8d102192a11e9ca5447e0541ad8 GIT binary patch literal 97 zcmYL>%ML&=6a~+@>JgQ&mH0R}G|@#%wEmwBc9WT8peYal4QD_m_CD@cUVS<=FPl4? b7lA^Ey5n!yRx}u3F`B5s(Gp*kQl2et9tadw literal 0 HcmV?d00001 diff --git a/experimental/testdata/inoutdispatcherclient.wat b/experimental/testdata/inoutdispatcherclient.wat new file mode 100644 index 0000000000..904284fb6e --- /dev/null +++ b/experimental/testdata/inoutdispatcherclient.wat @@ -0,0 +1,7 @@ + +(module + (import "inoutdispatcher" "dispatch" (func $dispatch)) + (func (export "dispatch") + (call $dispatch) + ) +) diff --git a/internal/expctxkeys/importresolver.go b/internal/expctxkeys/importresolver.go new file mode 100644 index 0000000000..af52cc80eb --- /dev/null +++ b/internal/expctxkeys/importresolver.go @@ -0,0 +1,6 @@ +package expctxkeys + +// ImportResolverKey is a context.Context Value key. +// Its associated value should be an ImportResolver. +// See issue 2294. +type ImportResolverKey struct{} diff --git a/internal/wasm/store.go b/internal/wasm/store.go index 1db661e853..cf87c30ea7 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -352,7 +352,7 @@ func (s *Store) instantiate( return nil, err } - if err = m.resolveImports(module); err != nil { + if err = m.resolveImports(ctx, module); err != nil { return nil, err } @@ -410,12 +410,22 @@ func (s *Store) instantiate( return } -func (m *ModuleInstance) resolveImports(module *Module) (err error) { +func (m *ModuleInstance) resolveImports(ctx context.Context, module *Module) (err error) { + // Check if ctx contains an ImportResolver. + resolveImport, _ := ctx.Value(expctxkeys.ImportResolverKey{}).(experimental.ImportResolver) + for moduleName, imports := range module.ImportPerModule { var importedModule *ModuleInstance - importedModule, err = m.s.module(moduleName) - if err != nil { - return err + if resolveImport != nil { + if v := resolveImport(moduleName); v != nil { + importedModule = v.(*ModuleInstance) + } + } + if importedModule == nil { + importedModule, err = m.s.module(moduleName) + if err != nil { + return err + } } for _, i := range imports { diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 8df289306e..08297591b4 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -701,13 +701,13 @@ func Test_resolveImports(t *testing.T) { t.Run("module not instantiated", func(t *testing.T) { m := &ModuleInstance{s: newStore()} - err := m.resolveImports(&Module{ImportPerModule: map[string][]*Import{"unknown": {{}}}}) + err := m.resolveImports(context.Background(), &Module{ImportPerModule: map[string][]*Import{"unknown": {{}}}}) require.EqualError(t, err, "module[unknown] not instantiated") }) t.Run("export instance not found", func(t *testing.T) { m := &ModuleInstance{s: newStore()} m.s.nameToModule[moduleName] = &ModuleInstance{Exports: map[string]*Export{}, ModuleName: moduleName} - err := m.resolveImports(&Module{ImportPerModule: map[string][]*Import{moduleName: {{Name: "unknown"}}}}) + err := m.resolveImports(context.Background(), &Module{ImportPerModule: map[string][]*Import{moduleName: {{Name: "unknown"}}}}) require.EqualError(t, err, "\"unknown\" is not exported in module \"test\"") }) t.Run("func", func(t *testing.T) { @@ -743,7 +743,7 @@ func Test_resolveImports(t *testing.T) { } m := &ModuleInstance{Engine: &mockModuleEngine{resolveImportsCalled: map[Index]Index{}}, s: s, Source: module} - err := m.resolveImports(module) + err := m.resolveImports(context.Background(), module) require.NoError(t, err) me := m.Engine.(*mockModuleEngine) @@ -773,7 +773,7 @@ func Test_resolveImports(t *testing.T) { } m := &ModuleInstance{Engine: &mockModuleEngine{resolveImportsCalled: map[Index]Index{}}, s: s, Source: module} - err := m.resolveImports(module) + err := m.resolveImports(context.Background(), module) require.EqualError(t, err, "import func[test.target]: signature mismatch: v_f32 != v_v") }) }) @@ -787,6 +787,7 @@ func Test_resolveImports(t *testing.T) { Exports: map[string]*Export{name: {Type: ExternTypeGlobal, Index: 0}}, ModuleName: moduleName, } err := m.resolveImports( + context.Background(), &Module{ ImportPerModule: map[string][]*Import{moduleName: {{Name: name, Type: ExternTypeGlobal, DescGlobal: g.Type}}}, }, @@ -805,11 +806,13 @@ func Test_resolveImports(t *testing.T) { ModuleName: moduleName, } m := &ModuleInstance{Globals: make([]*GlobalInstance, 1), s: s} - err := m.resolveImports(&Module{ - ImportPerModule: map[string][]*Import{moduleName: { - {Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{Mutable: true}}, - }}, - }) + err := m.resolveImports( + context.Background(), + &Module{ + ImportPerModule: map[string][]*Import{moduleName: { + {Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{Mutable: true}}, + }}, + }) require.EqualError(t, err, "import global[test.target]: mutability mismatch: true != false") }) t.Run("type mismatch", func(t *testing.T) { @@ -823,11 +826,13 @@ func Test_resolveImports(t *testing.T) { ModuleName: moduleName, } m := &ModuleInstance{Globals: make([]*GlobalInstance, 1), s: s} - err := m.resolveImports(&Module{ - ImportPerModule: map[string][]*Import{moduleName: { - {Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{ValType: ValueTypeF64}}, - }}, - }) + err := m.resolveImports( + context.Background(), + &Module{ + ImportPerModule: map[string][]*Import{moduleName: { + {Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{ValType: ValueTypeF64}}, + }}, + }) require.EqualError(t, err, "import global[test.target]: value type mismatch: f64 != i32") }) }) @@ -846,7 +851,7 @@ func Test_resolveImports(t *testing.T) { Engine: importedME, } m := &ModuleInstance{s: s, Engine: &mockModuleEngine{resolveImportsCalled: map[Index]Index{}}} - err := m.resolveImports(&Module{ + err := m.resolveImports(context.Background(), &Module{ ImportPerModule: map[string][]*Import{ moduleName: {{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: &Memory{Max: max}}}, }, @@ -866,7 +871,7 @@ func Test_resolveImports(t *testing.T) { ModuleName: moduleName, } m := &ModuleInstance{s: s} - err := m.resolveImports(&Module{ + err := m.resolveImports(context.Background(), &Module{ ImportPerModule: map[string][]*Import{ moduleName: {{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: importMemoryType}}, }, @@ -886,7 +891,7 @@ func Test_resolveImports(t *testing.T) { max := uint32(10) importMemoryType := &Memory{Max: max} m := &ModuleInstance{s: s} - err := m.resolveImports(&Module{ + err := m.resolveImports(context.Background(), &Module{ ImportPerModule: map[string][]*Import{moduleName: {{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: importMemoryType}}}, }) require.EqualError(t, err, "import memory[test.target]: maximum size mismatch: 10 < 65536") diff --git a/internal/wasm/table_test.go b/internal/wasm/table_test.go index 371cfa46ec..8a1e68fe39 100644 --- a/internal/wasm/table_test.go +++ b/internal/wasm/table_test.go @@ -1,6 +1,7 @@ package wasm import ( + "context" "math" "testing" @@ -29,7 +30,7 @@ func Test_resolveImports_table(t *testing.T) { ModuleName: moduleName, } m := &ModuleInstance{Tables: make([]*TableInstance, 1), s: s} - err := m.resolveImports(&Module{ + err := m.resolveImports(context.Background(), &Module{ ImportPerModule: map[string][]*Import{ moduleName: {{Module: moduleName, Name: name, Type: ExternTypeTable, DescTable: Table{Max: &max}}}, }, @@ -47,7 +48,7 @@ func Test_resolveImports_table(t *testing.T) { ModuleName: moduleName, } m := &ModuleInstance{Tables: make([]*TableInstance, 1), s: s} - err := m.resolveImports(&Module{ + err := m.resolveImports(context.Background(), &Module{ ImportPerModule: map[string][]*Import{ moduleName: {{Module: moduleName, Name: name, Type: ExternTypeTable, DescTable: importTableType}}, }, @@ -64,7 +65,7 @@ func Test_resolveImports_table(t *testing.T) { ModuleName: moduleName, } m := &ModuleInstance{Tables: make([]*TableInstance, 1), s: s} - err := m.resolveImports(&Module{ + err := m.resolveImports(context.Background(), &Module{ ImportPerModule: map[string][]*Import{ moduleName: {{Module: moduleName, Name: name, Type: ExternTypeTable, DescTable: importTableType}}, }, @@ -79,7 +80,7 @@ func Test_resolveImports_table(t *testing.T) { ModuleName: moduleName, } m := &ModuleInstance{Tables: make([]*TableInstance, 1), s: s} - err := m.resolveImports(&Module{ + err := m.resolveImports(context.Background(), &Module{ ImportPerModule: map[string][]*Import{ moduleName: {{Module: moduleName, Name: name, Type: ExternTypeTable, DescTable: Table{Type: RefTypeExternref}}}, },