Skip to content

Commit

Permalink
use snatchables for raw acceleration structures
Browse files Browse the repository at this point in the history
  • Loading branch information
Vecvec committed Oct 20, 2024
1 parent e777b76 commit 66c0cbb
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 93 deletions.
112 changes: 71 additions & 41 deletions wgpu-core/src/command/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,48 +292,53 @@ impl Global {
..BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
};

let blas_descriptors = blas_storage
.iter()
.map(|storage| map_blas(storage, scratch_buffer.raw()));

let tlas_descriptors = tlas_storage.iter().map(
|UnsafeTlasStore {
tlas,
entries,
scratch_buffer_offset,
}| {
if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
log::info!("only rebuild implemented")
}
hal::BuildAccelerationStructureDescriptor {
entries,
mode: hal::AccelerationStructureBuildMode::Build,
flags: tlas.flags,
source_acceleration_structure: None,
destination_acceleration_structure: tlas.raw(),
scratch_buffer: scratch_buffer.raw(),
scratch_buffer_offset: *scratch_buffer_offset,
}
},
);
let mut tlas_descriptors = Vec::new();

for UnsafeTlasStore {
tlas,
entries,
scratch_buffer_offset,
} in &tlas_storage
{
if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
log::info!("only rebuild implemented")
}
tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
entries,
mode: hal::AccelerationStructureBuildMode::Build,
flags: tlas.flags,
source_acceleration_structure: None,
destination_acceleration_structure: tlas.raw(&snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidTlas(tlas.error_ident()),
)?,
scratch_buffer: scratch_buffer.raw(),
scratch_buffer_offset: *scratch_buffer_offset,
})
}

let blas_present = !blas_storage.is_empty();
let tlas_present = !tlas_storage.is_empty();

let cmd_buf_raw = cmd_buf_data.encoder.open(device)?;

let mut descriptors = Vec::new();

for storage in &blas_storage {
descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?);
}

build_blas(
cmd_buf_raw,
blas_present,
tlas_present,
input_barriers,
&blas_descriptors.collect::<Vec<_>>(),
&descriptors,
scratch_buffer_barrier,
);

