Skip to content

Commit

Permalink
properly propagate global variables added by ray-tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
Vecvec committed Aug 31, 2024
1 parent e90967d commit f62e310
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 84 deletions.
13 changes: 5 additions & 8 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2047,7 +2047,6 @@ impl<'w> BlockContext<'w> {
loop_context: LoopContext,
debug_info: Option<&DebugInfoInner>,
stage: Option<crate::ShaderStage>,
interface: &mut Option<super::writer::FunctionInterface>,
) -> Result<(), Error> {
let mut block = Block::new(label_id);
for (statement, span) in naga_block.span_iter() {
Expand Down Expand Up @@ -2091,7 +2090,6 @@ impl<'w> BlockContext<'w> {
loop_context,
debug_info,
stage,
interface,
)?;

block = Block::new(merge_id);
Expand Down Expand Up @@ -2137,7 +2135,6 @@ impl<'w> BlockContext<'w> {
loop_context,
debug_info,
stage,
interface,
)?;
}
if let Some(block_id) = reject_id {
Expand All @@ -2148,7 +2145,6 @@ impl<'w> BlockContext<'w> {
loop_context,
debug_info,
stage,
interface,
)?;
}

Expand Down Expand Up @@ -2230,7 +2226,6 @@ impl<'w> BlockContext<'w> {
inner_context,
debug_info,
stage,
interface,
)?;
}

Expand Down Expand Up @@ -2281,7 +2276,6 @@ impl<'w> BlockContext<'w> {
},
debug_info,
stage,
interface,
)?;

let exit = match break_if {
Expand All @@ -2304,7 +2298,6 @@ impl<'w> BlockContext<'w> {
},
debug_info,
stage,
interface,
)?;

