-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
engine, interp/wasman: handle the 4 basic value types correctly when …
…calling wasman Signed-off-by: deadprogram <[email protected]>
- Loading branch information
1 parent
f286aa5
commit 27c6d9d
Showing
4 changed files
with
156 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
package engine | ||
|
||
type Instance interface { | ||
Call(name string, args ...interface{}) (interface{}, error) | ||
Call(name string, args ...any) (any, error) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,92 @@ | ||
package wasman | ||
|
||
import ( | ||
"math" | ||
|
||
wasmaneng "github.com/hybridgroup/wasman" | ||
"github.com/hybridgroup/wasman/types" | ||
) | ||
|
||
type Instance struct { | ||
instance *wasmaneng.Instance | ||
} | ||
|
||
func (i *Instance) Call(name string, args ...interface{}) (interface{}, error) { | ||
func (i *Instance) Call(name string, args ...any) (any, error) { | ||
if len(args) == 0 { | ||
results, _, err := i.instance.CallExportedFunc(name) | ||
results, types, err := i.instance.CallExportedFunc(name) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return results, nil | ||
res := decodeResults(results, types) | ||
if len(res) == 0 { | ||
return nil, nil | ||
} | ||
if len(res) == 1 { | ||
return res[0], nil | ||
} | ||
return res, nil | ||
} | ||
|
||
results, types, err := i.instance.CallExportedFunc(name, encodeArgs(args)...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
wargs := make([]uint64, len(args)) | ||
switch args[0].(type) { | ||
res := decodeResults(results, types) | ||
if len(res) == 0 { | ||
return nil, nil | ||
} | ||
if len(res) == 1 { | ||
return res[0], nil | ||
} | ||
return res, nil | ||
} | ||
|
||
func encodeArgs(args []any) []uint64 { | ||
encoded := make([]uint64, 0, len(args)) | ||
for _, arg := range args { | ||
encoded = append(encoded, encodeArg(arg)) | ||
} | ||
return encoded | ||
} | ||
|
||
func encodeArg(arg any) uint64 { | ||
switch val := arg.(type) { | ||
case int32: | ||
for i, v := range args { | ||
wargs[i] = uint64(v.(int32)) | ||
} | ||
case uint32: | ||
for i, v := range args { | ||
wargs[i] = uint64(v.(uint32)) | ||
} | ||
return uint64(val) | ||
case int64: | ||
for i, v := range args { | ||
wargs[i] = uint64(v.(int64)) | ||
} | ||
case uint64: | ||
for i, v := range args { | ||
wargs[i] = uint64(v.(uint64)) | ||
} | ||
return uint64(val) | ||
case float32: | ||
for i, v := range args { | ||
wargs[i] = uint64(v.(float32)) | ||
} | ||
return uint64(math.Float32bits(val)) | ||
case float64: | ||
for i, v := range args { | ||
wargs[i] = uint64(v.(float64)) | ||
} | ||
return uint64(math.Float64bits(val)) | ||
case uint32: | ||
return uint64(val) | ||
case uint64: | ||
return uint64(val) | ||
} | ||
results, _, err := i.instance.CallExportedFunc(name, wargs...) | ||
if err != nil { | ||
return nil, err | ||
panic("bad arg type") | ||
} | ||
|
||
func decodeResults(results []uint64, vtypes []types.ValueType) []any { | ||
decoded := make([]any, 0, len(results)) | ||
for i, result := range results { | ||
vtype := vtypes[i] | ||
decoded = append(decoded, decodeResult(result, vtype)) | ||
} | ||
return decoded | ||
} | ||
|
||
return results, nil | ||
func decodeResult(result uint64, vtype types.ValueType) any { | ||
switch vtype { | ||
case types.ValueTypeF32: | ||
return math.Float32frombits(uint32(result)) | ||
case types.ValueTypeF64: | ||
return math.Float64frombits(uint64(result)) | ||
case types.ValueTypeI32: | ||
return int32(result) | ||
case types.ValueTypeI64: | ||
return int64(result) | ||
} | ||
panic("unreachable") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package wasman | ||
|
||
import ( | ||
_ "embed" | ||
"testing" | ||
) | ||
|
||
//go:embed tester.wasm | ||
var wasmData []byte | ||
|
||
func TestInstance(t *testing.T) { | ||
i := Interpreter{ | ||
Memory: make([]byte, 65536), | ||
} | ||
if err := i.Init(); err != nil { | ||
t.Errorf("Interpreter.Init() failed: %v", err) | ||
} | ||
|
||
if err := i.Load(wasmData); err != nil { | ||
t.Errorf("Interpreter.Load() failed: %v", err) | ||
} | ||
|
||
inst, err := i.Run() | ||
if err != nil { | ||
t.Errorf("Interpreter.Run() failed: %v", err) | ||
} | ||
|
||
t.Run("Call int32", func(t *testing.T) { | ||
results, err := inst.Call("test_int32", int32(1), int32(2)) | ||
if err != nil { | ||
t.Errorf("Instance.Call() failed: %v", err) | ||
} | ||
if results != int32(3) { | ||
t.Errorf("Instance.Call() failed: %v", results) | ||
} | ||
}) | ||
|
||
t.Run("Call uint32", func(t *testing.T) { | ||
results, err := inst.Call("test_uint32", uint32(1), uint32(2)) | ||
if err != nil { | ||
t.Errorf("Instance.Call() failed: %v", err) | ||
} | ||
if uint32(results.(int32)) != uint32(3) { | ||
t.Errorf("Instance.Call() failed: %v", results) | ||
} | ||
}) | ||
|
||
t.Run("Call int64", func(t *testing.T) { | ||
results, err := inst.Call("test_int64", int64(1), int64(2)) | ||
if err != nil { | ||
t.Errorf("Instance.Call() failed: %v", err) | ||
} | ||
if results != int64(3) { | ||
t.Errorf("Instance.Call() failed: %v", results) | ||
} | ||
}) | ||
|
||
t.Run("Call uint64", func(t *testing.T) { | ||
results, err := inst.Call("test_uint64", uint64(1), uint64(2)) | ||
if err != nil { | ||
t.Errorf("Instance.Call() failed: %v", err) | ||
} | ||
if uint64(results.(int64)) != uint64(3) { | ||
t.Errorf("Instance.Call() failed: %v", results) | ||
} | ||
}) | ||
|
||
t.Run("Call float32", func(t *testing.T) { | ||
results, err := inst.Call("test_float32", float32(100.2), float32(300.8)) | ||
if err != nil { | ||
t.Errorf("Instance.Call() failed: %v", err) | ||
} | ||
if results != float32(401.0) { | ||
t.Errorf("Instance.Call() failed: %v", results) | ||
} | ||
}) | ||
|
||
t.Run("Call float64", func(t *testing.T) { | ||
results, err := inst.Call("test_float64", float64(111.2), float64(333.8)) | ||
if err != nil { | ||
t.Errorf("Instance.Call() failed: %v", err) | ||
} | ||
if results != float64(445.0) { | ||
t.Errorf("Instance.Call() failed: %v", results) | ||
} | ||
}) | ||
} |
Binary file not shown.