@@ -22,6 +22,7 @@ import './flags_webgpu';
2222import { backend_util , DataMover , DataType , ENV , KernelBackend , Rank , ShapeMap , Tensor , Tensor2D , Tensor3D , Tensor4D , util } from '@tensorflow/tfjs-core' ;
2323import * as shaderc from '@webgpu/shaderc' ;
2424
25+ import { BufferManager } from './buffer_manager' ;
2526import { ArgMinMaxProgram } from './kernels/argminmax_webgpu' ;
2627import * as binary_op from './kernels/binary_op_webgpu' ;
2728import { BinaryOpProgram } from './kernels/binary_op_webgpu' ;
@@ -50,33 +51,38 @@ export interface WebGPUMemoryInfo extends MemoryInfo {
5051 unreliable : boolean ;
5152}
5253
/**
 * Describes one GPU buffer: the buffer handle plus the byte size and usage
 * flags that are passed back to `releaseBuffer` when it is released.
 */
type BufferInfo = {
  byteSize: number;
  usage: GPUBufferUsage;
  buffer: GPUBuffer;
};
6059
/**
 * Per-tensor bookkeeping stored in the backend's tensor map.
 *
 * `values` holds the CPU-side copy of the data (set to `null` at
 * registration), and `bufferInfo` describes the GPU buffer backing the tensor.
 */
type TensorInfo = {
  values: Float32Array | Int32Array | Uint8Array;
  id: number;
  dtype: DataType;
  bufferInfo: BufferInfo;
};
6566
/** Opaque object used as the key into the backend's WeakMap / WeakSet. */
interface DataId {
}
6768
// Usage flags applied to buffers acquired without an explicit `usage`
// argument, and recorded for every tensor buffer created by `register`.
const DEFAULT_GPUBUFFER_USAGE =
    GPUBufferUsage.STORAGE |
    GPUBufferUsage.TRANSFER_SRC |
    GPUBufferUsage.TRANSFER_DST;
