Skip to content

Commit

Permalink
allow functions to call trace rays
Browse files Browse the repository at this point in the history
  • Loading branch information
Vecvec committed Aug 30, 2024
1 parent b1bd27b commit e90967d
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 22 deletions.
2 changes: 1 addition & 1 deletion naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2666,7 +2666,7 @@ impl<'w> BlockContext<'w> {
self.write_ray_query_function(query, fun, &mut block);
}
Statement::RayTracing { ref fun } => {
self.write_ray_tracing_function(fun, &mut block, stage.unwrap(), interface)?;
self.write_ray_tracing_function(fun, &mut block, interface)?;
}
Statement::SubgroupBallot {
result,
Expand Down
20 changes: 7 additions & 13 deletions naga/src/back/spv/ray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ Generating SPIR-V for ray query operations.
use super::{Block, BlockContext, Instruction, LocalType, LookupType};
use crate::arena::Handle;
use crate::proc::TypeResolution;
use crate::ShaderStage;

impl<'w> BlockContext<'w> {
pub(super) fn write_ray_query_function(
Expand Down Expand Up @@ -265,7 +264,6 @@ impl<'w> BlockContext<'w> {
&mut self,
function: &crate::RayTracingFunction,
block: &mut Block,
stage: ShaderStage,
interface: &mut Option<super::writer::FunctionInterface>,
) -> Result<(), super::Error> {
match *function {
Expand All @@ -278,7 +276,7 @@ impl<'w> BlockContext<'w> {
let acc_struct_id = self.get_handle_id(acceleration_structure);
let varying_id = self.writer.write_varying(
self.ir_module,
stage,
None,
spirv::StorageClass::RayPayloadKHR,
None,
payload_ty,
Expand Down Expand Up @@ -365,11 +363,9 @@ impl<'w> BlockContext<'w> {
block
.body
.push(Instruction::copy(payload_id, varying_id, None));
(interface
.as_mut()
.expect("can only call trace rays in ray gen entry"))
.varying_ids
.push(varying_id);
if let Some(interface) = interface.as_mut() {
interface.varying_ids.push(varying_id)
}
}
crate::RayTracingFunction::ReportIntersection {
hit_t,
Expand Down Expand Up @@ -400,11 +396,9 @@ impl<'w> BlockContext<'w> {
block
.body
.push(Instruction::store(pointer_type_id, intersection_id, None));
(interface
.as_mut()
.expect("can only call trace rays in ray gen entry"))
.varying_ids
.push(pointer_type_id);
if let Some(interface) = interface.as_mut() {
interface.varying_ids.push(pointer_type_id)
}
let result_id = self.gen_id();
let result_ty_id = self.writer.get_expression_type_id(&TypeResolution::Value(
crate::TypeInner::Scalar(crate::Scalar::BOOL),
Expand Down
18 changes: 10 additions & 8 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ impl Writer {

let varying_id = self.write_varying(
ir_module,
iface.stage,
Some(iface.stage),
class,
name,
argument.ty,
Expand Down Expand Up @@ -412,7 +412,7 @@ impl Writer {
let binding = member.binding.as_ref().unwrap();
let varying_id = self.write_varying(
ir_module,
iface.stage,
Some(iface.stage),
class,
name,
member.ty,
Expand Down Expand Up @@ -477,7 +477,7 @@ impl Writer {
let type_id = self.get_type_id(LookupType::Handle(result.ty));
let varying_id = self.write_varying(
ir_module,
iface.stage,
Some(iface.stage),
class,
None,
result.ty,
Expand All @@ -500,7 +500,7 @@ impl Writer {
*binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
let varying_id = self.write_varying(
ir_module,
iface.stage,
Some(iface.stage),
class,
name,
member.ty,
Expand Down Expand Up @@ -1491,7 +1491,7 @@ impl Writer {
pub(super) fn write_varying(
&mut self,
ir_module: &crate::Module,
stage: crate::ShaderStage,
stage: Option<crate::ShaderStage>,
class: spirv::StorageClass,
debug_name: Option<&str>,
ty: Handle<crate::Type>,
Expand Down Expand Up @@ -1526,11 +1526,11 @@ impl Writer {
// VUID-StandaloneSpirv-Flat-06202
// > The Flat, NoPerspective, Sample, and Centroid decorations
// > must not be used on variables with the Input storage class in a vertex shader
(class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
(class == spirv::StorageClass::Input && stage == Some(crate::ShaderStage::Vertex)) ||
// VUID-StandaloneSpirv-Flat-06201
// > The Flat, NoPerspective, Sample, and Centroid decorations
// > must not be used on variables with the Output storage class in a fragment shader
(class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);
(class == spirv::StorageClass::Output && stage == Some(crate::ShaderStage::Fragment));

if !no_decorations {
match interpolation {
Expand Down Expand Up @@ -1781,7 +1781,9 @@ impl Writer {
// > Any variable with integer or double-precision floating-
// > point type and with Input storage class in a fragment
// > shader, must be decorated Flat
if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
if class == spirv::StorageClass::Input
&& stage == Some(crate::ShaderStage::Fragment)
{
let is_flat = match ir_module.types[ty].inner {
crate::TypeInner::Scalar(scalar)
| crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
Expand Down
6 changes: 6 additions & 0 deletions naga/tests/in/ray-pipeline.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@ fn ray_gen() {
}
*/
traceRay(acc_struct, RayDesc(), &colour);
trace();
return;
}

fn trace() {
var colour = vec4<f32>();
traceRay(acc_struct, RayDesc(), &colour);
}

@ray_any
fn discard_any_hit(@builtin(payload) colour: ptr<ray_tracing, vec4<f32>>, @builtin(intersection) intersection: TriRayIntersection) {
*colour = vec4<f32>();
Expand Down

0 comments on commit e90967d

Please sign in to comment.