if tlas_present {
unsafe {
cmd_buf_raw.build_acceleration_structures(&tlas_descriptors.collect::<Vec<_>>());
cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);

cmd_buf_raw.place_acceleration_structure_barrier(
hal::AccelerationStructureBarrier {
Expand Down Expand Up @@ -614,10 +619,6 @@ impl Global {
..BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
};

let blas_descriptors = blas_storage
.iter()
.map(|storage| map_blas(storage, scratch_buffer.raw()));

let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

for &TlasStore {
Expand All @@ -638,7 +639,13 @@ impl Global {
mode: hal::AccelerationStructureBuildMode::Build,
flags: tlas.flags,
source_acceleration_structure: None,
destination_acceleration_structure: tlas.raw.as_ref(),
destination_acceleration_structure: tlas
.raw
.get(&snatch_guard)
.ok_or(BuildAccelerationStructureError::InvalidTlas(
tlas.error_ident(),
))?
.as_ref(),
scratch_buffer: scratch_buffer.raw(),
scratch_buffer_offset: *scratch_buffer_offset,
})
Expand All @@ -649,12 +656,18 @@ impl Global {

let cmd_buf_raw = cmd_buf_data.encoder.open(device)?;

let mut descriptors = Vec::new();

for storage in &blas_storage {
descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?);
}

build_blas(
cmd_buf_raw,
blas_present,
tlas_present,
input_barriers,
&blas_descriptors.collect::<Vec<_>>(),
&descriptors,
scratch_buffer_barrier,
);

Expand Down Expand Up @@ -771,7 +784,10 @@ impl CommandBufferMutable {
}

// makes sure a tlas is built before it is used
pub(crate) fn validate_tlas_actions(&self) -> Result<(), ValidateTlasActionsError> {
pub(crate) fn validate_tlas_actions(
&self,
snatch_guard: &SnatchGuard,
) -> Result<(), ValidateTlasActionsError> {
profiling::scope!("CommandEncoder::[submission]::validate_tlas_actions");
for action in &self.tlas_actions {
match &action.kind {
Expand All @@ -794,8 +810,9 @@ impl CommandBufferMutable {
for blas in dependencies.deref() {
let blas_build_index = *blas.built_index.read();
if blas_build_index.is_none() {
return Err(ValidateTlasActionsError::UsedUnbuilt(
return Err(ValidateTlasActionsError::UsedUnbuiltBlas(
action.tlas.error_ident(),
blas.error_ident(),
));
}
if blas_build_index.unwrap() > tlas_build_index.unwrap() {
Expand All @@ -804,6 +821,9 @@ impl CommandBufferMutable {
action.tlas.error_ident(),
));
}
if blas.raw.get(snatch_guard).is_none() {
return Err(ValidateTlasActionsError::InvalidBlas(blas.error_ident()));
}
}
}
}
Expand Down Expand Up @@ -1180,10 +1200,14 @@ fn iter_buffers<'a, 'b>(
fn map_blas<'a>(
storage: &'a BlasStore<'_>,
scratch_buffer: &'a dyn hal::DynBuffer,
) -> hal::BuildAccelerationStructureDescriptor<
'a,
dyn hal::DynBuffer,
dyn hal::DynAccelerationStructure,
snatch_guard: &'a SnatchGuard,
) -> Result<
hal::BuildAccelerationStructureDescriptor<
'a,
dyn hal::DynBuffer,
dyn hal::DynAccelerationStructure,
>,
BuildAccelerationStructureError,
> {
let BlasStore {
blas,
Expand All @@ -1193,15 +1217,21 @@ fn map_blas<'a>(
if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
log::info!("only rebuild implemented")
}
hal::BuildAccelerationStructureDescriptor {
Ok(hal::BuildAccelerationStructureDescriptor {
entries,
mode: hal::AccelerationStructureBuildMode::Build,
flags: blas.flags,
source_acceleration_structure: None,
destination_acceleration_structure: blas.raw.as_ref(),
destination_acceleration_structure: blas
.raw
.get(snatch_guard)
.ok_or(BuildAccelerationStructureError::InvalidBlas(
blas.error_ident(),
))?
.as_ref(),
scratch_buffer,
scratch_buffer_offset: *scratch_buffer_offset,
}
})
}

fn build_blas<'a>(
Expand Down
2 changes: 1 addition & 1 deletion wgpu-core/src/device/life.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ impl LifetimeTracker {
}

/// Returns the submission index of the most recent submission that uses the
/// given blas.
/// given tlas.
pub fn get_tlas_latest_submission_index(&self, tlas: &Tlas) -> Option<SubmissionIndex> {
// We iterate in reverse order, so that we can bail out early as soon
// as we find a hit.
Expand Down
15 changes: 11 additions & 4 deletions wgpu-core/src/device/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::{

use smallvec::SmallVec;

use crate::resource::{Blas, Tlas};
use crate::resource::{Blas, DestroyedAccelerationStructure, Tlas};
use crate::scratch::ScratchBuffer;
use std::{
iter,
Expand Down Expand Up @@ -148,8 +148,7 @@ pub enum TempResource {
ScratchBuffer(ScratchBuffer),
DestroyedBuffer(DestroyedBuffer),
DestroyedTexture(DestroyedTexture),
Blas(Arc<Blas>),
Tlas(Arc<Tlas>),
DestroyedAccelerationStructure(DestroyedAccelerationStructure),
}

/// A series of raw [`CommandBuffer`]s that have been submitted to a
Expand Down Expand Up @@ -281,6 +280,14 @@ impl PendingWrites {
self.dst_tlas_s.insert(tlas.tracker_index(), tlas.clone());
}

pub fn contains_blas(&mut self, blas: &Arc<Blas>) -> bool {
self.dst_blas_s.contains_key(&blas.tracker_index())
}

pub fn contains_tlas(&mut self, tlas: &Arc<Tlas>) -> bool {
self.dst_tlas_s.contains_key(&tlas.tracker_index())
}

pub fn consume_temp(&mut self, resource: TempResource) {
self.temp_resources.push(resource);
}
Expand Down Expand Up @@ -1492,7 +1499,7 @@ fn validate_command_buffer(
if let Err(e) = cmd_buf_data.validate_blas_actions() {
return Err(e.into());
}
if let Err(e) = cmd_buf_data.validate_tlas_actions() {
if let Err(e) = cmd_buf_data.validate_tlas_actions(snatch_guard) {
return Err(e.into());
}
}
Expand Down
41 changes: 13 additions & 28 deletions wgpu-core/src/device/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use std::sync::Arc;

#[cfg(feature = "trace")]
use crate::device::trace;
use crate::lock::rank;
use crate::lock::{rank, Mutex};
use crate::resource::{Fallible, TrackingData};
use crate::snatch::Snatchable;
use crate::weak_vec::WeakVec;
use crate::{
device::{queue::TempResource, Device, DeviceError},
device::{Device, DeviceError},
global::Global,
id::{self, BlasId, TlasId},
lock::RwLock,
Expand Down Expand Up @@ -90,7 +92,7 @@ impl Device {
};

Ok(Arc::new(resource::Blas {
raw: ManuallyDrop::new(raw),
raw: Snatchable::new(raw),
device: self.clone(),
size_info,
sizes,
Expand Down Expand Up @@ -146,7 +148,7 @@ impl Device {
.map_err(DeviceError::from_hal)?;

Ok(Arc::new(resource::Tlas {
raw: ManuallyDrop::new(raw),
raw: Snatchable::new(raw),
device: self.clone(),
size_info,
flags: desc.flags,
Expand All @@ -157,6 +159,7 @@ impl Device {
label: desc.label.to_string(),
max_instance_count: desc.max_instances,
tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
bind_groups: Mutex::new(rank::TLAS_BIND_GROUPS, WeakVec::new()),
}))
}
}
Expand Down Expand Up @@ -270,23 +273,14 @@ impl Global {
let hub = &self.hub;

let blas = hub.blas_s.get(blas_id).get()?;
let device = &blas.device;
let _device = &blas.device;

#[cfg(feature = "trace")]
if let Some(trace) = device.trace.lock().as_mut() {
if let Some(trace) = _device.trace.lock().as_mut() {
trace.add(trace::Action::FreeBlas(blas_id));
}

let temp = TempResource::Blas(blas.clone());
{
let mut device_lock = device.lock_life();
let last_submit_index = device_lock.get_blas_latest_submission_index(blas.as_ref());
if let Some(last_submit_index) = last_submit_index {
device_lock.schedule_resource_destruction(temp, last_submit_index);
}
}

Ok(())
blas.destroy()
}

pub fn blas_drop(&self, blas_id: BlasId) {
Expand Down Expand Up @@ -326,23 +320,14 @@ impl Global {
.clone();
drop(tlas_guard);

let device = &mut tlas.device.clone();
let _device = &mut tlas.device.clone();

#[cfg(feature = "trace")]
if let Some(trace) = device.trace.lock().as_mut() {
if let Some(trace) = _device.trace.lock().as_mut() {
trace.add(trace::Action::FreeTlas(tlas_id));
}

let temp = TempResource::Tlas(tlas.clone());
{
let mut device_lock = device.lock_life();
let last_submit_index = device_lock.get_tlas_latest_submission_index(tlas.as_ref());
if let Some(last_submit_index) = last_submit_index {
device_lock.schedule_resource_destruction(temp, last_submit_index);
}
}

Ok(())
tlas.destroy()
}

pub fn tlas_drop(&self, tlas_id: TlasId) {
Expand Down
10 changes: 7 additions & 3 deletions wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use wgt::{
math::align_to, DeviceLostReason, TextureFormat, TextureSampleType, TextureViewDimension,
};

use crate::resource::{AccelerationStructure, Tlas};
use crate::resource::{AccelerationStructure, DestroyedResourceError, Tlas};
use std::{
borrow::Cow,
mem::{self, ManuallyDrop},
Expand Down Expand Up @@ -2200,6 +2200,7 @@ impl Device {
binding: u32,
decl: &wgt::BindGroupLayoutEntry,
tlas: &'a Arc<Tlas>,
snatch_guard: &'a SnatchGuard<'a>,
) -> Result<&'a dyn hal::DynAccelerationStructure, binding_model::CreateBindGroupError> {
use crate::binding_model::CreateBindGroupError as Error;

Expand All @@ -2218,7 +2219,9 @@ impl Device {
}
}

Ok(tlas.raw())
Ok(tlas
.raw(snatch_guard)
.ok_or(DestroyedResourceError(tlas.error_ident()))?)
}

// This function expects the provided bind group layout to be resolved
Expand Down Expand Up @@ -2361,7 +2364,8 @@ impl Device {
(res_index, num_bindings)
}
Br::AccelerationStructure(ref tlas) => {
let tlas = self.create_tlas_binding(&mut used, binding, decl, tlas)?;
let tlas =
self.create_tlas_binding(&mut used, binding, decl, tlas, &snatch_guard)?;
let res_index = hal_tlas_s.len();
hal_tlas_s.push(tlas);
(res_index, 1)
Expand Down
4 changes: 3 additions & 1 deletion wgpu-core/src/hub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ use crate::{
instance::Adapter,
pipeline::{ComputePipeline, PipelineCache, RenderPipeline, ShaderModule},
registry::{Registry, RegistryReport},
resource::{Blas, Buffer, Fallible, QuerySet, Sampler, StagingBuffer, Texture, TextureView, Tlas},
resource::{
Blas, Buffer, Fallible, QuerySet, Sampler, StagingBuffer, Texture, TextureView, Tlas,
},
};
use std::{fmt::Debug, sync::Arc};

Expand Down
Loading

0 comments on commit 66c0cbb

Please sign in to comment.