Skip to content

Commit 28923ef

Browse files
authored
tfjs-webgpu: Improved timing in benchmarks. (#1823)
FEATURE
1 parent f3c503e commit 28923ef

File tree

2 files changed

+34
-39
lines changed

2 files changed

+34
-39
lines changed

tfjs-webgpu/karma.conf.js

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ module.exports = function(config) {
4242
'src/setup_test.ts', // Setup the environment for the tests.
4343
{pattern: 'src/**/*.ts'}, // Import all tests.
4444
],
45+
exclude: ['src/benchmark_ops_test.ts'],
4546
preprocessors: {'**/*.ts': ['karma-typescript']},
4647
karmaTypescriptConfig,
4748
reporters: ['progress', 'karma-typescript'],

tfjs-webgpu/src/benchmark_ops_test.ts

Lines changed: 33 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,9 @@ describeWebGPU('Ops benchmarks', () => {
2828
// avoided by using fences, but we don't have a common abstraction over
2929
// WebGL and WebGPU fences at the moment.
3030
async function time(
31-
trials: number, reps: number, doRep: (r: number) => tf.Tensor[],
32-
endTrial: () => Promise<void>) {
31+
doRep: (r: number) => tf.Tensor[] | tf.Tensor,
32+
endTrial?: () => Promise<void>, disposeAfterEachTrial = false,
33+
trials = 50, reps = 1) {
3334
const times = [];
3435

3536
let toDispose: tf.Tensor[] = [];
@@ -40,11 +41,19 @@ describeWebGPU('Ops benchmarks', () => {
4041
toDispose = [];
4142
};
4243

43-
const trial = () => {
44+
const trial = async () => {
45+
let result;
4446
for (let r = 0; r < reps; ++r) {
45-
toDispose = toDispose.concat(doRep(r));
47+
result = doRep(r);
48+
49+
toDispose = toDispose.concat(Array.isArray(result) ? result : [result]);
50+
}
51+
52+
if (endTrial != null) {
53+
await endTrial();
54+
} else {
55+
await (Array.isArray(result) ? result[0] : result).data();
4656
}
47-
return endTrial();
4857
};
4958

5059
// Warm-up. Specifically, this pre-allocates enough memory for an entire
@@ -57,7 +66,9 @@ describeWebGPU('Ops benchmarks', () => {
5766
const start = tf.util.now();
5867
await trial();
5968
times.push(tf.util.now() - start);
60-
dispose();
69+
if (disposeAfterEachTrial) {
70+
dispose();
71+
}
6172
}
6273

6374
const mean = times.reduce((a, b) => a + b, 0) / trials;
@@ -67,8 +78,7 @@ describeWebGPU('Ops benchmarks', () => {
6778
console.log(`Min time: ${fmt(min)} ms -> ${fmt(min / reps)} / rep`);
6879
}
6980

70-
// tslint:disable-next-line:ban
71-
xit('argMax', async () => {
81+
it('argMax', async () => {
7282
const n = 50;
7383
const doTest = async (axis: number) => {
7484
const tensors = new Array(n);
@@ -78,7 +88,6 @@ describeWebGPU('Ops benchmarks', () => {
7888
}
7989

8090
await time(
81-
5, n,
8291
(r) => {
8392
maxes[r] = tf.argMax(tensors[r], axis);
8493
return [];
@@ -96,39 +105,24 @@ describeWebGPU('Ops benchmarks', () => {
96105
await doTest(2);
97106
}, 60000);
98107

99-
// tslint:disable-next-line:ban
100-
xit('matMul', async () => {
101-
let a = tf.randomNormal([500, 500]);
108+
it('matMul', async () => {
109+
const a = tf.randomNormal([500, 500]);
102110
const b = tf.randomNormal([500, 500]);
103111

104-
await time(
105-
5, 50,
106-
() => {
107-
const c = tf.matMul(a, b);
108-
const toDispose = a;
109-
a = c;
110-
return [toDispose];
111-
},
112-
async () => {
113-
await a.data();
114-
});
115-
}, 60000);
112+
await time(() => tf.matMul(a, b));
113+
});
114+
115+
it('add', async () => {
116+
const a = tf.randomNormal([1, 65, 65, 256]);
117+
const b = tf.randomNormal([1, 65, 65, 256]);
116118

117-
// tslint:disable-next-line:ban
118-
xit('conv2d', async () => {
119-
let a = tf.randomNormal<tf.Rank.R4>([1, 128, 128, 4]);
119+
await time(() => tf.add(a, b));
120+
});
121+
122+
it('conv2d', async () => {
123+
const a = tf.randomNormal<tf.Rank.R4>([1, 128, 128, 4]);
120124
const b = tf.randomNormal<tf.Rank.R4>([25, 25, 4, 4]);
121125

122-
await time(
123-
5, 50,
124-
() => {
125-
const c = tf.conv2d(a, b, 1, 'same');
126-
const toDispose = a;
127-
a = c;
128-
return [toDispose];
129-
},
130-
async () => {
131-
await a.data();
132-
});
133-
}, 60000);
126+
await time(() => tf.conv2d(a, b, 1, 'same'));
127+
});
134128
});

0 commit comments

Comments
 (0)