From 27c6d9d9fde44e36ddf636b860700d9248ebc1b7 Mon Sep 17 00:00:00 2001 From: deadprogram Date: Tue, 27 Feb 2024 23:58:36 +0100 Subject: [PATCH] engine, interp/wasman: handle the 4 basic value types correctly when calling wasman Signed-off-by: deadprogram --- engine/instance.go | 2 +- interp/wasman/instance.go | 97 +++++++++++++++++++++++---------- interp/wasman/instance_test.go | 87 +++++++++++++++++++++++++++++ interp/wasman/tester.wasm | Bin 0 -> 506 bytes 4 files changed, 156 insertions(+), 30 deletions(-) create mode 100644 interp/wasman/instance_test.go create mode 100755 interp/wasman/tester.wasm diff --git a/engine/instance.go b/engine/instance.go index 3642c20..a40be7a 100644 --- a/engine/instance.go +++ b/engine/instance.go @@ -1,5 +1,5 @@ package engine type Instance interface { - Call(name string, args ...interface{}) (interface{}, error) + Call(name string, args ...any) (any, error) } diff --git a/interp/wasman/instance.go b/interp/wasman/instance.go index 2a30944..de2b163 100644 --- a/interp/wasman/instance.go +++ b/interp/wasman/instance.go @@ -1,53 +1,92 @@ package wasman import ( + "math" + wasmaneng "github.com/hybridgroup/wasman" + "github.com/hybridgroup/wasman/types" ) type Instance struct { instance *wasmaneng.Instance } -func (i *Instance) Call(name string, args ...interface{}) (interface{}, error) { +func (i *Instance) Call(name string, args ...any) (any, error) { if len(args) == 0 { - results, _, err := i.instance.CallExportedFunc(name) + results, types, err := i.instance.CallExportedFunc(name) if err != nil { return nil, err } - return results, nil + res := decodeResults(results, types) + if len(res) == 0 { + return nil, nil + } + if len(res) == 1 { + return res[0], nil + } + return res, nil + } + + results, types, err := i.instance.CallExportedFunc(name, encodeArgs(args)...) + if err != nil { + return nil, err } - wargs := make([]uint64, len(args)) - switch args[0].(type) { + res := decodeResults(results, types) + if len(res) == 0 { + return nil, nil + } + if len(res) == 1 { + return res[0], nil + } + return res, nil +} + +func encodeArgs(args []any) []uint64 { + encoded := make([]uint64, 0, len(args)) + for _, arg := range args { + encoded = append(encoded, encodeArg(arg)) + } + return encoded +} + +func encodeArg(arg any) uint64 { + switch val := arg.(type) { case int32: - for i, v := range args { - wargs[i] = uint64(v.(int32)) - } - case uint32: - for i, v := range args { - wargs[i] = uint64(v.(uint32)) - } + return uint64(val) case int64: - for i, v := range args { - wargs[i] = uint64(v.(int64)) - } - case uint64: - for i, v := range args { - wargs[i] = uint64(v.(uint64)) - } + return uint64(val) case float32: - for i, v := range args { - wargs[i] = uint64(v.(float32)) - } + return uint64(math.Float32bits(val)) case float64: - for i, v := range args { - wargs[i] = uint64(v.(float64)) - } + return uint64(math.Float64bits(val)) + case uint32: + return uint64(val) + case uint64: + return uint64(val) } - results, _, err := i.instance.CallExportedFunc(name, wargs...) - if err != nil { - return nil, err + panic("bad arg type") +} + +func decodeResults(results []uint64, vtypes []types.ValueType) []any { + decoded := make([]any, 0, len(results)) + for i, result := range results { + vtype := vtypes[i] + decoded = append(decoded, decodeResult(result, vtype)) } + return decoded +} - return results, nil +func decodeResult(result uint64, vtype types.ValueType) any { + switch vtype { + case types.ValueTypeF32: + return math.Float32frombits(uint32(result)) + case types.ValueTypeF64: + return math.Float64frombits(uint64(result)) + case types.ValueTypeI32: + return int32(result) + case types.ValueTypeI64: + return int64(result) + } + panic("unreachable") } diff --git a/interp/wasman/instance_test.go b/interp/wasman/instance_test.go new file mode 100644 index 0000000..8ded463 --- /dev/null +++ b/interp/wasman/instance_test.go @@ -0,0 +1,87 @@ +package wasman + +import ( + _ "embed" + "testing" +) + +//go:embed tester.wasm +var wasmData []byte + +func TestInstance(t *testing.T) { + i := Interpreter{ + Memory: make([]byte, 65536), + } + if err := i.Init(); err != nil { + t.Errorf("Interpreter.Init() failed: %v", err) + } + + if err := i.Load(wasmData); err != nil { + t.Errorf("Interpreter.Load() failed: %v", err) + } + + inst, err := i.Run() + if err != nil { + t.Errorf("Interpreter.Run() failed: %v", err) + } + + t.Run("Call int32", func(t *testing.T) { + results, err := inst.Call("test_int32", int32(1), int32(2)) + if err != nil { + t.Errorf("Instance.Call() failed: %v", err) + } + if results != int32(3) { + t.Errorf("Instance.Call() failed: %v", results) + } + }) + + t.Run("Call uint32", func(t *testing.T) { + results, err := inst.Call("test_uint32", uint32(1), uint32(2)) + if err != nil { + t.Errorf("Instance.Call() failed: %v", err) + } + if uint32(results.(int32)) != uint32(3) { + t.Errorf("Instance.Call() failed: %v", results) + } + }) + + t.Run("Call int64", func(t *testing.T) { + results, err := inst.Call("test_int64", int64(1), int64(2)) + if err != nil { + t.Errorf("Instance.Call() failed: %v", err) + } + if results != int64(3) { + t.Errorf("Instance.Call() failed: %v", results) + } + }) + + t.Run("Call uint64", func(t *testing.T) { + results, err := inst.Call("test_uint64", uint64(1), uint64(2)) + if err != nil { + t.Errorf("Instance.Call() failed: %v", err) + } + if uint64(results.(int64)) != uint64(3) { + t.Errorf("Instance.Call() failed: %v", results) + } + }) + + t.Run("Call float32", func(t *testing.T) { + results, err := inst.Call("test_float32", float32(100.2), float32(300.8)) + if err != nil { + t.Errorf("Instance.Call() failed: %v", err) + } + if results != float32(401.0) { + t.Errorf("Instance.Call() failed: %v", results) + } + }) + + t.Run("Call float64", func(t *testing.T) { + results, err := inst.Call("test_float64", float64(111.2), float64(333.8)) + if err != nil { + t.Errorf("Instance.Call() failed: %v", err) + } + if results != float64(445.0) { + t.Errorf("Instance.Call() failed: %v", results) + } + }) +} diff --git a/interp/wasman/tester.wasm b/interp/wasman/tester.wasm new file mode 100755 index 0000000000000000000000000000000000000000..da3a58552c464debe5bf03633ffa539ccceaee83 GIT binary patch literal 506 zcmY*WJx}XE5ZpW431B3IXem-qfFe8bImCpP0xD=;M`54OXLE4&S-$go1m)vD&?E6v zaZr>8)$Hud?r60@b%_7~ojM6Xg4b)bPVj1lRta7%(K5k{1zK1=2N1UGKI^-&sDu$z ziZ>2$V7NA2$Akos`%T&{+ey<2;Q!Kw9ex{gspw|T`awI@J~(qQxP2JF06lMUvBmc- ze!B2rSG45l&Vl1N{@0xmGb&}A8P#aC3l9n2)zit(sF2vdgFn!I-KeZ(ywP}1(yC}_ z!LfA{1Yq6FW^mZajBCv^FP&qDZ7h#L-{)!pm!;x@QdCw3RPj$xGGDhkIHS ImB(Mr7wuq_6#xJL literal 0 HcmV?d00001