@@ -43,6 +43,19 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
4343 expectArraysClose ( await c . data ( ) , [ 0 , 8 , 0 , 20 ] ) ;
4444 } ) ;
4545
46+ it ( 'A x B with elu' , async ( ) => {
47+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 2 , 3 ] ) ;
48+ const b = tf . tensor2d ( [ 0 , 1 , - 3 , 2 , 2 , 1 ] , [ 3 , 2 ] ) ;
49+ const transposeA = false ;
50+ const transposeB = false ;
51+
52+ const c = tf . fused . matMul (
53+ { a, b, transposeA, transposeB, bias : null , activation : 'elu' } ) ;
54+
55+ expect ( c . shape ) . toEqual ( [ 2 , 2 ] ) ;
56+ expectArraysClose ( await c . data ( ) , [ 0 , 8 , - 0.9502 , 20 ] ) ;
57+ } ) ;
58+
4659 it ( 'A x B with prelu' , async ( ) => {
4760 const a = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 2 , 3 ] ) ;
4861 const b = tf . tensor2d ( [ 0 , 1 , - 3 , 2 , 2 , 1 ] , [ 3 , 2 ] ) ;
@@ -106,6 +119,21 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
106119 expectArraysClose ( await d . data ( ) , [ 1 , 9 , 0 , 21 ] ) ;
107120 } ) ;
108121
122+ it ( 'A x B with elu and broadcasted bias' , async ( ) => {
123+ const a = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 2 , 3 ] ) ;
124+ const b = tf . tensor2d ( [ 0 , 1 , - 3 , 2 , 2 , 1 ] , [ 3 , 2 ] ) ;
125+ const c = tf . tensor1d ( [ 1 , 1 ] ) ;
126+ const act : tf . fused . Activation = 'elu' ;
127+ const transposeA = false ;
128+ const transposeB = false ;
129+
130+ const d = tf . fused . matMul (
131+ { a, b, transposeA, transposeB, bias : c , activation : act } ) ;
132+
133+ expect ( d . shape ) . toEqual ( [ 2 , 2 ] ) ;
134+ expectArraysClose ( await d . data ( ) , [ 1 , 9 , - 0.8647 , 21 ] ) ;
135+ } ) ;
136+
109137 it ( 'A x B with relu and broadcasted bias different rank' , async ( ) => {
110138 const a = tf . tensor3d ( [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 ] , [ 2 , 2 , 3 ] ) ;
111139 const b = tf . tensor3d ( [ 0 , 1 , - 3 , 2 , 2 , 1 , 0 , 1 , - 3 , 2 , 2 , 1 ] , [ 2 , 3 , 2 ] ) ;
@@ -318,6 +346,35 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => {
318346 expectArraysClose ( await result . data ( ) , expected ) ;
319347 } ) ;
320348
349+ it ( 'basic with elu' , async ( ) => {
350+ const inputDepth = 2 ;
351+ const inShape : [ number , number , number , number ] = [ 2 , 2 , 2 , inputDepth ] ;
352+ const outputDepth = 2 ;
353+ const fSize = 1 ;
354+ const pad = 0 ;
355+ const stride = 1 ;
356+
357+ const x = tf . tensor4d (
358+ [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 ] , inShape ) ;
359+ const w =
360+ tf . tensor4d ( [ - 1 , 1 , - 2 , 0.5 ] , [ fSize , fSize , inputDepth , outputDepth ] ) ;
361+
362+ const result = tf . fused . conv2d ( {
363+ x,
364+ filter : w ,
365+ strides : stride ,
366+ pad,
367+ dataFormat : 'NHWC' ,
368+ dilations : [ 1 , 1 ] ,
369+ activation : 'elu'
370+ } ) ;
371+ expect ( result . shape ) . toEqual ( [ 2 , 2 , 2 , 2 ] ) ;
372+ const expected =
373+ [ - 0.99326 , 2 , - 1 , 5 , - 1 , 8 , - 1 , 11 , - 1 , 14 , - 1 , 17 , - 1 , 20 , - 1 , 23 ] ;
374+
375+ expectArraysClose ( await result . data ( ) , expected ) ;
376+ } ) ;
377+
321378 it ( 'basic with prelu' , async ( ) => {
322379 const inputDepth = 2 ;
323380 const inShape : [ number , number , number , number ] = [ 2 , 2 , 2 , inputDepth ] ;
0 commit comments