Skip to content

Commit 61b355e

Browse files
authored
core: Add fused depthwiseConv with bias and activation kernel. (#1977)
PERF
1 parent 9e162f0 commit 61b355e

File tree

9 files changed

+631
-46
lines changed

9 files changed

+631
-46
lines changed

tfjs-core/src/backends/backend.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717

1818
import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
19-
import {Activation, FusedBatchMatMulConfig} from '../ops/fused_util';
19+
import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_util';
2020
import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
2121
import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types';
2222

@@ -410,8 +410,8 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
410410
}
411411

412412
fusedConv2d(
413-
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
414-
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
413+
{input, filter, convInfo, bias, activation, preluActivationWeights}:
414+
FusedConv2DConfig): Tensor4D {
415415
throw new Error('Not yet implemented');
416416
}
417417

@@ -426,6 +426,12 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
426426
throw new Error('Not yet implemented');
427427
}
428428

429+
fusedDepthwiseConv2D(
430+
{input, filter, convInfo, bias, activation, preluActivationWeights}:
431+
FusedConv2DConfig): Tensor4D {
432+
throw new Error('Not yet implemented');
433+
}
434+
429435
depthwiseConv2D(input: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
430436
Tensor4D {
431437
throw new Error('Not yet implemented');

tfjs-core/src/backends/cpu/backend_cpu.ts

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import {complex, imag, real} from '../../ops/complex_ops';
2727
import * as concat_util from '../../ops/concat_util';
2828
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
2929
import * as erf_util from '../../ops/erf_util';
30-
import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
30+
import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
3131
import * as gather_nd_util from '../../ops/gather_nd_util';
3232
import * as ops from '../../ops/ops';
3333
import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops';
@@ -1531,9 +1531,9 @@ export class MathBackendCPU implements KernelBackend {
15311531
}
15321532

15331533
fusedConv2d(
1534-
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
1535-
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
1536-
let result = this.conv2d(x, filter, convInfo);
1534+
{input, filter, convInfo, bias, activation, preluActivationWeights}:
1535+
FusedConv2DConfig): Tensor4D {
1536+
let result = this.conv2d(input, filter, convInfo);
15371537

15381538
if (bias) {
15391539
result = this.add(result, bias) as Tensor4D;
@@ -1973,6 +1973,22 @@ export class MathBackendCPU implements KernelBackend {
19731973
return dw.toTensor();
19741974
}
19751975

1976+
fusedDepthwiseConv2D(
1977+
{input, filter, convInfo, bias, activation, preluActivationWeights}:
1978+
FusedConv2DConfig): Tensor4D {
1979+
let result = this.depthwiseConv2D(input, filter, convInfo);
1980+
1981+
if (bias) {
1982+
result = this.add(result, bias) as Tensor4D;
1983+
}
1984+
if (activation) {
1985+
result =
1986+
mapActivation(this, result, activation, preluActivationWeights) as
1987+
Tensor4D;
1988+
}
1989+
return result;
1990+
}
1991+
19761992
depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
19771993
Tensor4D {
19781994
this.assertNotComplex([x, filter], 'depthwiseConv2D');

tfjs-core/src/backends/webgl/backend_webgl.ts

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import * as axis_util from '../../ops/axis_util';
2929
import {complex, imag, real} from '../../ops/complex_ops';
3030
import {computeOutShape} from '../../ops/concat_util';
3131
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
32-
import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
32+
import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
3333
import * as gather_nd_util from '../../ops/gather_nd_util';
3434
import * as reduce_util from '../../ops/reduce_util';
3535
import * as scatter_nd_util from '../../ops/scatter_nd_util';
@@ -1909,7 +1909,7 @@ export class MathBackendWebGL implements KernelBackend {
19091909
}
19101910

19111911
private conv2dByMatMul(
1912-
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
1912+
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor,
19131913
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
19141914
// Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
19151915
// result from 2D to 4D.
@@ -2008,7 +2008,7 @@ export class MathBackendWebGL implements KernelBackend {
20082008
}
20092009

20102010
private conv2dWithIm2Row(
2011-
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
2011+
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor,
20122012
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
20132013
// Rearranges conv2d input so each block to be convolved over forms the
20142014
// column of a new matrix with shape [filterWidth * filterHeight *
@@ -2067,19 +2067,19 @@ export class MathBackendWebGL implements KernelBackend {
20672067
}
20682068

20692069
fusedConv2d(
2070-
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
2071-
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
2070+
{input, filter, convInfo, bias, activation, preluActivationWeights}:
2071+
FusedConv2DConfig): Tensor4D {
20722072
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
20732073
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
20742074
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
20752075
(convInfo.padInfo.type === 'SAME' ||
20762076
convInfo.padInfo.type === 'VALID')) {
20772077
return this.conv2dByMatMul(
2078-
x, filter, convInfo, bias, activation, preluActivationWeights);
2078+
input, filter, convInfo, bias, activation, preluActivationWeights);
20792079
}
2080-
if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
2080+
if (ENV.getBool('WEBGL_CONV_IM2COL') && input.shape[0] === 1) {
20812081
return this.conv2dWithIm2Row(
2082-
x, filter, convInfo, bias, activation, preluActivationWeights);
2082+
input, filter, convInfo, bias, activation, preluActivationWeights);
20832083
}
20842084

20852085
const hasBias = bias != null;
@@ -2088,7 +2088,7 @@ export class MathBackendWebGL implements KernelBackend {
20882088
activation ? mapActivationToShaderProgram(activation, false) : null;
20892089
const program = new Conv2DProgram(
20902090
convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
2091-
const inputs: TensorHandle[] = [x, filter];
2091+
const inputs: TensorHandle[] = [input, filter];
20922092
if (bias) {
20932093
inputs.push(bias);
20942094
}
@@ -2124,6 +2124,40 @@ export class MathBackendWebGL implements KernelBackend {
21242124
return this.compileAndRun(program, [x, dy]);
21252125
}
21262126

2127+
fusedDepthwiseConv2D(
2128+
{input, filter, convInfo, bias, activation, preluActivationWeights}:
2129+
FusedConv2DConfig): Tensor4D {
2130+
const shouldPackDepthwiseConv = ENV.getBool('WEBGL_PACK_DEPTHWISECONV') &&
2131+
convInfo.strideWidth <= 2 &&
2132+
convInfo.outChannels / convInfo.inChannels === 1;
2133+
const fusedActivation = activation ?
2134+
mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :
2135+
null;
2136+
const inputs: Tensor[] = [input, filter];
2137+
2138+
const hasBias = bias != null;
2139+
const hasPreluActivationWeights = preluActivationWeights != null;
2140+
if (hasBias) {
2141+
inputs.push(bias);
2142+
}
2143+
if (hasPreluActivationWeights) {
2144+
inputs.push(preluActivationWeights);
2145+
}
2146+
2147+
let program: DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram;
2148+
if (shouldPackDepthwiseConv) {
2149+
program = new DepthwiseConvPacked2DProgram(
2150+
convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
2151+
return this.compileAndRun(
2152+
program, inputs,
2153+
this.makePackedTensor(convInfo.outShape, input.dtype));
2154+
}
2155+
2156+
program = new DepthwiseConv2DProgram(
2157+
convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
2158+
return this.compileAndRun(program, inputs);
2159+
}
2160+
21272161
depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
21282162
Tensor4D {
21292163
let program: DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram;

tfjs-core/src/backends/webgl/conv_gpu_depthwise.ts

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
2323
outputShape: number[];
2424
userCode: string;
2525

26-
constructor(convInfo: Conv2DInfo) {
26+
constructor(
27+
convInfo: Conv2DInfo, addBias = false, activation: string = null,
28+
hasPreluActivation = false) {
2729
this.outputShape = convInfo.outShape;
2830

2931
const xNumRows = convInfo.inHeight;
@@ -38,7 +40,36 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
3840
const filterWidth = convInfo.filterWidth;
3941
const channelMul = convInfo.outChannels / convInfo.inChannels;
4042

43+
let activationSnippet = '', applyActivationSnippet = '';
44+
if (activation) {
45+
if (hasPreluActivation) {
46+
activationSnippet = `float activation(float a) {
47+
float b = getPreluActivationWeightsAtOutCoords();
48+
${activation}
49+
}`;
50+
} else {
51+
activationSnippet = `
52+
float activation(float x) {
53+
${activation}
54+
}
55+
`;
56+
}
57+
58+
applyActivationSnippet = `result = activation(result);`;
59+
}
60+
61+
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
62+
if (addBias) {
63+
this.variableNames.push('bias');
64+
}
65+
66+
if (hasPreluActivation) {
67+
this.variableNames.push('preluActivationWeights');
68+
}
69+
4170
this.userCode = `
71+
${activationSnippet}
72+
4273
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
4374
const ivec2 pads = ivec2(${padTop}, ${padLeft});
4475
@@ -76,7 +107,11 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
76107
dotProd += xVal * wVal;
77108
}
78109
}
79-
setOutput(dotProd);
110+
111+
float result = dotProd;
112+
${addBiasSnippet}
113+
${applyActivationSnippet}
114+
setOutput(result);
80115
}
81116
`;
82117
}

tfjs-core/src/backends/webgl/conv_packed_gpu_depthwise.ts

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ export class DepthwiseConvPacked2DProgram implements GPGPUProgram {
2626
outputShape: number[];
2727
userCode: string;
2828

29-
constructor(convInfo: Conv2DInfo) {
29+
constructor(
30+
convInfo: Conv2DInfo, addBias = false, activation: string = null,
31+
hasPreluActivation = false) {
3032
this.outputShape = convInfo.outShape;
3133

3234
const xNumRows = convInfo.inHeight;
@@ -257,11 +259,38 @@ export class DepthwiseConvPacked2DProgram implements GPGPUProgram {
257259

258260
for (let r = 0; r < filterHeight; r++) {
259261
for (let c = 0; c < filterWidth; c++) {
260-
mainLoop += `result += xR${r}C${c} * wR${r}C${c};`;
262+
mainLoop += `dotProd += xR${r}C${c} * wR${r}C${c};`;
261263
}
262264
}
263265

266+
let activationSnippet = '', applyActivationSnippet = '';
267+
if (activation) {
268+
if (hasPreluActivation) {
269+
activationSnippet = `vec4 activation(vec4 a) {
270+
vec4 b = getPreluActivationWeightsAtOutCoords();
271+
${activation}
272+
}`;
273+
} else {
274+
activationSnippet = `vec4 activation(vec4 x) {
275+
${activation}
276+
}`;
277+
}
278+
279+
applyActivationSnippet = `result = activation(result);`;
280+
}
281+
282+
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
283+
if (addBias) {
284+
this.variableNames.push('bias');
285+
}
286+
287+
if (hasPreluActivation) {
288+
this.variableNames.push('preluActivationWeights');
289+
}
290+
264291
this.userCode = `
292+
${activationSnippet}
293+
265294
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
266295
const ivec2 pads = ivec2(${padTop}, ${padLeft});
267296
@@ -276,10 +305,13 @@ export class DepthwiseConvPacked2DProgram implements GPGPUProgram {
276305
int xRCorner = xRCCorner.x;
277306
int xCCorner = xRCCorner.y;
278307
279-
vec4 result = vec4(0.);
308+
vec4 dotProd = vec4(0.);
280309
281310
${mainLoop}
282311
312+
vec4 result = dotProd;
313+
${addBiasSnippet}
314+
${applyActivationSnippet}
283315
setOutput(result);
284316
}
285317
`;

tfjs-core/src/ops/conv.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,9 @@ function conv2d_<T extends Tensor3D|Tensor4D>(
194194
`are not yet supported in gradients. Got dilations '${dilations}'`);
195195

196196
return {
197-
x: () =>
198-
conv2dDerInput_(x4D.shape, dy, $filter, strides, pad, dataFormat),
197+
x: () => conv2dDerInput(x4D.shape, dy, $filter, strides, pad, dataFormat),
199198
$filter: () =>
200-
conv2dDerFilter_(x4D, dy, $filter.shape, strides, pad, dataFormat)
199+
conv2dDerFilter(x4D, dy, $filter.shape, strides, pad, dataFormat)
201200
};
202201
};
203202

@@ -675,7 +674,7 @@ function eitherStridesOrDilationsAreOne(
675674
return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
676675
}
677676

678-
function depthwiseConv2dDerInput<T extends Tensor3D|Tensor4D>(
677+
function depthwiseConv2dDerInput_<T extends Tensor3D|Tensor4D>(
679678
xShape: [number, number, number, number]|[number, number, number], dy: T,
680679
filter: Tensor4D, convInfo: conv_util.Conv2DInfo): T {
681680
let dy4D = dy as Tensor4D;
@@ -693,7 +692,7 @@ function depthwiseConv2dDerInput<T extends Tensor3D|Tensor4D>(
693692
return res as T;
694693
}
695694

696-
function depthwiseConv2dDerFilter<T extends Tensor3D|Tensor4D>(
695+
function depthwiseConv2dDerFilter_<T extends Tensor3D|Tensor4D>(
697696
x: T, dy: T, filterShape: [number, number, number, number],
698697
convInfo: conv_util.Conv2DInfo): Tensor4D {
699698
let x4D = x as Tensor4D;
@@ -973,6 +972,8 @@ export const conv3d = op({conv3d_});
973972
export const conv2dDerFilter = op({conv2dDerFilter_});
974973
export const conv2dDerInput = op({conv2dDerInput_});
975974
export const depthwiseConv2d = op({depthwiseConv2d_});
975+
export const depthwiseConv2dDerInput = op({depthwiseConv2dDerInput_});
976+
export const depthwiseConv2dDerFilter = op({depthwiseConv2dDerFilter_});
976977
export const separableConv2d = op({separableConv2d_});
977978
export const conv2dTranspose = op({conv2dTranspose_});
978979
export const conv3dTranspose = op({conv3dTranspose_});

0 commit comments

Comments
 (0)