1818import { BackendTimer } from './backends/backend' ;
1919import { Tensor } from './tensor' ;
2020import { NamedTensorMap } from './tensor_types' ;
21- import { TypedArray } from './types' ;
21+ import { DataType , DataTypeMap , TypedArray } from './types' ;
2222import * as util from './util' ;
2323
2424export class Profiler {
@@ -39,24 +39,47 @@ export class Profiler {
3939 const results : Tensor [ ] =
4040 Array . isArray ( result ) ? result : [ result ] as Tensor [ ] ;
4141 results . forEach ( r => {
42- const vals = r . dataSync ( ) ;
43- util . checkComputationForErrors ( vals , r . dtype , name ) ;
42+ // Dangling promise here because we don't want to propagate up
43+ // asynchronicity.
44+ r . data ( ) . then ( vals => {
45+ checkComputationForErrors ( vals , r . dtype , name ) ;
4446
45- timer . then ( timing => {
46- let extraInfo = '' ;
47- if ( timing . getExtraProfileInfo != null ) {
48- extraInfo = timing . getExtraProfileInfo ( ) ;
49- }
47+ timer . then ( timing => {
48+ let extraInfo = '' ;
49+ if ( timing . getExtraProfileInfo != null ) {
50+ extraInfo = timing . getExtraProfileInfo ( ) ;
51+ }
5052
51- this . logger . logKernelProfile (
52- name , r , vals , timing . kernelMs , inputs , extraInfo ) ;
53+ this . logger . logKernelProfile (
54+ name , r , vals , timing . kernelMs , inputs , extraInfo ) ;
55+ } ) ;
5356 } ) ;
5457 } ) ;
5558
5659 return result as T ;
5760 }
5861}
5962
63+ // Create a custom exception class so it can be stubbed in tests.
64+ export function ProfilerException ( msg : string ) {
65+ console . error ( msg ) ;
66+ }
67+
68+ export function checkComputationForErrors < D extends DataType > (
69+ vals : DataTypeMap [ D ] , dtype : D , name : string ) : void {
70+ if ( dtype !== 'float32' ) {
71+ // Only floating point computations will generate NaN values
72+ return ;
73+ }
74+ for ( let i = 0 ; i < vals . length ; i ++ ) {
75+ const num = vals [ i ] as number ;
76+ if ( isNaN ( num ) || ! isFinite ( num ) ) {
77+ // Throwing custom exception so behavior is testable.
78+ throw ProfilerException ( `The result of the '${ name } ' is ${ num } .` ) ;
79+ }
80+ }
81+ }
82+
6083export class Logger {
6184 logKernelProfile (
6285 name : string , result : Tensor , vals : TypedArray , timeMs : number ,
0 commit comments