Skip to content

Commit

Permalink
add blas and tlas to pending writes to remove (now unneeded) `device.…
Browse files Browse the repository at this point in the history
…poll()`s
  • Loading branch information
Vecvec committed Sep 28, 2024
1 parent 7662d99 commit a89c742
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 21 deletions.
17 changes: 7 additions & 10 deletions tests/tests/ray_tracing/as_build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ fn out_of_order_as_build(ctx: TestingContext) {
ctx.queue
.submit([encoder_blas.finish(), encoder_tlas.finish()]);

ctx.device.poll(wgt::Maintain::Wait);

drop(as_ctx);

//
Expand Down Expand Up @@ -219,18 +217,19 @@ fn out_of_order_as_build_use(ctx: TestingContext) {

encoder_tlas.build_acceleration_structures([], [&as_ctx.tlas_package]);

let mut encoder_blas = ctx
let mut encoder_blas2 = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("BLAS 2"),
});

encoder_blas.build_acceleration_structures([&as_ctx.blas_build_entry()], []);
encoder_blas2.build_acceleration_structures([&as_ctx.blas_build_entry()], []);

ctx.queue
.submit([encoder_blas.finish(), encoder_tlas.finish()]);

ctx.device.poll(wgt::Maintain::Wait);
ctx.queue.submit([
encoder_blas.finish(),
encoder_tlas.finish(),
encoder_blas2.finish(),
]);

//
// Create shader to use tlas with
Expand Down Expand Up @@ -283,6 +282,4 @@ fn out_of_order_as_build_use(ctx: TestingContext) {
},
None,
);

