diff --git a/crates/continuations/src/lib.rs b/crates/continuations/src/lib.rs index a8e75a1abd26..25dc7bf2ef85 100644 --- a/crates/continuations/src/lib.rs +++ b/crates/continuations/src/lib.rs @@ -150,6 +150,13 @@ impl StackChain { pub const ABSENT_DISCRIMINANT: usize = STACK_CHAIN_ABSENT_DISCRIMINANT; pub const MAIN_STACK_DISCRIMINANT: usize = STACK_CHAIN_MAIN_STACK_DISCRIMINANT; pub const CONTINUATION_DISCRIMINANT: usize = STACK_CHAIN_CONTINUATION_DISCRIMINANT; + + pub fn is_main_stack(&self) -> bool { + match self { + StackChain::MainStack(_) => true, + _ => false, + } + } } #[repr(transparent)] diff --git a/crates/runtime/src/traphandlers.rs b/crates/runtime/src/traphandlers.rs index 88cbf56090d1..3f0c9149e56e 100644 --- a/crates/runtime/src/traphandlers.rs +++ b/crates/runtime/src/traphandlers.rs @@ -5,13 +5,14 @@ mod backtrace; mod coredump; use crate::sys::traphandlers; -use crate::{Instance, VMContext, VMRuntimeLimits}; +use crate::{Instance, VMContext, VMOpaqueContext, VMRuntimeLimits}; use anyhow::Error; use std::any::Any; use std::cell::{Cell, UnsafeCell}; use std::mem::MaybeUninit; use std::ptr; use std::sync::Once; +use wasmtime_continuations::StackChainCell; pub use self::backtrace::{Backtrace, Frame}; pub use self::coredump::CoreDumpStack; @@ -207,22 +208,31 @@ pub unsafe fn catch_traps<'a, F>( capture_backtrace: bool, capture_coredump: bool, caller: *mut VMContext, + callee: *mut VMOpaqueContext, mut closure: F, ) -> Result<(), Box> where F: FnMut(*mut VMContext), { let limits = Instance::from_vmctx(caller, |i| i.runtime_limits()); - - let result = CallThreadState::new(signal_handler, capture_backtrace, capture_coredump, *limits) - .with(|cx| { - traphandlers::wasmtime_setjmp( - cx.jmp_buf.as_ptr(), - call_closure::, - &mut closure as *mut F as *mut u8, - caller, - ) - }); + let callee_stack_chain = VMContext::try_from_opaque(callee) + .map(|vmctx| Instance::from_vmctx(vmctx, |i| *i.stack_chain() as *const StackChainCell)); + + let result = CallThreadState::new( + signal_handler, + capture_backtrace, + capture_coredump, + *limits, + callee_stack_chain, + ) + .with(|cx| { + traphandlers::wasmtime_setjmp( + cx.jmp_buf.as_ptr(), + call_closure::, + &mut closure as *mut F as *mut u8, + caller, + ) + }); return match result { Ok(x) => Ok(x), @@ -242,6 +252,31 @@ where } } +/// Returns true if the first `CallThreadState` in this thread's chain that +/// actually executes wasm is doing so inside a continuation. Returns false +/// if there is no `CallThreadState` executing wasm. +pub fn first_wasm_state_on_fiber_stack() -> bool { + tls::with(|head_state| { + // Iterate this threads' CallThreadState chain starting at `head_state` + // (if chain is non-empty), skipping those CTSs whose + // `callee_stack_chain` is None. This means that if `first_wasm_state` + // is Some, it is the first entry in the call thread state chain + // actually executin wasm. + let first_wasm_state = head_state + .iter() + .flat_map(|head| head.iter()) + .skip_while(|state| state.callee_stack_chain.is_none()) + .next(); + + first_wasm_state.map_or(false, |state| unsafe { + let stack_chain = &*state + .callee_stack_chain + .expect("must be Some according to filtering above"); + !(*stack_chain.0.get()).is_main_stack() + }) + }) +} + // Module to hide visibility of the `CallThreadState::prev` field and force // usage of its accessor methods. mod call_thread_state { @@ -259,6 +294,10 @@ mod call_thread_state { pub(crate) limits: *const VMRuntimeLimits, + /// `Some(ptr)` iff this CallThreadState is for the execution of wasm. + /// In that case, `ptr` is the executing `Store`'s stack chain. + pub(crate) callee_stack_chain: Option<*const StackChainCell>, + pub(super) prev: Cell, // The values of `VMRuntimeLimits::last_wasm_{exit_{pc,fp},entry_sp}` @@ -291,6 +330,7 @@ mod call_thread_state { capture_backtrace: bool, capture_coredump: bool, limits: *const VMRuntimeLimits, + callee_stack_chain: Option<*const StackChainCell>, ) -> CallThreadState { CallThreadState { unwind: UnsafeCell::new(MaybeUninit::uninit()), @@ -299,6 +339,7 @@ mod call_thread_state { capture_backtrace, capture_coredump, limits, + callee_stack_chain, prev: Cell::new(ptr::null()), old_last_wasm_exit_fp: Cell::new(unsafe { *(*limits).last_wasm_exit_fp.get() }), old_last_wasm_exit_pc: Cell::new(unsafe { *(*limits).last_wasm_exit_pc.get() }), diff --git a/crates/runtime/src/vmcontext.rs b/crates/runtime/src/vmcontext.rs index 5b03966c11f0..2e8d9d378919 100644 --- a/crates/runtime/src/vmcontext.rs +++ b/crates/runtime/src/vmcontext.rs @@ -953,6 +953,16 @@ impl VMContext { debug_assert_eq!((*opaque).magic, VMCONTEXT_MAGIC); opaque.cast() } + + /// Alternative to `from_opaque` that returns `None` if the given opaque + /// context is not actually a `VMContext`. + pub unsafe fn try_from_opaque(opaque: *mut VMOpaqueContext) -> Option<*mut VMContext> { + if (*opaque).magic == VMCONTEXT_MAGIC { + Some(Self::from_opaque(opaque)) + } else { + None + } + } } /// A "raw" and unsafe representation of a WebAssembly value. diff --git a/crates/wasmtime/src/runtime/func.rs b/crates/wasmtime/src/runtime/func.rs index 2403534f1d73..d7b6992d85db 100644 --- a/crates/wasmtime/src/runtime/func.rs +++ b/crates/wasmtime/src/runtime/func.rs @@ -1033,15 +1033,19 @@ impl Func { params_and_returns: *mut ValRaw, params_and_returns_capacity: usize, ) -> Result<()> { - invoke_wasm_and_catch_traps(store, |caller| { - let func_ref = func_ref.as_ref(); - (func_ref.array_call)( - func_ref.vmctx, - caller.cast::(), - params_and_returns, - params_and_returns_capacity, - ) - }) + invoke_wasm_and_catch_traps( + store, + |caller| { + let func_ref = func_ref.as_ref(); + (func_ref.array_call)( + func_ref.vmctx, + caller.cast::(), + params_and_returns, + params_and_returns_capacity, + ) + }, + func_ref.as_ref().vmctx, + ) } /// Converts the raw representation of a `funcref` into an `Option` @@ -1533,8 +1537,43 @@ impl Func { pub(crate) fn invoke_wasm_and_catch_traps( store: &mut StoreContextMut<'_, T>, closure: impl FnMut(*mut VMContext), + callee: *mut VMOpaqueContext, ) -> Result<()> { unsafe { + if VMContext::try_from_opaque(callee).is_some() { + // If we get here, the callee is a "proper" `VMContext`, and we are + // indeed calling into wasm. + // + // We now ensure that the following invariant holds (see + // wasmfx/wasmfxtime#109): Since we know that we are (re)-entering + // wasm, it must not be the case that we weren't still running + // inside a continuation when reaching this point. In other words, + // we must currently be on the main stack. + // + // We check this by inspecting this thread's chain of + // `CallThreadState`s, which is a linked list of all (nested) + // invocations of wasm (and certain host calls). If any of them are + // executing wasm, we raise an error. + // Since we are doing this check every time we enter wasm, it is + // sufficient to only look at the most recent previous invocation of + // wasm (i.e., we do not need to walk the entire `CallTheadState` + // chain, but only walk to the first such state corresponding to an + // execution of wasm). + // + // As a result, the call below is O(n), where n is the number of + // `CallThreadState`s at the beginning in this thread's CTS chain before + // the first such state that corresponds to wasm execution. + // In other words, n is the nesting level of calls to wrapped host + // functions from within a host function (e.g., calling `f.call()` + // while within a host call, where `f` is the result from wrapping a + // Rust function inside a `Func`). + if wasmtime_runtime::first_wasm_state_on_fiber_stack() { + return Err(anyhow::anyhow!( + "Re-entering wasm while already executing on a continuation stack" + )); + } + } + let exit = enter_wasm(store); if let Err(trap) = store.0.call_hook(CallHook::CallingWasm) { @@ -1546,6 +1585,7 @@ pub(crate) fn invoke_wasm_and_catch_traps( store.0.engine().config().wasm_backtrace, store.0.engine().config().coredump_on_trap, store.0.default_caller(), + callee, closure, ); exit_wasm(store, exit); diff --git a/crates/wasmtime/src/runtime/func/typed.rs b/crates/wasmtime/src/runtime/func/typed.rs index 4cf926061292..5cbc298f1add 100644 --- a/crates/wasmtime/src/runtime/func/typed.rs +++ b/crates/wasmtime/src/runtime/func/typed.rs @@ -187,16 +187,25 @@ where // efficient to move in memory. This closure is actually invoked on the // other side of a C++ shim, so it can never be inlined enough to make // the memory go away, so the size matters here for performance. + let vmctx = unsafe { func.as_ref().vmctx }; let mut captures = (func, MaybeUninit::uninit(), params, false); - let result = invoke_wasm_and_catch_traps(store, |caller| { - let (func_ref, ret, params, returned) = &mut captures; - let func_ref = func_ref.as_ref(); - let result = - Params::invoke::(func_ref.native_call, func_ref.vmctx, caller, *params); - ptr::write(ret.as_mut_ptr(), result); - *returned = true - }); + let result = invoke_wasm_and_catch_traps( + store, + |caller| { + let (func_ref, ret, params, returned) = &mut captures; + let func_ref = func_ref.as_ref(); + let result = Params::invoke::( + func_ref.native_call, + func_ref.vmctx, + caller, + *params, + ); + ptr::write(ret.as_mut_ptr(), result); + *returned = true + }, + vmctx, + ); let (_, ret, _, returned) = captures; debug_assert_eq!(result.is_ok(), returned); result?; diff --git a/crates/wasmtime/src/runtime/instance.rs b/crates/wasmtime/src/runtime/instance.rs index c15a6020aaa0..300fc7bbf199 100644 --- a/crates/wasmtime/src/runtime/instance.rs +++ b/crates/wasmtime/src/runtime/instance.rs @@ -354,14 +354,19 @@ impl Instance { let instance = store.0.instance_mut(id); let f = instance.get_exported_func(start); let caller_vmctx = instance.vmctx(); + let callee_vmctx = unsafe { f.func_ref.as_ref().vmctx }; unsafe { - super::func::invoke_wasm_and_catch_traps(store, |_default_caller| { - let func = mem::transmute::< - NonNull, - extern "C" fn(*mut VMOpaqueContext, *mut VMContext), - >(f.func_ref.as_ref().native_call); - func(f.func_ref.as_ref().vmctx, caller_vmctx) - })?; + super::func::invoke_wasm_and_catch_traps( + store, + |_default_caller| { + let func = mem::transmute::< + NonNull, + extern "C" fn(*mut VMOpaqueContext, *mut VMContext), + >(f.func_ref.as_ref().native_call); + func(f.func_ref.as_ref().vmctx, caller_vmctx) + }, + callee_vmctx, + )?; } Ok(()) } diff --git a/tests/all/typed_continuations.rs b/tests/all/typed_continuations.rs index ca62b58f89f9..2c11845dc320 100644 --- a/tests/all/typed_continuations.rs +++ b/tests/all/typed_continuations.rs @@ -1,107 +1,207 @@ use anyhow::Result; use wasmtime::*; -use wasmtime_wasi::*; -struct WasiHostCtx { - preview2_ctx: WasiCtx, - preview2_table: wasmtime::component::ResourceTable, - preview1_adapter: preview1::WasiPreview1Adapter, -} +mod test_utils { + use anyhow::{bail, Result}; + use wasmtime::*; -impl WasiView for WasiHostCtx { - fn table(&mut self) -> &mut wasmtime::component::ResourceTable { - &mut self.preview2_table + pub struct Runner { + pub engine: Engine, + pub store: Store<()>, } - fn ctx(&mut self) -> &mut WasiCtx { - &mut self.preview2_ctx + impl Runner { + pub fn new() -> Runner { + let mut config = Config::default(); + config.wasm_function_references(true); + config.wasm_exceptions(true); + config.wasm_typed_continuations(true); + + let engine = Engine::new(&config).unwrap(); + + let store = Store::<()>::new(&engine, ()); + + Runner { engine, store } + } + + /// Uses this `Runner` to run the module defined in `wat`, satisfying + /// its imports using `imports`. The module must export a function + /// `entry`, taking no parameters and returning `Results`. + pub fn run_test( + mut self, + wat: &str, + imports: &[Extern], + ) -> Result { + let module = Module::new(&self.engine, wat)?; + + let instance = Instance::new(&mut self.store, &module, imports)?; + let entry = instance.get_typed_func::<(), Results>(&mut self.store, "entry")?; + + entry.call(&mut self.store, ()) + } + + /// Uses this `Runner` to run the module defined in `wat`, satisfying + /// its imports using `imports`. The module must export a function + /// `entry`, taking no parameters and returning `Results`. Execution of + /// `entry` is expected to yield a runtime `Error` with a &str payload + /// (such as an error raised with anyhow::anyhow!("Something is wrong") + pub fn run_test_expect_str_error(self, wat: &str, imports: &[Extern], error_message: &str) { + let result = self.run_test::<()>(wat, imports); + + let err = result.expect_err("Was expecting wasm execution to yield error"); + + assert_eq!(err.downcast_ref::<&'static str>(), Some(&error_message)); + } } -} -impl preview1::WasiPreview1View for WasiHostCtx { - fn adapter(&self) -> &preview1::WasiPreview1Adapter { - &self.preview1_adapter + /// Creates a simple Host function that increments an i32 + pub fn make_i32_inc_host_func(runner: &mut Runner) -> Func { + Func::new( + &mut runner.store, + FuncType::new(&runner.engine, vec![ValType::I32], vec![ValType::I32]), + |mut _caller, args: &[Val], results: &mut [Val]| { + let res = match args { + [Val::I32(i)] => i + 1, + _ => bail!("Error: Received illegal argument (should be single i32)"), + }; + results[0] = Val::I32(res); + Ok(()) + }, + ) } - fn adapter_mut(&mut self) -> &mut preview1::WasiPreview1Adapter { - &mut self.preview1_adapter + /// Creates a host function of type i32 -> i32. `export_func` must denote an + /// exported function of type i32 -> i32. The created host function + /// increments its argument by 1, passes it to the exported function, and in + /// turn increments the result before returning it as the overall result. + pub fn make_i32_inc_via_export_host_func( + runner: &mut Runner, + export_func: &'static str, + ) -> Func { + Func::new( + &mut runner.store, + FuncType::new(&runner.engine, vec![ValType::I32], vec![ValType::I32]), + |mut caller, args: &[Val], results: &mut [Val]| { + let export = caller + .get_export(export_func) + .ok_or(anyhow::anyhow!("could not get export"))?; + let func = export + .into_func() + .ok_or(anyhow::anyhow!("export is not a Func"))?; + let func_typed = func.typed::(caller.as_context())?; + let arg = args[0].unwrap_i32(); + let res = func_typed.call(caller.as_context_mut(), arg + 1)?; + results[0] = Val::I32(res + 1); + Ok(()) + }, + ) } } -fn run_wasi_test(wat: &'static str) -> Result { - // Construct the wasm engine with async support disabled. - let mut config = Config::new(); - config - .async_support(false) - .wasm_exceptions(true) - .wasm_function_references(true) - .wasm_typed_continuations(true); - let engine = Engine::new(&config)?; +mod wasi { + use anyhow::Result; + use wasmtime::*; + use wasmtime_wasi::*; + struct WasiHostCtx { + preview2_ctx: WasiCtx, + preview2_table: wasmtime::component::ResourceTable, + preview1_adapter: preview1::WasiPreview1Adapter, + } - // Add the WASI preview1 API to the linker (will be implemented in terms of - // the preview2 API) - let mut linker: Linker = Linker::new(&engine); - preview1::add_to_linker_sync(&mut linker)?; - - // Add capabilities (e.g. filesystem access) to the WASI preview2 context here. - let wasi_ctx = WasiCtxBuilder::new().inherit_stdio().build(); - - let host_ctx = WasiHostCtx { - preview2_ctx: wasi_ctx, - preview2_table: ResourceTable::new(), - preview1_adapter: preview1::WasiPreview1Adapter::new(), - }; - let mut store: Store = Store::new(&engine, host_ctx); - - // Instantiate our wasm module. - let module = Module::new(&engine, wat)?; - let func = linker - .module(&mut store, "", &module)? - .get_default(&mut store, "")? - .typed::<(), i32>(&store)?; - - // Invoke the WASI program default function. - func.call(&mut store, ()) -} + impl WasiView for WasiHostCtx { + fn table(&mut self) -> &mut wasmtime::component::ResourceTable { + &mut self.preview2_table + } -async fn run_wasi_test_async(wat: &'static str) -> Result { - // Construct the wasm engine with async support enabled. - let mut config = Config::new(); - config - .async_support(true) - .wasm_exceptions(true) - .wasm_function_references(true) - .wasm_typed_continuations(true); - let engine = Engine::new(&config)?; + fn ctx(&mut self) -> &mut WasiCtx { + &mut self.preview2_ctx + } + } - // Add the WASI preview1 API to the linker (will be implemented in terms of - // the preview2 API) - let mut linker: Linker = Linker::new(&engine); - preview1::add_to_linker_async(&mut linker)?; - - // Add capabilities (e.g. filesystem access) to the WASI preview2 context here. - let wasi_ctx = WasiCtxBuilder::new().inherit_stdio().build(); - - let host_ctx = WasiHostCtx { - preview2_ctx: wasi_ctx, - preview2_table: ResourceTable::new(), - preview1_adapter: preview1::WasiPreview1Adapter::new(), - }; - let mut store: Store = Store::new(&engine, host_ctx); - - // Instantiate our wasm module. - let module = Module::new(&engine, wat)?; - let func = linker - .module_async(&mut store, "", &module) - .await? - .get_default(&mut store, "")? - .typed::<(), i32>(&store)?; - - // Invoke the WASI program default function. - func.call_async(&mut store, ()).await -} + impl preview1::WasiPreview1View for WasiHostCtx { + fn adapter(&self) -> &preview1::WasiPreview1Adapter { + &self.preview1_adapter + } -static WRITE_SOMETHING_WAT: &'static str = &r#" + fn adapter_mut(&mut self) -> &mut preview1::WasiPreview1Adapter { + &mut self.preview1_adapter + } + } + + fn run_wasi_test(wat: &'static str) -> Result { + // Construct the wasm engine with async support disabled. + let mut config = Config::new(); + config + .async_support(false) + .wasm_exceptions(true) + .wasm_function_references(true) + .wasm_typed_continuations(true); + let engine = Engine::new(&config)?; + + // Add the WASI preview1 API to the linker (will be implemented in terms of + // the preview2 API) + let mut linker: Linker = Linker::new(&engine); + preview1::add_to_linker_sync(&mut linker)?; + + // Add capabilities (e.g. filesystem access) to the WASI preview2 context here. + let wasi_ctx = WasiCtxBuilder::new().inherit_stdio().build(); + + let host_ctx = WasiHostCtx { + preview2_ctx: wasi_ctx, + preview2_table: ResourceTable::new(), + preview1_adapter: preview1::WasiPreview1Adapter::new(), + }; + let mut store: Store = Store::new(&engine, host_ctx); + + // Instantiate our wasm module. + let module = Module::new(&engine, wat)?; + let func = linker + .module(&mut store, "", &module)? + .get_default(&mut store, "")? + .typed::<(), i32>(&store)?; + + // Invoke the WASI program default function. + func.call(&mut store, ()) + } + + async fn run_wasi_test_async(wat: &'static str) -> Result { + // Construct the wasm engine with async support enabled. + let mut config = Config::new(); + config + .async_support(true) + .wasm_exceptions(true) + .wasm_function_references(true) + .wasm_typed_continuations(true); + let engine = Engine::new(&config)?; + + // Add the WASI preview1 API to the linker (will be implemented in terms of + // the preview2 API) + let mut linker: Linker = Linker::new(&engine); + preview1::add_to_linker_async(&mut linker)?; + + // Add capabilities (e.g. filesystem access) to the WASI preview2 context here. + let wasi_ctx = WasiCtxBuilder::new().inherit_stdio().build(); + + let host_ctx = WasiHostCtx { + preview2_ctx: wasi_ctx, + preview2_table: ResourceTable::new(), + preview1_adapter: preview1::WasiPreview1Adapter::new(), + }; + let mut store: Store = Store::new(&engine, host_ctx); + + // Instantiate our wasm module. + let module = Module::new(&engine, wat)?; + let func = linker + .module_async(&mut store, "", &module) + .await? + .get_default(&mut store, "")? + .typed::<(), i32>(&store)?; + + // Invoke the WASI program default function. + func.call_async(&mut store, ()).await + } + + static WRITE_SOMETHING_WAT: &'static str = &r#" (module (type $ft (func (result i32))) (type $ct (cont $ft)) @@ -140,19 +240,19 @@ static WRITE_SOMETHING_WAT: &'static str = &r#" ) )"#; -#[test] -fn write_something_test() -> Result<()> { - assert_eq!(run_wasi_test(WRITE_SOMETHING_WAT)?, 0); - Ok(()) -} + #[test] + fn write_something_test() -> Result<()> { + assert_eq!(run_wasi_test(WRITE_SOMETHING_WAT)?, 0); + Ok(()) + } -#[tokio::test] -async fn write_something_test_async() -> Result<()> { - assert_eq!(run_wasi_test_async(WRITE_SOMETHING_WAT).await?, 0); - Ok(()) -} + #[tokio::test] + async fn write_something_test_async() -> Result<()> { + assert_eq!(run_wasi_test_async(WRITE_SOMETHING_WAT).await?, 0); + Ok(()) + } -static SCHED_YIELD_WAT: &'static str = r#" + static SCHED_YIELD_WAT: &'static str = r#" (module (type $ft (func (result i32))) (type $ct (cont $ft)) @@ -171,16 +271,17 @@ static SCHED_YIELD_WAT: &'static str = r#" ) )"#; -#[test] -fn sched_yield_test() -> Result<()> { - assert_eq!(run_wasi_test(SCHED_YIELD_WAT)?, 0); - Ok(()) -} + #[test] + fn sched_yield_test() -> Result<()> { + assert_eq!(run_wasi_test(SCHED_YIELD_WAT)?, 0); + Ok(()) + } -#[tokio::test] -async fn sched_yield_test_async() -> Result<()> { - assert_eq!(run_wasi_test_async(SCHED_YIELD_WAT).await?, 0); - Ok(()) + #[tokio::test] + async fn sched_yield_test_async() -> Result<()> { + assert_eq!(run_wasi_test_async(SCHED_YIELD_WAT).await?, 0); + Ok(()) + } } /// Test that we can handle a `suspend` from another instance. Note that this @@ -255,3 +356,291 @@ fn inter_instance_suspend() -> Result<()> { Ok(()) } + +/// Tests interaction with host functions. +mod host { + use super::test_utils::*; + use wasmtime::*; + + const RE_ENTER_ON_CONTINUATION_ERROR: &'static str = + "Re-entering wasm while already executing on a continuation stack"; + + #[test] + /// Tests calling a host function from within a wasm function running inside a continuation. + /// Call chain: + /// $entry -resume-> a -call-> host_func_a + fn call_host_from_continuation() -> Result<()> { + let wat = r#" + (module + (type $ft (func (result i32))) + (type $ct (cont $ft)) + + (import "" "" (func $host_func_a (param i32) (result i32))) + + (func $a (export "a") (result i32) + (call $host_func_a (i32.const 122)) + ) + (func $entry (export "entry") (result i32) + (resume $ct (cont.new $ct (ref.func $a))) + ) + ) + "#; + + let mut runner = Runner::new(); + + let host_func_a = make_i32_inc_host_func(&mut runner); + + let result = runner.run_test::(wat, &[host_func_a.into()]).unwrap(); + assert_eq!(result, 123); + Ok(()) + } + + #[test] + /// We re-enter wasm from a host function and execute a continuation. + /// Call chain: + /// $entry -call-> $a -call-> $host_func_a -call-> $b -resume-> $c + fn re_enter_wasm_ok1() -> Result<()> { + let wat = r#" + (module + (type $ft (func (param i32) (result i32))) + (type $ct (cont $ft)) + + (import "" "" (func $host_func_a (param i32) (result i32))) + + + (func $a (export "a") (param $x i32) (result i32) + (call $host_func_a (local.get $x)) + ) + + (func $b (export "b") (param $x i32) (result i32) + (resume $ct (local.get $x) (cont.new $ct (ref.func $c))) + ) + + (func $c (export "c") (param $x i32) (result i32) + (return (i32.add (local.get $x) (i32.const 1))) + ) + + + (func $entry (export "entry") (result i32) + (call $a (i32.const 120)) + ) + ) + "#; + + let mut runner = Runner::new(); + + let host_func_a = make_i32_inc_via_export_host_func(&mut runner, "b"); + + let result = runner.run_test::(wat, &[host_func_a.into()]).unwrap(); + assert_eq!(result, 123); + Ok(()) + } + + #[test] + /// Similar to `re_enter_wasm_ok2, but we run a continuation before the host call. + /// Call chain: + /// $entry -call-> $a -call-> $host_func_a -call-> $b -resume-> $c + fn re_enter_wasm_ok2() -> Result<()> { + let wat = r#" + (module + (type $ft (func (param i32) (result i32))) + (type $ct (cont $ft)) + + (import "" "" (func $host_func_a (param i32) (result i32))) + + + (func $a (export "a") (param $x i32) (result i32) + ;; Running continuation before calling into host is fine + (resume $ct (local.get $x) (cont.new $ct (ref.func $c))) + (drop) + + (call $host_func_a (local.get $x)) + ) + + (func $b (export "b") (param $x i32) (result i32) + (resume $ct (local.get $x) (cont.new $ct (ref.func $c))) + ) + + (func $c (export "c") (param $x i32) (result i32) + (return (i32.add (local.get $x) (i32.const 1))) + ) + + + (func $entry (export "entry") (result i32) + (call $a (i32.const 120)) + ) + ) + "#; + + let mut runner = Runner::new(); + + let host_func_a = make_i32_inc_via_export_host_func(&mut runner, "b"); + + let result = runner.run_test::(wat, &[host_func_a.into()]).unwrap(); + assert_eq!(result, 123); + Ok(()) + } + + #[cfg_attr(feature = "typed_continuations_baseline_implementation", ignore)] + #[test] + /// We re-enter wasm from a host function while we were already on a continuation stack. + /// This is currently forbidden (see wasmfx/wasmfxtime#109), but may be + /// allowed in the future. + /// Call chain: + /// $entry -resume-> $a -call-> $host_func_a -call-> $b + fn re_enter_wasm_bad() -> Result<()> { + let wat = r#" + (module + (type $ft (func (param i32) (result i32))) + (type $ct (cont $ft)) + + (import "" "" (func $host_func_a (param i32) (result i32))) + + + (func $a (export "a") (param $x i32) (result i32) + (call $host_func_a (local.get $x)) + ) + + + (func $b (export "b") (param $x i32) (result i32) + (return (i32.add (local.get $x) (i32.const 1))) + ) + + + (func $entry (export "entry") + (resume $ct (i32.const 120) (cont.new $ct (ref.func $a))) + (drop) + ) + ) + "#; + let mut runner = Runner::new(); + + let host_func_a = make_i32_inc_via_export_host_func(&mut runner, "b"); + + runner.run_test_expect_str_error( + &wat, + &[host_func_a.into()], + RE_ENTER_ON_CONTINUATION_ERROR, + ); + Ok(()) + } + + #[cfg_attr(feature = "typed_continuations_baseline_implementation", ignore)] + #[test] + /// After crossing from the host back into wasm, we suspend to a tag that is + /// handled by the surrounding function (i.e., without needing to cross the + /// host frame to reach the handler). + /// This is currently forbidden (see wasmfx/wasmfxtime#109), but could be + /// allowed in the future. + /// Call chain: + /// $entry -resume-> $a -call-> $host_func_a -call-> $b -resume-> $c + fn call_host_from_continuation_nested_suspend_ok() -> Result<()> { + let wat = r#" + (module + (type $ft (func (param i32) (result i32))) + (type $ct (cont $ft)) + (tag $t (result i32)) + + (import "" "" (func $host_func_a (param i32) (result i32))) + + + (func $a (export "a") (param $x i32) (result i32) + (call $host_func_a (local.get $x)) + ) + + + (func $b (export "b") (param $x i32) (result i32) + (block $h (result (ref $ct)) + (resume $ct (tag $t $h) (local.get $x) (cont.new $ct (ref.func $c))) + (unreachable) + ) + (drop) + ;; note that we do not run the continuation to completion + (i32.add (local.get $x) (i32.const 1)) + ) + + (func $c (export "c") (param $x i32) (result i32) + (suspend $t) + ) + + + (func $entry (export "entry") + (resume $ct (i32.const 120) (cont.new $ct (ref.func $a))) + (drop) + ) + ) + "#; + + let mut runner = Runner::new(); + + let host_func_a = make_i32_inc_via_export_host_func(&mut runner, "b"); + + runner.run_test_expect_str_error( + &wat, + &[host_func_a.into()], + RE_ENTER_ON_CONTINUATION_ERROR, + ); + Ok(()) + } + + #[cfg_attr(feature = "typed_continuations_baseline_implementation", ignore)] + #[test] + /// Similar to `call_host_from_continuation_nested_suspend_ok`. However, + /// we suspend to a tag that is only handled if we were to cross a host function + /// boundary. + /// This currently triggers the check that we must not re-enter wasm while + /// on a continuation (see wasmfx/wasmfxtime#109), but will most likely stay + /// forbidden if host calls acts as barriers for suspensions. In that case, + /// the test case will exhibit a case of suspending to an unhandled tag. + /// + /// Call chain: + /// $entry -resume-> $a -call-> $host_func_a -call-> $b -resume-> $c + fn call_host_from_continuation_nested_suspend_unhandled() -> Result<()> { + let wat = r#" + (module + (type $ft (func (param i32) (result i32))) + (type $ct (cont $ft)) + (tag $t (result i32)) + + (import "" "" (func $host_func_a (param i32) (result i32))) + + + (func $a (export "a") (param $x i32) (result i32) + (call $host_func_a (local.get $x)) + ) + + + (func $b (export "b") (param $x i32) (result i32) + (resume $ct (local.get $x) (cont.new $ct (ref.func $c))) + ) + + (func $c (export "c") (param $x i32) (result i32) + (suspend $t) + ) + + + (func $entry (export "entry") + (block $h (result (ref $ct)) + (return + (resume $ct + (tag $t $h) + (i32.const 123) + (cont.new $ct (ref.func $a)))) + ) + (drop) + ) + ) + "#; + + let mut runner = Runner::new(); + + let host_func_a = make_i32_inc_via_export_host_func(&mut runner, "b"); + + runner.run_test_expect_str_error( + &wat, + &[host_func_a.into()], + RE_ENTER_ON_CONTINUATION_ERROR, + ); + Ok(()) + } +}