|
| 1 | +use super::backend::{self, ComputeBackend}; |
1 | 2 | use crate::config::Config; |
| 3 | +use crate::scaffold::shader::RustComputeShader; |
| 4 | +use crate::scaffold::shader::WgpuShader; |
| 5 | +use crate::scaffold::shader::WgslComputeShader; |
2 | 6 | use anyhow::Context; |
3 | 7 | use bytemuck::Pod; |
4 | 8 | use futures::executor::block_on; |
5 | | -use spirv_builder::{ModuleResult, SpirvBuilder}; |
6 | | -use std::{ |
7 | | - borrow::Cow, |
8 | | - env, |
9 | | - fs::{self, File}, |
10 | | - io::Write, |
11 | | - path::PathBuf, |
12 | | - sync::Arc, |
13 | | -}; |
| 9 | +use std::{borrow::Cow, fs::File, io::Write, sync::Arc}; |
14 | 10 | use wgpu::{PipelineCompilationOptions, util::DeviceExt}; |
15 | 11 |
|
16 | | -use super::backend::{self, ComputeBackend}; |
17 | | - |
18 | 12 | pub type BufferConfig = backend::BufferConfig; |
19 | 13 | pub type BufferUsage = backend::BufferUsage; |
20 | 14 |
|
21 | | -/// Trait for shaders that can provide SPIRV bytes. |
22 | | -pub trait SpirvShader { |
23 | | - /// Returns the SPIRV bytes and entry point name. |
24 | | - fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)>; |
25 | | -} |
26 | | - |
27 | | -/// Trait for shaders that can create wgpu modules. |
28 | | -pub trait WgpuShader { |
29 | | - /// Creates a wgpu shader module. |
30 | | - fn create_module( |
31 | | - &self, |
32 | | - device: &wgpu::Device, |
33 | | - ) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)>; |
34 | | -} |
35 | | - |
36 | | -/// A compute shader written in Rust compiled with spirv-builder. |
37 | | -pub struct RustComputeShader { |
38 | | - pub path: PathBuf, |
39 | | - pub target: String, |
40 | | - pub capabilities: Vec<spirv_builder::Capability>, |
41 | | -} |
42 | | - |
43 | | -impl RustComputeShader { |
44 | | - pub fn new<P: Into<PathBuf>>(path: P) -> Self { |
45 | | - Self { |
46 | | - path: path.into(), |
47 | | - target: "spirv-unknown-vulkan1.1".to_string(), |
48 | | - capabilities: Vec::new(), |
49 | | - } |
50 | | - } |
51 | | - |
52 | | - pub fn with_target<P: Into<PathBuf>>(path: P, target: impl Into<String>) -> Self { |
53 | | - Self { |
54 | | - path: path.into(), |
55 | | - target: target.into(), |
56 | | - capabilities: Vec::new(), |
57 | | - } |
58 | | - } |
59 | | - |
60 | | - pub fn with_capability(mut self, capability: spirv_builder::Capability) -> Self { |
61 | | - self.capabilities.push(capability); |
62 | | - self |
63 | | - } |
64 | | -} |
65 | | - |
66 | | -impl SpirvShader for RustComputeShader { |
67 | | - fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)> { |
68 | | - let mut builder = SpirvBuilder::new(&self.path, &self.target) |
69 | | - .print_metadata(spirv_builder::MetadataPrintout::None) |
70 | | - .release(true) |
71 | | - .multimodule(false) |
72 | | - .shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit) |
73 | | - .preserve_bindings(true); |
74 | | - |
75 | | - for capability in &self.capabilities { |
76 | | - builder = builder.capability(*capability); |
77 | | - } |
78 | | - |
79 | | - let artifact = builder.build().context("SpirvBuilder::build() failed")?; |
80 | | - |
81 | | - if artifact.entry_points.len() != 1 { |
82 | | - anyhow::bail!( |
83 | | - "Expected exactly one entry point, found {}", |
84 | | - artifact.entry_points.len() |
85 | | - ); |
86 | | - } |
87 | | - let entry_point = artifact.entry_points.into_iter().next().unwrap(); |
88 | | - |
89 | | - let shader_bytes = match artifact.module { |
90 | | - ModuleResult::SingleModule(path) => fs::read(&path) |
91 | | - .with_context(|| format!("reading spv file '{}' failed", path.display()))?, |
92 | | - ModuleResult::MultiModule(_modules) => { |
93 | | - anyhow::bail!("MultiModule modules produced"); |
94 | | - } |
95 | | - }; |
96 | | - |
97 | | - Ok((shader_bytes, entry_point)) |
98 | | - } |
99 | | -} |
100 | | - |
101 | | -impl WgpuShader for RustComputeShader { |
102 | | - fn create_module( |
103 | | - &self, |
104 | | - device: &wgpu::Device, |
105 | | - ) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> { |
106 | | - let (shader_bytes, entry_point) = self.spirv_bytes()?; |
107 | | - |
108 | | - if shader_bytes.len() % 4 != 0 { |
109 | | - anyhow::bail!("SPIR-V binary length is not a multiple of 4"); |
110 | | - } |
111 | | - let shader_words: Vec<u32> = bytemuck::cast_slice(&shader_bytes).to_vec(); |
112 | | - let module = device.create_shader_module(wgpu::ShaderModuleDescriptor { |
113 | | - label: Some("Compute Shader"), |
114 | | - source: wgpu::ShaderSource::SpirV(Cow::Owned(shader_words)), |
115 | | - }); |
116 | | - Ok((module, Some(entry_point))) |
117 | | - } |
118 | | -} |
119 | | - |
120 | | -/// A WGSL compute shader. |
121 | | -pub struct WgslComputeShader { |
122 | | - pub path: PathBuf, |
123 | | - pub entry_point: Option<String>, |
124 | | -} |
125 | | - |
126 | | -impl WgslComputeShader { |
127 | | - pub fn new<P: Into<PathBuf>>(path: P, entry_point: Option<String>) -> Self { |
128 | | - Self { |
129 | | - path: path.into(), |
130 | | - entry_point, |
131 | | - } |
132 | | - } |
133 | | -} |
134 | | - |
135 | | -impl WgpuShader for WgslComputeShader { |
136 | | - fn create_module( |
137 | | - &self, |
138 | | - device: &wgpu::Device, |
139 | | - ) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> { |
140 | | - let shader_source = fs::read_to_string(&self.path) |
141 | | - .with_context(|| format!("reading wgsl source file '{}'", &self.path.display()))?; |
142 | | - let module = device.create_shader_module(wgpu::ShaderModuleDescriptor { |
143 | | - label: Some("Compute Shader"), |
144 | | - source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader_source)), |
145 | | - }); |
146 | | - Ok((module, self.entry_point.clone())) |
147 | | - } |
148 | | -} |
149 | | - |
150 | 15 | /// Compute test that is generic over the shader type. |
151 | 16 | pub struct WgpuComputeTest<S> { |
152 | 17 | shader: S, |
@@ -539,48 +404,6 @@ impl ComputeBackend for WgpuBackend { |
539 | 404 | } |
540 | 405 | } |
541 | 406 |
|
542 | | -/// For WGSL, the code checks for "shader.wgsl" then "compute.wgsl". |
543 | | -impl Default for WgslComputeShader { |
544 | | - fn default() -> Self { |
545 | | - let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); |
546 | | - let manifest_path = PathBuf::from(manifest_dir); |
547 | | - let shader_path = manifest_path.join("shader.wgsl"); |
548 | | - let compute_path = manifest_path.join("compute.wgsl"); |
549 | | - |
550 | | - let (file, source) = if shader_path.exists() { |
551 | | - ( |
552 | | - shader_path.clone(), |
553 | | - fs::read_to_string(&shader_path).unwrap_or_default(), |
554 | | - ) |
555 | | - } else if compute_path.exists() { |
556 | | - ( |
557 | | - compute_path.clone(), |
558 | | - fs::read_to_string(&compute_path).unwrap_or_default(), |
559 | | - ) |
560 | | - } else { |
561 | | - panic!("No default WGSL shader found in manifest directory"); |
562 | | - }; |
563 | | - |
564 | | - let entry_point = if source.contains("fn main_cs(") { |
565 | | - Some("main_cs".to_string()) |
566 | | - } else if source.contains("fn main(") { |
567 | | - Some("main".to_string()) |
568 | | - } else { |
569 | | - None |
570 | | - }; |
571 | | - |
572 | | - Self::new(file, entry_point) |
573 | | - } |
574 | | -} |
575 | | - |
576 | | -/// For the SPIR-V shader, the manifest directory is used as the build path. |
577 | | -impl Default for RustComputeShader { |
578 | | - fn default() -> Self { |
579 | | - let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); |
580 | | - Self::new(PathBuf::from(manifest_dir)) |
581 | | - } |
582 | | -} |
583 | | - |
584 | 407 | impl<S> WgpuComputeTestMultiBuffer<S> |
585 | 408 | where |
586 | 409 | S: WgpuShader, |
|
0 commit comments