@@ -29,7 +29,7 @@ import * as axis_util from '../../ops/axis_util';
2929import { complex , imag , real } from '../../ops/complex_ops' ;
3030import { computeOutShape } from '../../ops/concat_util' ;
3131import { Conv2DInfo , Conv3DInfo } from '../../ops/conv_util' ;
32- import { Activation , FusedBatchMatMulConfig } from '../../ops/fused_util' ;
32+ import { Activation , FusedBatchMatMulConfig , FusedConv2DConfig } from '../../ops/fused_util' ;
3333import * as gather_nd_util from '../../ops/gather_nd_util' ;
3434import * as reduce_util from '../../ops/reduce_util' ;
3535import * as scatter_nd_util from '../../ops/scatter_nd_util' ;
@@ -1909,7 +1909,7 @@ export class MathBackendWebGL implements KernelBackend {
19091909 }
19101910
19111911 private conv2dByMatMul (
1912- x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo , bias ?: Tensor4D ,
1912+ x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo , bias ?: Tensor ,
19131913 activation ?: Activation , preluActivationWeights ?: Tensor ) : Tensor4D {
19141914 // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
19151915 // result from 2D to 4D.
@@ -2008,7 +2008,7 @@ export class MathBackendWebGL implements KernelBackend {
20082008 }
20092009
20102010 private conv2dWithIm2Row (
2011- x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo , bias ?: Tensor4D ,
2011+ x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo , bias ?: Tensor ,
20122012 activation ?: Activation , preluActivationWeights ?: Tensor ) : Tensor4D {
20132013 // Rearranges conv2d input so each block to be convolved over forms the
20142014 // column of a new matrix with shape [filterWidth * filterHeight *
@@ -2067,19 +2067,19 @@ export class MathBackendWebGL implements KernelBackend {
20672067 }
20682068
20692069 fusedConv2d (
2070- x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo , bias ?: Tensor4D ,
2071- activation ?: Activation , preluActivationWeights ?: Tensor ) : Tensor4D {
2070+ { input , filter, convInfo, bias, activation , preluActivationWeights } :
2071+ FusedConv2DConfig ) : Tensor4D {
20722072 if ( convInfo . filterHeight === 1 && convInfo . filterWidth === 1 &&
20732073 convInfo . dilationHeight === 1 && convInfo . dilationWidth === 1 &&
20742074 convInfo . strideHeight === 1 && convInfo . strideWidth === 1 &&
20752075 ( convInfo . padInfo . type === 'SAME' ||
20762076 convInfo . padInfo . type === 'VALID' ) ) {
20772077 return this . conv2dByMatMul (
2078- x , filter , convInfo , bias , activation , preluActivationWeights ) ;
2078+ input , filter , convInfo , bias , activation , preluActivationWeights ) ;
20792079 }
2080- if ( ENV . getBool ( 'WEBGL_CONV_IM2COL' ) && x . shape [ 0 ] === 1 ) {
2080+ if ( ENV . getBool ( 'WEBGL_CONV_IM2COL' ) && input . shape [ 0 ] === 1 ) {
20812081 return this . conv2dWithIm2Row (
2082- x , filter , convInfo , bias , activation , preluActivationWeights ) ;
2082+ input , filter , convInfo , bias , activation , preluActivationWeights ) ;
20832083 }
20842084
20852085 const hasBias = bias != null ;
@@ -2088,7 +2088,7 @@ export class MathBackendWebGL implements KernelBackend {
20882088 activation ? mapActivationToShaderProgram ( activation , false ) : null ;
20892089 const program = new Conv2DProgram (
20902090 convInfo , hasBias , fusedActivation , hasPreluActivationWeights ) ;
2091- const inputs : TensorHandle [ ] = [ x , filter ] ;
2091+ const inputs : TensorHandle [ ] = [ input , filter ] ;
20922092 if ( bias ) {
20932093 inputs . push ( bias ) ;
20942094 }
@@ -2124,6 +2124,40 @@ export class MathBackendWebGL implements KernelBackend {
21242124 return this . compileAndRun ( program , [ x , dy ] ) ;
21252125 }
21262126
/**
 * Runs a depthwise 2D convolution with an optional bias add, activation and
 * PReLU weights folded into a single shader program.
 *
 * @param input The tensor to convolve; its dtype is used for the packed
 *     output tensor.
 * @param filter The depthwise filter tensor.
 * @param convInfo Convolution geometry (strides, channels, output shape).
 * @param bias Optional bias; uploaded as an extra program input when given.
 * @param activation Optional activation, mapped to a shader snippet.
 * @param preluActivationWeights Optional PReLU weights; uploaded as an extra
 *     program input when given.
 * @returns The result of compiling and running the selected program.
 */
fusedDepthwiseConv2D(
    {input, filter, convInfo, bias, activation, preluActivationWeights}:
        FusedConv2DConfig): Tensor4D {
  // The packed program is selected only when the feature flag is enabled,
  // the horizontal stride is at most 2, and outChannels / inChannels is 1.
  const usePackedProgram = ENV.getBool('WEBGL_PACK_DEPTHWISECONV') &&
      convInfo.strideWidth <= 2 &&
      convInfo.outChannels / convInfo.inChannels === 1;

  // The activation snippet is requested in packed or unpacked form to match
  // whichever program will be built below.
  const fusedActivation = activation ?
      mapActivationToShaderProgram(activation, usePackedProgram) :
      null;

  const hasBias = bias != null;
  const hasPreluActivationWeights = preluActivationWeights != null;

  // Program inputs in order: input, filter, then bias and PReLU weights
  // (each only when present).
  const inputs: Tensor[] = [
    input,
    filter,
    ...(hasBias ? [bias] : []),
    ...(hasPreluActivationWeights ? [preluActivationWeights] : []),
  ];

  if (!usePackedProgram) {
    const unpackedProgram = new DepthwiseConv2DProgram(
        convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
    return this.compileAndRun(unpackedProgram, inputs);
  }

  // The packed path passes an explicit packed output tensor with the
  // convolution's output shape and the input's dtype.
  const packedProgram = new DepthwiseConvPacked2DProgram(
      convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
  return this.compileAndRun(
      packedProgram, inputs,
      this.makePackedTensor(convInfo.outShape, input.dtype));
}
2160+
21272161 depthwiseConv2D ( x : Tensor4D , filter : Tensor4D , convInfo : Conv2DInfo ) :
21282162 Tensor4D {
21292163 let program : DepthwiseConv2DProgram | DepthwiseConvPacked2DProgram ;
0 commit comments