ctx.device.poll(wgt::Maintain::Wait);
}
11 changes: 10 additions & 1 deletion wgpu-core/src/command/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ use crate::{
use wgt::{math::align_to, BufferUsages, Features};

use super::CommandBufferMutable;
use crate::device::queue::PendingWrites;
use hal::BufferUses;
use std::mem::ManuallyDrop;
use std::ops::DerefMut;
use std::{
cmp::max,
num::NonZeroU64,
Expand Down Expand Up @@ -181,6 +184,7 @@ impl Global {
build_command_index,
&mut buf_storage,
hub,
device.pending_writes.lock().deref_mut(),
)?;

let snatch_guard = device.snatchable_lock.read();
Expand Down Expand Up @@ -244,6 +248,7 @@ impl Global {
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
device.pending_writes.lock().insert_tlas(&tlas);

cmd_buf_data.tlas_actions.push(TlasAction {
tlas: tlas.clone(),
Expand Down Expand Up @@ -485,6 +490,7 @@ impl Global {
build_command_index,
&mut buf_storage,
hub,
device.pending_writes.lock().deref_mut(),
)?;

let snatch_guard = device.snatchable_lock.read();
Expand All @@ -505,8 +511,9 @@ impl Global {
.get(package.tlas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;

device.pending_writes.lock().insert_tlas(&tlas);
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());

tlas_lock_store.push((Some(package), tlas.clone()))
}

Expand Down Expand Up @@ -812,6 +819,7 @@ fn iter_blas<'a>(
build_command_index: NonZeroU64,
buf_storage: &mut Vec<TriangleBufferStore<'a>>,
hub: &Hub,
pending_writes: &mut ManuallyDrop<PendingWrites>,
) -> Result<(), BuildAccelerationStructureError> {
let mut temp_buffer = Vec::new();
for entry in blas_iter {
Expand All @@ -821,6 +829,7 @@ fn iter_blas<'a>(
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?;
cmd_buf_data.trackers.blas_s.set_single(blas.clone());
pending_writes.insert_blas(&blas);

cmd_buf_data.blas_actions.push(BlasAction {
blas: blas.clone(),
Expand Down
4 changes: 2 additions & 2 deletions wgpu-core/src/device/life.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl ActiveSubmission {
return true;
}

if encoder.pending_buffers.contains_key(&blas.tracker_index()) {
if encoder.pending_blas_s.contains_key(&blas.tracker_index()) {
return true;
}
}
Expand All @@ -135,7 +135,7 @@ impl ActiveSubmission {
return true;
}

if encoder.pending_buffers.contains_key(&tlas.tracker_index()) {
if encoder.pending_tlas_s.contains_key(&tlas.tracker_index()) {
return true;
}
}
Expand Down
24 changes: 24 additions & 0 deletions wgpu-core/src/device/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ pub(crate) struct EncoderInFlight {
pub(crate) pending_buffers: FastHashMap<TrackerIndex, Arc<Buffer>>,
/// These are the textures that have been tracked by `PendingWrites`.
pub(crate) pending_textures: FastHashMap<TrackerIndex, Arc<Texture>>,
/// These are the BLASes that have been tracked by `PendingWrites`.
pub(crate) pending_blas_s: FastHashMap<TrackerIndex, Arc<Blas>>,
/// These are the TLASes that have been tracked by `PendingWrites`.
pub(crate) pending_tlas_s: FastHashMap<TrackerIndex, Arc<Tlas>>,
}

impl EncoderInFlight {
Expand All @@ -182,6 +186,8 @@ impl EncoderInFlight {
drop(self.trackers);
drop(self.pending_buffers);
drop(self.pending_textures);
drop(self.pending_blas_s);
drop(self.pending_tlas_s);
}
self.raw
}
Expand Down Expand Up @@ -221,6 +227,8 @@ pub(crate) struct PendingWrites {
temp_resources: Vec<TempResource>,
dst_buffers: FastHashMap<TrackerIndex, Arc<Buffer>>,
dst_textures: FastHashMap<TrackerIndex, Arc<Texture>>,
dst_blas_s: FastHashMap<TrackerIndex, Arc<Blas>>,
dst_tlas_s: FastHashMap<TrackerIndex, Arc<Tlas>>,
}

impl PendingWrites {
Expand All @@ -231,6 +239,8 @@ impl PendingWrites {
temp_resources: Vec::new(),
dst_buffers: FastHashMap::default(),
dst_textures: FastHashMap::default(),
dst_blas_s: FastHashMap::default(),
dst_tlas_s: FastHashMap::default(),
}
}

Expand Down Expand Up @@ -263,6 +273,14 @@ impl PendingWrites {
self.dst_textures.contains_key(&texture.tracker_index())
}

pub fn insert_blas(&mut self, blas: &Arc<Blas>) {
self.dst_blas_s.insert(blas.tracker_index(), blas.clone());
}

pub fn insert_tlas(&mut self, tlas: &Arc<Tlas>) {
self.dst_tlas_s.insert(tlas.tracker_index(), tlas.clone());
}

pub fn consume_temp(&mut self, resource: TempResource) {
self.temp_resources.push(resource);
}
Expand All @@ -281,6 +299,8 @@ impl PendingWrites {
if self.is_recording {
let pending_buffers = mem::take(&mut self.dst_buffers);
let pending_textures = mem::take(&mut self.dst_textures);
let pending_blas_s = mem::take(&mut self.dst_blas_s);
let pending_tlas_s = mem::take(&mut self.dst_tlas_s);

let cmd_buf = unsafe { self.command_encoder.end_encoding() }
.map_err(|e| device.handle_hal_error(e))?;
Expand All @@ -296,6 +316,8 @@ impl PendingWrites {
trackers: Tracker::new(),
pending_buffers,
pending_textures,
pending_blas_s,
pending_tlas_s,
};
Ok(Some(encoder))
} else {
Expand Down Expand Up @@ -1201,6 +1223,8 @@ impl Global {
trackers: baked.trackers,
pending_buffers: FastHashMap::default(),
pending_textures: FastHashMap::default(),
pending_blas_s: FastHashMap::default(),
pending_tlas_s: FastHashMap::default(),
});
}

Expand Down
10 changes: 2 additions & 8 deletions wgpu-core/src/device/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,11 @@ impl Global {

pub fn blas_destroy(&self, blas_id: BlasId) -> Result<(), resource::DestroyError> {
profiling::scope!("Blas::destroy");
log::info!("Blas::destroy {blas_id:?}");

let hub = &self.hub;

log::info!("Blas {:?} is destroyed", blas_id);
let blas_guard = hub.blas_s.write();
let blas = blas_guard
.get(blas_id)
.get()
.map_err(resource::DestroyError::InvalidResource)?
.clone();
drop(blas_guard);
let blas = hub.blas_s.get(blas_id).get()?;
let device = &blas.device;

#[cfg(feature = "trace")]
Expand Down

0 comments on commit a89c742

Please sign in to comment.