@@ -28,8 +28,9 @@ describeWebGPU('Ops benchmarks', () => {
2828 // avoided by using fences, but we don't have a common abstraction over
2929 // WebGL and WebGPU fences at the moment.
3030 async function time (
31- trials : number , reps : number , doRep : ( r : number ) => tf . Tensor [ ] ,
32- endTrial : ( ) => Promise < void > ) {
31+ doRep : ( r : number ) => tf . Tensor [ ] | tf . Tensor ,
32+ endTrial ?: ( ) => Promise < void > , disposeAfterEachTrial = false ,
33+ trials = 50 , reps = 1 ) {
3334 const times = [ ] ;
3435
3536 let toDispose : tf . Tensor [ ] = [ ] ;
@@ -40,11 +41,19 @@ describeWebGPU('Ops benchmarks', () => {
4041 toDispose = [ ] ;
4142 } ;
4243
43- const trial = ( ) => {
44+ const trial = async ( ) => {
45+ let result ;
4446 for ( let r = 0 ; r < reps ; ++ r ) {
45- toDispose = toDispose . concat ( doRep ( r ) ) ;
47+ result = doRep ( r ) ;
48+
49+ toDispose = toDispose . concat ( Array . isArray ( result ) ? result : [ result ] ) ;
50+ }
51+
52+ if ( endTrial != null ) {
53+ await endTrial ( ) ;
54+ } else {
55+ await ( Array . isArray ( result ) ? result [ 0 ] : result ) . data ( ) ;
4656 }
47- return endTrial ( ) ;
4857 } ;
4958
5059 // Warm-up. Specifically, this pre-allocates enough memory for an entire
@@ -57,7 +66,9 @@ describeWebGPU('Ops benchmarks', () => {
5766 const start = tf . util . now ( ) ;
5867 await trial ( ) ;
5968 times . push ( tf . util . now ( ) - start ) ;
60- dispose ( ) ;
69+ if ( disposeAfterEachTrial ) {
70+ dispose ( ) ;
71+ }
6172 }
6273
6374 const mean = times . reduce ( ( a , b ) => a + b , 0 ) / trials ;
@@ -67,8 +78,7 @@ describeWebGPU('Ops benchmarks', () => {
6778 console . log ( `Min time: ${ fmt ( min ) } ms -> ${ fmt ( min / reps ) } / rep` ) ;
6879 }
6980
70- // tslint:disable-next-line:ban
71- xit ( 'argMax' , async ( ) => {
81+ it ( 'argMax' , async ( ) => {
7282 const n = 50 ;
7383 const doTest = async ( axis : number ) => {
7484 const tensors = new Array ( n ) ;
@@ -78,7 +88,6 @@ describeWebGPU('Ops benchmarks', () => {
7888 }
7989
8090 await time (
81- 5 , n ,
8291 ( r ) => {
8392 maxes [ r ] = tf . argMax ( tensors [ r ] , axis ) ;
8493 return [ ] ;
@@ -96,39 +105,24 @@ describeWebGPU('Ops benchmarks', () => {
96105 await doTest ( 2 ) ;
97106 } , 60000 ) ;
98107
99- // tslint:disable-next-line:ban
100- xit ( 'matMul' , async ( ) => {
101- let a = tf . randomNormal ( [ 500 , 500 ] ) ;
108+ it ( 'matMul' , async ( ) => {
109+ const a = tf . randomNormal ( [ 500 , 500 ] ) ;
102110 const b = tf . randomNormal ( [ 500 , 500 ] ) ;
103111
104- await time (
105- 5 , 50 ,
106- ( ) => {
107- const c = tf . matMul ( a , b ) ;
108- const toDispose = a ;
109- a = c ;
110- return [ toDispose ] ;
111- } ,
112- async ( ) => {
113- await a . data ( ) ;
114- } ) ;
115- } , 60000 ) ;
112+ await time ( ( ) => tf . matMul ( a , b ) ) ;
113+ } ) ;
114+
115+ it ( 'add' , async ( ) => {
116+ const a = tf . randomNormal ( [ 1 , 65 , 65 , 256 ] ) ;
117+ const b = tf . randomNormal ( [ 1 , 65 , 65 , 256 ] ) ;
116118
117- // tslint:disable-next-line:ban
118- xit ( 'conv2d' , async ( ) => {
119- let a = tf . randomNormal < tf . Rank . R4 > ( [ 1 , 128 , 128 , 4 ] ) ;
119+ await time ( ( ) => tf . add ( a , b ) ) ;
120+ } ) ;
121+
122+ it ( 'conv2d' , async ( ) => {
123+ const a = tf . randomNormal < tf . Rank . R4 > ( [ 1 , 128 , 128 , 4 ] ) ;
120124 const b = tf . randomNormal < tf . Rank . R4 > ( [ 25 , 25 , 4 , 4 ] ) ;
121125
122- await time (
123- 5 , 50 ,
124- ( ) => {
125- const c = tf . conv2d ( a , b , 1 , 'same' ) ;
126- const toDispose = a ;
127- a = c ;
128- return [ toDispose ] ;
129- } ,
130- async ( ) => {
131- await a . data ( ) ;
132- } ) ;
133- } , 60000 ) ;
126+ await time ( ( ) => tf . conv2d ( a , b , 1 , 'same' ) ) ;
127+ } ) ;
134128} ) ;
0 commit comments