Skip to content

Commit

Permalink
Fix and harden HostWrapper by enforcing nothing can borrow it mutably
Browse files Browse the repository at this point in the history
  • Loading branch information
prokopyl committed Apr 3, 2024
1 parent 5fb548e commit 4904044
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 107 deletions.
2 changes: 1 addition & 1 deletion extensions/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
//! }
//!
//! impl<'a> HostShared<'a> for MyHostShared<'a> {
//! fn instantiated(&self, instance: PluginSharedHandle<'a>) {
//! fn initializing(&self, instance: PluginInitializingHandle<'a>) {
//! let _ = self.state_ext.set(instance.get_extension());
//! }
//! # fn request_restart(&self) { unimplemented!() }
Expand Down
24 changes: 15 additions & 9 deletions host/examples/cpal/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,36 +63,38 @@ impl Host for CpalHost {
/// (This is unused in this example, but this is kept here for demonstration purposes)
#[allow(dead_code)]
struct PluginCallbacks<'a> {
/// The plugin's own shared handle.
handle: PluginSharedHandle<'a>,
/// A handle to the plugin's Audio Ports extension, if it supports it.
audio_ports: Option<&'a PluginAudioPorts>,
}

/// Data, accessible by all of the plugin's threads.
/// Data, accessible by all the plugin's threads.
pub struct CpalHostShared<'a> {
/// The sender side of the channel to the main thread.
sender: Sender<MainThreadMessage>,
/// The plugin callbacks.
/// This is stored in a separate, thread-safe lock because the instantiated method might be
/// called concurrently with any other thread-safe host methods.
plugin: OnceLock<PluginCallbacks<'a>>,
/// This is stored in a separate, thread-safe lock because the initializing method might be
/// called concurrently with any other thread-safe host methods.
callbacks: OnceLock<PluginCallbacks<'a>>,
/// The plugin's shared handle.
/// This is stored in a separate, thread-safe lock because the instantiation might complete
/// concurrently with any other thread-safe host methods.
plugin: OnceLock<PluginSharedHandle<'a>>,
}

impl<'a> CpalHostShared<'a> {
/// Initializes the shared data.
fn new(sender: Sender<MainThreadMessage>) -> Self {
Self {
sender,
callbacks: OnceLock::new(),
plugin: OnceLock::new(),
}
}
}

