Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Implement snapshot for interpreter and add tests (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
anuraaga authored Dec 8, 2023
1 parent b3503f4 commit e6b7256
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 3 deletions.
2 changes: 0 additions & 2 deletions experimental/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ type Snapshot interface {
}

// Snapshotter allows host functions to snapshot the WebAssembly execution environment.
// Currently, only the Wasm stack is captured, but in the future, this may be expanded
// to things like globals.
type Snapshotter interface {
// Snapshot captures the current execution state.
Snapshot() Snapshot
Expand Down
100 changes: 100 additions & 0 deletions experimental/checkpoint_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package experimental_test

import (
"context"
_ "embed"
"fmt"
"log"

wazero "github.com/wasilibs/wazerox"
"github.com/wasilibs/wazerox/api"
"github.com/wasilibs/wazerox/experimental"
)

// snapshotWasm was generated by the following:
//
// cd testdata; wat2wasm snapshot.wat
//
//go:embed testdata/snapshot.wasm
var snapshotWasm []byte

type snapshotsKey struct{}

func Example_enableSnapshotterKey() {
ctx := context.Background()

rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx) // This closes everything this Runtime created.

// Enable experimental snapshotting functionality by setting it to context. We use this
// context when invoking functions, indicating to wazero to enable it.
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})

// Also place a mutable holder of snapshots to be referenced during restore.
var snapshots []experimental.Snapshot
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)

// Register host functions using snapshot and restore. Generally snapshot is saved
// into a mutable location in context to be referenced during restore.
_, err := rt.NewHostModuleBuilder("example").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
// Because we set EnableSnapshotterKey to context, this is non-nil.
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()

// Get our mutable snapshots holder to be able to add to it. Our example only calls snapshot
// and restore once but real programs will often call them at multiple layers within a call
// stack with various e.g., try/catch statements.
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
idx := len(*snapshots)
*snapshots = append(*snapshots, snapshot)

// Write a value to be passed back to restore. This is meant to be opaque to the guest
// and used to re-reference the snapshot.
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
if !ok {
log.Panicln("failed to write snapshot index")
}

return 0
}).
Export("snapshot").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
// Read the value written by snapshot to re-reference the snapshot.
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
if !ok {
log.Panicln("failed to read snapshot index")
}

// Get the snapshot
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
snapshot := (*snapshots)[idx]

// Restore! The invocation of this function will end as soon as we invoke
// Restore, so we also pass in our return value. The guest function run
// will finish with this return value.
snapshot.Restore([]uint64{5})
}).
Export("restore").
Instantiate(ctx)
if err != nil {
log.Panicln(err)
}

mod, err := rt.Instantiate(ctx, snapshotWasm) // Instantiate the actual code
if err != nil {
log.Panicln(err)
}

// Call the guest entrypoint.
res, err := mod.ExportedFunction("run").Call(ctx)
if err != nil {
log.Panicln(err)
}
// We restored and returned the restore value, so it's our result. If restore
// was instead a no-op, we would have returned 10 from normal code flow.
fmt.Println(res[0])
// Output:
// 5
}
121 changes: 121 additions & 0 deletions experimental/checkpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package experimental_test

import (
"context"
"testing"

wazero "github.com/wasilibs/wazerox"
"github.com/wasilibs/wazerox/api"
"github.com/wasilibs/wazerox/experimental"
"github.com/wasilibs/wazerox/internal/testing/require"
)

func TestSnapshotNestedWasmInvocation(t *testing.T) {
ctx := context.Background()

rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx)

sidechannel := 0

_, err := rt.NewHostModuleBuilder("example").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
defer func() {
sidechannel = 10
}()
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
idx := len(*snapshots)
*snapshots = append(*snapshots, snapshot)
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
require.True(t, ok)

_, err := mod.ExportedFunction("restore").Call(ctx, uint64(snapshotPtr))
require.NoError(t, err)

return 2
}).
Export("snapshot").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
require.True(t, ok)
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
snapshot := (*snapshots)[idx]

snapshot.Restore([]uint64{12})
}).
Export("restore").
Instantiate(ctx)
require.NoError(t, err)

mod, err := rt.Instantiate(ctx, snapshotWasm)
require.NoError(t, err)

var snapshots []experimental.Snapshot
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})

snapshotPtr := uint64(0)
res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr)
require.NoError(t, err)
// return value from restore
require.Equal(t, uint64(12), res[0])
// Host function defers within the call stack work fine
require.Equal(t, 10, sidechannel)
}

func TestSnapshotMultipleWasmInvocations(t *testing.T) {
ctx := context.Background()

rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx)

_, err := rt.NewHostModuleBuilder("example").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
idx := len(*snapshots)
*snapshots = append(*snapshots, snapshot)
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
require.True(t, ok)

return 0
}).
Export("snapshot").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
require.True(t, ok)
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
snapshot := (*snapshots)[idx]

snapshot.Restore([]uint64{12})
}).
Export("restore").
Instantiate(ctx)
require.NoError(t, err)

mod, err := rt.Instantiate(ctx, snapshotWasm)
require.NoError(t, err)

var snapshots []experimental.Snapshot
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})

snapshotPtr := uint64(0)
res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr)
require.NoError(t, err)
// snapshot returned zero
require.Equal(t, uint64(0), res[0])

