Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow re-entering wasm while off main stack #117

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions crates/continuations/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ impl StackChain {
pub const ABSENT_DISCRIMINANT: usize = STACK_CHAIN_ABSENT_DISCRIMINANT;
pub const MAIN_STACK_DISCRIMINANT: usize = STACK_CHAIN_MAIN_STACK_DISCRIMINANT;
pub const CONTINUATION_DISCRIMINANT: usize = STACK_CHAIN_CONTINUATION_DISCRIMINANT;

pub fn is_main_stack(&self) -> bool {
match self {
StackChain::MainStack(_) => true,
_ => false,
}
}
}

#[repr(transparent)]
Expand Down
63 changes: 52 additions & 11 deletions crates/runtime/src/traphandlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ mod backtrace;
mod coredump;

use crate::sys::traphandlers;
use crate::{Instance, VMContext, VMRuntimeLimits};
use crate::{Instance, VMContext, VMOpaqueContext, VMRuntimeLimits};
use anyhow::Error;
use std::any::Any;
use std::cell::{Cell, UnsafeCell};
use std::mem::MaybeUninit;
use std::ptr;
use std::sync::Once;
use wasmtime_continuations::StackChainCell;

pub use self::backtrace::{Backtrace, Frame};
pub use self::coredump::CoreDumpStack;
Expand Down Expand Up @@ -207,22 +208,31 @@ pub unsafe fn catch_traps<'a, F>(
capture_backtrace: bool,
capture_coredump: bool,
caller: *mut VMContext,
callee: *mut VMOpaqueContext,
mut closure: F,
) -> Result<(), Box<Trap>>
where
F: FnMut(*mut VMContext),
{
let limits = Instance::from_vmctx(caller, |i| i.runtime_limits());

let result = CallThreadState::new(signal_handler, capture_backtrace, capture_coredump, *limits)
.with(|cx| {
traphandlers::wasmtime_setjmp(
cx.jmp_buf.as_ptr(),
call_closure::<F>,
&mut closure as *mut F as *mut u8,
caller,
)
});
let callee_stack_chain = VMContext::try_from_opaque(callee)
.map(|vmctx| Instance::from_vmctx(vmctx, |i| *i.stack_chain() as *const StackChainCell));

let result = CallThreadState::new(
signal_handler,
capture_backtrace,
capture_coredump,
*limits,
callee_stack_chain,
)
.with(|cx| {
traphandlers::wasmtime_setjmp(
cx.jmp_buf.as_ptr(),
call_closure::<F>,
&mut closure as *mut F as *mut u8,
caller,
)
});

return match result {
Ok(x) => Ok(x),
Expand All @@ -242,6 +252,31 @@ where
}
}

/// Returns true if the first `CallThreadState` in this thread's chain that
/// actually executes wasm is doing so inside a continuation. Returns false
/// if there is no `CallThreadState` executing wasm.
pub fn first_wasm_state_on_fiber_stack() -> bool {
tls::with(|head_state| {
// Iterate this threads' CallThreadState chain starting at `head_state`
// (if chain is non-empty), skipping those CTSs whose
// `callee_stack_chain` is None. This means that if `first_wasm_state`
// is Some, it is the first entry in the call thread state chain
// actually executin wasm.
let first_wasm_state = head_state
.iter()
.flat_map(|head| head.iter())
.skip_while(|state| state.callee_stack_chain.is_none())
.next();

first_wasm_state.map_or(false, |state| unsafe {
let stack_chain = &*state
.callee_stack_chain
.expect("must be Some according to filtering above");
!(*stack_chain.0.get()).is_main_stack()
})
})
}

