Skip to content

Commit

Permalink
fix merge
Browse files Browse the repository at this point in the history
  • Loading branch information
Vecvec committed Sep 14, 2024
1 parent dbfbe56 commit fda4ff8
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 115 deletions.
4 changes: 2 additions & 2 deletions examples/src/ray_cube_compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ impl crate::framework::Example for Example {
timestamp_writes: None,
});
cpass.set_pipeline(&self.compute_pipeline);
cpass.set_bind_group(0, &self.compute_bind_group, &[]);
cpass.set_bind_group(0, Some(&self.compute_bind_group), &[]);
cpass.dispatch_workgroups(self.rt_target.width() / 8, self.rt_target.height() / 8, 1);
}

Expand All @@ -600,7 +600,7 @@ impl crate::framework::Example for Example {
});

rpass.set_pipeline(&self.blit_pipeline);
rpass.set_bind_group(0, &self.blit_bind_group, &[]);
rpass.set_bind_group(0, Some(&self.blit_bind_group), &[]);
rpass.draw(0..3, 0..1);
}

Expand Down
2 changes: 1 addition & 1 deletion examples/src/ray_cube_fragment/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ impl crate::framework::Example for Example {
});

rpass.set_pipeline(&self.pipeline);
rpass.set_bind_group(0, &self.bind_group, &[]);
rpass.set_bind_group(0, Some(&self.bind_group), &[]);
rpass.draw(0..3, 0..1);
}

Expand Down
2 changes: 1 addition & 1 deletion examples/src/ray_scene/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ impl crate::framework::Example for Example {
});

rpass.set_pipeline(&self.pipeline);
rpass.set_bind_group(0, &self.bind_group, &[]);
rpass.set_bind_group(0, Some(&self.bind_group), &[]);
rpass.draw(0..3, 0..1);
}

Expand Down
2 changes: 2 additions & 0 deletions wgpu-core/src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ impl CommandBufferMutable {
trackers: self.trackers,
buffer_memory_init_actions: self.buffer_memory_init_actions,
texture_memory_actions: self.texture_memory_actions,
blas_actions: self.blas_actions,
tlas_actions: self.tlas_actions,
}
}

