Skip to content

Commit

Permalink
wazevo: support for experimental snapshotting (#2144)
Browse files Browse the repository at this point in the history
Signed-off-by: Takeshi Yoneda <[email protected]>
  • Loading branch information
mathetake authored Mar 11, 2024
1 parent 472719d commit ac1961d
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 25 deletions.
3 changes: 1 addition & 2 deletions experimental/checkpoint_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ type snapshotsKey struct{}
func Example_enableSnapshotterKey() {
ctx := context.Background()

// TODO: currently, only the interpreter is supported for snapshotting.
rt := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter())
rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx) // This closes everything this Runtime created.

// Enable experimental snapshotting functionality by setting it to context. We use this
Expand Down
6 changes: 2 additions & 4 deletions experimental/checkpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ import (
func TestSnapshotNestedWasmInvocation(t *testing.T) {
ctx := context.Background()

// TODO: currently, only the interpreter is supported for snapshotting.
rt := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter())
rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx)

sidechannel := 0
Expand Down Expand Up @@ -70,8 +69,7 @@ func TestSnapshotNestedWasmInvocation(t *testing.T) {
func TestSnapshotMultipleWasmInvocations(t *testing.T) {
ctx := context.Background()

// TODO: currently, only the interpreter is supported for snapshotting.
rt := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter())
rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx)

_, err := rt.NewHostModuleBuilder("example").
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/wazevo/backend/isa/amd64/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func GoCallStackView(stackPointerBeforeGoCall *uint64) []uint64 {
return view
}

