From 66c0cbb4cf2dfed689fbe8c96cbb791f2db5a639 Mon Sep 17 00:00:00 2001 From: Vecvec Date: Mon, 21 Oct 2024 07:17:16 +1300 Subject: [PATCH] use snatchables for raw acceleration structures --- wgpu-core/src/command/ray_tracing.rs | 112 +++++++++++++-------- wgpu-core/src/device/life.rs | 2 +- wgpu-core/src/device/queue.rs | 15 ++- wgpu-core/src/device/ray_tracing.rs | 41 +++----- wgpu-core/src/device/resource.rs | 10 +- wgpu-core/src/hub.rs | 4 +- wgpu-core/src/lock/rank.rs | 1 + wgpu-core/src/ray_tracing.rs | 4 +- wgpu-core/src/resource.rs | 139 ++++++++++++++++++++++++--- 9 files changed, 235 insertions(+), 93 deletions(-) diff --git a/wgpu-core/src/command/ray_tracing.rs b/wgpu-core/src/command/ray_tracing.rs index 085761df5f..e0e20ace75 100644 --- a/wgpu-core/src/command/ray_tracing.rs +++ b/wgpu-core/src/command/ray_tracing.rs @@ -292,48 +292,53 @@ impl Global { ..BufferUses::ACCELERATION_STRUCTURE_SCRATCH, }; - let blas_descriptors = blas_storage - .iter() - .map(|storage| map_blas(storage, scratch_buffer.raw())); - - let tlas_descriptors = tlas_storage.iter().map( - |UnsafeTlasStore { - tlas, - entries, - scratch_buffer_offset, - }| { - if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate { - log::info!("only rebuild implemented") - } - hal::BuildAccelerationStructureDescriptor { - entries, - mode: hal::AccelerationStructureBuildMode::Build, - flags: tlas.flags, - source_acceleration_structure: None, - destination_acceleration_structure: tlas.raw(), - scratch_buffer: scratch_buffer.raw(), - scratch_buffer_offset: *scratch_buffer_offset, - } - }, - ); + let mut tlas_descriptors = Vec::new(); + + for UnsafeTlasStore { + tlas, + entries, + scratch_buffer_offset, + } in &tlas_storage + { + if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate { + log::info!("only rebuild implemented") + } + tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor { + entries, + mode: hal::AccelerationStructureBuildMode::Build, + flags: tlas.flags, + source_acceleration_structure: None, + destination_acceleration_structure: tlas.raw(&snatch_guard).ok_or( + BuildAccelerationStructureError::InvalidTlas(tlas.error_ident()), + )?, + scratch_buffer: scratch_buffer.raw(), + scratch_buffer_offset: *scratch_buffer_offset, + }) + } let blas_present = !blas_storage.is_empty(); let tlas_present = !tlas_storage.is_empty(); let cmd_buf_raw = cmd_buf_data.encoder.open(device)?; + let mut descriptors = Vec::new(); + + for storage in &blas_storage { + descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?); + } + build_blas( cmd_buf_raw, blas_present, tlas_present, input_barriers, - &blas_descriptors.collect::>(), + &descriptors, scratch_buffer_barrier, ); if tlas_present { unsafe { - cmd_buf_raw.build_acceleration_structures(&tlas_descriptors.collect::>()); + cmd_buf_raw.build_acceleration_structures(&tlas_descriptors); cmd_buf_raw.place_acceleration_structure_barrier( hal::AccelerationStructureBarrier { @@ -614,10 +619,6 @@ impl Global { ..BufferUses::ACCELERATION_STRUCTURE_SCRATCH, }; - let blas_descriptors = blas_storage - .iter() - .map(|storage| map_blas(storage, scratch_buffer.raw())); - let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len()); for &TlasStore { @@ -638,7 +639,13 @@ impl Global { mode: hal::AccelerationStructureBuildMode::Build, flags: tlas.flags, source_acceleration_structure: None, - destination_acceleration_structure: tlas.raw.as_ref(), + destination_acceleration_structure: tlas + .raw + .get(&snatch_guard) + .ok_or(BuildAccelerationStructureError::InvalidTlas( + tlas.error_ident(), + ))? + .as_ref(), scratch_buffer: scratch_buffer.raw(), scratch_buffer_offset: *scratch_buffer_offset, }) @@ -649,12 +656,18 @@ impl Global { let cmd_buf_raw = cmd_buf_data.encoder.open(device)?; + let mut descriptors = Vec::new(); + + for storage in &blas_storage { + descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?); + } + build_blas( cmd_buf_raw, blas_present, tlas_present, input_barriers, - &blas_descriptors.collect::>(), + &descriptors, scratch_buffer_barrier, ); @@ -771,7 +784,10 @@ impl CommandBufferMutable { } // makes sure a tlas is built before it is used - pub(crate) fn validate_tlas_actions(&self) -> Result<(), ValidateTlasActionsError> { + pub(crate) fn validate_tlas_actions( + &self, + snatch_guard: &SnatchGuard, + ) -> Result<(), ValidateTlasActionsError> { profiling::scope!("CommandEncoder::[submission]::validate_tlas_actions"); for action in &self.tlas_actions { match &action.kind { @@ -794,8 +810,9 @@ impl CommandBufferMutable { for blas in dependencies.deref() { let blas_build_index = *blas.built_index.read(); if blas_build_index.is_none() { - return Err(ValidateTlasActionsError::UsedUnbuilt( + return Err(ValidateTlasActionsError::UsedUnbuiltBlas( action.tlas.error_ident(), + blas.error_ident(), )); } if blas_build_index.unwrap() > tlas_build_index.unwrap() { @@ -804,6 +821,9 @@ impl CommandBufferMutable { action.tlas.error_ident(), )); } + if blas.raw.get(snatch_guard).is_none() { + return Err(ValidateTlasActionsError::InvalidBlas(blas.error_ident())); + } } } } @@ -1180,10 +1200,14 @@ fn iter_buffers<'a, 'b>( fn map_blas<'a>( storage: &'a BlasStore<'_>, scratch_buffer: &'a dyn hal::DynBuffer, -) -> hal::BuildAccelerationStructureDescriptor< - 'a, - dyn hal::DynBuffer, - dyn hal::DynAccelerationStructure, + snatch_guard: &'a SnatchGuard, +) -> Result< + hal::BuildAccelerationStructureDescriptor< + 'a, + dyn hal::DynBuffer, + dyn hal::DynAccelerationStructure, + >, + BuildAccelerationStructureError, > { let BlasStore { blas, @@ -1193,15 +1217,21 @@ fn map_blas<'a>( if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate { log::info!("only rebuild implemented") } - hal::BuildAccelerationStructureDescriptor { + Ok(hal::BuildAccelerationStructureDescriptor { entries, mode: hal::AccelerationStructureBuildMode::Build, flags: blas.flags, source_acceleration_structure: None, - destination_acceleration_structure: blas.raw.as_ref(), + destination_acceleration_structure: blas + .raw + .get(snatch_guard) + .ok_or(BuildAccelerationStructureError::InvalidBlas( + blas.error_ident(), + ))? + .as_ref(), scratch_buffer, scratch_buffer_offset: *scratch_buffer_offset, - } + }) } fn build_blas<'a>( diff --git a/wgpu-core/src/device/life.rs b/wgpu-core/src/device/life.rs index 8e261fb42a..246649b589 100644 --- a/wgpu-core/src/device/life.rs +++ b/wgpu-core/src/device/life.rs @@ -283,7 +283,7 @@ impl LifetimeTracker { } /// Returns the submission index of the most recent submission that uses the - /// given blas. + /// given tlas. pub fn get_tlas_latest_submission_index(&self, tlas: &Tlas) -> Option { // We iterate in reverse order, so that we can bail out early as soon // as we find a hit. diff --git a/wgpu-core/src/device/queue.rs b/wgpu-core/src/device/queue.rs index b26707ed11..94ab451389 100644 --- a/wgpu-core/src/device/queue.rs +++ b/wgpu-core/src/device/queue.rs @@ -27,7 +27,7 @@ use crate::{ use smallvec::SmallVec; -use crate::resource::{Blas, Tlas}; +use crate::resource::{Blas, DestroyedAccelerationStructure, Tlas}; use crate::scratch::ScratchBuffer; use std::{ iter, @@ -148,8 +148,7 @@ pub enum TempResource { ScratchBuffer(ScratchBuffer), DestroyedBuffer(DestroyedBuffer), DestroyedTexture(DestroyedTexture), - Blas(Arc), - Tlas(Arc), + DestroyedAccelerationStructure(DestroyedAccelerationStructure), } /// A series of raw [`CommandBuffer`]s that have been submitted to a @@ -281,6 +280,14 @@ impl PendingWrites { self.dst_tlas_s.insert(tlas.tracker_index(), tlas.clone()); } + pub fn contains_blas(&mut self, blas: &Arc) -> bool { + self.dst_blas_s.contains_key(&blas.tracker_index()) + } + + pub fn contains_tlas(&mut self, tlas: &Arc) -> bool { + self.dst_tlas_s.contains_key(&tlas.tracker_index()) + } + pub fn consume_temp(&mut self, resource: TempResource) { self.temp_resources.push(resource); } @@ -1492,7 +1499,7 @@ fn validate_command_buffer( if let Err(e) = cmd_buf_data.validate_blas_actions() { return Err(e.into()); } - if let Err(e) = cmd_buf_data.validate_tlas_actions() { + if let Err(e) = cmd_buf_data.validate_tlas_actions(snatch_guard) { return Err(e.into()); } } diff --git a/wgpu-core/src/device/ray_tracing.rs b/wgpu-core/src/device/ray_tracing.rs index 1f0941a8d5..da68ab8fc5 100644 --- a/wgpu-core/src/device/ray_tracing.rs +++ b/wgpu-core/src/device/ray_tracing.rs @@ -3,10 +3,12 @@ use std::sync::Arc; #[cfg(feature = "trace")] use crate::device::trace; -use crate::lock::rank; +use crate::lock::{rank, Mutex}; use crate::resource::{Fallible, TrackingData}; +use crate::snatch::Snatchable; +use crate::weak_vec::WeakVec; use crate::{ - device::{queue::TempResource, Device, DeviceError}, + device::{Device, DeviceError}, global::Global, id::{self, BlasId, TlasId}, lock::RwLock, @@ -90,7 +92,7 @@ impl Device { }; Ok(Arc::new(resource::Blas { - raw: ManuallyDrop::new(raw), + raw: Snatchable::new(raw), device: self.clone(), size_info, sizes, @@ -146,7 +148,7 @@ impl Device { .map_err(DeviceError::from_hal)?; Ok(Arc::new(resource::Tlas { - raw: ManuallyDrop::new(raw), + raw: Snatchable::new(raw), device: self.clone(), size_info, flags: desc.flags, @@ -157,6 +159,7 @@ impl Device { label: desc.label.to_string(), max_instance_count: desc.max_instances, tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()), + bind_groups: Mutex::new(rank::TLAS_BIND_GROUPS, WeakVec::new()), })) } } @@ -270,23 +273,14 @@ impl Global { let hub = &self.hub; let blas = hub.blas_s.get(blas_id).get()?; - let device = &blas.device; + let _device = &blas.device; #[cfg(feature = "trace")] - if let Some(trace) = device.trace.lock().as_mut() { + if let Some(trace) = _device.trace.lock().as_mut() { trace.add(trace::Action::FreeBlas(blas_id)); } - let temp = TempResource::Blas(blas.clone()); - { - let mut device_lock = device.lock_life(); - let last_submit_index = device_lock.get_blas_latest_submission_index(blas.as_ref()); - if let Some(last_submit_index) = last_submit_index { - device_lock.schedule_resource_destruction(temp, last_submit_index); - } - } - - Ok(()) + blas.destroy() } pub fn blas_drop(&self, blas_id: BlasId) { @@ -326,23 +320,14 @@ impl Global { .clone(); drop(tlas_guard); - let device = &mut tlas.device.clone(); + let _device = &mut tlas.device.clone(); #[cfg(feature = "trace")] - if let Some(trace) = device.trace.lock().as_mut() { + if let Some(trace) = _device.trace.lock().as_mut() { trace.add(trace::Action::FreeTlas(tlas_id)); } - let temp = TempResource::Tlas(tlas.clone()); - { - let mut device_lock = device.lock_life(); - let last_submit_index = device_lock.get_tlas_latest_submission_index(tlas.as_ref()); - if let Some(last_submit_index) = last_submit_index { - device_lock.schedule_resource_destruction(temp, last_submit_index); - } - } - - Ok(()) + tlas.destroy() } pub fn tlas_drop(&self, tlas_id: TlasId) { diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index 75dba162a5..e76988519b 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -41,7 +41,7 @@ use wgt::{ math::align_to, DeviceLostReason, TextureFormat, TextureSampleType, TextureViewDimension, }; -use crate::resource::{AccelerationStructure, Tlas}; +use crate::resource::{AccelerationStructure, DestroyedResourceError, Tlas}; use std::{ borrow::Cow, mem::{self, ManuallyDrop}, @@ -2200,6 +2200,7 @@ impl Device { binding: u32, decl: &wgt::BindGroupLayoutEntry, tlas: &'a Arc, + snatch_guard: &'a SnatchGuard<'a>, ) -> Result<&'a dyn hal::DynAccelerationStructure, binding_model::CreateBindGroupError> { use crate::binding_model::CreateBindGroupError as Error; @@ -2218,7 +2219,9 @@ impl Device { } } - Ok(tlas.raw()) + Ok(tlas + .raw(snatch_guard) + .ok_or(DestroyedResourceError(tlas.error_ident()))?) } // This function expects the provided bind group layout to be resolved @@ -2361,7 +2364,8 @@ impl Device { (res_index, num_bindings) } Br::AccelerationStructure(ref tlas) => { - let tlas = self.create_tlas_binding(&mut used, binding, decl, tlas)?; + let tlas = + self.create_tlas_binding(&mut used, binding, decl, tlas, &snatch_guard)?; let res_index = hal_tlas_s.len(); hal_tlas_s.push(tlas); (res_index, 1) diff --git a/wgpu-core/src/hub.rs b/wgpu-core/src/hub.rs index d8b8c46d84..15d3e3c06b 100644 --- a/wgpu-core/src/hub.rs +++ b/wgpu-core/src/hub.rs @@ -107,7 +107,9 @@ use crate::{ instance::Adapter, pipeline::{ComputePipeline, PipelineCache, RenderPipeline, ShaderModule}, registry::{Registry, RegistryReport}, - resource::{Blas, Buffer, Fallible, QuerySet, Sampler, StagingBuffer, Texture, TextureView, Tlas}, + resource::{ + Blas, Buffer, Fallible, QuerySet, Sampler, StagingBuffer, Texture, TextureView, Tlas, + }, }; use std::{fmt::Debug, sync::Arc}; diff --git a/wgpu-core/src/lock/rank.rs b/wgpu-core/src/lock/rank.rs index 87b977d62a..842dadf26d 100644 --- a/wgpu-core/src/lock/rank.rs +++ b/wgpu-core/src/lock/rank.rs @@ -147,6 +147,7 @@ define_lock_ranks! { rank BLAS_BUILT_INDEX "Blas::built_index" followed by { } rank TLAS_BUILT_INDEX "Tlas::built_index" followed by { } rank TLAS_DEPENDENCIES "Tlas::dependencies" followed by { } + rank TLAS_BIND_GROUPS "Tlas::bind_groups" followed by { } #[cfg(test)] rank PAWN "pawn" followed by { ROOK, BISHOP } diff --git a/wgpu-core/src/ray_tracing.rs b/wgpu-core/src/ray_tracing.rs index ab28ca6f12..11ccb714f1 100644 --- a/wgpu-core/src/ray_tracing.rs +++ b/wgpu-core/src/ray_tracing.rs @@ -168,8 +168,8 @@ pub enum ValidateTlasActionsError { #[error("Blas {0:?} is used before it is built (in Tlas {1:?})")] UsedUnbuiltBlas(ResourceErrorIdent, ResourceErrorIdent), - #[error("BlasId is invalid or destroyed (in Tlas {0:?})")] - InvalidBlasId(ResourceErrorIdent), + #[error("BlasId is destroyed (in Tlas {0:?})")] + InvalidBlas(ResourceErrorIdent), #[error("Blas {0:?} is newer than the containing Tlas {1:?}")] BlasNewerThenTlas(ResourceErrorIdent, ResourceErrorIdent), diff --git a/wgpu-core/src/resource.rs b/wgpu-core/src/resource.rs index 6ee6b3aa88..69230e4d77 100644 --- a/wgpu-core/src/resource.rs +++ b/wgpu-core/src/resource.rs @@ -1900,12 +1900,12 @@ pub type BlasDescriptor<'a> = wgt::CreateBlasDescriptor>; pub type TlasDescriptor<'a> = wgt::CreateTlasDescriptor>; pub(crate) trait AccelerationStructure: Trackable { - fn raw(&self) -> &dyn hal::DynAccelerationStructure; + fn raw<'a>(&'a self, guard: &'a SnatchGuard) -> Option<&'a dyn hal::DynAccelerationStructure>; } #[derive(Debug)] pub struct Blas { - pub(crate) raw: ManuallyDrop>, + pub(crate) raw: Snatchable>, pub(crate) device: Arc, pub(crate) size_info: hal::AccelerationStructureBuildSizes, pub(crate) sizes: wgt::BlasGeometrySizeDescriptors, @@ -1922,16 +1922,56 @@ impl Drop for Blas { fn drop(&mut self) { resource_log!("Destroy raw {}", self.error_ident()); // SAFETY: We are in the Drop impl, and we don't use self.raw anymore after this point. - let raw = unsafe { ManuallyDrop::take(&mut self.raw) }; - unsafe { - self.device.raw().destroy_acceleration_structure(raw); + if let Some(raw) = self.raw.take() { + unsafe { + self.device.raw().destroy_acceleration_structure(raw); + } } } } impl AccelerationStructure for Blas { - fn raw(&self) -> &dyn hal::DynAccelerationStructure { - self.raw.as_ref() + fn raw<'a>(&'a self, guard: &'a SnatchGuard) -> Option<&'a dyn hal::DynAccelerationStructure> { + Some(self.raw.get(guard)?.as_ref()) + } +} + +impl Blas { + pub(crate) fn destroy(self: &Arc) -> Result<(), DestroyError> { + let device = &self.device; + + let temp = { + let mut snatch_guard = device.snatchable_lock.write(); + + let raw = match self.raw.snatch(&mut snatch_guard) { + Some(raw) => raw, + None => { + return Err(DestroyError::AlreadyDestroyed); + } + }; + + drop(snatch_guard); + + queue::TempResource::DestroyedAccelerationStructure(DestroyedAccelerationStructure { + raw: ManuallyDrop::new(raw), + device: Arc::clone(&self.device), + label: self.label().to_owned(), + bind_groups: WeakVec::new(), + }) + }; + + let mut pending_writes = device.pending_writes.lock(); + if pending_writes.contains_blas(self) { + pending_writes.consume_temp(temp); + } else { + let mut life_lock = device.lock_life(); + let last_submit_index = life_lock.get_blas_latest_submission_index(self); + if let Some(last_submit_index) = last_submit_index { + life_lock.schedule_resource_destruction(temp, last_submit_index); + } + } + + Ok(()) } } @@ -1943,7 +1983,7 @@ crate::impl_trackable!(Blas); #[derive(Debug)] pub struct Tlas { - pub(crate) raw: ManuallyDrop>, + pub(crate) raw: Snatchable>, pub(crate) device: Arc, pub(crate) size_info: hal::AccelerationStructureBuildSizes, pub(crate) max_instance_count: u32, @@ -1955,23 +1995,25 @@ pub struct Tlas { /// The `label` from the descriptor used to create the resource. pub(crate) label: String, pub(crate) tracking_data: TrackingData, + pub(crate) bind_groups: Mutex>, } impl Drop for Tlas { fn drop(&mut self) { unsafe { - let structure = ManuallyDrop::take(&mut self.raw); - let buffer = ManuallyDrop::take(&mut self.instance_buffer); resource_log!("Destroy raw {}", self.error_ident()); - self.device.raw().destroy_acceleration_structure(structure); + if let Some(structure) = self.raw.take() { + self.device.raw().destroy_acceleration_structure(structure); + } + let buffer = ManuallyDrop::take(&mut self.instance_buffer); self.device.raw().destroy_buffer(buffer); } } } impl AccelerationStructure for Tlas { - fn raw(&self) -> &dyn hal::DynAccelerationStructure { - self.raw.as_ref() + fn raw<'a>(&'a self, guard: &'a SnatchGuard) -> Option<&dyn hal::DynAccelerationStructure> { + Some(self.raw.get(guard)?.as_ref()) } } @@ -1980,3 +2022,74 @@ crate::impl_labeled!(Tlas); crate::impl_parent_device!(Tlas); crate::impl_storage_item!(Tlas); crate::impl_trackable!(Tlas); + +impl Tlas { + pub(crate) fn destroy(self: &Arc) -> Result<(), DestroyError> { + let device = &self.device; + + let temp = { + let mut snatch_guard = device.snatchable_lock.write(); + + let raw = match self.raw.snatch(&mut snatch_guard) { + Some(raw) => raw, + None => { + return Err(DestroyError::AlreadyDestroyed); + } + }; + + drop(snatch_guard); + + queue::TempResource::DestroyedAccelerationStructure(DestroyedAccelerationStructure { + raw: ManuallyDrop::new(raw), + device: Arc::clone(&self.device), + label: self.label().to_owned(), + bind_groups: mem::take(&mut self.bind_groups.lock()), + }) + }; + + let mut pending_writes = device.pending_writes.lock(); + if pending_writes.contains_tlas(self) { + pending_writes.consume_temp(temp); + } else { + let mut life_lock = device.lock_life(); + let last_submit_index = life_lock.get_tlas_latest_submission_index(self); + if let Some(last_submit_index) = last_submit_index { + life_lock.schedule_resource_destruction(temp, last_submit_index); + } + } + + Ok(()) + } +} + +#[derive(Debug)] +pub struct DestroyedAccelerationStructure { + raw: ManuallyDrop>, + device: Arc, + label: String, + // only filled if the acceleration structure is a TLAS + bind_groups: WeakVec, +} + +impl DestroyedAccelerationStructure { + pub fn label(&self) -> &dyn Debug { + &self.label + } +} + +impl Drop for DestroyedAccelerationStructure { + fn drop(&mut self) { + let mut deferred = self.device.deferred_destroy.lock(); + deferred.push(DeferredDestroy::BindGroups(mem::take( + &mut self.bind_groups, + ))); + drop(deferred); + + resource_log!("Destroy raw Buffer (destroyed) {:?}", self.label()); + // SAFETY: We are in the Drop impl and we don't use self.raw anymore after this point. + let raw = unsafe { ManuallyDrop::take(&mut self.raw) }; + unsafe { + hal::DynDevice::destroy_acceleration_structure(self.device.raw(), raw); + } + } +}