71+
6872export class WebGPUBackend extends KernelBackend {
// Device handles and shader-compiler state.
device: GPUDevice;
queue: GPUQueue;
shaderc: shaderc.Shaderc;
compiler: shaderc.Compiler;
compileOpts: shaderc.CompileOptions;
// Command encoders recorded but not yet submitted to `queue`.
commandQueue: GPUCommandEncoder[];

// Data ids referenced by encoders currently sitting in `commandQueue`;
// replaced with a fresh set every time the queue is submitted.
private commandQueueOwnedIds = new WeakSet<DataId>();
private binaryCache: {[key: string]: WebGPUBinary};
private fromPixels2DContext: CanvasRenderingContext2D;
// Source of all GPU buffers handed out by this backend.
private bufferManager: BufferManager;
private tensorMap = new WeakMap<DataId, TensorInfo>();
// Buffers whose release is deferred until the next queue submission.
private disposalQueue: BufferInfo[] = [];

private disposed = false;
8288
@@ -86,12 +92,13 @@ export class WebGPUBackend extends KernelBackend {
8692 this . device = device ;
8793 this . queue = device . getQueue ( ) ;
8894 this . commandQueue = [ ] ;
89- this . disposalQueue = [ ] ;
9095 this . shaderc = shaderc ;
9196 this . compiler = new shaderc . Compiler ( ) ;
9297 const opts = new shaderc . CompileOptions ( ) ;
9398 opts . SetOptimizationLevel ( shaderc . optimization_level . performance ) ;
9499 this . compileOpts = opts ;
100+
101+ this . bufferManager = new BufferManager ( this . device ) ;
95102 }
96103
97104 floatPrecision ( ) : 32 {
@@ -102,43 +109,62 @@ export class WebGPUBackend extends KernelBackend {
102109 // TODO: tfjs team to implement this. Call GPUBuffer.destroy()
103110 }
104111
105- private tensorMap = new WeakMap < DataId , TensorInfo > ( ) ;
/**
 * Releases every buffer waiting in the disposal queue, then empties the
 * queue.
 */
flushDisposalQueue() {
  for (const {buffer, byteSize, usage} of this.disposalQueue) {
    this.releaseBuffer(buffer, byteSize, usage);
  }

  this.disposalQueue = [];
}
106119
/**
 * Drops the tensor registered under `dataId` and gives up its GPU buffer.
 *
 * If the id is still referenced by a not-yet-submitted command encoder the
 * buffer is parked on the disposal queue; otherwise it is released right
 * away.
 *
 * @param dataId Key the tensor was registered under.
 * @throws Error if `dataId` is not in the tensor map.
 */
disposeData(dataId: DataId): void {
  const registered = this.tensorMap.has(dataId);
  if (!registered) {
    throw new Error(`Tensor ${dataId} was not registered!`);
  }

  const {bufferInfo} = this.tensorMap.get(dataId);
  const ownedByPendingCommands = this.commandQueueOwnedIds.has(dataId);
  if (!ownedByPendingCommands) {
    this.releaseBuffer(
        bufferInfo.buffer, bufferInfo.byteSize, bufferInfo.usage);
  } else {
    // Defer: the queued commands still need this buffer.
    this.disposalQueue.push(bufferInfo);
  }

  this.tensorMap.delete(dataId);
}
117136
/**
 * Reports GPU memory usage as tracked by the buffer manager.
 *
 * @returns Memory info with `numBytesInGPU` taken from
 *     `bufferManager.numBytesUsed` and `unreliable` set to `false`.
 */
memory(): WebGPUMemoryInfo {
  const numBytesInGPU = this.bufferManager.numBytesUsed;
  return {numBytesInGPU, unreliable: false} as WebGPUMemoryInfo;
}
143+
/** Exposes the buffer manager this backend acquires its GPU buffers from. */
getBufferManager(): BufferManager {
  const {bufferManager} = this;
  return bufferManager;
}
122147
/**
 * Obtains a GPU buffer from the buffer manager.
 *
 * @param byteSize Requested size in bytes.
 * @param usage Usage flags; defaults to `DEFAULT_GPUBUFFER_USAGE`.
 * @returns Whatever `bufferManager.acquireBuffer` returns for the request.
 */
private acquireBuffer(
    byteSize: number, usage: GPUBufferUsage = DEFAULT_GPUBUFFER_USAGE) {
  const acquired = this.bufferManager.acquireBuffer(byteSize, usage);
  return acquired;
}
130152
/**
 * Hands a GPU buffer back to the buffer manager.
 *
 * @param buffer The buffer being given up.
 * @param byteSize Size the buffer was acquired with.
 * @param usage Usage flags the buffer was acquired with.
 */
private releaseBuffer(
    buffer: GPUBuffer, byteSize: number, usage: GPUBufferUsage) {
  const manager = this.bufferManager;
  manager.releaseBuffer(buffer, byteSize, usage);
}
135157
/**
 * Registers a tensor with the backend, acquiring a GPU buffer sized for its
 * shape and dtype. Does nothing when `dataId` is already registered.
 *
 * @param dataId Key under which the tensor is tracked.
 * @param shape Tensor shape; its element count determines the buffer size.
 * @param dtype Element type; determines the bytes per element.
 */
register(dataId: object, shape: number[], dtype: DataType): void {
  if (this.tensorMap.has(dataId)) {
    return;
  }

  const elementCount = util.sizeFromShape(shape);
  const byteSize = elementCount * util.bytesPerElement(dtype);
  const bufferInfo: BufferInfo = {
    byteSize,
    usage: DEFAULT_GPUBUFFER_USAGE,
    buffer: this.acquireBuffer(byteSize)
  };

  // No CPU-side values yet; `id` starts out as the placeholder -1.
  this.tensorMap.set(dataId, {values: null, id: -1, dtype, bufferInfo});
}
144170
@@ -149,49 +175,45 @@ export class WebGPUBackend extends KernelBackend {
149175
150176 const info = this . tensorMap . get ( dataId ) ;
151177 info . values = values ;
152- info . buffer . setSubData ( 0 , values ) ;
178+ info . bufferInfo . buffer . setSubData ( 0 , values ) ;
153179 this . tensorMap . set ( dataId , info ) ;
154180 }
155181
/**
 * Finishes and submits all recorded command encoders, then resets the
 * per-submission state: the encoder list, the set of data ids owned by the
 * command queue, and the deferred-disposal queue.
 */
private submitQueue() {
  const finishedCommands = this.commandQueue.map(encoder => encoder.finish());
  this.queue.submit(finishedCommands);

  this.commandQueue = [];
  this.commandQueueOwnedIds = new WeakSet<DataId>();

  // Buffers parked by disposeData can be released now that their commands
  // have been submitted.
  this.flushDisposalQueue();
}
171190
/**
 * Reads the contents of a tensor's GPU buffer back to the CPU.
 *
 * A staging buffer with TRANSFER_DST | MAP_READ usage is acquired, a
 * buffer-to-buffer copy is recorded and submitted, and the staging buffer is
 * mapped for reading. The mapped contents are copied out before the staging
 * buffer is unmapped and released.
 *
 * The staging buffer is released even when mapping or copying fails, so a
 * rejected `mapReadAsync()` no longer leaves it permanently acquired.
 *
 * @param info Tensor whose `bufferInfo` describes the source GPU buffer.
 * @returns A copy of the buffer contents, `bufferInfo.byteSize` bytes long.
 */
private async getBufferData(info: TensorInfo): Promise<ArrayBuffer> {
  const {buffer: source, byteSize} = info.bufferInfo;
  // Single definition so acquire and release always agree on the usage.
  const stagingUsage = GPUBufferUsage.TRANSFER_DST | GPUBufferUsage.MAP_READ;
  const staging = this.acquireBuffer(byteSize, stagingUsage);

  const encoder = this.device.createCommandEncoder({});
  encoder.copyBufferToBuffer(source, 0, staging, 0, byteSize);
  this.commandQueue.push(encoder);
  this.submitQueue();

  try {
    const mapped: ArrayBuffer = await staging.mapReadAsync();
    try {
      // Copy out before unmapping; the copy is what callers receive.
      return mapped.slice(0);
    } finally {
      staging.unmap();
    }
  } finally {
    this.releaseBuffer(staging, byteSize, stagingUsage);
  }
}
185211
/**
 * Caches `data` as the CPU-side values of the tensor registered under
 * `dataId` and returns that same array.
 *
 * @param dataId Key of a tensor present in the tensor map.
 * @param data Typed array to store as the tensor's values.
 * @returns The array that was stored.
 */
private convertAndCacheOnCPU(
    dataId: DataId,
    data: backend_util.TypedArray): backend_util.TypedArray {
  const tensorInfo = this.tensorMap.get(dataId);
  tensorInfo.values = data;
  return data;
}
196218
197219 // TODO: Remove once this is fixed:
@@ -246,7 +268,7 @@ export class WebGPUBackend extends KernelBackend {
246268 resource : {
247269 offset : 0 ,
248270 size : tensor . size * util . bytesPerElement ( tensor . dtype ) ,
249- buffer : tensorData . buffer
271+ buffer : tensorData . bufferInfo . buffer
250272 }
251273 } ;
252274 }
@@ -337,16 +359,23 @@ export class WebGPUBackend extends KernelBackend {
337359 pass . endPass ( ) ;
338360 this . commandQueue . push ( encoder ) ;
339361
362+ inputs . forEach ( input => {
363+ this . commandQueueOwnedIds . add ( input . dataId ) ;
364+ } ) ;
365+ this . commandQueueOwnedIds . add ( output . dataId ) ;
366+
340367 if ( ENV . get ( 'WEBGPU_IMMEDIATE_EXECUTION_ENABLED' ) ) {
341368 this . submitQueue ( ) ;
342369 }
343- this . disposeBuffer ( uniformData . byteLength , uniforms . resource . buffer ) ;
370+ this . releaseBuffer (
371+ uniforms . resource . buffer , uniformData . byteLength ,
372+ GPUBufferUsage . TRANSFER_DST | GPUBufferUsage . UNIFORM ) ;
344373 return output as { } as K ;
345374 }
346375
347376 private makeUniforms ( data : Uint32Array |
348377 Int32Array ) : webgpu_program . BindingInfo {
349- const dimensionsBuffer = this . createBuffer (
378+ const dimensionsBuffer = this . acquireBuffer (
350379 data . byteLength , GPUBufferUsage . TRANSFER_DST | GPUBufferUsage . UNIFORM ) ;
351380 dimensionsBuffer . setSubData ( 0 , data ) ;
352381
@@ -611,6 +640,7 @@ export class WebGPUBackend extends KernelBackend {
611640 if ( this . disposed ) {
612641 return ;
613642 }
643+ this . bufferManager . dispose ( ) ;
614644 this . disposed = true ;
615645 }
616646}
0 commit comments