diff --git a/experimental/checkpoint.go b/experimental/checkpoint.go index 8791abd2..1560e6ba 100644 --- a/experimental/checkpoint.go +++ b/experimental/checkpoint.go @@ -10,8 +10,6 @@ type Snapshot interface { } // Snapshotter allows host functions to snapshot the WebAssembly execution environment. -// Currently, only the Wasm stack is captured, but in the future, this may be expanded -// to things like globals. type Snapshotter interface { // Snapshot captures the current execution state. Snapshot() Snapshot diff --git a/experimental/checkpoint_example_test.go b/experimental/checkpoint_example_test.go new file mode 100644 index 00000000..ab400e7e --- /dev/null +++ b/experimental/checkpoint_example_test.go @@ -0,0 +1,100 @@ +package experimental_test + +import ( + "context" + _ "embed" + "fmt" + "log" + + wazero "github.com/wasilibs/wazerox" + "github.com/wasilibs/wazerox/api" + "github.com/wasilibs/wazerox/experimental" +) + +// snapshotWasm was generated by the following: +// +// cd testdata; wat2wasm snapshot.wat +// +//go:embed testdata/snapshot.wasm +var snapshotWasm []byte + +type snapshotsKey struct{} + +func Example_enableSnapshotterKey() { + ctx := context.Background() + + rt := wazero.NewRuntime(ctx) + defer rt.Close(ctx) // This closes everything this Runtime created. + + // Enable experimental snapshotting functionality by setting it to context. We use this + // context when invoking functions, indicating to wazero to enable it. + ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{}) + + // Also place a mutable holder of snapshots to be referenced during restore. + var snapshots []experimental.Snapshot + ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots) + + // Register host functions using snapshot and restore. Generally snapshot is saved + // into a mutable location in context to be referenced during restore. + _, err := rt.NewHostModuleBuilder("example"). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 { + // Because we set EnableSnapshotterKey to context, this is non-nil. + snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot() + + // Get our mutable snapshots holder to be able to add to it. Our example only calls snapshot + // and restore once but real programs will often call them at multiple layers within a call + // stack with various e.g., try/catch statements. + snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot) + idx := len(*snapshots) + *snapshots = append(*snapshots, snapshot) + + // Write a value to be passed back to restore. This is meant to be opaque to the guest + // and used to re-reference the snapshot. + ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx)) + if !ok { + log.Panicln("failed to write snapshot index") + } + + return 0 + }). + Export("snapshot"). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) { + // Read the value written by snapshot to re-reference the snapshot. + idx, ok := mod.Memory().ReadUint32Le(snapshotPtr) + if !ok { + log.Panicln("failed to read snapshot index") + } + + // Get the snapshot + snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot) + snapshot := (*snapshots)[idx] + + // Restore! The invocation of this function will end as soon as we invoke + // Restore, so we also pass in our return value. The guest function run + // will finish with this return value. + snapshot.Restore([]uint64{5}) + }). + Export("restore"). + Instantiate(ctx) + if err != nil { + log.Panicln(err) + } + + mod, err := rt.Instantiate(ctx, snapshotWasm) // Instantiate the actual code + if err != nil { + log.Panicln(err) + } + + // Call the guest entrypoint. + res, err := mod.ExportedFunction("run").Call(ctx) + if err != nil { + log.Panicln(err) + } + // We restored and returned the restore value, so it's our result. If restore + // was instead a no-op, we would have returned 10 from normal code flow. + fmt.Println(res[0]) + // Output: + // 5 +} diff --git a/experimental/checkpoint_test.go b/experimental/checkpoint_test.go new file mode 100644 index 00000000..d5a66f0a --- /dev/null +++ b/experimental/checkpoint_test.go @@ -0,0 +1,121 @@ +package experimental_test + +import ( + "context" + "testing" + + wazero "github.com/wasilibs/wazerox" + "github.com/wasilibs/wazerox/api" + "github.com/wasilibs/wazerox/experimental" + "github.com/wasilibs/wazerox/internal/testing/require" +) + +func TestSnapshotNestedWasmInvocation(t *testing.T) { + ctx := context.Background() + + rt := wazero.NewRuntime(ctx) + defer rt.Close(ctx) + + sidechannel := 0 + + _, err := rt.NewHostModuleBuilder("example"). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 { + defer func() { + sidechannel = 10 + }() + snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot() + snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot) + idx := len(*snapshots) + *snapshots = append(*snapshots, snapshot) + ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx)) + require.True(t, ok) + + _, err := mod.ExportedFunction("restore").Call(ctx, uint64(snapshotPtr)) + require.NoError(t, err) + + return 2 + }). + Export("snapshot"). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) { + idx, ok := mod.Memory().ReadUint32Le(snapshotPtr) + require.True(t, ok) + snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot) + snapshot := (*snapshots)[idx] + + snapshot.Restore([]uint64{12}) + }). + Export("restore"). + Instantiate(ctx) + require.NoError(t, err) + + mod, err := rt.Instantiate(ctx, snapshotWasm) + require.NoError(t, err) + + var snapshots []experimental.Snapshot + ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots) + ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{}) + + snapshotPtr := uint64(0) + res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr) + require.NoError(t, err) + // return value from restore + require.Equal(t, uint64(12), res[0]) + // Host function defers within the call stack work fine + require.Equal(t, 10, sidechannel) +} + +func TestSnapshotMultipleWasmInvocations(t *testing.T) { + ctx := context.Background() + + rt := wazero.NewRuntime(ctx) + defer rt.Close(ctx) + + _, err := rt.NewHostModuleBuilder("example"). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 { + snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot() + snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot) + idx := len(*snapshots) + *snapshots = append(*snapshots, snapshot) + ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx)) + require.True(t, ok) + + return 0 + }). + Export("snapshot"). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) { + idx, ok := mod.Memory().ReadUint32Le(snapshotPtr) + require.True(t, ok) + snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot) + snapshot := (*snapshots)[idx] + + snapshot.Restore([]uint64{12}) + }). + Export("restore"). + Instantiate(ctx) + require.NoError(t, err) + + mod, err := rt.Instantiate(ctx, snapshotWasm) + require.NoError(t, err) + + var snapshots []experimental.Snapshot + ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots) + ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{}) + + snapshotPtr := uint64(0) + res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr) + require.NoError(t, err) + // snapshot returned zero + require.Equal(t, uint64(0), res[0]) + + // Fails, snapshot and restore are called from different wasm invocations. Currently, this + // results in a panic. + err = require.CapturePanic(func() { + _, _ = mod.ExportedFunction("restore").Call(ctx, snapshotPtr) + }) + require.EqualError(t, err, "unhandled snapshot restore, this generally indicates restore was called from a different "+ + "exported function invocation than snapshot") +} diff --git a/experimental/testdata/snapshot.wasm b/experimental/testdata/snapshot.wasm new file mode 100644 index 00000000..b07f5f2a Binary files /dev/null and b/experimental/testdata/snapshot.wasm differ diff --git a/experimental/testdata/snapshot.wat b/experimental/testdata/snapshot.wat new file mode 100644 index 00000000..68714237 --- /dev/null +++ b/experimental/testdata/snapshot.wat @@ -0,0 +1,34 @@ +(module + (import "example" "snapshot" (func $snapshot (param i32) (result i32))) + (import "example" "restore" (func $restore (param i32))) + + (func $helper (result i32) + (call $restore (i32.const 0)) + ;; Not executed + i32.const 10 + ) + + (func (export "run") (result i32) (local i32) + (call $snapshot (i32.const 0)) + local.set 0 + local.get 0 + (if (result i32) + (then ;; restore return, finish with the value returned by it + local.get 0 + ) + (else ;; snapshot return, call heloer + (call $helper) + ) + ) + ) + + (func (export "snapshot") (param i32) (result i32) + (call $snapshot (local.get 0)) + ) + + (func (export "restore") (param i32) + (call $restore (local.get 0)) + ) + + (memory (export "memory") 1 1) +) diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index 492a0468..88d9cee9 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -848,6 +848,11 @@ func callFrameOffset(funcType *wasm.FunctionType) (ret int) { // // This is defined for testability. func (ce *callEngine) deferredOnCall(ctx context.Context, m *wasm.ModuleInstance, recovered interface{}) (err error) { + if s, ok := recovered.(*snapshot); ok { + // A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation, + // let it propagate up to be handled by the caller. + panic(s) + } if recovered != nil { builder := wasmdebug.NewErrorBuilder() @@ -1260,6 +1265,12 @@ func (s *snapshot) doRestore() { copy(ce.stack[s.hostBase:], s.ret) } +// Error implements the same method on error. +func (s *snapshot) Error() string { + return "unhandled snapshot restore, this generally indicates restore was called from a different " + + "exported function invocation than snapshot" +} + // stackIterator implements experimental.StackIterator. type stackIterator struct { stack []uint64 diff --git a/internal/engine/interpreter/interpreter.go b/internal/engine/interpreter/interpreter.go index 2feee100..dfea6808 100644 --- a/internal/engine/interpreter/interpreter.go +++ b/internal/engine/interpreter/interpreter.go @@ -204,6 +204,53 @@ type function struct { parent *compiledFunction } +type snapshot struct { + stack []uint64 + frames []*callFrame + pc uint64 + + ret []uint64 + + ce *callEngine +} + +// Snapshot implements the same method as documented on experimental.Snapshotter. +func (ce *callEngine) Snapshot() experimental.Snapshot { + stack := make([]uint64, len(ce.stack)) + copy(stack, ce.stack) + + frames := make([]*callFrame, len(ce.frames)) + copy(frames, ce.frames) + + return &snapshot{ + stack: stack, + frames: frames, + ce: ce, + } +} + +// Restore implements the same method as documented on experimental.Snapshot. +func (s *snapshot) Restore(ret []uint64) { + s.ret = ret + panic(s) +} + +func (s *snapshot) doRestore() { + ce := s.ce + + ce.stack = s.stack + ce.frames = s.frames + ce.frames[len(ce.frames)-1].pc = s.pc + + copy(ce.stack[len(ce.stack)-len(s.ret):], s.ret) +} + +// Error implements the same method on error. +func (s *snapshot) Error() string { + return "unhandled snapshot restore, this generally indicates restore was called from a different " + + "exported function invocation than snapshot" +} + // functionFromUintptr resurrects the original *function from the given uintptr // which comes from either funcref table or OpcodeRefFunc instruction. func functionFromUintptr(ptr uintptr) *function { @@ -512,6 +559,10 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u } } + if ctx.Value(experimental.EnableSnapshotterKey{}) != nil { + ctx = context.WithValue(ctx, experimental.SnapshotterKey{}, ce) + } + defer func() { // If the module closed during the call, and the call didn't err for another reason, set an ExitError. if err == nil { @@ -555,6 +606,12 @@ type functionListenerInvocation struct { // with the call frame stack traces. Also, reset the state of callEngine // so that it can be used for the subsequent calls. func (ce *callEngine) recoverOnCall(ctx context.Context, m *wasm.ModuleInstance, v interface{}) (err error) { + if s, ok := v.(*snapshot); ok { + // A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation, + // let it propagate up to be handled by the caller. + panic(s) + } + builder := wasmdebug.NewErrorBuilder() frameCount := len(ce.frames) functionListeners := make([]functionListenerInvocation, 0, 16) @@ -669,7 +726,25 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, m *wasm.ModuleInstance ce.drop(op.Us[v+1]) frame.pc = op.Us[v] case wazeroir.OperationKindCall: - ce.callFunction(ctx, f.moduleInstance, &functions[op.U1]) + func() { + defer func() { + if r := recover(); r != nil { + if s, ok := r.(*snapshot); ok { + if s.ce == ce { + s.doRestore() + frame = ce.frames[len(ce.frames)-1] + body = frame.f.parent.body + bodyLen = uint64(len(body)) + } else { + panic(r) + } + } else { + panic(r) + } + } + }() + ce.callFunction(ctx, f.moduleInstance, &functions[op.U1]) + }() frame.pc++ case wazeroir.OperationKindCallIndirect: offset := ce.popValue()