Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

engine, interp/wasman: correct param handling for all WASM basic types #3

Merged
merged 2 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/instance.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package engine

type Instance interface {
Call(name string, args ...interface{}) (interface{}, error)
Call(name string, args ...any) (any, error)
}
85 changes: 70 additions & 15 deletions interp/wasman/instance.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,92 @@
package wasman

import (
"math"

wasmaneng "github.com/hybridgroup/wasman"
"github.com/hybridgroup/wasman/types"
)

type Instance struct {
instance *wasmaneng.Instance
}

func (i *Instance) Call(name string, args ...interface{}) (interface{}, error) {
func (i *Instance) Call(name string, args ...any) (any, error) {
if len(args) == 0 {
results, _, err := i.instance.CallExportedFunc(name)
results, types, err := i.instance.CallExportedFunc(name)
if err != nil {
return nil, err
}
return results, nil
}

wargs := make([]uint64, len(args))
switch args[0].(type) {
case int32:
for i, v := range args {
wargs[i] = uint64(v.(int32))
res := decodeResults(results, types)
if len(res) == 0 {
return nil, nil
}
case uint32:
for i, v := range args {
wargs[i] = uint64(v.(uint32))
if len(res) == 1 {
return res[0], nil
}
return res, nil
}
results, _, err := i.instance.CallExportedFunc(name, wargs...)

results, types, err := i.instance.CallExportedFunc(name, encodeArgs(args)...)
if err != nil {
return nil, err
}

return results, nil
res := decodeResults(results, types)
if len(res) == 0 {
return nil, nil
}
if len(res) == 1 {
return res[0], nil
}
return res, nil
}

func encodeArgs(args []any) []uint64 {
deadprogram marked this conversation as resolved.
Show resolved Hide resolved
encoded := make([]uint64, 0, len(args))
for _, arg := range args {
encoded = append(encoded, encodeArg(arg))
}
return encoded
}

func encodeArg(arg any) uint64 {
switch val := arg.(type) {
case int32:
return uint64(val)
case int64:
return uint64(val)
case float32:
return uint64(math.Float32bits(val))
case float64:
return uint64(math.Float64bits(val))
case uint32:
return uint64(val)
case uint64:
return uint64(val)
}
panic("bad arg type")
}

func decodeResults(results []uint64, vtypes []types.ValueType) []any {
decoded := make([]any, 0, len(results))
for i, result := range results {
vtype := vtypes[i]
decoded = append(decoded, decodeResult(result, vtype))
}
return decoded
}

func decodeResult(result uint64, vtype types.ValueType) any {
switch vtype {
case types.ValueTypeF32:
return math.Float32frombits(uint32(result))
case types.ValueTypeF64:
return math.Float64frombits(uint64(result))
case types.ValueTypeI32:
return int32(result)
case types.ValueTypeI64:
return int64(result)
}
panic("unreachable")
}
87 changes: 87 additions & 0 deletions interp/wasman/instance_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package wasman

import (
_ "embed"
"testing"
)

//go:embed tester.wasm
var wasmData []byte

func TestInstance(t *testing.T) {
i := Interpreter{
Memory: make([]byte, 65536),
}
if err := i.Init(); err != nil {
t.Errorf("Interpreter.Init() failed: %v", err)
}

if err := i.Load(wasmData); err != nil {
t.Errorf("Interpreter.Load() failed: %v", err)
}

inst, err := i.Run()
if err != nil {
t.Errorf("Interpreter.Run() failed: %v", err)
}

t.Run("Call int32", func(t *testing.T) {
results, err := inst.Call("test_int32", int32(1), int32(2))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if results != int32(3) {
t.Errorf("Instance.Call() failed: %v", results)
}
})

t.Run("Call uint32", func(t *testing.T) {
results, err := inst.Call("test_uint32", uint32(1), uint32(2))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if uint32(results.(int32)) != uint32(3) {
t.Errorf("Instance.Call() failed: %v", results)
}
})

t.Run("Call int64", func(t *testing.T) {
results, err := inst.Call("test_int64", int64(1), int64(2))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if results != int64(3) {
t.Errorf("Instance.Call() failed: %v", results)
}
})

t.Run("Call uint64", func(t *testing.T) {
results, err := inst.Call("test_uint64", uint64(1), uint64(2))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if uint64(results.(int64)) != uint64(3) {
t.Errorf("Instance.Call() failed: %v", results)
}
})

t.Run("Call float32", func(t *testing.T) {
results, err := inst.Call("test_float32", float32(100.2), float32(300.8))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if results != float32(401.0) {
t.Errorf("Instance.Call() failed: %v", results)
}
})

t.Run("Call float64", func(t *testing.T) {
results, err := inst.Call("test_float64", float64(111.2), float64(333.8))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if results != float64(445.0) {
t.Errorf("Instance.Call() failed: %v", results)
}
})
}
Binary file added interp/wasman/tester.wasm
Binary file not shown.