@@ -28,7 +28,7 @@ import * as array_ops_util from '../../ops/array_ops_util';
 import * as axis_util from '../../ops/axis_util';
 import {computeOutShape} from '../../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
-import {Activation} from '../../ops/fused_util';
+import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
 import * as gather_nd_util from '../../ops/gather_nd_util';
 import * as reduce_util from '../../ops/reduce_util';
 import * as scatter_nd_util from '../../ops/scatter_nd_util';
@@ -174,6 +174,11 @@ function mapActivationToShaderProgram(
       return unary_packed_op.RELU;
     }
     return unary_op.RELU;
+  } else if (activation === 'prelu') {
+    if (packed) {
+      return binaryop_packed_gpu.PRELU;
+    }
+    return binaryop_gpu.PRELU;
   }
   throw new Error(`Activation ${
       activation} has not been implemented for the WebGL backend.`);
@@ -865,26 +870,30 @@ export class MathBackendWebGL implements KernelBackend {
   }
 
   fusedBatchMatMul(
-      a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
-      bias?: Tensor, activation?: Activation): Tensor3D {
+      {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
+          FusedBatchMatMulConfig): Tensor3D {
     const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
     const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
     const [batch, , ] = a.shape;
 
     const dtype = upcastType(a.dtype, b.dtype);
 
     const hasBias = bias != null;
+    const hasPreluActivationWeights = preluActivationWeights != null;
     const fusedActivation =
         activation ? mapActivationToShaderProgram(activation, true) : null;
     const program = new MatMulPackedProgram(
         a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB,
-        hasBias, fusedActivation);
+        hasBias, fusedActivation, hasPreluActivationWeights);
     const output =
         this.makePackedTensor(program.outputShape, dtype) as Tensor3D;
     const inputs: TensorHandle[] = [a, b];
     if (bias) {
       inputs.push(bias);
     }
+    if (preluActivationWeights) {
+      inputs.push(preluActivationWeights);
+    }
     return this.compileAndRun<Tensor3D>(program, inputs, output);
   }
 
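The former positional parameters are now bundled into a single FusedBatchMatMulConfig object, so optional fields like preluActivationWeights can be added without further growing the parameter list. Judging from the destructuring pattern above, the type is presumably shaped roughly like this sketch (the authoritative definition lives in ../../ops/fused_util):

import {Tensor, Tensor3D} from '../../tensor';
import {Activation} from '../../ops/fused_util';

// Sketch of the config shape implied by the destructuring above.
type FusedBatchMatMulConfigSketch = {
  a: Tensor3D,
  b: Tensor3D,
  transposeA: boolean,
  transposeB: boolean,
  bias?: Tensor,
  activation?: Activation,
  preluActivationWeights?: Tensor
};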
@@ -1819,7 +1828,7 @@ export class MathBackendWebGL implements KernelBackend {
 
   private conv2dByMatMul(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
     // result from 2D to 4D.
     const xShape = x.shape;
@@ -1850,9 +1859,15 @@ export class MathBackendWebGL implements KernelBackend {
         Tensor3D;
 
     return this.reshape<Rank.R4>(
-        this.fusedBatchMatMul(
-            xReshaped, filterReshaped, transposeA, transposeB, bias,
-            activation),
+        this.fusedBatchMatMul({
+          a: xReshaped,
+          b: filterReshaped,
+          transposeA,
+          transposeB,
+          bias,
+          activation,
+          preluActivationWeights
+        }),
         convInfo.outShape);
   }
 
@@ -1888,8 +1903,15 @@ export class MathBackendWebGL implements KernelBackend {
         this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]) as
         Tensor3D;
 
-    const pointwiseConv = this.fusedBatchMatMul(
-        xReshaped, filterReshaped, transposeA, transposeB, bias, activation);
+    const pointwiseConv = this.fusedBatchMatMul({
+      a: xReshaped,
+      b: filterReshaped,
+      transposeA,
+      transposeB,
+      bias,
+      activation,
+      preluActivationWeights
+    });
     const pointwiseConvTexData = this.texData.get(pointwiseConv.dataId);
     util.assert(
         pointwiseConvTexData.isPacked,
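From user code, these backend paths are reached through the fused matMul op. A usage sketch, assuming the public op accepts the same config fields it forwards to backend.fusedBatchMatMul (expected values worked out by hand):

import * as tf from '@tensorflow/tfjs-core';

const a = tf.tensor3d([[[1, 2], [3, 4]]]);   // [1, 2, 2]
const b = tf.tensor3d([[[5, 6], [-7, 8]]]);  // [1, 2, 2]
const alpha = tf.tensor1d([0.1, 1.0]);       // per-channel PRELU slopes

const out = tf.fused.matMul({
  a,
  b,
  transposeA: false,
  transposeB: false,
  activation: 'prelu',
  preluActivationWeights: alpha
});
// Raw matMul gives [[-9, 22], [-13, 50]]; PRELU scales the negative first
// channel by 0.1, so the result is [[-0.9, 22], [-1.3, 50]].
out.print();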
@@ -1906,7 +1928,7 @@ export class MathBackendWebGL implements KernelBackend {
 
   private conv2dWithIm2Row(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     // Rearranges conv2d input so each block to be convolved over forms the
     // column of a new matrix with shape [filterWidth * filterHeight *
     // inChannels, outHeight * outWidth]. The filter is also rearranged so each
@@ -1938,42 +1960,53 @@ export class MathBackendWebGL implements KernelBackend {
     ]) as Tensor3D;
 
     const hasBias = bias != null;
+    const hasPreluActivationWeights = preluActivationWeights != null;
     const fusedActivation =
         activation ? mapActivationToShaderProgram(activation, true) : null;
     const matmulProgram = new MatMulPackedProgram(
         im2Col.shape, [1, numCols, convInfo.outChannels], transposeA,
-        transposeB, hasBias, fusedActivation);
+        transposeB, hasBias, fusedActivation, hasPreluActivationWeights);
     const inputs: TensorHandle[] = [im2Col, w2Row];
     if (bias) {
       inputs.push(bias);
     }
+    if (hasPreluActivationWeights) {
+      inputs.push(preluActivationWeights);
+    }
     const product = this.compileAndRun<Tensor4D>(matmulProgram, inputs);
 
     return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);
   }
 
   fusedConv2d(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
         convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
         convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
         (convInfo.padInfo.type === 'SAME' ||
          convInfo.padInfo.type === 'VALID')) {
-      return this.conv2dByMatMul(x, filter, convInfo, bias, activation);
+      return this.conv2dByMatMul(
+          x, filter, convInfo, bias, activation, preluActivationWeights);
     }
     if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
-      return this.conv2dWithIm2Row(x, filter, convInfo, bias, activation);
+      return this.conv2dWithIm2Row(
+          x, filter, convInfo, bias, activation, preluActivationWeights);
     }
 
     const hasBias = bias != null;
+    const hasPreluActivationWeights = preluActivationWeights != null;
     const fusedActivation =
         activation ? mapActivationToShaderProgram(activation, false) : null;
-    const program = new Conv2DProgram(convInfo, hasBias, fusedActivation);
+    const program = new Conv2DProgram(
+        convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
     const inputs: TensorHandle[] = [x, filter];
     if (bias) {
       inputs.push(bias);
     }
+    if (preluActivationWeights) {
+      inputs.push(preluActivationWeights);
+    }
     return this.compileAndRun(program, inputs);
   }
 
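Putting it together, the new parameter flows from the public fused conv2d op down to whichever of the three paths applies: the 1x1 matMul fast path, the im2row path, or the plain Conv2DProgram. A usage sketch, assuming the public API forwards preluActivationWeights as in this diff:

import * as tf from '@tensorflow/tfjs-core';

const x = tf.randomNormal([1, 8, 8, 3]) as tf.Tensor4D;
const filter = tf.randomNormal([1, 1, 3, 4]) as tf.Tensor4D;  // 1x1 filter
const alpha = tf.fill([4], 0.1);  // one learned slope per output channel

const y = tf.fused.conv2d({
  x,
  filter,
  strides: 1,
  pad: 'same',
  activation: 'prelu',
  preluActivationWeights: alpha
});
// 1x1 filter, unit strides, SAME padding: this hits the conv2dByMatMul
// fast path above; larger filters with batch size 1 may take the im2row
// path when WEBGL_CONV_IM2COL is enabled.
console.log(y.shape);  // [1, 8, 8, 4]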