Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for inter-instance stack switching #108

Merged
merged 5 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion crates/continuations/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ptr;
use std::{cell::UnsafeCell, ptr};
use wasmtime_fibre::Fiber;

/// TODO
Expand Down Expand Up @@ -132,6 +132,25 @@ impl StackChain {
pub const CONTINUATION_DISCRIMINANT: usize = STACK_CHAIN_CONTINUATION_DISCRIMINANT;
}

#[repr(transparent)]
pub struct StackChainCell(pub UnsafeCell<StackChain>);

impl StackChainCell {
pub fn absent() -> Self {
StackChainCell(UnsafeCell::new(StackChain::Absent))
}
}

// Since `StackChainCell` and `StackLimits` objects appear in the `StoreOpaque`,
// they need to be `Send` and `Sync`.
// This is safe for the same reason it is for `VMRuntimeLimits` (see comment
// there): Both types are pod-type with no destructor, and we don't access any
// of their fields from other threads.
unsafe impl Send for StackLimits {}
unsafe impl Sync for StackLimits {}
unsafe impl Send for StackChainCell {}
unsafe impl Sync for StackChainCell {}

pub struct Payloads {
/// Number of currently occupied slots.
pub length: types::payloads::Length,
Expand Down
26 changes: 24 additions & 2 deletions crates/cranelift/src/wasmfx/optimized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,20 @@ pub(crate) mod typed_continuation_helpers {

let offset =
i32::try_from(env.offsets.vmctx_typed_continuations_stack_chain()).unwrap();
StackChain::load(env, builder, base_addr, offset, self.pointer_type)

// The `typed_continuations_stack_chain` field of the VMContext only
// contains a pointer to the `StackChainCell` in the `Store`.
// The pointer never changes through the liftime of a `VMContext`,
// which is why this load is `readonly`.
// TODO(frank-emrich) Consider turning this pointer into a global
// variable, similar to `env.vmruntime_limits_ptr`.
let memflags = ir::MemFlags::trusted().with_readonly();
let stack_chain_ptr =
builder
.ins()
.load(self.pointer_type, memflags, base_addr, offset);

StackChain::load(env, builder, stack_chain_ptr, 0, self.pointer_type)
}

/// Stores the given stack chain saved in this `VMContext`, overwriting
Expand All @@ -777,7 +790,16 @@ pub(crate) mod typed_continuation_helpers {

let offset =
i32::try_from(env.offsets.vmctx_typed_continuations_stack_chain()).unwrap();
stack_chain.store(env, builder, base_addr, offset)

// Same situation as in `load_stack_chain` regarding pointer
// indirection and it being `readonly`.
let memflags = ir::MemFlags::trusted().with_readonly();
let stack_chain_ptr =
builder
.ins()
.load(self.pointer_type, memflags, base_addr, offset);

stack_chain.store(env, builder, stack_chain_ptr, 0)
}

/// Similar to `store_stack_chain`, but instead of storing an arbitrary
Expand Down
24 changes: 4 additions & 20 deletions crates/environ/src/vmoffsets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,9 @@ pub struct VMOffsets<P> {
defined_func_refs: u32,
size: u32,

// The following field stores a value of type
// `wasmtime_continuations::StackLimits`.
typed_continuations_main_stack_limits: u32,
// The following field stores a value of type
// `wasmtime_continuations::StackChain`. The head of the chain is the
// The following field stores a pointer into the StoreOpauqe, to value of
// type `wasmtime_continuations::StackChain`.
// The head of the chain is the
// currently executing stack (main stack or a continuation).
typed_continuations_stack_chain: u32,
typed_continuations_payloads: u32,
Expand Down Expand Up @@ -363,7 +361,6 @@ impl<P: PtrSize> VMOffsets<P> {
calculate_sizes! {
typed_continuations_payloads: "typed continuations payloads object",
typed_continuations_stack_chain: "typed continuations stack chain",
typed_continuations_main_stack_limits: "typed continuations main stack limits",
defined_func_refs: "module functions",
defined_globals: "defined globals",
owned_memories: "owned memories",
Expand Down Expand Up @@ -416,7 +413,6 @@ impl<P: PtrSize> From<VMOffsetsFields<P>> for VMOffsets<P> {
defined_globals: 0,
defined_func_refs: 0,
size: 0,
typed_continuations_main_stack_limits: 0,
typed_continuations_stack_chain: 0,
typed_continuations_payloads: 0,
};
Expand Down Expand Up @@ -482,14 +478,8 @@ impl<P: PtrSize> From<VMOffsetsFields<P>> for VMOffsets<P> {
ret.ptr.size_of_vm_func_ref(),
),

align(std::mem::align_of::<wasmtime_continuations::StackLimits>() as u32),
size(typed_continuations_main_stack_limits)
= std::mem::size_of::<wasmtime_continuations::StackLimits>() as u32,

align(std::mem::align_of::<wasmtime_continuations::StackChain>() as u32),
size(typed_continuations_stack_chain)
= std::mem::size_of::<wasmtime_continuations::StackChain>() as u32,

= ret.ptr.size(),
align(std::mem::align_of::<wasmtime_continuations::Payloads>() as u32),
size(typed_continuations_payloads) =
std::mem::size_of::<wasmtime_continuations::Payloads>() as u32,
Expand Down Expand Up @@ -746,12 +736,6 @@ impl<P: PtrSize> VMOffsets<P> {
self.builtin_functions
}

/// TODO
#[inline]
pub fn vmctx_typed_continuations_main_stack_limits(&self) -> u32 {
self.typed_continuations_main_stack_limits
}

/// TODO
#[inline]
pub fn vmctx_typed_continuations_stack_chain(&self) -> u32 {
Expand Down
6 changes: 3 additions & 3 deletions crates/runtime/src/continuation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::mem;
use wasmtime_continuations::{debug_println, ENABLE_DEBUG_PRINTING};
pub use wasmtime_continuations::{
ContinuationFiber, ContinuationObject, ContinuationReference, Payloads, StackChain,
StackLimits, State,
StackChainCell, StackLimits, State,
};
use wasmtime_fibre::{Fiber, FiberStack, Suspend};

Expand Down Expand Up @@ -187,7 +187,7 @@ pub fn resume(
// SAFETY: We maintain as an invariant that the stack chain field in the
// VMContext is non-null and contains a chain of zero or more
// StackChain::Continuation values followed by StackChain::Main.
match unsafe { &*chain } {
match unsafe { (**chain).0.get_mut() } {
StackChain::Continuation(running_contobj) => {
debug_assert_eq!(contobj, *running_contobj);
debug_println!(
Expand Down Expand Up @@ -273,7 +273,7 @@ pub fn suspend(instance: &mut Instance, tag_index: u32) -> Result<(), TrapReason
// SAFETY: We maintain as an invariant that the stack chain field in the
// VMContext is non-null and contains a chain of zero or more
// StackChain::Continuation values followed by StackChain::Main.
let chain = unsafe { &*chain_ptr };
let chain = unsafe { (**chain_ptr).0.get_mut() };
let running = match chain {
StackChain::Absent => Err(TrapReason::user_without_backtrace(anyhow::anyhow!(
"Internal error: StackChain not initialised"
Expand Down
30 changes: 12 additions & 18 deletions crates/runtime/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::ptr::NonNull;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::{mem, ptr};
use wasmtime_continuations::StackChainCell;
use wasmtime_environ::ModuleInternedTypeIndex;
use wasmtime_environ::{
packed_option::ReservedValue, DataIndex, DefinedGlobalIndex, DefinedMemoryIndex,
Expand Down Expand Up @@ -432,6 +433,14 @@ impl Instance {
unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_runtime_limits()) }
}

/// Return a pointer to the stack chain
#[inline]
pub fn stack_chain(&mut self) -> *mut *mut StackChainCell {
unsafe {
self.vmctx_plus_offset_mut(self.offsets().vmctx_typed_continuations_stack_chain())
}
}

/// Return a pointer to the global epoch counter used by this instance.
pub fn epoch_ptr(&mut self) -> *mut *const AtomicU64 {
unsafe { self.vmctx_plus_offset_mut(self.offsets().vmctx_epoch_ptr()) }
Expand Down Expand Up @@ -464,6 +473,7 @@ impl Instance {
if let Some(store) = store {
*self.vmctx_plus_offset_mut(self.offsets().vmctx_store()) = store;
*self.runtime_limits() = (*store).vmruntime_limits();
*self.stack_chain() = (*store).stack_chain();
*self.epoch_ptr() = (*store).epoch_ptr();
*self.externref_activations_table() = (*store).externref_activations_table().0;
} else {
Expand Down Expand Up @@ -1133,13 +1143,6 @@ impl Instance {
*self.vmctx_plus_offset_mut(offsets.vmctx_builtin_functions()) =
&VMBuiltinFunctionsArray::INIT;

let main_stack_limits_ptr =
self.vmctx_plus_offset_mut(offsets.vmctx_typed_continuations_main_stack_limits());
*main_stack_limits_ptr = wasmtime_continuations::StackLimits::default();

*self.vmctx_plus_offset_mut(offsets.vmctx_typed_continuations_stack_chain()) =
wasmtime_continuations::StackChain::MainStack(main_stack_limits_ptr);

// Initialize the Payloads object to be empty
let vmctx_payloads: *mut wasmtime_continuations::Payloads =
self.vmctx_plus_offset_mut(offsets.vmctx_typed_continuations_payloads());
Expand Down Expand Up @@ -1283,18 +1286,9 @@ impl Instance {
fault
}

#[allow(dead_code)]
pub(crate) fn typed_continuations_main_stack_limits(
&mut self,
) -> *mut wasmtime_continuations::StackLimits {
unsafe {
self.vmctx_plus_offset_mut(self.offsets().vmctx_typed_continuations_main_stack_limits())
}
}

pub(crate) fn typed_continuations_stack_chain(
&mut self,
) -> *mut wasmtime_continuations::StackChain {
) -> *mut *mut wasmtime_continuations::StackChainCell {
unsafe {
self.vmctx_plus_offset_mut(self.offsets().vmctx_typed_continuations_stack_chain())
}
Expand All @@ -1303,7 +1297,7 @@ impl Instance {
#[allow(dead_code)]
pub(crate) fn set_typed_continuations_stack_chain(
&mut self,
chain: *mut wasmtime_continuations::StackChain,
chain: *mut *mut wasmtime_continuations::StackChainCell,
) {
unsafe {
let ptr =
Expand Down
5 changes: 5 additions & 0 deletions crates/runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::fmt;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use wasmtime_continuations::StackChainCell;
use wasmtime_environ::{DefinedFuncIndex, DefinedMemoryIndex, HostPtr, VMOffsets};

mod arch;
Expand Down Expand Up @@ -97,6 +98,10 @@ pub unsafe trait Store {
/// in the `VMContext`.
fn vmruntime_limits(&self) -> *mut VMRuntimeLimits;

/// Used to configure `VMContext` initialization and store the right pointer
/// in the `VMContext`.
fn stack_chain(&self) -> *mut StackChainCell;

/// Returns a pointer to the global epoch counter.
///
/// Used to configure the `VMContext` on initialization.
Expand Down
45 changes: 45 additions & 0 deletions crates/wasmtime/src/runtime/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ use std::ptr;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::task::{Context, Poll};
use wasmtime_runtime::continuation::{StackChain, StackChainCell, StackLimits};
use wasmtime_runtime::mpk::{self, ProtectionKey, ProtectionMask};
use wasmtime_runtime::{
ExportGlobal, InstanceAllocationRequest, InstanceAllocator, InstanceHandle, ModuleInfo,
Expand Down Expand Up @@ -303,6 +304,21 @@ pub struct StoreOpaque {

engine: Engine,
runtime_limits: VMRuntimeLimits,

// Stack information used by typed continuations instructions. See
// documentation on `wasmtime_continuations::StackChain` for details.
//
// Note that in terms of (interior) mutability, we generally follow the same
// pattern as the `VMRuntimeLimits` object above: In the case of
// `StackLimits`, all of its fields are `UnsafeCell`s. For the stack chain,
// we wrap the entire `StackChainObject` in an `UnsafeCell`.
//
// Finally, observe that the stack chain adds more internal self references:
// The stack chain always contains a `MainStack` element at the ends which
// has a pointer to the `main_stack_limits` field of the same `StoreOpaque`.
main_stack_limits: StackLimits,
stack_chain: StackChainCell,

instances: Vec<StoreInstance>,
#[cfg(feature = "component-model")]
num_component_instances: usize,
Expand Down Expand Up @@ -492,6 +508,8 @@ impl<T> Store<T> {
_marker: marker::PhantomPinned,
engine: engine.clone(),
runtime_limits: Default::default(),
main_stack_limits: Default::default(),
stack_chain: StackChainCell::absent(),
instances: Vec::new(),
#[cfg(feature = "component-model")]
num_component_instances: 0,
Expand Down Expand Up @@ -573,6 +591,15 @@ impl<T> Store<T> {
instance
};

unsafe {
// NOTE(frank-emrich) The setup code for `default_caller` above
// together with the comment on the `PhantomPinned` marker inside
// `Store` indicates that `inner` is supposed to be at a stable
// location at this point, without explicitly being `Pin`-ed.
let stack_chain = inner.stack_chain.0.get();
*stack_chain = StackChain::MainStack(inner.main_stack_limits());
}

Self {
inner: ManuallyDrop::new(inner),
}
Expand Down Expand Up @@ -1513,6 +1540,20 @@ impl StoreOpaque {
&self.runtime_limits as *const VMRuntimeLimits as *mut VMRuntimeLimits
}

#[inline]
pub fn main_stack_limits(&self) -> *mut StackLimits {
// NOTE(frank-emrich) This looks dogdy, but follows the same pattern as
// `vmruntime_limits()` above.
&self.main_stack_limits as *const StackLimits as *mut StackLimits
}

#[inline]
pub fn stack_chain(&self) -> *mut StackChainCell {
// NOTE(frank-emrich) This looks dogdy, but follows the same pattern as
// `vmruntime_limits()` above.
&self.stack_chain as *const StackChainCell as *mut StackChainCell
}

pub unsafe fn insert_vmexternref_without_gc(&mut self, r: VMExternRef) {
self.externref_activations_table.insert_without_gc(r);
}
Expand Down Expand Up @@ -2025,6 +2066,10 @@ unsafe impl<T> wasmtime_runtime::Store for StoreInner<T> {
<StoreOpaque>::vmruntime_limits(self)
}

fn stack_chain(&self) -> *mut StackChainCell {
<StoreOpaque>::stack_chain(self)
}

fn epoch_ptr(&self) -> *const AtomicU64 {
self.engine.epoch_counter() as *const _
}
Expand Down
14 changes: 7 additions & 7 deletions tests/all/pooling_allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -661,12 +661,12 @@ configured maximum of 16 bytes; breakdown of allocation requirement:
"
} else {
"\
instance allocation for this module requires 320 bytes which exceeds the \
instance allocation for this module requires 272 bytes which exceeds the \
configured maximum of 16 bytes; breakdown of allocation requirement:

* 50.00% - 160 bytes - instance state management
* 10.00% - 32 bytes - typed continuations payloads object
* 10.00% - 32 bytes - typed continuations main stack limits
* 58.82% - 160 bytes - instance state management
* 8.82% - 24 bytes - typed continuations payloads object
* 5.88% - 16 bytes - jit store state
"
};
match Module::new(&engine, "(module)") {
Expand All @@ -690,11 +690,11 @@ configured maximum of 16 bytes; breakdown of allocation requirement:
"
} else {
"\
instance allocation for this module requires 1920 bytes which exceeds the \
instance allocation for this module requires 1872 bytes which exceeds the \
configured maximum of 16 bytes; breakdown of allocation requirement:

* 8.33% - 160 bytes - instance state management
* 83.33% - 1600 bytes - defined globals
* 8.55% - 160 bytes - instance state management
* 85.47% - 1600 bytes - defined globals
"
};
match Module::new(&engine, &lots_of_globals) {
Expand Down
Loading