Skip to content

Commit

Permalink
add parse_and_output function
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Jan 21, 2024
1 parent 730378d commit 05ed24b
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 29 deletions.
3 changes: 1 addition & 2 deletions src/devices/wgsl/error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use core::fmt::Display;

use naga::{WithSpan, valid::ValidationError};

use naga::{valid::ValidationError, WithSpan};

#[derive(Debug)]
pub enum TranslateError {
Expand Down
55 changes: 55 additions & 0 deletions src/devices/wgsl/glsl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use naga::{
back::glsl::{Options, PipelineOptions},
proc::BoundsCheckPolicies,
valid::ModuleInfo,
};

use super::error::TranslateError;

pub struct Glsl {}

impl Glsl {
pub fn from_wgsl(src: impl AsRef<str>) {}
}

pub fn write_glsl(
module: &naga::Module,
info: &ModuleInfo,
shader_stage: naga::ShaderStage,
entry_point: &str,
) -> Result<String, TranslateError> {
let mut glsl = String::new();
let options = Options::default();
let pipeline_options = PipelineOptions {
shader_stage,
entry_point: entry_point.into(),
multiview: None,
};

let mut writer = naga::back::glsl::Writer::new(
&mut glsl,
module,
info,
&options,
&pipeline_options,
BoundsCheckPolicies::default(),
)
.map_err(TranslateError::BackendGlsl)?;
writer.write().map_err(TranslateError::BackendGlsl)?;
Ok(glsl)
}

#[cfg(test)]
mod tests {
use super::write_glsl;

#[test]
fn test_wgsl_to_glsl_translation() {
let wgsl = "
@fragment
fn fs_main() -> @location(0) vec4<f32> {
return vec4<f32>(1.0, 0.0, 0.0, 1.0);
}
";
}
}
34 changes: 33 additions & 1 deletion src/devices/wgsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,37 @@ mod glsl;
#[cfg(feature = "glsl")]
pub use glsl::*;

mod error;
mod wgsl_device;
mod error;

use self::error::TranslateError;
use naga::{valid::ModuleInfo, Module, ShaderStage};

pub fn parse_and_output<O>(
src: impl AsRef<str>,
output_fn: fn(&Module, &ModuleInfo, ShaderStage, &str) -> Result<O, TranslateError>,
) -> Result<Vec<O>, TranslateError> {
let (module, info) = parse_and_validate_wgsl(src.as_ref())?;

module
.entry_points
.iter()
.map(|entry_point| output_fn(&module, &info, entry_point.stage, &entry_point.name))
.collect::<Result<Vec<_>, _>>()
}

pub fn parse_and_validate_wgsl(src: &str) -> Result<(naga::Module, ModuleInfo), TranslateError> {
let mut frontend = naga::front::wgsl::Frontend::new();

let module = frontend.parse(src).map_err(TranslateError::Frontend)?;

let mut validator = naga::valid::Validator::new(
naga::valid::ValidationFlags::all(),
naga::valid::Capabilities::all(),
);

let info = validator
.validate(&module)
.map_err(TranslateError::Validate)?;
Ok((module, info))
}
31 changes: 5 additions & 26 deletions src/devices/wgsl/spirv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,18 @@ use naga::{
valid::ModuleInfo,
};

use super::error::TranslateError;
use super::{error::TranslateError, parse_and_output};

pub struct Spirv {
words_of_entries: Vec<Vec<u32>>,
}

impl Spirv {
#[inline]
pub fn from_wgsl(src: impl AsRef<str>) -> Result<Self, TranslateError> {
let (module, info) = parse_and_validate_src(src.as_ref())?;

let words_of_entries = module
.entry_points
.iter()
.map(|entry_point| write_spirv(&module, &info, entry_point.stage, &entry_point.name))
.collect::<Result<Vec<_>, _>>()?;

Ok(Spirv { words_of_entries })
Ok(Spirv {
words_of_entries: parse_and_output(src, write_spirv)?,
})
}

#[inline]
Expand All @@ -40,22 +35,6 @@ impl Spirv {
}
}

pub fn parse_and_validate_src(src: &str) -> Result<(naga::Module, ModuleInfo), TranslateError> {
let mut frontend = naga::front::wgsl::Frontend::new();

let module = frontend.parse(src).map_err(TranslateError::Frontend)?;

let mut validator = naga::valid::Validator::new(
naga::valid::ValidationFlags::all(),
naga::valid::Capabilities::all(),
);

let info = validator
.validate(&module)
.map_err(TranslateError::Validate)?;
Ok((module, info))
}

pub fn write_spirv(
module: &naga::Module,
info: &ModuleInfo,
Expand Down

0 comments on commit 05ed24b

Please sign in to comment.