@@ -873,10 +873,12 @@ export class MathBackendWebGL implements KernelBackend {
873873
874874 const dtype = upcastType ( a . dtype , b . dtype ) ;
875875
876+ const hasBias = bias != null ;
877+ const fusedActivation =
878+ activation ? mapActivationToShaderProgram ( activation , true ) : null ;
876879 const program = new MatMulPackedProgram (
877880 a . shape , [ batch , outerShapeA , outerShapeB ] , transposeA , transposeB ,
878- ! ! bias ,
879- activation ? mapActivationToShaderProgram ( activation , true ) : null ) ;
881+ hasBias , fusedActivation ) ;
880882 const output =
881883 this . makePackedTensor ( program . outputShape , dtype ) as Tensor3D ;
882884 const inputs : TensorHandle [ ] = [ a , b ] ;
@@ -1815,15 +1817,18 @@ export class MathBackendWebGL implements KernelBackend {
18151817 return this . compileAndRun ( program , [ x ] ) as T ;
18161818 }
18171819
1818- conv2dByMatMul ( x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo ) :
1819- Tensor4D {
1820+ private conv2dByMatMul (
1821+ x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo , bias ?: Tensor4D ,
1822+ activation ?: Activation ) : Tensor4D {
18201823 // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
18211824 // result from 2D to 4D.
18221825 const xShape = x . shape ;
18231826 const xTexData = this . texData . get ( x . dataId ) ;
18241827 const sharedMatMulDim = convInfo . inChannels ;
18251828 const outerShapeX = xShape [ 0 ] * xShape [ 1 ] * xShape [ 2 ] ;
18261829 const outerShapeFilter = convInfo . outChannels ;
1830+ const transposeA = false ;
1831+ const transposeB = false ;
18271832
18281833 // TODO: Once reduction ops are packed, batchMatMul will always be packed
18291834 // and we can remove this condition.
@@ -1843,8 +1848,11 @@ export class MathBackendWebGL implements KernelBackend {
18431848 this . reshape (
18441849 filter , [ 1 , convInfo . inChannels , convInfo . outChannels ] ) as
18451850 Tensor3D ;
1851+
18461852 return this . reshape < Rank . R4 > (
1847- this . batchMatMul ( xReshaped , filterReshaped , false , false ) ,
1853+ this . fusedBatchMatMul (
1854+ xReshaped , filterReshaped , transposeA , transposeB , bias ,
1855+ activation ) ,
18481856 convInfo . outShape ) ;
18491857 }
18501858
@@ -1880,8 +1888,8 @@ export class MathBackendWebGL implements KernelBackend {
18801888 this . reshape ( filter , [ 1 , convInfo . inChannels , convInfo . outChannels ] ) as
18811889 Tensor3D ;
18821890
1883- const pointwiseConv =
1884- this . batchMatMul ( xReshaped , filterReshaped , false , false ) ;
1891+ const pointwiseConv = this . fusedBatchMatMul (
1892+ xReshaped , filterReshaped , transposeA , transposeB , bias , activation ) ;
18851893 const pointwiseConvTexData = this . texData . get ( pointwiseConv . dataId ) ;
18861894 util . assert (
18871895 pointwiseConvTexData . isPacked ,
@@ -1896,8 +1904,9 @@ export class MathBackendWebGL implements KernelBackend {
18961904 pointwiseConv . dtype , this ) as Tensor4D ;
18971905 }
18981906
1899- conv2dWithIm2Row ( x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo ) :
1900- Tensor4D {
1907+ private conv2dWithIm2Row (
1908+ x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo , bias ?: Tensor4D ,
1909+ activation ?: Activation ) : Tensor4D {
19011910 // Rearranges conv2d input so each block to be convolved over forms the
19021911 // column of a new matrix with shape [filterWidth * filterHeight *
19031912 // inChannels, outHeight * outWidth]. The filter is also rearranged so each
@@ -1915,6 +1924,8 @@ export class MathBackendWebGL implements KernelBackend {
19151924 const sharedDim = filterWidth * filterHeight * inChannels ;
19161925 const numCols = outHeight * outWidth ;
19171926 const x2ColShape = [ sharedDim , numCols ] ;
1927+ const transposeA = true ;
1928+ const transposeB = false ;
19181929
19191930 const xSqueezed = x . squeeze ( [ 0 ] ) ;
19201931 const w2Row = filter . reshape ( [ 1 , sharedDim , - 1 ] ) as Tensor3D ;
@@ -1926,14 +1937,46 @@ export class MathBackendWebGL implements KernelBackend {
19261937 1 , x2ColShape [ 0 ] , x2ColShape [ 1 ]
19271938 ] ) as Tensor3D ;
19281939
1940+ const hasBias = bias != null ;
1941+ const fusedActivation =
1942+ activation ? mapActivationToShaderProgram ( activation , true ) : null ;
19291943 const matmulProgram = new MatMulPackedProgram (
1930- im2Col . shape , [ 1 , numCols , convInfo . outChannels ] , true , false ) ;
1931- const product =
1932- this . compileAndRun < Tensor4D > ( matmulProgram , [ im2Col , w2Row ] ) ;
1944+ im2Col . shape , [ 1 , numCols , convInfo . outChannels ] , transposeA ,
1945+ transposeB , hasBias , fusedActivation ) ;
1946+ const inputs : TensorHandle [ ] = [ im2Col , w2Row ] ;
1947+ if ( bias ) {
1948+ inputs . push ( bias ) ;
1949+ }
1950+ const product = this . compileAndRun < Tensor4D > ( matmulProgram , inputs ) ;
19331951
19341952 return product . reshape ( [ 1 , outHeight , outWidth , convInfo . outChannels ] ) ;
19351953 }
19361954
1955+ fusedConv2d (
1956+ x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo , bias ?: Tensor4D ,
1957+ activation ?: Activation ) : Tensor4D {
1958+ if ( convInfo . filterHeight === 1 && convInfo . filterWidth === 1 &&
1959+ convInfo . dilationHeight === 1 && convInfo . dilationWidth === 1 &&
1960+ convInfo . strideHeight === 1 && convInfo . strideWidth === 1 &&
1961+ ( convInfo . padInfo . type === 'SAME' ||
1962+ convInfo . padInfo . type === 'VALID' ) ) {
1963+ return this . conv2dByMatMul ( x , filter , convInfo , bias , activation ) ;
1964+ }
1965+ if ( ENV . getBool ( 'WEBGL_CONV_IM2COL' ) && x . shape [ 0 ] === 1 ) {
1966+ return this . conv2dWithIm2Row ( x , filter , convInfo , bias , activation ) ;
1967+ }
1968+
1969+ const hasBias = bias != null ;
1970+ const fusedActivation =
1971+ activation ? mapActivationToShaderProgram ( activation , false ) : null ;
1972+ const program = new Conv2DProgram ( convInfo , hasBias , fusedActivation ) ;
1973+ const inputs : TensorHandle [ ] = [ x , filter ] ;
1974+ if ( bias ) {
1975+ inputs . push ( bias ) ;
1976+ }
1977+ return this . compileAndRun ( program , inputs ) ;
1978+ }
1979+
19371980 conv2d ( x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo ) : Tensor4D {
19381981 if ( convInfo . filterHeight === 1 && convInfo . filterWidth === 1 &&
19391982 convInfo . dilationHeight === 1 && convInfo . dilationWidth === 1 &&
0 commit comments