@@ -243,30 +243,28 @@ impl<A: hal::Api> PendingWrites<A> {
243243 }
244244}
245245
246- impl < A : HalApi > super :: Device < A > {
247- fn prepare_staging_buffer (
248- & mut self ,
249- size : wgt:: BufferAddress ,
250- ) -> Result < ( StagingBuffer < A > , * mut u8 ) , DeviceError > {
251- profiling:: scope!( "prepare_staging_buffer" ) ;
252- let stage_desc = hal:: BufferDescriptor {
253- label : Some ( "(wgpu internal) Staging" ) ,
254- size,
255- usage : hal:: BufferUses :: MAP_WRITE | hal:: BufferUses :: COPY_SRC ,
256- memory_flags : hal:: MemoryFlags :: TRANSIENT ,
257- } ;
258-
259- let buffer = unsafe { self . raw . create_buffer ( & stage_desc) ? } ;
260- let mapping = unsafe { self . raw . map_buffer ( & buffer, 0 ..size) } ?;
261-
262- let staging_buffer = StagingBuffer {
263- raw : buffer,
264- size,
265- is_coherent : mapping. is_coherent ,
266- } ;
267-
268- Ok ( ( staging_buffer, mapping. ptr . as_ptr ( ) ) )
269- }
246+ fn prepare_staging_buffer < A : HalApi > (
247+ device : & mut A :: Device ,
248+ size : wgt:: BufferAddress ,
249+ ) -> Result < ( StagingBuffer < A > , * mut u8 ) , DeviceError > {
250+ profiling:: scope!( "prepare_staging_buffer" ) ;
251+ let stage_desc = hal:: BufferDescriptor {
252+ label : Some ( "(wgpu internal) Staging" ) ,
253+ size,
254+ usage : hal:: BufferUses :: MAP_WRITE | hal:: BufferUses :: COPY_SRC ,
255+ memory_flags : hal:: MemoryFlags :: TRANSIENT ,
256+ } ;
257+
258+ let buffer = unsafe { device. create_buffer ( & stage_desc) ? } ;
259+ let mapping = unsafe { device. map_buffer ( & buffer, 0 ..size) } ?;
260+
261+ let staging_buffer = StagingBuffer {
262+ raw : buffer,
263+ size,
264+ is_coherent : mapping. is_coherent ,
265+ } ;
266+
267+ Ok ( ( staging_buffer, mapping. ptr . as_ptr ( ) ) )
270268}
271269
272270impl < A : hal:: Api > StagingBuffer < A > {
@@ -350,21 +348,31 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
350348 return Ok ( ( ) ) ;
351349 }
352350
353- let ( staging_buffer, staging_buffer_ptr) = device. prepare_staging_buffer ( data_size) ?;
351+ // Platform validation requires that the staging buffer always be
352+ // freed, even if an error occurs. All paths from here must call
353+ // `device.pending_writes.consume`.
354+ let ( staging_buffer, staging_buffer_ptr) =
355+ prepare_staging_buffer ( & mut device. raw , data_size) ?;
354356
355- unsafe {
357+ if let Err ( flush_error ) = unsafe {
356358 profiling:: scope!( "copy" ) ;
357359 ptr:: copy_nonoverlapping ( data. as_ptr ( ) , staging_buffer_ptr, data. len ( ) ) ;
358- staging_buffer. flush ( & device. raw ) ?;
359- } ;
360+ staging_buffer. flush ( & device. raw )
361+ } {
362+ device. pending_writes . consume ( staging_buffer) ;
363+ return Err ( flush_error. into ( ) ) ;
364+ }
360365
361- self . queue_write_staging_buffer_impl (
366+ let result = self . queue_write_staging_buffer_impl (
362367 device,
363368 device_token,
364- staging_buffer,
369+ & staging_buffer,
365370 buffer_id,
366371 buffer_offset,
367- )
372+ ) ;
373+
374+ device. pending_writes . consume ( staging_buffer) ;
375+ result
368376 }
369377
370378 pub fn queue_create_staging_buffer < A : HalApi > (
@@ -382,7 +390,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
382390 . map_err ( |_| DeviceError :: Invalid ) ?;
383391
384392 let ( staging_buffer, staging_buffer_ptr) =
385- device . prepare_staging_buffer ( buffer_size. get ( ) ) ?;
393+ prepare_staging_buffer ( & mut device . raw , buffer_size. get ( ) ) ?;
386394
387395 let fid = hub. staging_buffers . prepare ( id_in) ;
388396 let id = fid. assign ( staging_buffer, device_token) ;
@@ -413,15 +421,25 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
413421 . 0
414422 . ok_or ( TransferError :: InvalidBuffer ( buffer_id) ) ?;
415423
416- unsafe { staging_buffer. flush ( & device. raw ) ? } ;
424+ // At this point, we have taken ownership of the staging_buffer from the
425+ // user. Platform validation requires that the staging buffer always
426+ // be freed, even if an error occurs. All paths from here must call
427+ // `device.pending_writes.consume`.
428+ if let Err ( flush_error) = unsafe { staging_buffer. flush ( & device. raw ) } {
429+ device. pending_writes . consume ( staging_buffer) ;
430+ return Err ( flush_error. into ( ) ) ;
431+ }
417432
418- self . queue_write_staging_buffer_impl (
433+ let result = self . queue_write_staging_buffer_impl (
419434 device,
420435 device_token,
421- staging_buffer,
436+ & staging_buffer,
422437 buffer_id,
423438 buffer_offset,
424- )
439+ ) ;
440+
441+ device. pending_writes . consume ( staging_buffer) ;
442+ result
425443 }
426444
427445 pub fn queue_validate_write_buffer < A : HalApi > (
@@ -481,7 +499,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
481499 & self ,
482500 device : & mut super :: Device < A > ,
483501 device_token : & mut Token < super :: Device < A > > ,
484- staging_buffer : StagingBuffer < A > ,
502+ staging_buffer : & StagingBuffer < A > ,
485503 buffer_id : id:: BufferId ,
486504 buffer_offset : u64 ,
487505 ) -> Result < ( ) , QueueWriteError > {
@@ -520,7 +538,6 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
520538 encoder. copy_buffer_to_buffer ( & staging_buffer. raw , dst_raw, region. into_iter ( ) ) ;
521539 }
522540
523- device. pending_writes . consume ( staging_buffer) ;
524541 device. pending_writes . dst_buffers . insert ( buffer_id) ;
525542
526543 // Ensure the overwritten bytes are marked as initialized so they don't need to be nulled prior to mapping or binding.
@@ -613,7 +630,6 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
613630 let block_rows_in_copy =
614631 ( size. depth_or_array_layers - 1 ) * block_rows_per_image + height_blocks;
615632 let stage_size = stage_bytes_per_row as u64 * block_rows_in_copy as u64 ;
616- let ( staging_buffer, staging_buffer_ptr) = device. prepare_staging_buffer ( stage_size) ?;
617633
618634 let dst = texture_guard. get_mut ( destination. texture ) . unwrap ( ) ;
619635 if !dst. desc . usage . contains ( wgt:: TextureUsages :: COPY_DST ) {
@@ -676,12 +692,23 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
676692 validate_texture_copy_range ( destination, & dst. desc , CopySide :: Destination , size) ?;
677693 dst. life_guard . use_at ( device. active_submission_index + 1 ) ;
678694
695+ let dst_raw = dst
696+ . inner
697+ . as_raw ( )
698+ . ok_or ( TransferError :: InvalidTexture ( destination. texture ) ) ?;
699+
679700 let bytes_per_row = if let Some ( bytes_per_row) = data_layout. bytes_per_row {
680701 bytes_per_row. get ( )
681702 } else {
682703 width_blocks * format_desc. block_size as u32
683704 } ;
684705
706+ // Platform validation requires that the staging buffer always be
707+ // freed, even if an error occurs. All paths from here must call
708+ // `device.pending_writes.consume`.
709+ let ( staging_buffer, staging_buffer_ptr) =
710+ prepare_staging_buffer ( & mut device. raw , stage_size) ?;
711+
685712 if stage_bytes_per_row == bytes_per_row {
686713 profiling:: scope!( "copy aligned" ) ;
687714 // Fast path if the data is already being aligned optimally.
@@ -715,7 +742,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
715742 }
716743 }
717744
718- unsafe { staging_buffer. flush ( & device. raw ) } ?;
745+ if let Err ( e) = unsafe { staging_buffer. flush ( & device. raw ) } {
746+ device. pending_writes . consume ( staging_buffer) ;
747+ return Err ( e. into ( ) ) ;
748+ }
719749
720750 let regions = ( 0 ..array_layer_count) . map ( |rel_array_layer| {
721751 let mut texture_base = dst_base. clone ( ) ;
@@ -737,11 +767,6 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
737767 usage : hal:: BufferUses :: MAP_WRITE ..hal:: BufferUses :: COPY_SRC ,
738768 } ;
739769
740- let dst_raw = dst
741- . inner
742- . as_raw ( )
743- . ok_or ( TransferError :: InvalidTexture ( destination. texture ) ) ?;
744-
745770 unsafe {
746771 encoder
747772 . transition_textures ( transition. map ( |pending| pending. into_hal ( dst) ) . into_iter ( ) ) ;
0 commit comments