Skip to content

Commit

Permalink
engine, interp/wasman: handle the 4 basic value types correctly when …
Browse files Browse the repository at this point in the history
…calling wasman

Signed-off-by: deadprogram <[email protected]>
  • Loading branch information
deadprogram committed Feb 28, 2024
1 parent f286aa5 commit 27c6d9d
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 30 deletions.
2 changes: 1 addition & 1 deletion engine/instance.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package engine

type Instance interface {
Call(name string, args ...interface{}) (interface{}, error)
Call(name string, args ...any) (any, error)
}
97 changes: 68 additions & 29 deletions interp/wasman/instance.go
Original file line number Diff line number Diff line change
@@ -1,53 +1,92 @@
package wasman

import (
"math"

wasmaneng "github.com/hybridgroup/wasman"
"github.com/hybridgroup/wasman/types"
)

type Instance struct {
instance *wasmaneng.Instance
}

func (i *Instance) Call(name string, args ...interface{}) (interface{}, error) {
func (i *Instance) Call(name string, args ...any) (any, error) {
if len(args) == 0 {
results, _, err := i.instance.CallExportedFunc(name)
results, types, err := i.instance.CallExportedFunc(name)
if err != nil {
return nil, err
}
return results, nil
res := decodeResults(results, types)
if len(res) == 0 {
return nil, nil
}
if len(res) == 1 {
return res[0], nil
}
return res, nil
}

results, types, err := i.instance.CallExportedFunc(name, encodeArgs(args)...)
if err != nil {
return nil, err
}

wargs := make([]uint64, len(args))
switch args[0].(type) {
res := decodeResults(results, types)
if len(res) == 0 {
return nil, nil
}
if len(res) == 1 {
return res[0], nil
}
return res, nil
}

func encodeArgs(args []any) []uint64 {
encoded := make([]uint64, 0, len(args))
for _, arg := range args {
encoded = append(encoded, encodeArg(arg))
}
return encoded
}

func encodeArg(arg any) uint64 {
switch val := arg.(type) {
case int32:
for i, v := range args {
wargs[i] = uint64(v.(int32))
}
case uint32:
for i, v := range args {
wargs[i] = uint64(v.(uint32))
}
return uint64(val)
case int64:
for i, v := range args {
wargs[i] = uint64(v.(int64))
}
case uint64:
for i, v := range args {
wargs[i] = uint64(v.(uint64))
}
return uint64(val)
case float32:
for i, v := range args {
wargs[i] = uint64(v.(float32))
}
return uint64(math.Float32bits(val))
case float64:
for i, v := range args {
wargs[i] = uint64(v.(float64))
}
return uint64(math.Float64bits(val))
case uint32:
return uint64(val)
case uint64:
return uint64(val)
}
results, _, err := i.instance.CallExportedFunc(name, wargs...)
if err != nil {
return nil, err
panic("bad arg type")
}

func decodeResults(results []uint64, vtypes []types.ValueType) []any {
decoded := make([]any, 0, len(results))
for i, result := range results {
vtype := vtypes[i]
decoded = append(decoded, decodeResult(result, vtype))
}
return decoded
}

return results, nil
func decodeResult(result uint64, vtype types.ValueType) any {
switch vtype {
case types.ValueTypeF32:
return math.Float32frombits(uint32(result))
case types.ValueTypeF64:
return math.Float64frombits(uint64(result))
case types.ValueTypeI32:
return int32(result)
case types.ValueTypeI64:
return int64(result)
}
panic("unreachable")
}
87 changes: 87 additions & 0 deletions interp/wasman/instance_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package wasman

import (
_ "embed"
"testing"
)

//go:embed tester.wasm
var wasmData []byte

func TestInstance(t *testing.T) {
i := Interpreter{
Memory: make([]byte, 65536),
}
if err := i.Init(); err != nil {
t.Errorf("Interpreter.Init() failed: %v", err)
}

if err := i.Load(wasmData); err != nil {
t.Errorf("Interpreter.Load() failed: %v", err)
}

inst, err := i.Run()
if err != nil {
t.Errorf("Interpreter.Run() failed: %v", err)
}

t.Run("Call int32", func(t *testing.T) {
results, err := inst.Call("test_int32", int32(1), int32(2))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if results != int32(3) {
t.Errorf("Instance.Call() failed: %v", results)
}
})

t.Run("Call uint32", func(t *testing.T) {
results, err := inst.Call("test_uint32", uint32(1), uint32(2))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if uint32(results.(int32)) != uint32(3) {
t.Errorf("Instance.Call() failed: %v", results)
}
})

t.Run("Call int64", func(t *testing.T) {
results, err := inst.Call("test_int64", int64(1), int64(2))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if results != int64(3) {
t.Errorf("Instance.Call() failed: %v", results)
}
})

t.Run("Call uint64", func(t *testing.T) {
results, err := inst.Call("test_uint64", uint64(1), uint64(2))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if uint64(results.(int64)) != uint64(3) {
t.Errorf("Instance.Call() failed: %v", results)
}
})

t.Run("Call float32", func(t *testing.T) {
results, err := inst.Call("test_float32", float32(100.2), float32(300.8))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if results != float32(401.0) {
t.Errorf("Instance.Call() failed: %v", results)
}
})

t.Run("Call float64", func(t *testing.T) {
results, err := inst.Call("test_float64", float64(111.2), float64(333.8))
if err != nil {
t.Errorf("Instance.Call() failed: %v", err)
}
if results != float64(445.0) {
t.Errorf("Instance.Call() failed: %v", results)
}
})
}
Binary file added interp/wasman/tester.wasm
Binary file not shown.

0 comments on commit 27c6d9d

Please sign in to comment.