Skip to content

Commit a015b38

Browse files
authored
tfjs-webgpu: Add backend.time (#1817)
FEATURE
1 parent 892dd22 commit a015b38

File tree

11 files changed

+290
-103
lines changed

11 files changed

+290
-103
lines changed

tfjs-core/src/debug_mode_test.ts

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,38 @@ describeWithFlags('debug on', SYNC_BACKEND_ENVS, () => {
5555
expect(a).toThrowError();
5656
});
5757

58-
it('debug mode errors when infinities in op output', () => {
58+
it('debug mode errors when infinities in op output', async () => {
5959
const a = tf.tensor1d([1, 2, 3, 4]);
6060
const b = tf.tensor1d([2, -1, 0, 3]);
61-
const c = () => a.div(b);
62-
expect(c).toThrowError();
61+
62+
spyOn(console, 'error');
63+
64+
const c = async () => {
65+
const result = a.div(b);
66+
// Must await result so we know exception would have happened by the
67+
// time we call `expect`.
68+
await result.data();
69+
};
70+
71+
await c();
72+
73+
expect(console.error).toHaveBeenCalled();
6374
});
6475

65-
it('debug mode errors when nans in op output', () => {
76+
it('debug mode errors when nans in op output', async () => {
6677
const a = tf.tensor1d([-1, 2]);
6778
const b = tf.tensor1d([0.5, 1]);
68-
const c = () => a.pow(b);
69-
expect(c).toThrowError();
79+
80+
spyOn(console, 'error');
81+
82+
const c = async () => {
83+
const result = a.pow(b);
84+
await result.data();
85+
};
86+
87+
await c();
88+
89+
expect(console.error).toHaveBeenCalled();
7090
});
7191

7292
it('debug mode errors when nans in oneHot op (tensorlike), int32', () => {

tfjs-core/src/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
5555
export {SGDOptimizer} from './optimizers/sgd_optimizer';
5656
export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, variable, Variable} from './tensor';
5757
export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types';
58-
export {DataType, DataTypeMap, DataValues, Rank, ShapeMap, TensorLike} from './types';
58+
export {DataType, DataTypeMap, DataValues, Rank, RecursiveArray, ShapeMap, TensorLike} from './types';
5959

6060
export * from './ops/ops';
6161
export {LSTMCellFunc} from './ops/lstm';
@@ -65,7 +65,7 @@ export * from './train';
6565
export * from './globals';
6666
export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients';
6767

68-
export {TimingInfo} from './engine';
68+
export {TimingInfo, MemoryInfo} from './engine';
6969
export {ENV, Environment} from './environment';
7070
export {Platform} from './platforms/platform';
7171

tfjs-core/src/profiler.ts

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import {BackendTimer} from './backends/backend';
1919
import {Tensor} from './tensor';
2020
import {NamedTensorMap} from './tensor_types';
21-
import {TypedArray} from './types';
21+
import {DataType, DataTypeMap, TypedArray} from './types';
2222
import * as util from './util';
2323

2424
export class Profiler {
@@ -39,24 +39,47 @@ export class Profiler {
3939
const results: Tensor[] =
4040
Array.isArray(result) ? result : [result] as Tensor[];
4141
results.forEach(r => {
42-
const vals = r.dataSync();
43-
util.checkComputationForErrors(vals, r.dtype, name);
42+
// Dangling promise here because we don't want to propagate up
43+
// asynchronicity.
44+
r.data().then(vals => {
45+
checkComputationForErrors(vals, r.dtype, name);
4446

45-
timer.then(timing => {
46-
let extraInfo = '';
47-
if (timing.getExtraProfileInfo != null) {
48-
extraInfo = timing.getExtraProfileInfo();
49-
}
47+
timer.then(timing => {
48+
let extraInfo = '';
49+
if (timing.getExtraProfileInfo != null) {
50+
extraInfo = timing.getExtraProfileInfo();
51+
}
5052

51-
this.logger.logKernelProfile(
52-
name, r, vals, timing.kernelMs, inputs, extraInfo);
53+
this.logger.logKernelProfile(
54+
name, r, vals, timing.kernelMs, inputs, extraInfo);
55+
});
5356
});
5457
});
5558

5659
return result as T;
5760
}
5861
}
5962

63+
// Create a custom exception class so it can be stubbed in tests.
64+
export function ProfilerException(msg: string) {
65+
console.error(msg);
66+
}
67+
68+
export function checkComputationForErrors<D extends DataType>(
69+
vals: DataTypeMap[D], dtype: D, name: string): void {
70+
if (dtype !== 'float32') {
71+
// Only floating point computations will generate NaN values
72+
return;
73+
}
74+
for (let i = 0; i < vals.length; i++) {
75+
const num = vals[i] as number;
76+
if (isNaN(num) || !isFinite(num)) {
77+
// Throwing custom exception so behavior is testable.
78+
throw ProfilerException(`The result of the '${name}' is ${num}.`);
79+
}
80+
}
81+
}
82+
6083
export class Logger {
6184
logKernelProfile(
6285
name: string, result: Tensor, vals: TypedArray, timeMs: number,

tfjs-core/src/profiler_test.ts

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import {BackendTimer, BackendTimingInfo} from './backends/backend';
1919
import * as tf from './index';
2020
import {describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util';
21-
import {Logger, Profiler} from './profiler';
21+
import {checkComputationForErrors, Logger, Profiler} from './profiler';
2222
import {Tensor} from './tensor';
2323
import {TypedArray} from './types';
2424

@@ -129,3 +129,27 @@ describeWithFlags('profiler.Profiler', SYNC_BACKEND_ENVS, () => {
129129
}, delayMs * 2);
130130
});
131131
});
132+
133+
describe('profiler.checkComputationForErrors', () => {
134+
it('Float32Array has NaN', () => {
135+
expect(
136+
() => checkComputationForErrors(
137+
new Float32Array([1, 2, 3, NaN, 4, 255]), 'float32', ''))
138+
.toThrow();
139+
});
140+
141+
it('Float32Array has Infinity', () => {
142+
expect(
143+
() => checkComputationForErrors(
144+
new Float32Array([1, 2, 3, Infinity, 4, 255]), 'float32', ''))
145+
.toThrow();
146+
});
147+
148+
it('Float32Array no NaN', () => {
149+
// Int32 and Bool NaNs should not trigger an error.
150+
expect(
151+
() => checkComputationForErrors(
152+
new Float32Array([1, 2, 3, -1, 4, 255]), 'float32', ''))
153+
.not.toThrow();
154+
});
155+
});

tfjs-core/src/util.ts

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -405,20 +405,6 @@ export function getArrayFromDType<D extends DataType>(
405405
return values as DataTypeMap[D];
406406
}
407407

408-
export function checkComputationForErrors<D extends DataType>(
409-
vals: DataTypeMap[D], dtype: D, name: string): void {
410-
if (dtype !== 'float32') {
411-
// Only floating point computations will generate NaN values
412-
return;
413-
}
414-
for (let i = 0; i < vals.length; i++) {
415-
const num = vals[i] as number;
416-
if (isNaN(num) || !isFinite(num)) {
417-
throw Error(`The result of the '${name}' is ${num}.`);
418-
}
419-
}
420-
}
421-
422408
export function checkConversionForErrors<D extends DataType>(
423409
vals: DataTypeMap[D]|number[], dtype: D): void {
424410
for (let i = 0; i < vals.length; i++) {

tfjs-core/src/util_test.ts

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -407,30 +407,6 @@ describe('util.squeezeShape', () => {
407407
});
408408
});
409409

410-
describe('util.checkComputationForErrors', () => {
411-
it('Float32Array has NaN', () => {
412-
expect(
413-
() => util.checkComputationForErrors(
414-
new Float32Array([1, 2, 3, NaN, 4, 255]), 'float32', ''))
415-
.toThrowError();
416-
});
417-
418-
it('Float32Array has Infinity', () => {
419-
expect(
420-
() => util.checkComputationForErrors(
421-
new Float32Array([1, 2, 3, Infinity, 4, 255]), 'float32', ''))
422-
.toThrowError();
423-
});
424-
425-
it('Float32Array no NaN', () => {
426-
// Int32 and Bool NaNs should not trigger an error.
427-
expect(
428-
() => util.checkComputationForErrors(
429-
new Float32Array([1, 2, 3, 4, -1, 255]), 'float32', ''))
430-
.not.toThrowError();
431-
});
432-
});
433-
434410
describe('util.checkConversionForErrors', () => {
435411
it('Float32Array has NaN', () => {
436412
expect(

tfjs-webgpu/package.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
},
1919
"license": "Apache-2.0",
2020
"devDependencies": {
21-
"@tensorflow/tfjs-core": "1.2.1",
21+
"@tensorflow/tfjs-core": "1.2.7",
2222
"@types/jasmine": "~2.5.53",
2323
"clang-format": "~1.2.2",
2424
"http-server": "~0.10.0",
@@ -34,6 +34,7 @@
3434
"rollup-plugin-commonjs": "9.1.3",
3535
"rollup-plugin-node-resolve": "3.3.0",
3636
"rollup-plugin-typescript2": "0.13.0",
37+
"rollup-plugin-terser": "^5.1.1",
3738
"rollup-plugin-uglify": "~3.0.0",
3839
"tslint": "~5.11.0",
3940
"tslint-no-circular-imports": "^0.5.0",
@@ -47,4 +48,4 @@
4748
"peerDependencies": {
4849
"@tensorflow/tfjs-core": "1.2.1"
4950
}
50-
}
51+
}

tfjs-webgpu/rollup.config.js

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ function config({plugins = [], output = {}, external = []}) {
2626
return {
2727
input: 'src/index.ts',
2828
plugins: [
29-
typescript({
30-
tsconfigOverride: {compilerOptions: {module: 'ES2015'}}
31-
}),
29+
typescript({tsconfigOverride: {compilerOptions: {module: 'ES2015'}}}),
3230
node(),
3331
// Polyfill require() from dependencies.
3432
commonjs({

0 commit comments

Comments
 (0)