func AdjustStackAfterGrown(oldRsp, oldTop, rsp, rbp, top uintptr) {
func AdjustClonedStack(oldRsp, oldTop, rsp, rbp, top uintptr) {
diff := uint64(rsp - oldRsp)

newBuf := stackView(rbp, top)
Expand Down
4 changes: 2 additions & 2 deletions internal/engine/wazevo/backend/isa/amd64/stack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func addressOf(v *byte) uint64 {
return uint64(uintptr(unsafe.Pointer(v)))
}

func TestAdjustStackAfterGrown(t *testing.T) {
func TestAdjustClonedStack(t *testing.T) {
// In order to allocate slices on Go heap, we need to allocSlice function.
allocSlice := func(size int) []byte {
return make([]byte, size)
Expand All @@ -63,7 +63,7 @@ func TestAdjustStackAfterGrown(t *testing.T) {
// Coy old stack to new stack which contains the old pointers to the old stack elements.
copy(newStack, oldStack)

AdjustStackAfterGrown(oldRsp, oldTop, rsp, rbp, uintptr(addressOf(&newStack[len(newStack)-1])))
AdjustClonedStack(oldRsp, oldTop, rsp, rbp, uintptr(addressOf(&newStack[len(newStack)-1])))
require.Equal(t, addressOf(&newStack[rbpIndex+16]), binary.LittleEndian.Uint64(newStack[rbpIndex:]))
require.Equal(t, addressOf(&newStack[rbpIndex+32]), binary.LittleEndian.Uint64(newStack[rbpIndex+16:]))
require.Equal(t, addressOf(&newStack[rbpIndex+160]), binary.LittleEndian.Uint64(newStack[rbpIndex+32:]))
Expand Down
124 changes: 111 additions & 13 deletions internal/engine/wazevo/call_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ func (c *callEngine) CallWithStack(ctx context.Context, paramResultStack []uint6

// CallWithStack implements api.Function.
func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint64) (err error) {
snapshotEnabled := ctx.Value(experimental.EnableSnapshotterKey{}) != nil
if snapshotEnabled {
ctx = context.WithValue(ctx, experimental.SnapshotterKey{}, c)
}

if wazevoapi.StackGuardCheckEnabled {
defer func() {
wazevoapi.CheckStackGuardPage(c.stack)
Expand All @@ -217,7 +222,13 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
paramResultPtr = &paramResultStack[0]
}
defer func() {
if r := recover(); r != nil {
r := recover()
if s, ok := r.(*snapshot); ok {
// A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation,
// let it propagate up to be handled by the caller.
panic(s)
}
if r != nil {
type listenerForAbort struct {
def api.FunctionDefinition
lsn experimental.FunctionListener
Expand Down Expand Up @@ -284,7 +295,7 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
if err != nil {
return err
}
adjustStackAfterGrown(oldsp, oldTop, newsp, newfp, c.stackTop)
adjustClonedStack(oldsp, oldTop, newsp, newfp, c.stackTop)
// Old stack must be alive until the new stack is adjusted.
runtime.KeepAlive(oldStack)
c.execCtx.exitCode = wazevoapi.ExitCodeOK
Expand Down Expand Up @@ -322,7 +333,12 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
case wazevoapi.ExitCodeCallGoFunction:
index := wazevoapi.GoFunctionIndexFromExitCode(ec)
f := hostModuleGoFuncFromOpaque[api.GoFunction](index, c.execCtx.goFunctionCallCalleeModuleContextOpaque)
f.Call(ctx, goCallStackView(c.execCtx.stackPointerBeforeGoCall))
func() {
if snapshotEnabled {
defer snapshotRecoverFn(c)
}
f.Call(ctx, goCallStackView(c.execCtx.stackPointerBeforeGoCall))
}()
// Back to the native code.
c.execCtx.exitCode = wazevoapi.ExitCodeOK
afterGoFunctionCallEntrypoint(c.execCtx.goCallReturnAddress, c.execCtxPtr,
Expand All @@ -339,7 +355,12 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
def := hostModule.FunctionDefinition(wasm.Index(index))
listener.Before(ctx, callerModule, def, s, c.stackIterator(true))
// Call into the Go function.
f.Call(ctx, s)
func() {
if snapshotEnabled {
defer snapshotRecoverFn(c)
}
f.Call(ctx, s)
}()
// Call Listener.After.
listener.After(ctx, callerModule, def, s)
// Back to the native code.
Expand All @@ -350,7 +371,12 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
index := wazevoapi.GoFunctionIndexFromExitCode(ec)
f := hostModuleGoFuncFromOpaque[api.GoModuleFunction](index, c.execCtx.goFunctionCallCalleeModuleContextOpaque)
mod := c.callerModuleInstance()
f.Call(ctx, mod, goCallStackView(c.execCtx.stackPointerBeforeGoCall))
func() {
if snapshotEnabled {
defer snapshotRecoverFn(c)
}
f.Call(ctx, mod, goCallStackView(c.execCtx.stackPointerBeforeGoCall))
}()
// Back to the native code.
c.execCtx.exitCode = wazevoapi.ExitCodeOK
afterGoFunctionCallEntrypoint(c.execCtx.goCallReturnAddress, c.execCtxPtr,
Expand All @@ -367,7 +393,12 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
def := hostModule.FunctionDefinition(wasm.Index(index))
listener.Before(ctx, callerModule, def, s, c.stackIterator(true))
// Call into the Go function.
f.Call(ctx, callerModule, s)
func() {
if snapshotEnabled {
defer snapshotRecoverFn(c)
}
f.Call(ctx, callerModule, s)
}()
// Call Listener.After.
listener.After(ctx, callerModule, def, s)
// Back to the native code.
Expand Down Expand Up @@ -429,7 +460,7 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
addr := unsafe.Add(unsafe.Pointer(&mem.Buffer[0]), offset)
return atomic.LoadUint32((*uint32)(addr))
})
s[0] = uint64(res)
s[0] = res
c.execCtx.exitCode = wazevoapi.ExitCodeOK
afterGoFunctionCallEntrypoint(c.execCtx.goCallReturnAddress, c.execCtxPtr,
uintptr(unsafe.Pointer(c.execCtx.stackPointerBeforeGoCall)), c.execCtx.framePointerBeforeGoCall)
Expand Down Expand Up @@ -527,7 +558,13 @@ func (c *callEngine) growStack() (newSP, newFP uintptr, err error) {
}

newLen := 2*currentLen + c.execCtx.stackGrowRequiredSize + 16 // Stack might be aligned to 16 bytes, so add 16 bytes just in case.
newStack := make([]byte, newLen)
newSP, newFP, c.stackTop, c.stack = c.cloneStack(newLen)
c.execCtx.stackBottomPtr = &c.stack[0]
return
}

func (c *callEngine) cloneStack(l uintptr) (newSP, newFP, newTop uintptr, newStack []byte) {
newStack = make([]byte, l)

relSp := c.stackTop - uintptr(unsafe.Pointer(c.execCtx.stackPointerBeforeGoCall))
relFp := c.stackTop - c.execCtx.framePointerBeforeGoCall
Expand All @@ -540,7 +577,7 @@ func (c *callEngine) growStack() (newSP, newFP uintptr, err error) {
sh.Len = int(relSp)
sh.Cap = int(relSp)
}
newTop := alignedStackTop(newStack)
newTop = alignedStackTop(newStack)
{
newSP = newTop - relSp
newFP = newTop - relFp
Expand All @@ -550,10 +587,6 @@ func (c *callEngine) growStack() (newSP, newFP uintptr, err error) {
sh.Cap = int(relSp)
}
copy(newStackAligned, prevStackAligned)

c.stack = newStack
c.stackTop = newTop
c.execCtx.stackBottomPtr = &newStack[0]
return
}

Expand Down Expand Up @@ -624,3 +657,68 @@ func (si *stackIterator) SourceOffsetForPC(pc experimental.ProgramCounter) uint6
cm := si.eng.compiledModuleOfAddr(upc)
return cm.getSourceOffset(upc)
}

// snapshot implements experimental.Snapshot
type snapshot struct {
sp, fp, top uintptr
returnAddress *byte
stack []byte
savedRegisters [64][2]uint64
ret []uint64
c *callEngine
}

// Snapshot implements the same method as documented on experimental.Snapshotter.
func (c *callEngine) Snapshot() experimental.Snapshot {
returnAddress := c.execCtx.goCallReturnAddress
oldTop, oldSp := c.stackTop, uintptr(unsafe.Pointer(c.execCtx.stackPointerBeforeGoCall))
newSP, newFP, newTop, newStack := c.cloneStack(uintptr(len(c.stack)) + 16)
adjustClonedStack(oldSp, oldTop, newSP, newFP, newTop)
return &snapshot{
sp: newSP,
fp: newFP,
top: newTop,
savedRegisters: c.execCtx.savedRegisters,
returnAddress: returnAddress,
stack: newStack,
c: c,
}
}

// Restore implements the same method as documented on experimental.Snapshot.
func (s *snapshot) Restore(ret []uint64) {
s.ret = ret
panic(s)
}

func (s *snapshot) doRestore() {
spp := *(**uint64)(unsafe.Pointer(&s.sp))
view := goCallStackView(spp)
copy(view, s.ret)

c := s.c
c.stack = s.stack
c.stackTop = s.top
ec := &c.execCtx
ec.stackBottomPtr = &c.stack[0]
ec.stackPointerBeforeGoCall = spp
ec.framePointerBeforeGoCall = s.fp
ec.goCallReturnAddress = s.returnAddress
ec.savedRegisters = s.savedRegisters
}

// Error implements the same method on error.
func (s *snapshot) Error() string {
return "unhandled snapshot restore, this generally indicates restore was called from a different " +
"exported function invocation than snapshot"
}

func snapshotRecoverFn(c *callEngine) {
if r := recover(); r != nil {
if s, ok := r.(*snapshot); ok && s.c == c {
s.doRestore()
} else {
panic(r)
}
}
}
6 changes: 3 additions & 3 deletions internal/engine/wazevo/isa.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ func goCallStackView(stackPointerBeforeGoCall *uint64) []uint64 {
}
}

// adjustStackAfterGrown is a function to adjust the stack after it is grown.
// adjustClonedStack is a function to adjust the stack after it is grown.
// More precisely, absolute addresses (frame pointers) in the stack must be adjusted.
func adjustStackAfterGrown(oldsp, oldTop, sp, fp, top uintptr) {
func adjustClonedStack(oldsp, oldTop, sp, fp, top uintptr) {
switch runtime.GOARCH {
case "arm64":
// TODO: currently, the frame pointers are not used, and saved old sps are relative to the current stack pointer,
// so no need to adjustment on arm64. However, when we make it absolute, which in my opinion is better perf-wise
// at the expense of slightly costly stack growth, we need to adjust the pushed frame pointers.
case "amd64":
amd64.AdjustStackAfterGrown(oldsp, oldTop, sp, fp, top)
amd64.AdjustClonedStack(oldsp, oldTop, sp, fp, top)
default:
panic("unsupported architecture")
}
Expand Down

0 comments on commit ac1961d

Please sign in to comment.