// Fails, snapshot and restore are called from different wasm invocations. Currently, this
// results in a panic.
err = require.CapturePanic(func() {
_, _ = mod.ExportedFunction("restore").Call(ctx, snapshotPtr)
})
require.EqualError(t, err, "unhandled snapshot restore, this generally indicates restore was called from a different "+
"exported function invocation than snapshot")
}
Binary file added experimental/testdata/snapshot.wasm
Binary file not shown.
34 changes: 34 additions & 0 deletions experimental/testdata/snapshot.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
(module
(import "example" "snapshot" (func $snapshot (param i32) (result i32)))
(import "example" "restore" (func $restore (param i32)))

(func $helper (result i32)
(call $restore (i32.const 0))
;; Not executed
i32.const 10
)

(func (export "run") (result i32) (local i32)
(call $snapshot (i32.const 0))
local.set 0
local.get 0
(if (result i32)
(then ;; restore return, finish with the value returned by it
local.get 0
)
(else ;; snapshot return, call heloer
(call $helper)
)
)
)

(func (export "snapshot") (param i32) (result i32)
(call $snapshot (local.get 0))
)

(func (export "restore") (param i32)
(call $restore (local.get 0))
)

(memory (export "memory") 1 1)
)
11 changes: 11 additions & 0 deletions internal/engine/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,11 @@ func callFrameOffset(funcType *wasm.FunctionType) (ret int) {
//
// This is defined for testability.
func (ce *callEngine) deferredOnCall(ctx context.Context, m *wasm.ModuleInstance, recovered interface{}) (err error) {
if s, ok := recovered.(*snapshot); ok {
// A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation,
// let it propagate up to be handled by the caller.
panic(s)
}
if recovered != nil {
builder := wasmdebug.NewErrorBuilder()

Expand Down Expand Up @@ -1260,6 +1265,12 @@ func (s *snapshot) doRestore() {
copy(ce.stack[s.hostBase:], s.ret)
}

// Error implements the same method on error.
func (s *snapshot) Error() string {
return "unhandled snapshot restore, this generally indicates restore was called from a different " +
"exported function invocation than snapshot"
}

// stackIterator implements experimental.StackIterator.
type stackIterator struct {
stack []uint64
Expand Down
77 changes: 76 additions & 1 deletion internal/engine/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,53 @@ type function struct {
parent *compiledFunction
}

type snapshot struct {
stack []uint64
frames []*callFrame
pc uint64

ret []uint64

ce *callEngine
}

// Snapshot implements the same method as documented on experimental.Snapshotter.
func (ce *callEngine) Snapshot() experimental.Snapshot {
stack := make([]uint64, len(ce.stack))
copy(stack, ce.stack)

frames := make([]*callFrame, len(ce.frames))
copy(frames, ce.frames)

return &snapshot{
stack: stack,
frames: frames,
ce: ce,
}
}

// Restore implements the same method as documented on experimental.Snapshot.
func (s *snapshot) Restore(ret []uint64) {
s.ret = ret
panic(s)
}

func (s *snapshot) doRestore() {
ce := s.ce

ce.stack = s.stack
ce.frames = s.frames
ce.frames[len(ce.frames)-1].pc = s.pc

copy(ce.stack[len(ce.stack)-len(s.ret):], s.ret)
}

// Error implements the same method on error.
func (s *snapshot) Error() string {
return "unhandled snapshot restore, this generally indicates restore was called from a different " +
"exported function invocation than snapshot"
}

// functionFromUintptr resurrects the original *function from the given uintptr
// which comes from either funcref table or OpcodeRefFunc instruction.
func functionFromUintptr(ptr uintptr) *function {
Expand Down Expand Up @@ -512,6 +559,10 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u
}
}

if ctx.Value(experimental.EnableSnapshotterKey{}) != nil {
ctx = context.WithValue(ctx, experimental.SnapshotterKey{}, ce)
}

defer func() {
// If the module closed during the call, and the call didn't err for another reason, set an ExitError.
if err == nil {
Expand Down Expand Up @@ -555,6 +606,12 @@ type functionListenerInvocation struct {
// with the call frame stack traces. Also, reset the state of callEngine
// so that it can be used for the subsequent calls.
func (ce *callEngine) recoverOnCall(ctx context.Context, m *wasm.ModuleInstance, v interface{}) (err error) {
if s, ok := v.(*snapshot); ok {
// A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation,
// let it propagate up to be handled by the caller.
panic(s)
}

builder := wasmdebug.NewErrorBuilder()
frameCount := len(ce.frames)
functionListeners := make([]functionListenerInvocation, 0, 16)
Expand Down Expand Up @@ -669,7 +726,25 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, m *wasm.ModuleInstance
ce.drop(op.Us[v+1])
frame.pc = op.Us[v]
case wazeroir.OperationKindCall:
ce.callFunction(ctx, f.moduleInstance, &functions[op.U1])
func() {
defer func() {
if r := recover(); r != nil {
if s, ok := r.(*snapshot); ok {
if s.ce == ce {
s.doRestore()
frame = ce.frames[len(ce.frames)-1]
body = frame.f.parent.body
bodyLen = uint64(len(body))
} else {
panic(r)
}
} else {
panic(r)
}
}
}()
ce.callFunction(ctx, f.moduleInstance, &functions[op.U1])
}()
frame.pc++
case wazeroir.OperationKindCallIndirect:
offset := ce.popValue()
Expand Down

0 comments on commit e6b7256

Please sign in to comment.