impl<'a> HostShared<'a> for CpalHostShared<'a> {
fn instantiated(&self, instance: PluginSharedHandle<'a>) {
let _ = self.plugin.set(PluginCallbacks {
handle: instance,
fn initializing(&self, instance: PluginInitializingHandle<'a>) {
let _ = self.callbacks.set(PluginCallbacks {
audio_ports: instance.get_extension(),
});
}
Expand Down Expand Up @@ -150,6 +152,10 @@ impl<'a> HostMainThread<'a> for CpalHostMainThread<'a> {
.map(|gui| Gui::new(gui, &mut instance));

self.timer_support = instance.shared().get_extension();
self._shared
.plugin
.set(instance.shared())
.expect("This is the only method that should set the instance handles.");
self.plugin = Some(instance);
}
}
Expand Down
4 changes: 2 additions & 2 deletions host/src/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@
//! }
//!
//! impl<'a> HostShared<'a> for MyHostShared<'a> {
//! // Once the plugin is fully instantiated, we can query its extensions
//! fn instantiated(&self, instance: PluginSharedHandle<'a>) {
//! // We can query the plugin's extensions as soon as the plugin starts initializing
//! fn initializing(&self, instance: PluginInitializingHandle<'a>) {
//! let _ = self.latency_extension.set(instance.get_extension());
//! }
//!
Expand Down
117 changes: 69 additions & 48 deletions host/src/extensions/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::cell::UnsafeCell;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::Arc;

mod panic {
#[cfg(not(test))]
Expand All @@ -25,9 +26,11 @@ mod panic {

pub(crate) mod descriptor;

// Safety note: once this type is constructed, a pointer to it will be given to the plugin instance,
// which means we can never
pub struct HostWrapper<H: Host> {
audio_processor: Option<UnsafeCell<<H as Host>::AudioProcessor<'static>>>,
main_thread: Option<UnsafeCell<<H as Host>::MainThread<'static>>>,
audio_processor: UnsafeCell<Option<<H as Host>::AudioProcessor<'static>>>,
main_thread: UnsafeCell<Option<<H as Host>::MainThread<'static>>>,
shared: Pin<Box<<H as Host>::Shared<'static>>>,
}

Expand Down Expand Up @@ -71,7 +74,9 @@ impl<H: Host> HostWrapper<H> {
/// aliased, as per usual safety rules.
#[inline]
pub unsafe fn main_thread(&self) -> NonNull<<H as Host>::MainThread<'_>> {
NonNull::new_unchecked(self.main_thread.as_ref().unwrap_unchecked().get()).cast()
let ptr: NonNull<_> = (*self.main_thread.get()).as_ref().unwrap_unchecked().into();

ptr.cast()
}

/// Returns a raw, non-null pointer to the host's [`AudioProcessor`](Host::AudioProcessor)
Expand All @@ -86,10 +91,12 @@ impl<H: Host> HostWrapper<H> {
pub unsafe fn audio_processor(
&self,
) -> Result<NonNull<<H as Host>::AudioProcessor<'_>>, HostError> {
match &self.audio_processor {
None => Err(HostError::DeactivatedPlugin),
Some(ap) => Ok(NonNull::new_unchecked(ap.get()).cast()),
}
let ptr: NonNull<_> = (*self.audio_processor.get())
.as_ref()
.ok_or(HostError::DeactivatedPlugin)?
.into();

Ok(ptr.cast())
}

/// Returns a shared reference to the host's [`Shared`](Host::Shared) struct.
Expand All @@ -99,49 +106,52 @@ impl<H: Host> HostWrapper<H> {
unsafe { shrink_shared_ref::<H>(&self.shared) }
}

pub(crate) fn new<FS, FH>(shared: FS, main_thread: FH) -> Pin<Box<Self>>
pub(crate) fn new<FS, FH>(shared: FS, main_thread: FH) -> Pin<Arc<Self>>
where
FS: for<'s> FnOnce(&'s ()) -> <H as Host>::Shared<'s>,
FH: for<'s> FnOnce(&'s <H as Host>::Shared<'s>) -> <H as Host>::MainThread<'s>,
{
let mut wrapper = Box::pin(Self {
audio_processor: None,
main_thread: None,
// We use Arc only because Box<T> implies Unique<T>, which is not the case since the plugin
// will effectively hold a shared pointer to this.
let mut wrapper = Arc::new(Self {
audio_processor: UnsafeCell::new(None),
main_thread: UnsafeCell::new(None),
shared: Box::pin(shared(&())),
});

// Safety: we never move out of pinned_wrapper, we only update main_thread.
let pinned_wrapper = unsafe { Pin::get_unchecked_mut(wrapper.as_mut()) };
// PANIC: we have the only Arc copy of this wrapper data.
let wrapper_mut = Arc::get_mut(&mut wrapper).unwrap();

// SAFETY: This type guarantees main thread data cannot outlive shared
pinned_wrapper.main_thread = Some(UnsafeCell::new(main_thread(unsafe {
extend_shared_ref(&pinned_wrapper.shared)
})));
*wrapper_mut.main_thread.get_mut() = Some(main_thread(unsafe {
extend_shared_ref(&wrapper_mut.shared)
}));

wrapper
// SAFETY: wrapper is the only reference to the data, we can guarantee it will remain pinned
// until drop happens.
unsafe { Pin::new_unchecked(wrapper) }
}

/// # Safety
/// This must only be called on the main thread. User must ensure the provided instance pointer
/// is valid.
pub(crate) unsafe fn instantiated(self: Pin<&mut Self>, instance: NonNull<clap_plugin>) {
// SAFETY: we only update the fields, we don't move them
let pinned_self = unsafe { Pin::get_unchecked_mut(self) };

pub(crate) unsafe fn instantiated(&self, instance: NonNull<clap_plugin>) {
// SAFETY: At this point there is no way main_thread could not have been set.
unsafe { pinned_self.main_thread.as_mut().unwrap_unchecked() }
.get_mut()
self.main_thread()
.as_mut()
.instantiated(PluginMainThreadHandle::new(instance));

pinned_self
.shared
.instantiated(PluginSharedHandle::new(instance));
self.shared
.initializing(PluginInitializingHandle::new(instance));
}

/// # Safety
/// The user must ensure this is only called on the main thread, and not concurrently
/// to any other main-thread OR audio-thread method.
#[inline]
pub(crate) fn setup_audio_processor<FA>(
self: Pin<&mut Self>,
audio_processor: FA,
pub(crate) unsafe fn setup_audio_processor<FA>(
&self,
audio_processor_builder: FA,
instance: NonNull<clap_plugin>,
) -> Result<(), HostError>
where
Expand All @@ -151,48 +161,59 @@ impl<H: Host> HostWrapper<H> {
&mut <H as Host>::MainThread<'a>,
) -> <H as Host>::AudioProcessor<'a>,
{
// SAFETY: we only update the fields, we don't move the struct
let pinned_self = unsafe { Pin::get_unchecked_mut(self) };
// SAFETY: the user enforces this is called non-concurrently to any other audio-thread method.
let audio_processor = unsafe { &mut *self.audio_processor.get() };

match &mut pinned_self.audio_processor {
match audio_processor {
Some(_) => Err(HostError::AlreadyActivatedPlugin),
None => {
pinned_self.audio_processor = Some(UnsafeCell::new(audio_processor(
*audio_processor = Some(audio_processor_builder(
PluginAudioProcessorHandle::new(instance),
// SAFETY: Shared lives at least as long as the audio processor does.
unsafe { extend_shared_ref(&pinned_self.shared) },
// SAFETY: At this point there is no way main_thread could not have been set.
unsafe { pinned_self.main_thread.as_mut().unwrap_unchecked() }.get_mut(),
)));
unsafe { extend_shared_ref(&self.shared) },
// SAFETY: The user enforces that this is only called on the main thread, and
// non-concurrently to any other main-thread method.
unsafe { self.main_thread().cast().as_mut() },
));
Ok(())
}
}
}

/// # Safety
/// The user must ensure this is only called on the main thread, and not concurrently
/// to any other main-thread OR audio-thread method.
#[inline]
pub(crate) fn deactivate<T>(
self: Pin<&mut Self>,
pub(crate) unsafe fn teardown_audio_processor<T>(
&self,
drop: impl for<'s> FnOnce(
<H as Host>::AudioProcessor<'s>,
&mut <H as Host>::MainThread<'s>,
) -> T,
) -> Result<T, HostError> {
// SAFETY: we only update the fields, we don't move the struct
let pinned_self = unsafe { Pin::get_unchecked_mut(self) };
// SAFETY: The user enforces that this is called and non-concurrently to any other audio-thread method.
let audio_processor = unsafe { &mut *self.audio_processor.get() };

match pinned_self.audio_processor.take() {
match audio_processor.take() {
None => Err(HostError::DeactivatedPlugin),
Some(cell) => Ok(drop(
cell.into_inner(),
// SAFETY: At this point there is no way main_thread could not have been set.
unsafe { pinned_self.main_thread.as_mut().unwrap_unchecked() }.get_mut(),
Some(audio_processor) => Ok(drop(
audio_processor,
// SAFETY: The user enforces that this is only called on the main thread, and
// non-concurrently to any other main-thread method.
unsafe { self.main_thread().cast().as_mut() },
)),
}
}

/// # Safety
/// the user must ensure this is not called concurrently
/// to [`Self::setup_audio_processor`] or [`Self::teardown_audio_processor`]
#[inline]
pub(crate) fn is_active(&self) -> bool {
self.audio_processor.is_some()
pub(crate) unsafe fn is_active(&self) -> bool {
// SAFETY: The user enforces this isn't called to any audio processor method that would
// get a mutable reference to this
// TODO: make this actually the case
unsafe { (*self.audio_processor.get()).is_some() }
}

fn handle_panic<T, F, Pa>(handler: F, param: Pa) -> Result<T, HostWrapperError>
Expand Down
16 changes: 2 additions & 14 deletions host/src/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,26 +148,14 @@ impl<'a> PluginFactory<'a> {
plugin_id: &CStr,
host: *const clap_host,
) -> Result<NonNull<clap_plugin>, HostError> {
let plugin = NonNull::new((*self.inner)
NonNull::new((*self.inner)
.create_plugin
.ok_or(HostError::NullFactoryCreatePluginFunction)?(
self.inner,
host,
plugin_id.as_ptr(),
) as *mut clap_plugin)
.ok_or(HostError::PluginNotFound)?;

if let Some(init) = plugin.as_ref().init {
if !init(plugin.as_ptr()) {
if let Some(destroy) = plugin.as_ref().destroy {
destroy(plugin.as_ptr());
}

return Err(HostError::InstantiationFailed);
}
}

Ok(plugin)
.ok_or(HostError::PluginNotFound)
}
}

Expand Down
8 changes: 4 additions & 4 deletions host/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
//!
//! impl<'a> HostShared<'a> for MyHostShared<'a> {
//! // Once the plugin is fully instantiated, we can query its extensions
//! fn instantiated(&self, instance: PluginSharedHandle<'a>) {
//! fn initializing(&self, instance: PluginInitializingHandle<'a>) {
//! let _ = self.latency_extension.set(instance.get_extension());
//! }
//!
Expand Down Expand Up @@ -198,7 +198,7 @@ pub use error::HostError;
pub use extensions::HostExtensions;
pub use info::HostInfo;

use crate::plugin::{PluginMainThreadHandle, PluginSharedHandle};
use crate::plugin::{PluginInitializingHandle, PluginMainThreadHandle};

/// Host data and callbacks that are tied to `[main-thread]` operations.
///
Expand Down Expand Up @@ -249,7 +249,7 @@ pub trait HostShared<'a>: Send + Sync {
/// plugin instance's lifetime.
#[inline]
#[allow(unused)]
fn instantiated(&self, instance: PluginSharedHandle<'a>) {}
fn initializing(&self, instance: PluginInitializingHandle<'a>) {}

/// Called by the plugin when it requests to be deactivated and then restarted by the host.
///
Expand Down Expand Up @@ -285,7 +285,7 @@ pub trait Host: 'static {
/// See the [`HostAudioProcessor`] docs and the [module docs](self) for more information.
type AudioProcessor<'a>: HostAudioProcessor<'a> + 'a;

/// Declares all of the extensions supported by this host.
/// Declares all the extensions supported by this host.
///
/// Extension declaration is done using the [`HostExtensions::register`] method.
///
Expand Down
5 changes: 4 additions & 1 deletion host/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,10 @@ pub mod prelude {
HostShared,
},
plugin::PluginInstance,
plugin::{PluginAudioProcessorHandle, PluginMainThreadHandle, PluginSharedHandle},
plugin::{
PluginAudioProcessorHandle, PluginInitializingHandle, PluginMainThreadHandle,
PluginSharedHandle,
},
process::{
audio_buffers::{
AudioPortBuffer, AudioPortBufferType, AudioPorts, InputAudioBuffers, InputChannel,
Expand Down
4 changes: 2 additions & 2 deletions host/src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<H: Host> PluginInstance<H> {
drop_with: D,
) -> T
where
D: for<'s> FnOnce(<H as Host>::AudioProcessor<'s>, &mut <H as Host>::MainThread<'s>) -> T,
D: for<'s> FnOnce(<H as Host>::AudioProcessor<'_>, &mut <H as Host>::MainThread<'s>) -> T,
{
if !Arc::ptr_eq(self.inner.get(), &processor.inner) {
panic!("Given plugin audio processor does not match the instance being deactivated")
Expand All @@ -98,7 +98,7 @@ impl<H: Host> PluginInstance<H> {

pub fn try_deactivate_with<T, D>(&mut self, drop_with: D) -> Result<T, HostError>
where
D: for<'s> FnOnce(<H as Host>::AudioProcessor<'s>, &mut <H as Host>::MainThread<'s>) -> T,
D: for<'s> FnOnce(<H as Host>::AudioProcessor<'_>, &mut <H as Host>::MainThread<'s>) -> T,
{
self.inner.use_mut(|inner| {
let wrapper = Arc::get_mut(inner).ok_or(HostError::StillActivatedPlugin)?;
Expand Down
Loading

0 comments on commit 4904044

Please sign in to comment.