// Module to hide visibility of the `CallThreadState::prev` field and force
// usage of its accessor methods.
mod call_thread_state {
Expand All @@ -259,6 +294,10 @@ mod call_thread_state {

pub(crate) limits: *const VMRuntimeLimits,

/// `Some(ptr)` iff this CallThreadState is for the execution of wasm.
/// In that case, `ptr` is the executing `Store`'s stack chain.
pub(crate) callee_stack_chain: Option<*const StackChainCell>,

pub(super) prev: Cell<tls::Ptr>,

// The values of `VMRuntimeLimits::last_wasm_{exit_{pc,fp},entry_sp}`
Expand Down Expand Up @@ -291,6 +330,7 @@ mod call_thread_state {
capture_backtrace: bool,
capture_coredump: bool,
limits: *const VMRuntimeLimits,
callee_stack_chain: Option<*const StackChainCell>,
) -> CallThreadState {
CallThreadState {
unwind: UnsafeCell::new(MaybeUninit::uninit()),
Expand All @@ -299,6 +339,7 @@ mod call_thread_state {
capture_backtrace,
capture_coredump,
limits,
callee_stack_chain,
prev: Cell::new(ptr::null()),
old_last_wasm_exit_fp: Cell::new(unsafe { *(*limits).last_wasm_exit_fp.get() }),
old_last_wasm_exit_pc: Cell::new(unsafe { *(*limits).last_wasm_exit_pc.get() }),
Expand Down
10 changes: 10 additions & 0 deletions crates/runtime/src/vmcontext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,16 @@ impl VMContext {
debug_assert_eq!((*opaque).magic, VMCONTEXT_MAGIC);
opaque.cast()
}

/// Alternative to `from_opaque` that returns `None` if the given opaque
/// context is not actually a `VMContext`.
pub unsafe fn try_from_opaque(opaque: *mut VMOpaqueContext) -> Option<*mut VMContext> {
if (*opaque).magic == VMCONTEXT_MAGIC {
Some(Self::from_opaque(opaque))
} else {
None
}
}
}

/// A "raw" and unsafe representation of a WebAssembly value.
Expand Down
58 changes: 49 additions & 9 deletions crates/wasmtime/src/runtime/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1033,15 +1033,19 @@ impl Func {
params_and_returns: *mut ValRaw,
params_and_returns_capacity: usize,
) -> Result<()> {
invoke_wasm_and_catch_traps(store, |caller| {
let func_ref = func_ref.as_ref();
(func_ref.array_call)(
func_ref.vmctx,
caller.cast::<VMOpaqueContext>(),
params_and_returns,
params_and_returns_capacity,
)
})
invoke_wasm_and_catch_traps(
store,
|caller| {
let func_ref = func_ref.as_ref();
(func_ref.array_call)(
func_ref.vmctx,
caller.cast::<VMOpaqueContext>(),
params_and_returns,
params_and_returns_capacity,
)
},
func_ref.as_ref().vmctx,
)
}

/// Converts the raw representation of a `funcref` into an `Option<Func>`
Expand Down Expand Up @@ -1533,8 +1537,43 @@ impl Func {
pub(crate) fn invoke_wasm_and_catch_traps<T>(
store: &mut StoreContextMut<'_, T>,
closure: impl FnMut(*mut VMContext),
callee: *mut VMOpaqueContext,
) -> Result<()> {
unsafe {
if VMContext::try_from_opaque(callee).is_some() {
// If we get here, the callee is a "proper" `VMContext`, and we are
// indeed calling into wasm.
//
// We now ensure that the following invariant holds (see
// wasmfx/wasmfxtime#109): Since we know that we are (re)-entering
// wasm, it must not be the case that we weren't still running
// inside a continuation when reaching this point. In other words,
// we must currently be on the main stack.
//
// We check this by inspecting this thread's chain of
// `CallThreadState`s, which is a linked list of all (nested)
// invocations of wasm (and certain host calls). If any of them are
// executing wasm, we raise an error.
// Since we are doing this check every time we enter wasm, it is
// sufficient to only look at the most recent previous invocation of
// wasm (i.e., we do not need to walk the entire `CallTheadState`
// chain, but only walk to the first such state corresponding to an
// execution of wasm).
//
// As a result, the call below is O(n), where n is the number of
// `CallThreadState`s at the beginning in this thread's CTS chain before
// the first such state that corresponds to wasm execution.
// In other words, n is the nesting level of calls to wrapped host
// functions from within a host function (e.g., calling `f.call()`
// while within a host call, where `f` is the result from wrapping a
// Rust function inside a `Func`).
if wasmtime_runtime::first_wasm_state_on_fiber_stack() {
return Err(anyhow::anyhow!(
"Re-entering wasm while already executing on a continuation stack"
));
}
}

let exit = enter_wasm(store);

if let Err(trap) = store.0.call_hook(CallHook::CallingWasm) {
Expand All @@ -1546,6 +1585,7 @@ pub(crate) fn invoke_wasm_and_catch_traps<T>(
store.0.engine().config().wasm_backtrace,
store.0.engine().config().coredump_on_trap,
store.0.default_caller(),
callee,
closure,
);
exit_wasm(store, exit);
Expand Down
25 changes: 17 additions & 8 deletions crates/wasmtime/src/runtime/func/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,16 +187,25 @@ where
// efficient to move in memory. This closure is actually invoked on the
// other side of a C++ shim, so it can never be inlined enough to make
// the memory go away, so the size matters here for performance.
let vmctx = unsafe { func.as_ref().vmctx };
let mut captures = (func, MaybeUninit::uninit(), params, false);

let result = invoke_wasm_and_catch_traps(store, |caller| {
let (func_ref, ret, params, returned) = &mut captures;
let func_ref = func_ref.as_ref();
let result =
Params::invoke::<Results>(func_ref.native_call, func_ref.vmctx, caller, *params);
ptr::write(ret.as_mut_ptr(), result);
*returned = true
});
let result = invoke_wasm_and_catch_traps(
store,
|caller| {
let (func_ref, ret, params, returned) = &mut captures;
let func_ref = func_ref.as_ref();
let result = Params::invoke::<Results>(
func_ref.native_call,
func_ref.vmctx,
caller,
*params,
);
ptr::write(ret.as_mut_ptr(), result);
*returned = true
},
vmctx,
);
let (_, ret, _, returned) = captures;
debug_assert_eq!(result.is_ok(), returned);
result?;
Expand Down
19 changes: 12 additions & 7 deletions crates/wasmtime/src/runtime/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,19 @@ impl Instance {
let instance = store.0.instance_mut(id);
let f = instance.get_exported_func(start);
let caller_vmctx = instance.vmctx();
let callee_vmctx = unsafe { f.func_ref.as_ref().vmctx };
unsafe {
super::func::invoke_wasm_and_catch_traps(store, |_default_caller| {
let func = mem::transmute::<
NonNull<VMNativeCallFunction>,
extern "C" fn(*mut VMOpaqueContext, *mut VMContext),
>(f.func_ref.as_ref().native_call);
func(f.func_ref.as_ref().vmctx, caller_vmctx)
})?;
super::func::invoke_wasm_and_catch_traps(
store,
|_default_caller| {
let func = mem::transmute::<
NonNull<VMNativeCallFunction>,
extern "C" fn(*mut VMOpaqueContext, *mut VMContext),
>(f.func_ref.as_ref().native_call);
func(f.func_ref.as_ref().vmctx, caller_vmctx)
},
callee_vmctx,
)?;
}
Ok(())
}
Expand Down
Loading