Skip to content

Commit 853e371

Browse files
authored
[webgpu] Add buffer manager for buffer reuse. (tensorflow#1822)
FEATURE
1 parent 1dec063 commit 853e371

File tree

3 files changed

+290
-56
lines changed

3 files changed

+290
-56
lines changed

tfjs-webgpu/src/backend_webgpu.ts

Lines changed: 85 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ import './flags_webgpu';
2222
import {backend_util, DataMover, DataType, ENV, KernelBackend, Rank, ShapeMap, Tensor, Tensor2D, Tensor3D, Tensor4D, util} from '@tensorflow/tfjs-core';
2323
import * as shaderc from '@webgpu/shaderc';
2424

25+
import {BufferManager} from './buffer_manager';
2526
import {ArgMinMaxProgram} from './kernels/argminmax_webgpu';
2627
import * as binary_op from './kernels/binary_op_webgpu';
2728
import {BinaryOpProgram} from './kernels/binary_op_webgpu';
@@ -50,33 +51,38 @@ export interface WebGPUMemoryInfo extends MemoryInfo {
5051
unreliable: boolean;
5152
}
5253

53-
type TensorInfo = {
54+
type BufferInfo = {
5455
byteSize: number,
55-
values: Float32Array|Int32Array|Uint8Array,
56-
id: number,
57-
dtype: DataType,
56+
usage: GPUBufferUsage,
5857
buffer: GPUBuffer
5958
};
6059

61-
type BufferInfo = {
62-
byteSize: number,
63-
buffer: GPUBuffer
60+
type TensorInfo = {
61+
values: Float32Array|Int32Array|Uint8Array,
62+
id: number,
63+
dtype: DataType,
64+
bufferInfo: BufferInfo
6465
};
6566

6667
interface DataId {}
6768

69+
const DEFAULT_GPUBUFFER_USAGE = GPUBufferUsage.STORAGE |
70+
GPUBufferUsage.TRANSFER_SRC | GPUBufferUsage.TRANSFER_DST;
71+
6872
export class WebGPUBackend extends KernelBackend {
6973
device: GPUDevice;
7074
queue: GPUQueue;
7175
shaderc: shaderc.Shaderc;
7276
compiler: shaderc.Compiler;
7377
compileOpts: shaderc.CompileOptions;
7478
commandQueue: GPUCommandEncoder[];
75-
disposalQueue: BufferInfo[];
7679

80+
private commandQueueOwnedIds = new WeakSet<DataId>();
7781
private binaryCache: {[key: string]: WebGPUBinary};
7882
private fromPixels2DContext: CanvasRenderingContext2D;
79-
private numBytesInGPU = 0;
83+
private bufferManager: BufferManager;
84+
private tensorMap = new WeakMap<DataId, TensorInfo>();
85+
private disposalQueue: BufferInfo[] = [];
8086

8187
private disposed = false;
8288

@@ -86,12 +92,13 @@ export class WebGPUBackend extends KernelBackend {
8692
this.device = device;
8793
this.queue = device.getQueue();
8894
this.commandQueue = [];
89-
this.disposalQueue = [];
9095
this.shaderc = shaderc;
9196
this.compiler = new shaderc.Compiler();
9297
const opts = new shaderc.CompileOptions();
9398
opts.SetOptimizationLevel(shaderc.optimization_level.performance);
9499
this.compileOpts = opts;
100+
101+
this.bufferManager = new BufferManager(this.device);
95102
}
96103

97104
floatPrecision(): 32 {
@@ -102,43 +109,62 @@ export class WebGPUBackend extends KernelBackend {
102109
// TODO: tfjs team to implement this. Call GPUBuffer.destroy()
103110
}
104111

105-
private tensorMap = new WeakMap<DataId, TensorInfo>();
112+
flushDisposalQueue() {
113+
this.disposalQueue.forEach(d => {
114+
this.releaseBuffer(d.buffer, d.byteSize, d.usage);
115+
});
116+
117+
this.disposalQueue = [];
118+
}
106119

107120
disposeData(dataId: DataId): void {
108121
if (!this.tensorMap.has(dataId)) {
109122
throw new Error(`Tensor ${dataId} was not registered!`);
110123
}
111124

112125
const info = this.tensorMap.get(dataId);
113-
this.disposeBuffer(info.byteSize, info.buffer);
126+
if (this.commandQueueOwnedIds.has(dataId)) {
127+
this.disposalQueue.push(info.bufferInfo);
128+
} else {
129+
this.releaseBuffer(
130+
info.bufferInfo.buffer, info.bufferInfo.byteSize,
131+
info.bufferInfo.usage);
132+
}
114133

115134
this.tensorMap.delete(dataId);
116135
}
117136

118137
memory(): WebGPUMemoryInfo {
119-
return {numBytesInGPU: this.numBytesInGPU, unreliable: false} as
120-
WebGPUMemoryInfo;
138+
return {
139+
numBytesInGPU: this.bufferManager.numBytesUsed,
140+
unreliable: false
141+
} as WebGPUMemoryInfo;
142+
}
143+
144+
getBufferManager(): BufferManager {
145+
return this.bufferManager;
121146
}
122147

123-
private createBuffer(
124-
byteSize: number,
125-
usage: GPUBufferUsage = GPUBufferUsage.STORAGE |
126-
GPUBufferUsage.TRANSFER_SRC | GPUBufferUsage.TRANSFER_DST) {
127-
this.numBytesInGPU += byteSize;
128-
return this.device.createBuffer({size: byteSize, usage});
148+
private acquireBuffer(
149+
byteSize: number, usage: GPUBufferUsage = DEFAULT_GPUBUFFER_USAGE) {
150+
return this.bufferManager.acquireBuffer(byteSize, usage);
129151
}
130152

131-
private disposeBuffer(byteSize: number, buffer: GPUBuffer) {
132-
this.disposalQueue.push({byteSize, buffer});
133-
// TODO: recycle deleted buffers
153+
private releaseBuffer(
154+
buffer: GPUBuffer, byteSize: number, usage: GPUBufferUsage) {
155+
this.bufferManager.releaseBuffer(buffer, byteSize, usage);
134156
}
135157

136158
register(dataId: object, shape: number[], dtype: DataType): void {
137159
if (!this.tensorMap.has(dataId)) {
138160
const byteSize = util.sizeFromShape(shape) * util.bytesPerElement(dtype);
139-
const buffer = this.createBuffer(byteSize);
140-
this.tensorMap.set(
141-
dataId, {byteSize, values: null, id: -1, buffer, dtype});
161+
const buffer = this.acquireBuffer(byteSize);
162+
this.tensorMap.set(dataId, {
163+
values: null,
164+
id: -1,
165+
dtype,
166+
bufferInfo: {byteSize, usage: DEFAULT_GPUBUFFER_USAGE, buffer}
167+
});
142168
}
143169
}
144170

@@ -149,49 +175,45 @@ export class WebGPUBackend extends KernelBackend {
149175

150176
const info = this.tensorMap.get(dataId);
151177
info.values = values;
152-
info.buffer.setSubData(0, values);
178+
info.bufferInfo.buffer.setSubData(0, values);
153179
this.tensorMap.set(dataId, info);
154180
}
155181

156182
private submitQueue() {
157183
this.queue.submit(this.commandQueue.map(enc => enc.finish()));
158184
this.commandQueue = [];
159185

160-
this.flushDisposalQueue();
161-
}
162-
163-
private flushDisposalQueue() {
164-
this.disposalQueue.forEach(d => {
165-
d.buffer.destroy();
166-
this.numBytesInGPU -= d.byteSize;
167-
});
186+
this.commandQueueOwnedIds = new WeakSet<DataId>();
168187

169-
this.disposalQueue = [];
188+
this.flushDisposalQueue();
170189
}
171190

172191
private async getBufferData(info: TensorInfo): Promise<ArrayBuffer> {
173-
const staging = this.createBuffer(
174-
info.byteSize, GPUBufferUsage.TRANSFER_DST | GPUBufferUsage.MAP_READ);
175-
{
176-
const encoder = this.device.createCommandEncoder({});
177-
encoder.copyBufferToBuffer(info.buffer, 0, staging, 0, info.byteSize);
178-
this.commandQueue.push(encoder);
179-
this.submitQueue();
180-
}
192+
const staging = this.acquireBuffer(
193+
info.bufferInfo.byteSize,
194+
GPUBufferUsage.TRANSFER_DST | GPUBufferUsage.MAP_READ);
195+
const encoder = this.device.createCommandEncoder({});
196+
encoder.copyBufferToBuffer(
197+
info.bufferInfo.buffer, 0, staging, 0, info.bufferInfo.byteSize);
198+
this.commandQueue.push(encoder);
199+
this.submitQueue();
200+
181201
const mapped: ArrayBuffer = await staging.mapReadAsync();
202+
const values = mapped.slice(0);
182203

183-
return mapped.slice(0);
204+
staging.unmap();
205+
this.releaseBuffer(
206+
staging, info.bufferInfo.byteSize,
207+
GPUBufferUsage.TRANSFER_DST | GPUBufferUsage.MAP_READ);
208+
209+
return values;
184210
}
185211

186212
private convertAndCacheOnCPU(dataId: DataId, data: backend_util.TypedArray):
187213
backend_util.TypedArray {
188-
const texData = this.tensorMap.get(dataId);
189-
190-
// TODO: implement release GPU data.
191-
// TODO: add backend_webgl float32ToTypedArray to util and use that here.
192-
193-
texData.values = data;
194-
return texData.values as backend_util.TypedArray;
214+
const info = this.tensorMap.get(dataId);
215+
info.values = data;
216+
return info.values as backend_util.TypedArray;
195217
}
196218

197219
// TODO: Remove once this is fixed:
@@ -246,7 +268,7 @@ export class WebGPUBackend extends KernelBackend {
246268
resource: {
247269
offset: 0,
248270
size: tensor.size * util.bytesPerElement(tensor.dtype),
249-
buffer: tensorData.buffer
271+
buffer: tensorData.bufferInfo.buffer
250272
}
251273
};
252274
}
@@ -337,16 +359,23 @@ export class WebGPUBackend extends KernelBackend {
337359
pass.endPass();
338360
this.commandQueue.push(encoder);
339361

362+
inputs.forEach(input => {
363+
this.commandQueueOwnedIds.add(input.dataId);
364+
});
365+
this.commandQueueOwnedIds.add(output.dataId);
366+
340367
if (ENV.get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED')) {
341368
this.submitQueue();
342369
}
343-
this.disposeBuffer(uniformData.byteLength, uniforms.resource.buffer);
370+
this.releaseBuffer(
371+
uniforms.resource.buffer, uniformData.byteLength,
372+
GPUBufferUsage.TRANSFER_DST | GPUBufferUsage.UNIFORM);
344373
return output as {} as K;
345374
}
346375

347376
private makeUniforms(data: Uint32Array|
348377
Int32Array): webgpu_program.BindingInfo {
349-
const dimensionsBuffer = this.createBuffer(
378+
const dimensionsBuffer = this.acquireBuffer(
350379
data.byteLength, GPUBufferUsage.TRANSFER_DST | GPUBufferUsage.UNIFORM);
351380
dimensionsBuffer.setSubData(0, data);
352381

@@ -611,6 +640,7 @@ export class WebGPUBackend extends KernelBackend {
611640
if (this.disposed) {
612641
return;
613642
}
643+
this.bufferManager.dispose();
614644
this.disposed = true;
615645
}
616646
}

tfjs-webgpu/src/backend_webgpu_test.ts

Lines changed: 84 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,8 @@
1616
*/
1717

1818
import * as tf from '@tensorflow/tfjs-core';
19-
import {WebGPUMemoryInfo} from './backend_webgpu';
19+
20+
import {WebGPUBackend, WebGPUMemoryInfo} from './backend_webgpu';
2021
import {describeWebGPU} from './test_util';
2122

2223
describeWebGPU('backend webgpu', () => {
@@ -78,6 +79,88 @@ describeWebGPU('backend webgpu', () => {
7879
tf.ENV.set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
7980
});
8081

82+
it('should recycle buffers in immediate mode', () => {
83+
const savedFlag = tf.ENV.get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
84+
tf.ENV.set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
85+
const backend = tf.backend() as WebGPUBackend;
86+
const bufferManager = backend.getBufferManager();
87+
bufferManager.reset();
88+
89+
const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
90+
const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);
91+
92+
const c = tf.mul(a, b);
93+
const freeBuffersAfterFirstMul = bufferManager.getNumFreeBuffers();
94+
const usedBuffersAfterFirstMul = bufferManager.getNumUsedBuffers();
95+
96+
const f = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
97+
tf.matMul(c, f);
98+
const freeBuffersAfterFirstMatMul = bufferManager.getNumFreeBuffers();
99+
const usedBuffersAfterFirstMatMul = bufferManager.getNumUsedBuffers();
100+
expect(freeBuffersAfterFirstMatMul - freeBuffersAfterFirstMul)
101+
.toEqual(1); // from released uniform
102+
expect(usedBuffersAfterFirstMatMul - usedBuffersAfterFirstMul).toEqual(2);
103+
104+
const a2 = tf.tensor2d([2, 4, 6, 8], [2, 2]);
105+
const b2 = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);
106+
107+
const c2 = tf.mul(a2, b2);
108+
const freeBuffersAfterSecondMul = bufferManager.getNumFreeBuffers();
109+
const usedBuffersAfterSecondMul = bufferManager.getNumUsedBuffers();
110+
expect(freeBuffersAfterSecondMul - freeBuffersAfterFirstMatMul)
111+
.toEqual(0); // released a uniform buffer and reused a buffer
112+
expect(usedBuffersAfterSecondMul - usedBuffersAfterFirstMatMul).toEqual(3);
113+
114+
const f2 = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
115+
tf.matMul(c2, f2);
116+
const freeBuffersAfterSecondMatMul = bufferManager.getNumFreeBuffers();
117+
const usedBuffersAfterSecondMatMul = bufferManager.getNumUsedBuffers();
118+
expect(freeBuffersAfterSecondMatMul - freeBuffersAfterSecondMul).toEqual(0);
119+
expect(usedBuffersAfterSecondMatMul - usedBuffersAfterSecondMul).toEqual(2);
120+
tf.ENV.set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
121+
});
122+
123+
it('should recycle buffers in delayed mode', () => {
124+
const savedFlag = tf.ENV.get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
125+
tf.ENV.set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', false);
126+
const backend = tf.backend() as WebGPUBackend;
127+
const bufferManager = backend.getBufferManager();
128+
bufferManager.reset();
129+
130+
const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
131+
const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);
132+
133+
const c = tf.mul(a, b);
134+
const freeBuffersAfterFirstMul = bufferManager.getNumFreeBuffers();
135+
const usedBuffersAfterFirstMul = bufferManager.getNumUsedBuffers();
136+
137+
const f = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
138+
tf.matMul(c, f);
139+
const freeBuffersAfterFirstMatMul = bufferManager.getNumFreeBuffers();
140+
const usedBuffersAfterFirstMatMul = bufferManager.getNumUsedBuffers();
141+
expect(freeBuffersAfterFirstMatMul - freeBuffersAfterFirstMul)
142+
.toEqual(1); // from released uniform
143+
expect(usedBuffersAfterFirstMatMul - usedBuffersAfterFirstMul).toEqual(2);
144+
145+
const a2 = tf.tensor2d([2, 4, 6, 8], [2, 2]);
146+
const b2 = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);
147+
148+
const c2 = tf.mul(a2, b2);
149+
const freeBuffersAfterSecondMul = bufferManager.getNumFreeBuffers();
150+
const usedBuffersAfterSecondMul = bufferManager.getNumUsedBuffers();
151+
expect(freeBuffersAfterSecondMul - freeBuffersAfterFirstMatMul)
152+
.toEqual(0); // released a uniform buffer and reused a buffer
153+
expect(usedBuffersAfterSecondMul - usedBuffersAfterFirstMatMul).toEqual(3);
154+
155+
const f2 = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
156+
tf.matMul(c2, f2);
157+
const freeBuffersAfterSecondMatMul = bufferManager.getNumFreeBuffers();
158+
const usedBuffersAfterSecondMatMul = bufferManager.getNumUsedBuffers();
159+
expect(freeBuffersAfterSecondMatMul - freeBuffersAfterSecondMul).toEqual(0);
160+
expect(usedBuffersAfterSecondMatMul - usedBuffersAfterSecondMul).toEqual(2);
161+
tf.ENV.set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
162+
});
163+
81164
it('readSync should throw if tensors are on the GPU', async () => {
82165
const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
83166
const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);

0 commit comments

Comments
 (0)