11use crate :: config:: Config ;
2+ use anyhow:: Context ;
23use bytemuck:: Pod ;
3- use futures:: { channel :: oneshot :: Canceled , executor:: block_on} ;
4+ use futures:: executor:: block_on;
45use spirv_builder:: { ModuleResult , SpirvBuilder } ;
56use std:: {
67 borrow:: Cow ,
@@ -9,29 +10,14 @@ use std::{
910 io:: Write ,
1011 path:: PathBuf ,
1112} ;
12- use thiserror:: Error ;
13- use wgpu:: { BufferAsyncError , PipelineCompilationOptions , util:: DeviceExt } ;
14-
15- #[ derive( Error , Debug ) ]
16- pub enum ComputeError {
17- #[ error( "Failed to find a suitable GPU adapter" ) ]
18- AdapterNotFound ,
19- #[ error( "Failed to create device: {0}" ) ]
20- DeviceCreationFailed ( String ) ,
21- #[ error( "Failed to load shader: {0}" ) ]
22- ShaderLoadFailed ( String ) ,
23- #[ error( "Mapping compute output future canceled: {0}" ) ]
24- MappingCanceled ( Canceled ) ,
25- #[ error( "Mapping compute output failed: {0}" ) ]
26- MappingFailed ( BufferAsyncError ) ,
27- }
13+ use wgpu:: { PipelineCompilationOptions , util:: DeviceExt } ;
2814
2915/// Trait that creates a shader module and provides its entry point.
3016pub trait ComputeShader {
3117 fn create_module (
3218 & self ,
3319 device : & wgpu:: Device ,
34- ) -> Result < ( wgpu:: ShaderModule , Option < String > ) , ComputeError > ;
20+ ) -> anyhow :: Result < ( wgpu:: ShaderModule , Option < String > ) > ;
3521}
3622
3723/// A compute shader written in Rust compiled with spirv-builder.
@@ -49,40 +35,33 @@ impl ComputeShader for RustComputeShader {
4935 fn create_module (
5036 & self ,
5137 device : & wgpu:: Device ,
52- ) -> Result < ( wgpu:: ShaderModule , Option < String > ) , ComputeError > {
38+ ) -> anyhow :: Result < ( wgpu:: ShaderModule , Option < String > ) > {
5339 let builder = SpirvBuilder :: new ( & self . path , "spirv-unknown-vulkan1.1" )
5440 . print_metadata ( spirv_builder:: MetadataPrintout :: None )
5541 . release ( true )
5642 . multimodule ( false )
5743 . shader_panic_strategy ( spirv_builder:: ShaderPanicStrategy :: SilentExit )
5844 . preserve_bindings ( true ) ;
59- let artifact = builder
60- . build ( )
61- . map_err ( |e| ComputeError :: ShaderLoadFailed ( e. to_string ( ) ) ) ?;
45+ let artifact = builder. build ( ) . context ( "SpirvBuilder::build() failed" ) ?;
6246
6347 if artifact. entry_points . len ( ) != 1 {
64- return Err ( ComputeError :: ShaderLoadFailed ( format ! (
48+ anyhow :: bail !(
6549 "Expected exactly one entry point, found {}" ,
6650 artifact. entry_points. len( )
67- ) ) ) ;
51+ ) ;
6852 }
6953 let entry_point = artifact. entry_points . into_iter ( ) . next ( ) . unwrap ( ) ;
7054
7155 let shader_bytes = match artifact. module {
72- ModuleResult :: SingleModule ( path) => {
73- fs:: read ( & path) . map_err ( |e| ComputeError :: ShaderLoadFailed ( e. to_string ( ) ) ) ?
74- }
56+ ModuleResult :: SingleModule ( path) => fs:: read ( & path)
57+ . with_context ( || format ! ( "reading spv file '{}' failed" , path. display( ) ) ) ?,
7558 ModuleResult :: MultiModule ( _modules) => {
76- return Err ( ComputeError :: ShaderLoadFailed (
77- "Multiple modules produced" . to_string ( ) ,
78- ) ) ;
59+ anyhow:: bail!( "MultiModule modules produced" ) ;
7960 }
8061 } ;
8162
8263 if shader_bytes. len ( ) % 4 != 0 {
83- return Err ( ComputeError :: ShaderLoadFailed (
84- "SPIR-V binary length is not a multiple of 4" . to_string ( ) ,
85- ) ) ;
64+ anyhow:: bail!( "SPIR-V binary length is not a multiple of 4" ) ;
8665 }
8766 let shader_words: Vec < u32 > = bytemuck:: cast_slice ( & shader_bytes) . to_vec ( ) ;
8867 let module = device. create_shader_module ( wgpu:: ShaderModuleDescriptor {
@@ -112,9 +91,9 @@ impl ComputeShader for WgslComputeShader {
11291 fn create_module (
11392 & self ,
11493 device : & wgpu:: Device ,
115- ) -> Result < ( wgpu:: ShaderModule , Option < String > ) , ComputeError > {
94+ ) -> anyhow :: Result < ( wgpu:: ShaderModule , Option < String > ) > {
11695 let shader_source = fs:: read_to_string ( & self . path )
117- . map_err ( |e| ComputeError :: ShaderLoadFailed ( e . to_string ( ) ) ) ?;
96+ . with_context ( || format ! ( "reading wgsl source file '{}'" , & self . path . display ( ) ) ) ?;
11897 let module = device. create_shader_module ( wgpu:: ShaderModuleDescriptor {
11998 label : Some ( "Compute Shader" ) ,
12099 source : wgpu:: ShaderSource :: Wgsl ( Cow :: Owned ( shader_source) ) ,
@@ -142,7 +121,7 @@ where
142121 }
143122 }
144123
145- fn init ( ) -> Result < ( wgpu:: Device , wgpu:: Queue ) , ComputeError > {
124+ fn init ( ) -> anyhow :: Result < ( wgpu:: Device , wgpu:: Queue ) > {
146125 block_on ( async {
147126 let instance = wgpu:: Instance :: new ( wgpu:: InstanceDescriptor {
148127 #[ cfg( target_os = "linux" ) ]
@@ -160,7 +139,7 @@ where
160139 force_fallback_adapter : false ,
161140 } )
162141 . await
163- . ok_or ( ComputeError :: AdapterNotFound ) ?;
142+ . context ( "Failed to find a suitable GPU adapter" ) ?;
164143 let ( device, queue) = adapter
165144 . request_device (
166145 & wgpu:: DeviceDescriptor {
@@ -175,12 +154,12 @@ where
175154 None ,
176155 )
177156 . await
178- . map_err ( |e| ComputeError :: DeviceCreationFailed ( e . to_string ( ) ) ) ?;
157+ . context ( "Failed to create device" ) ?;
179158 Ok ( ( device, queue) )
180159 } )
181160 }
182161
183- fn run_internal < I > ( self , input : Option < I > ) -> Result < Vec < u8 > , ComputeError >
162+ fn run_internal < I > ( self , input : Option < I > ) -> anyhow :: Result < Vec < u8 > >
184163 where
185164 I : Sized + Pod ,
186165 {
@@ -278,42 +257,42 @@ where
278257 } ) ;
279258 device. poll ( wgpu:: Maintain :: Wait ) ;
280259 block_on ( receiver)
281- . map_err ( ComputeError :: MappingCanceled ) ?
282- . map_err ( ComputeError :: MappingFailed ) ?;
260+ . context ( "mapping canceled" ) ?
261+ . context ( "mapping failed" ) ?;
283262 let data = buffer_slice. get_mapped_range ( ) . to_vec ( ) ;
284263 staging_buffer. unmap ( ) ;
285264 Ok ( data)
286265 }
287266
288267 /// Runs the compute shader with no input.
289- pub fn run ( self ) -> Result < Vec < u8 > , ComputeError > {
268+ pub fn run ( self ) -> anyhow :: Result < Vec < u8 > > {
290269 self . run_internal :: < ( ) > ( None )
291270 }
292271
293272 /// Runs the compute shader with provided input.
294- pub fn run_with_input < I > ( self , input : I ) -> Result < Vec < u8 > , ComputeError >
273+ pub fn run_with_input < I > ( self , input : I ) -> anyhow :: Result < Vec < u8 > >
295274 where
296275 I : Sized + Pod ,
297276 {
298277 self . run_internal ( Some ( input) )
299278 }
300279
301280 /// Runs the compute shader with no input and writes the output to a file.
302- pub fn run_test ( self , config : & Config ) -> Result < ( ) , ComputeError > {
281+ pub fn run_test ( self , config : & Config ) -> anyhow :: Result < ( ) > {
303282 let output = self . run ( ) ?;
304- let mut f = File :: create ( & config. output_path ) . unwrap ( ) ;
305- f. write_all ( & output) . unwrap ( ) ;
283+ let mut f = File :: create ( & config. output_path ) ? ;
284+ f. write_all ( & output) ? ;
306285 Ok ( ( ) )
307286 }
308287
309288 /// Runs the compute shader with provided input and writes the output to a file.
310- pub fn run_test_with_input < I > ( self , config : & Config , input : I ) -> Result < ( ) , ComputeError >
289+ pub fn run_test_with_input < I > ( self , config : & Config , input : I ) -> anyhow :: Result < ( ) >
311290 where
312291 I : Sized + Pod ,
313292 {
314293 let output = self . run_with_input ( input) ?;
315- let mut f = File :: create ( & config. output_path ) . unwrap ( ) ;
316- f. write_all ( & output) . unwrap ( ) ;
294+ let mut f = File :: create ( & config. output_path ) ? ;
295+ f. write_all ( & output) ? ;
317296 Ok ( ( ) )
318297 }
319298}
0 commit comments