Expand Down
75 changes: 35 additions & 40 deletions wgpu-core/src/command/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ use crate::{

use wgt::{math::align_to, BufferAddress, BufferUsages};

use super::{BakedCommands, CommandBufferMutable, CommandEncoderError};
use super::{BakedCommands, CommandBufferMutable};
use crate::ray_tracing::BlasTriangleGeometry;
use crate::resource::{
AccelerationStructure, Buffer, Labeled, ScratchBuffer, StagingBuffer, Trackable,
AccelerationStructure, Buffer, Fallible, Labeled, ScratchBuffer, StagingBuffer, Trackable,
};
use crate::snatch::SnatchGuard;
use crate::storage::Storage;
Expand Down Expand Up @@ -57,14 +57,9 @@ impl Global {

let hub = &self.hub;

let cmd_buf = match hub
let cmd_buf = hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;
.get(command_encoder_id.into_command_buffer_id());

let buffer_guard = hub.buffers.read();
let blas_guard = hub.blas_s.read();
Expand Down Expand Up @@ -181,7 +176,7 @@ impl Global {

let mut scratch_buffer_tlas_size = 0;
let mut tlas_storage = Vec::<(
&Tlas,
Arc<Tlas>,
hal::AccelerationStructureEntries<dyn hal::DynBuffer>,
u64,
)>::new();
Expand All @@ -192,14 +187,14 @@ impl Global {
)>::new();

for entry in tlas_iter {
let instance_buffer = match buffer_guard.get(entry.instance_buffer_id) {
let instance_buffer = match buffer_guard.get(entry.instance_buffer_id).get() {
Ok(buffer) => buffer,
Err(_) => {
return Err(BuildAccelerationStructureError::InvalidBufferId);
}
};
let data = cmd_buf_data.trackers.buffers.set_single(
instance_buffer,
&instance_buffer,
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
);
tlas_buf_storage.push((instance_buffer.clone(), data, entry.clone()));
Expand Down Expand Up @@ -228,6 +223,7 @@ impl Global {

let tlas = tlas_guard
.get(entry.tlas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());

Expand Down Expand Up @@ -278,7 +274,7 @@ impl Global {
let tlas_descriptors =
tlas_storage
.iter()
.map(|&(tlas, ref entries, ref scratch_buffer_offset)| {
.map(|(tlas, entries, scratch_buffer_offset)| {
if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
log::info!("only rebuild implemented")
}
Expand All @@ -296,7 +292,7 @@ impl Global {
let blas_present = !blas_storage.is_empty();
let tlas_present = !tlas_storage.is_empty();

let cmd_buf_raw = cmd_buf_data.encoder.open()?;
let cmd_buf_raw = cmd_buf_data.encoder.open(device)?;

build_blas(
cmd_buf_raw,
Expand Down Expand Up @@ -338,14 +334,9 @@ impl Global {

let hub = &self.hub;

let cmd_buf = match hub
let cmd_buf = hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;
.get(command_encoder_id.into_command_buffer_id());

let buffer_guard = hub.buffers.read();
let blas_guard = hub.blas_s.read();
Expand Down Expand Up @@ -490,16 +481,16 @@ impl Global {
&mut scratch_buffer_blas_size,
&mut blas_storage,
)?;
let mut tlas_lock_store =
Vec::<(&dyn hal::DynBuffer, Option<TlasPackage>, Arc<Tlas>)>::new();
let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

for package in tlas_iter {
let tlas = tlas_guard
.get(package.tlas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;

cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
tlas_lock_store.push((tlas.instance_buffer.as_ref(), Some(package), tlas.clone()))
tlas_lock_store.push((Some(package), tlas.clone()))
}

let mut scratch_buffer_tlas_size = 0;
Expand All @@ -511,9 +502,8 @@ impl Global {
)>::new();
let mut instance_buffer_staging_source = Vec::<u8>::new();

for entry in &mut tlas_lock_store {
let package = entry.1.take().unwrap();
let tlas = &entry.2;
for (package, tlas) in &mut tlas_lock_store {
let package = package.take().unwrap();

let scratch_buffer_offset = scratch_buffer_tlas_size;
scratch_buffer_tlas_size += align_to(
Expand All @@ -534,6 +524,7 @@ impl Global {
}
let blas = blas_guard
.get(instance.blas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBlasIdForInstance)?
.clone();

Expand Down Expand Up @@ -574,7 +565,7 @@ impl Global {
tlas_storage.push((
tlas,
hal::AccelerationStructureEntries::Instances(hal::AccelerationStructureInstances {
buffer: Some(entry.0),
buffer: Some(tlas.instance_buffer.as_ref()),
offset: 0,
count: instance_count,
}),
Expand Down Expand Up @@ -623,7 +614,7 @@ impl Global {
let blas_present = !blas_storage.is_empty();
let tlas_present = !tlas_storage.is_empty();

let cmd_buf_raw = cmd_buf_data.encoder.open()?;
let cmd_buf_raw = cmd_buf_data.encoder.open(device)?;

build_blas(
cmd_buf_raw,
Expand Down Expand Up @@ -789,13 +780,14 @@ fn iter_blas<'a>(
blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
cmd_buf_data: &mut CommandBufferMutable,
build_command_index: NonZeroU64,
buffer_guard: &RwLockReadGuard<Storage<Buffer>>,
blas_guard: &RwLockReadGuard<Storage<Blas>>,
buffer_guard: &RwLockReadGuard<Storage<Fallible<Buffer>>>,
blas_guard: &RwLockReadGuard<Storage<Fallible<Blas>>>,
buf_storage: &mut BufferStorage<'a>,
) -> Result<(), BuildAccelerationStructureError> {
for entry in blas_iter {
let blas = blas_guard
.get(entry.blas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?;
cmd_buf_data.trackers.blas_s.set_single(blas.clone());

Expand Down Expand Up @@ -837,16 +829,16 @@ fn iter_blas<'a>(
blas.error_ident(),
));
}
let vertex_buffer = match buffer_guard.get(mesh.vertex_buffer) {
let vertex_buffer = match buffer_guard.get(mesh.vertex_buffer).get() {
Ok(buffer) => buffer,
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
};
let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
vertex_buffer,
&vertex_buffer,
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
);
let index_data = if let Some(index_id) = mesh.index_buffer {
let index_buffer = match buffer_guard.get(index_id) {
let index_buffer = match buffer_guard.get(index_id).get() {
Ok(buffer) => buffer,
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
};
Expand All @@ -859,15 +851,15 @@ fn iter_blas<'a>(
));
}
let data = cmd_buf_data.trackers.buffers.set_single(
index_buffer,
&index_buffer,
hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
);
Some((index_buffer.clone(), data))
} else {
None
};
let transform_data = if let Some(transform_id) = mesh.transform_buffer {
let transform_buffer = match buffer_guard.get(transform_id) {
let transform_buffer = match buffer_guard.get(transform_id).get() {
Ok(buffer) => buffer,
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
};
Expand All @@ -877,10 +869,10 @@ fn iter_blas<'a>(
));
}
let data = cmd_buf_data.trackers.buffers.set_single(
transform_buffer,
&transform_buffer,
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
);
Some((transform_buffer.clone(), data))
Some((transform_buffer, data))
} else {
None
};
Expand Down Expand Up @@ -909,7 +901,7 @@ fn iter_buffers<'a, 'b>(
snatch_guard: &'a SnatchGuard,
input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
cmd_buf_data: &mut CommandBufferMutable,
buffer_guard: &RwLockReadGuard<Storage<Buffer>>,
buffer_guard: &RwLockReadGuard<Storage<Fallible<Buffer>>>,
scratch_buffer_blas_size: &mut u64,
blas_storage: &mut BlasStorage<'a>,
) -> Result<(), BuildAccelerationStructureError> {
Expand Down Expand Up @@ -946,7 +938,10 @@ fn iter_buffers<'a, 'b>(
let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
cmd_buf_data.buffer_memory_init_actions.extend(
vertex_buffer.initialization_status.read().create_action(
buffer_guard.get(mesh.vertex_buffer).unwrap(),
&buffer_guard
.get(mesh.vertex_buffer)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBufferId)?,
vertex_buffer_offset
..(vertex_buffer_offset
+ mesh.size.vertex_count as u64 * mesh.vertex_stride),
Expand Down
19 changes: 14 additions & 5 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -758,10 +758,11 @@ impl Global {
.get()
.map_err(binding_model::CreateBindGroupError::from)
};
let map_tlas = |id: &id::TlasId| {
let resolve_tlas = |id: &id::TlasId| {
tlas_storage
.get_owned(*id)
.map_err(|_| binding_model::CreateBindGroupError::InvalidTlasId(*id))
.get(*id)
.get()
.map_err(binding_model::CreateBindGroupError::from)
};
let resource = match e.resource {
BindingResource::Buffer(ref buffer) => {
Expand Down Expand Up @@ -795,7 +796,7 @@ impl Global {
ResolvedBindingResource::TextureViewArray(Cow::Owned(views))
}
BindingResource::AccelerationStructure(ref tlas) => {
ResolvedBindingResource::AccelerationStructure(map_tlas(tlas)?)
ResolvedBindingResource::AccelerationStructure(resolve_tlas(tlas)?)
}
};
Ok(ResolvedBindGroupEntry {
Expand All @@ -811,7 +812,15 @@ impl Global {
let tlas_guard = hub.tlas_s.read();
desc.entries
.iter()
.map(|e| resolve_entry(e, &buffer_guard, &sampler_guard, &texture_view_guard, &tlas_guard))
.map(|e| {
resolve_entry(
e,
&buffer_guard,
&sampler_guard,
&texture_view_guard,
&tlas_guard,
)
})
.collect::<Result<Vec<_>, _>>()
};
let entries = match entries {
Expand Down
Loading

0 comments on commit fda4ff8

Please sign in to comment.