block = Block::new(merge_id);
Expand Down Expand Up @@ -2415,6 +2408,10 @@ impl<'w> BlockContext<'w> {
ref arguments,
result,
} => {
self.ray_tracing_global_vars.append(
&mut self.writer.lookup_ray_global_variables[&local_function].to_vec(),
);

let id = self.gen_id();
self.temp_list.clear();
for &argument in arguments {
Expand Down Expand Up @@ -2666,7 +2663,7 @@ impl<'w> BlockContext<'w> {
self.write_ray_query_function(query, fun, &mut block);
}
Statement::RayTracing { ref fun } => {
self.write_ray_tracing_function(fun, &mut block, interface)?;
self.write_ray_tracing_function(fun, &mut block)?;
}
Statement::SubgroupBallot {
result,
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ struct BlockContext<'w> {

/// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
expression_constness: ExpressionConstnessTracker,

ray_tracing_global_vars: &'w mut Vec<Word>,
}

impl BlockContext<'_> {
Expand Down Expand Up @@ -661,6 +663,7 @@ pub struct Writer {
lookup_type: crate::FastHashMap<LookupType, Word>,
lookup_function: crate::FastHashMap<Handle<crate::Function>, Word>,
lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>,
lookup_ray_global_variables: crate::FastHashMap<Handle<crate::Function>, Box<[Word]>>,
/// Indexed by const-expression handle indexes
constant_ids: HandleVec<crate::Expression, Word>,
cached_constants: crate::FastHashMap<CachedConstant, Word>,
Expand Down
9 changes: 2 additions & 7 deletions naga/src/back/spv/ray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ impl<'w> BlockContext<'w> {
&mut self,
function: &crate::RayTracingFunction,
block: &mut Block,
interface: &mut Option<super::writer::FunctionInterface>,
) -> Result<(), super::Error> {
match *function {
crate::RayTracingFunction::TraceRay {
Expand Down Expand Up @@ -363,9 +362,7 @@ impl<'w> BlockContext<'w> {
block
.body
.push(Instruction::copy(payload_id, varying_id, None));
if let Some(interface) = interface.as_mut() {
interface.varying_ids.push(varying_id)
}
self.ray_tracing_global_vars.push(varying_id);
}
crate::RayTracingFunction::ReportIntersection {
hit_t,
Expand Down Expand Up @@ -396,9 +393,7 @@ impl<'w> BlockContext<'w> {
block
.body
.push(Instruction::store(pointer_type_id, intersection_id, None));
if let Some(interface) = interface.as_mut() {
interface.varying_ids.push(pointer_type_id)
}
self.ray_tracing_global_vars.push(pointer_type_id);
let result_id = self.gen_id();
let result_ty_id = self.writer.get_expression_type_id(&TypeResolution::Value(
crate::TypeInner::Scalar(crate::Scalar::BOOL),
Expand Down
16 changes: 11 additions & 5 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ impl Writer {
lookup_type: crate::FastHashMap::default(),
lookup_function: crate::FastHashMap::default(),
lookup_function_type: crate::FastHashMap::default(),
lookup_ray_global_variables: crate::FastHashMap::default(),
constant_ids: HandleVec::new(),
cached_constants: crate::FastHashMap::default(),
global_variables: HandleVec::new(),
Expand Down Expand Up @@ -125,6 +126,7 @@ impl Writer {
lookup_type: take(&mut self.lookup_type).recycle(),
lookup_function: take(&mut self.lookup_function).recycle(),
lookup_function_type: take(&mut self.lookup_function_type).recycle(),
lookup_ray_global_variables: take(&mut self.lookup_ray_global_variables).recycle(),
constant_ids: take(&mut self.constant_ids).recycle(),
cached_constants: take(&mut self.cached_constants).recycle(),
global_variables: take(&mut self.global_variables).recycle(),
Expand Down Expand Up @@ -333,7 +335,7 @@ impl Writer {
mut interface: Option<FunctionInterface>,
debug_info: &Option<DebugInfoInner>,
stage: Option<crate::ShaderStage>,
) -> Result<Word, Error> {
) -> Result<(Word, Box<[Word]>), Error> {
let mut function = Function::default();

let prelude_id = self.id_gen.next();
Expand Down Expand Up @@ -627,6 +629,8 @@ impl Writer {
self.global_variables[handle] = gv;
}

let mut ray_global_vars = Vec::new();

// Create a `BlockContext` for generating SPIR-V for the function's
// body.
let mut context = BlockContext {
Expand All @@ -643,6 +647,7 @@ impl Writer {
expression_constness: super::ExpressionConstnessTracker::from_arena(
&ir_function.expressions,
),
ray_tracing_global_vars: &mut ray_global_vars,
};

// fill up the pre-emitted and const expressions
Expand Down Expand Up @@ -737,7 +742,6 @@ impl Writer {
LoopContext::default(),
debug_info.as_ref(),
stage,
&mut interface,
)?;

// Consume the `BlockContext`, ending its borrows and letting the
Expand All @@ -751,7 +755,7 @@ impl Writer {
function.to_words(&mut self.logical_layout.function_definitions);
Instruction::function_end().to_words(&mut self.logical_layout.function_definitions);

Ok(function_id)
Ok((function_id, Box::from(ray_global_vars)))
}

fn write_execution_mode(
Expand All @@ -774,7 +778,7 @@ impl Writer {
debug_info: &Option<DebugInfoInner>,
) -> Result<Instruction, Error> {
let mut interface_ids = Vec::new();
let function_id = self.write_function(
let (function_id, ray_global_vars) = self.write_function(
&entry_point.function,
info,
ir_module,
Expand All @@ -785,6 +789,7 @@ impl Writer {
debug_info,
Some(entry_point.stage),
)?;
interface_ids.append(&mut ray_global_vars.to_vec());

let exec_model = match entry_point.stage {
crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
Expand Down Expand Up @@ -2198,7 +2203,8 @@ impl Writer {
}
let id =
self.write_function(ir_function, info, ir_module, None, &debug_info_inner, None)?;
self.lookup_function.insert(handle, id);
self.lookup_function.insert(handle, id.0);
self.lookup_ray_global_variables.insert(handle, id.1);
}

// write all or one entry points
Expand Down
Loading

0 comments on commit f62e310

Please sign in to comment.