Skip to content

Commit f39acb2

Browse files
authored
[tfjs-core] Fix caching on cpu with async reads (#1974)
BUG, PERF This PR fixes a bug introduced in tensorflow/tfjs-core#1798 The values were not cached after an async read, which has performance implications. The inference of an AutoML object detection model including post-processing changed from 107ms to 80ms. The reason could be an accumulation of repeated reads done by the converter executor when we call `model.executeAsync()`.
1 parent 8049775 commit f39acb2

File tree

3 files changed

+61
-8
lines changed

3 files changed

+61
-8
lines changed

tfjs-core/src/backends/webgl/backend_webgl.ts

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -472,12 +472,12 @@ export class MathBackendWebGL implements KernelBackend {
472472
}
473473

474474
let buffer = null;
475+
let tmpDownloadTarget: TensorHandle;
476+
475477
if (dtype !== 'complex64' && ENV.get('WEBGL_BUFFER_SUPPORTED')) {
476478
// Possibly copy the texture into a buffer before inserting a fence.
477-
const tmpTarget = this.decode(dataId);
478-
479-
dataId = tmpTarget.dataId;
480-
const tmpData = this.texData.get(tmpTarget.dataId);
479+
tmpDownloadTarget = this.decode(dataId);
480+
const tmpData = this.texData.get(tmpDownloadTarget.dataId);
481481

482482
buffer = this.gpgpu.createBufferFromTexture(
483483
tmpData.texture, ...tex_util.getDenseTexShape(shape));
@@ -503,9 +503,10 @@ export class MathBackendWebGL implements KernelBackend {
503503
vals = this.getValuesFromTexture(dataId);
504504
} else {
505505
const size = util.sizeFromShape(shape);
506-
507506
vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
508-
this.disposeData(dataId);
507+
}
508+
if (tmpDownloadTarget != null) {
509+
this.disposeData(tmpDownloadTarget.dataId);
509510
}
510511
const dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
511512

@@ -690,6 +691,14 @@ export class MathBackendWebGL implements KernelBackend {
690691
return this.texData.get(dataId).texture;
691692
}
692693

694+
/**
695+
* Returns internal information for the specific data bucket. Used in unit
696+
* tests.
697+
*/
698+
getDataInfo(dataId: DataId): TextureData {
699+
return this.texData.get(dataId);
700+
}
701+
693702
private getCPUBackend(): KernelBackend|null {
694703
if (!ENV.getBool('WEBGL_CPU_FORWARD')) {
695704
return null;

tfjs-core/src/backends/webgl/backend_webgl_test.ts

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,50 @@ describeWithFlags('time webgl', WEBGL_ENVS, () => {
506506
});
507507
});
508508

509+
describeWithFlags('caching on cpu', WEBGL_ENVS, () => {
510+
beforeAll(() => {
511+
tf.ENV.set('WEBGL_CPU_FORWARD', false);
512+
});
513+
514+
it('caches on cpu after async read', async () => {
515+
const backend = new MathBackendWebGL();
516+
tf.registerBackend('cache-on-cpu', () => backend);
517+
tf.setBackend('cache-on-cpu');
518+
519+
const t = tf.square(2);
520+
const info = backend.getDataInfo(t.dataId);
521+
522+
// Make sure the tensor is on the GPU.
523+
expect(info.values == null).toBe(true);
524+
525+
await t.data();
526+
527+
// Make sure the tensor is cached on CPU.
528+
expect(info.values).not.toBe(null);
529+
530+
tf.removeBackend('cache-on-cpu');
531+
});
532+
533+
it('caches on cpu after sync read', () => {
534+
const backend = new MathBackendWebGL();
535+
tf.registerBackend('cache-on-cpu', () => backend);
536+
tf.setBackend('cache-on-cpu');
537+
538+
const t = tf.square(2);
539+
const info = backend.getDataInfo(t.dataId);
540+
541+
// Make sure the tensor is on the GPU.
542+
expect(info.values == null).toBe(true);
543+
544+
t.dataSync();
545+
546+
// Make sure the tensor is cached on CPU.
547+
expect(info.values).not.toBe(null);
548+
549+
tf.removeBackend('cache-on-cpu');
550+
});
551+
});
552+
509553
describe('WebGL backend has sync init', () => {
510554
it('can do matmul without waiting for ready', async () => {
511555
tf.registerBackend('my-webgl', () => {

tfjs-core/src/ops/image_ops.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ async function nonMaxSuppressionAsync_(
191191
iouThreshold = inputs.iouThreshold;
192192
scoreThreshold = inputs.scoreThreshold;
193193

194-
const boxesVals = await $boxes.data();
195-
const scoresVals = await $scores.data();
194+
const [boxesVals, scoresVals] =
195+
await Promise.all([$boxes.data(), $scores.data()]);
196196
const res = nonMaxSuppressionImpl(
197197
boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
198198
if ($boxes !== boxes) {

0 commit comments

Comments
 (0)