Skip to content

Commit

Permalink
Support for inter-instance stack switching (#108)
Browse files Browse the repository at this point in the history
There is currently a conceptual error in the implementation, where the
chain of active stacks (= continuations + the main stack) is stored per
`Instance`/`VMContext`.

This means that when a function `$f` calls an imported function `$g`,
where `$f` and `$g` are not part of the same instance, their stack
chains are completely separate. For example, it is not possible to
`suspend` to a tag `$t` in `$g` and handle this in a resume block in
`$f` (assuming that the underlying modules import and export the tag
`$t` so that it is shared between the two).

This PR rectifies this situation by sharing a single `StackChain` object
between all instances of the same `Store`. The `VMContext` then contains
merely a pointer to this shared chain, rather than a chain of its own.
This fully mirrors how the `VMRuntimeLimits` are already shared between
all instances of a `Store`.
  • Loading branch information
frank-emrich authored Feb 16, 2024
1 parent 4df0bc5 commit 82d8569
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 51 deletions.
21 changes: 20 additions & 1 deletion crates/continuations/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ptr;
use std::{cell::UnsafeCell, ptr};
use wasmtime_fibre::Fiber;

/// TODO
Expand Down Expand Up @@ -132,6 +132,25 @@ impl StackChain {
pub const CONTINUATION_DISCRIMINANT: usize = STACK_CHAIN_CONTINUATION_DISCRIMINANT;
}

/// Newtype that wraps a [`StackChain`] in an [`UnsafeCell`].
///
/// The type is `#[repr(transparent)]` over its single
/// `UnsafeCell<StackChain>` field, which is public.
#[repr(transparent)]
pub struct StackChainCell(pub UnsafeCell<StackChain>);

impl StackChainCell {
/// Creates a cell whose contained chain is `StackChain::Absent`.
pub fn absent() -> Self {
StackChainCell(UnsafeCell::new(StackChain::Absent))
}
}

// Since `StackChainCell` and `StackLimits` objects appear in the `StoreOpaque`,
// they need to be `Send` and `Sync`.
// This is safe for the same reason it is for `VMRuntimeLimits` (see comment
// there): Both types are pod-type with no destructor, and we don't access any
// of their fields from other threads.
unsafe impl Send for StackLimits {}
unsafe impl Sync for StackLimits {}
unsafe impl Send for StackChainCell {}
unsafe impl Sync for StackChainCell {}

pub struct Payloads {
/// Number of currently occupied slots.
pub length: types::payloads::Length,
Expand Down
26 changes: 24 additions & 2 deletions crates/cranelift/src/wasmfx/optimized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,20 @@ pub(crate) mod typed_continuation_helpers {

let offset =
i32::try_from(env.offsets.vmctx_typed_continuations_stack_chain()).unwrap();
StackChain::load(env, builder, base_addr, offset, self.pointer_type)

// The `typed_continuations_stack_chain` field of the VMContext only
// contains a pointer to the `StackChainCell` in the `Store`.
// The pointer never changes throughout the lifetime of a `VMContext`,
// which is why this load is `readonly`.
// TODO(frank-emrich) Consider turning this pointer into a global
// variable, similar to `env.vmruntime_limits_ptr`.
let memflags = ir::MemFlags::trusted().with_readonly();
let stack_chain_ptr =
builder
.ins()
.load(self.pointer_type, memflags, base_addr, offset);

StackChain::load(env, builder, stack_chain_ptr, 0, self.pointer_type)
}

/// Stores the given stack chain saved in this `VMContext`, overwriting
Expand All @@ -777,7 +790,16 @@ pub(crate) mod typed_continuation_helpers {

let offset =
i32::try_from(env.offsets.vmctx_typed_continuations_stack_chain()).unwrap();
stack_chain.store(env, builder, base_addr, offset)

// Same situation as in `load_stack_chain` regarding pointer
// indirection and it being `readonly`.
let memflags = ir::MemFlags::trusted().with_readonly();
let stack_chain_ptr =
builder
.ins()
.load(self.pointer_type, memflags, base_addr, offset);

stack_chain.store(env, builder, stack_chain_ptr, 0)
}

/// Similar to `store_stack_chain`, but instead of storing an arbitrary
Expand Down
24 changes: 4 additions & 20 deletions crates/environ/src/vmoffsets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,9 @@ pub struct VMOffsets<P> {
defined_func_refs: u32,
size: u32,

// The following field stores a value of type
// `wasmtime_continuations::StackLimits`.
typed_continuations_main_stack_limits: u32,
// The following field stores a value of type
// `wasmtime_continuations::StackChain`. The head of the chain is the
// The following field stores a pointer into the `StoreOpaque`, to a value of
// type `wasmtime_continuations::StackChain`.
// The head of the chain is the
// currently executing stack (main stack or a continuation).
typed_continuations_stack_chain: u32,
typed_continuations_payloads: u32,
Expand Down Expand Up @@ -363,7 +361,6 @@ impl<P: PtrSize> VMOffsets<P> {
calculate_sizes! {
typed_continuations_payloads: "typed continuations payloads object",
typed_continuations_stack_chain: "typed continuations stack chain",
typed_continuations_main_stack_limits: "typed continuations main stack limits",
defined_func_refs: "module functions",
defined_globals: "defined globals",
owned_memories: "owned memories",
Expand Down Expand Up @@ -416,7 +413,6 @@ impl<P: PtrSize> From<VMOffsetsFields<P>> for VMOffsets<P> {
defined_globals: 0,
defined_func_refs: 0,
size: 0,
typed_continuations_main_stack_limits: 0,
typed_continuations_stack_chain: 0,
typed_continuations_payloads: 0,
};
Expand Down Expand Up @@ -482,14 +478,8 @@ impl<P: PtrSize> From<VMOffsetsFields<P>> for VMOffsets<P> {
ret.ptr.size_of_vm_func_ref(),
),

align(std::mem::align_of::<wasmtime_continuations::StackLimits>() as u32),
size(typed_continuations_main_stack_limits)
= std::mem::size_of::<wasmtime_continuations::StackLimits>() as u32,

align(std::mem::align_of::<wasmtime_continuations::StackChain>() as u32),
size(typed_continuations_stack_chain)
= std::mem::size_of::<wasmtime_continuations::StackChain>() as u32,

= ret.ptr.size(),
align(std::mem::align_of::<wasmtime_continuations::Payloads>() as u32),
size(typed_continuations_payloads) =
std::mem::size_of::<wasmtime_continuations::Payloads>() as u32,
Expand Down Expand Up @@ -746,12 +736,6 @@ impl<P: PtrSize> VMOffsets<P> {
self.builtin_functions
}

/// TODO
#[inline]
pub fn vmctx_typed_continuations_main_stack_limits(&self) -> u32 {
self.typed_continuations_main_stack_limits
}

/// TODO
#[inline]
pub fn vmctx_typed_continuations_stack_chain(&self) -> u32 {
Expand Down
6 changes: 3 additions & 3 deletions crates/runtime/src/continuation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::mem;
use wasmtime_continuations::{debug_println, ENABLE_DEBUG_PRINTING};
pub use wasmtime_continuations::{
ContinuationFiber, ContinuationObject, ContinuationReference, Payloads, StackChain,
StackLimits, State,
StackChainCell, StackLimits, State,
};
use wasmtime_fibre::{Fiber, FiberStack, Suspend};

Expand Down Expand Up @@ -187,7 +187,7 @@ pub fn resume(
// SAFETY: We maintain as an invariant that the stack chain field in the
// VMContext is non-null and contains a chain of zero or more
// StackChain::Continuation values followed by StackChain::Main.
match unsafe { &*chain } {
match unsafe { (**chain).0.get_mut() } {
StackChain::Continuation(running_contobj) => {
debug_assert_eq!(contobj, *running_contobj);
debug_println!(
Expand Down Expand Up @@ -273,7 +273,7 @@ pub fn suspend(instance: &mut Instance, tag_index: u32) -> Result<(), TrapReason
// SAFETY: We maintain as an invariant that the stack chain field in the
// VMContext is non-null and contains a chain of zero or more
// StackChain::Continuation values followed by StackChain::Main.
let chain = unsafe { &*chain_ptr };
let chain = unsafe { (**chain_ptr).0.get_mut() };
let running = match chain {
StackChain::Absent => Err(TrapReason::user_without_backtrace(anyhow::anyhow!(
"Internal error: StackChain not initialised"
Expand Down
30 changes: 12 additions & 18 deletions crates/runtime/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::ptr::NonNull;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::{mem, ptr};
use wasmtime_continuations::StackChainCell;
use wasmtime_environ::ModuleInternedTypeIndex;
use wasmtime_environ::{
packed_option::ReservedValue, DataIndex, DefinedGlobalIndex, DefinedMemoryIndex,
Expand Down Expand Up @@ -432,6 +433,14 @@ impl Instance {
unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_runtime_limits()) }
}

/// Return a pointer to the `VMContext` field that holds the stack chain
/// pointer.
///
/// The field is located at offset `vmctx_typed_continuations_stack_chain()`
/// and itself stores a `*mut StackChainCell`, which is why a double pointer
/// is returned.
#[inline]
pub fn stack_chain(&mut self) -> *mut *mut StackChainCell {
unsafe {
self.vmctx_plus_offset_mut(self.offsets().vmctx_typed_continuations_stack_chain())
}
}

/// Return a pointer to the global epoch counter used by this instance.
pub fn epoch_ptr(&mut self) -> *mut *const AtomicU64 {
unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_epoch_ptr()) }
Expand Down Expand Up @@ -464,6 +473,7 @@ impl Instance {
if let Some(store) = store {
*self.vmctx_plus_offset_mut(self.offsets().vmctx_store()) = store;
*self.runtime_limits() = (*store).vmruntime_limits();
*self.stack_chain() = (*store).stack_chain();
*self.epoch_ptr() = (*store).epoch_ptr();
*self.externref_activations_table() = (*store).externref_activations_table().0;
} else {
Expand Down Expand Up @@ -1133,13 +1143,6 @@ impl Instance {
*self.vmctx_plus_offset_mut(offsets.vmctx_builtin_functions()) =
&VMBuiltinFunctionsArray::INIT;

let main_stack_limits_ptr =
self.vmctx_plus_offset_mut(offsets.vmctx_typed_continuations_main_stack_limits());
*main_stack_limits_ptr = wasmtime_continuations::StackLimits::default();

*self.vmctx_plus_offset_mut(offsets.vmctx_typed_continuations_stack_chain()) =
wasmtime_continuations::StackChain::MainStack(main_stack_limits_ptr);

// Initialize the Payloads object to be empty
let vmctx_payloads: *mut wasmtime_continuations::Payloads =
self.vmctx_plus_offset_mut(offsets.vmctx_typed_continuations_payloads());
Expand Down Expand Up @@ -1283,18 +1286,9 @@ impl Instance {
fault
}

#[allow(dead_code)]
pub(crate) fn typed_continuations_main_stack_limits(
&mut self,
) -> *mut wasmtime_continuations::StackLimits {
unsafe {
self.vmctx_plus_offset_mut(self.offsets().vmctx_typed_continuations_main_stack_limits())
}
}

pub(crate) fn typed_continuations_stack_chain(
&mut self,
) -> *mut wasmtime_continuations::StackChain {
) -> *mut *mut wasmtime_continuations::StackChainCell {
unsafe {
self.vmctx_plus_offset_mut(self.offsets().vmctx_typed_continuations_stack_chain())
}
Expand All @@ -1303,7 +1297,7 @@ impl Instance {
#[allow(dead_code)]
pub(crate) fn set_typed_continuations_stack_chain(
&mut self,
chain: *mut wasmtime_continuations::StackChain,
chain: *mut *mut wasmtime_continuations::StackChainCell,
) {
unsafe {
let ptr =
Expand Down
5 changes: 5 additions & 0 deletions crates/runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::fmt;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use wasmtime_continuations::StackChainCell;
use wasmtime_environ::{DefinedFuncIndex, DefinedMemoryIndex, HostPtr, VMOffsets};

mod arch;
Expand Down Expand Up @@ -97,6 +98,10 @@ pub unsafe trait Store {
/// in the `VMContext`.
fn vmruntime_limits(&self) -> *mut VMRuntimeLimits;

/// Returns a pointer to the store's `StackChainCell`.
///
/// Used to configure `VMContext` initialization and store the right pointer
/// in the `VMContext`.
fn stack_chain(&self) -> *mut StackChainCell;

/// Returns a pointer to the global epoch counter.
///
/// Used to configure the `VMContext` on initialization.
Expand Down
45 changes: 45 additions & 0 deletions crates/wasmtime/src/runtime/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ use std::ptr;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::task::{Context, Poll};
use wasmtime_runtime::continuation::{StackChain, StackChainCell, StackLimits};
use wasmtime_runtime::mpk::{self, ProtectionKey, ProtectionMask};
use wasmtime_runtime::{
ExportGlobal, InstanceAllocationRequest, InstanceAllocator, InstanceHandle, ModuleInfo,
Expand Down Expand Up @@ -303,6 +304,21 @@ pub struct StoreOpaque {

engine: Engine,
runtime_limits: VMRuntimeLimits,

// Stack information used by typed continuations instructions. See
// documentation on `wasmtime_continuations::StackChain` for details.
//
// Note that in terms of (interior) mutability, we generally follow the same
// pattern as the `VMRuntimeLimits` object above: In the case of
// `StackLimits`, all of its fields are `UnsafeCell`s. For the stack chain,
// we wrap the entire `StackChain` object in an `UnsafeCell`.
//
// Finally, observe that the stack chain adds more internal self references:
// The stack chain always contains a `MainStack` element at the end, which
// has a pointer to the `main_stack_limits` field of the same `StoreOpaque`.
main_stack_limits: StackLimits,
stack_chain: StackChainCell,

instances: Vec<StoreInstance>,
#[cfg(feature = "component-model")]
num_component_instances: usize,
Expand Down Expand Up @@ -492,6 +508,8 @@ impl<T> Store<T> {
_marker: marker::PhantomPinned,
engine: engine.clone(),
runtime_limits: Default::default(),
main_stack_limits: Default::default(),
stack_chain: StackChainCell::absent(),
instances: Vec::new(),
#[cfg(feature = "component-model")]
num_component_instances: 0,
Expand Down Expand Up @@ -573,6 +591,15 @@ impl<T> Store<T> {
instance
};

unsafe {
// NOTE(frank-emrich) The setup code for `default_caller` above
// together with the comment on the `PhantomPinned` marker inside
// `Store` indicates that `inner` is supposed to be at a stable
// location at this point, without explicitly being `Pin`-ed.
let stack_chain = inner.stack_chain.0.get();
*stack_chain = StackChain::MainStack(inner.main_stack_limits());
}

Self {
inner: ManuallyDrop::new(inner),
}
Expand Down Expand Up @@ -1513,6 +1540,20 @@ impl StoreOpaque {
&self.runtime_limits as *const VMRuntimeLimits as *mut VMRuntimeLimits
}

/// Returns a raw mutable pointer to this store's `main_stack_limits` field.
///
/// The pointer is obtained by casting a shared reference to the field.
#[inline]
pub fn main_stack_limits(&self) -> *mut StackLimits {
// NOTE(frank-emrich) This looks dodgy, but follows the same pattern as
// `vmruntime_limits()` above.
&self.main_stack_limits as *const StackLimits as *mut StackLimits
}

/// Returns a raw mutable pointer to this store's `stack_chain` field.
///
/// The pointer is obtained by casting a shared reference to the field.
#[inline]
pub fn stack_chain(&self) -> *mut StackChainCell {
// NOTE(frank-emrich) This looks dodgy, but follows the same pattern as
// `vmruntime_limits()` above.
&self.stack_chain as *const StackChainCell as *mut StackChainCell
}

pub unsafe fn insert_vmexternref_without_gc(&mut self, r: VMExternRef) {
self.externref_activations_table.insert_without_gc(r);
}
Expand Down Expand Up @@ -2025,6 +2066,10 @@ unsafe impl<T> wasmtime_runtime::Store for StoreInner<T> {
<StoreOpaque>::vmruntime_limits(self)
}

fn stack_chain(&self) -> *mut StackChainCell {
<StoreOpaque>::stack_chain(self)
}

fn epoch_ptr(&self) -> *const AtomicU64 {
self.engine.epoch_counter() as *const _
}
Expand Down
14 changes: 7 additions & 7 deletions tests/all/pooling_allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -661,12 +661,12 @@ configured maximum of 16 bytes; breakdown of allocation requirement:
"
} else {
"\
instance allocation for this module requires 320 bytes which exceeds the \
instance allocation for this module requires 272 bytes which exceeds the \
configured maximum of 16 bytes; breakdown of allocation requirement:
* 50.00% - 160 bytes - instance state management
* 10.00% - 32 bytes - typed continuations payloads object
* 10.00% - 32 bytes - typed continuations main stack limits
* 58.82% - 160 bytes - instance state management
* 8.82% - 24 bytes - typed continuations payloads object
* 5.88% - 16 bytes - jit store state
"
};
match Module::new(&engine, "(module)") {
Expand All @@ -690,11 +690,11 @@ configured maximum of 16 bytes; breakdown of allocation requirement:
"
} else {
"\
instance allocation for this module requires 1920 bytes which exceeds the \
instance allocation for this module requires 1872 bytes which exceeds the \
configured maximum of 16 bytes; breakdown of allocation requirement:
* 8.33% - 160 bytes - instance state management
* 83.33% - 1600 bytes - defined globals
* 8.55% - 160 bytes - instance state management
* 85.47% - 1600 bytes - defined globals
"
};
match Module::new(&engine, &lots_of_globals) {
Expand Down
Loading

0 comments on commit 82d8569

Please sign in to comment.