
Commit 99e707a

core: Fuse elu activations. (#1964)
PERF
1 parent 3ce439e · commit 99e707a

5 files changed (+79, -1 lines changed)

tfjs-core/src/backends/cpu/backend_cpu.ts

Lines changed: 2 additions & 0 deletions

@@ -54,6 +54,8 @@ function mapActivation(
     return backend.linear(x);
   } else if (activation === 'relu') {
     return backend.relu(x);
+  } else if (activation === 'elu') {
+    return backend.elu(x);
   } else if (activation === 'prelu') {
     return backend.prelu(x, preluActivationWeights);
   }
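
Note: on the CPU backend the fused 'elu' branch simply reuses the existing backend.elu kernel, so a fused call should produce the same values as composing the unfused ops. A minimal sketch of that equivalence, using the same matrices as the new test further below (illustrative only, not part of the commit):

    import * as tf from '@tensorflow/tfjs-core';

    async function compareFusedAndUnfused() {
      const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
      const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);

      // Fused matmul + elu in a single call.
      const fused = tf.fused.matMul(
          {a, b, transposeA: false, transposeB: false, bias: null, activation: 'elu'});
      // The unfused composition it should agree with.
      const unfused = tf.elu(tf.matMul(a, b));

      // Both print approximately [0, 8, -0.9502, 20].
      console.log(await fused.data(), await unfused.data());
    }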

tfjs-core/src/backends/webgl/backend_webgl.ts

Lines changed: 8 additions & 0 deletions

@@ -176,6 +176,11 @@ function mapActivationToShaderProgram(
       return unary_packed_op.RELU;
     }
     return unary_op.RELU;
+  } else if (activation === 'elu') {
+    if (packed) {
+      return unary_packed_op.ELU;
+    }
+    return unary_op.ELU;
   } else if (activation === 'prelu') {
     if (packed) {
       return binaryop_packed_gpu.PRELU;

@@ -1745,6 +1750,9 @@ export class MathBackendWebGL implements KernelBackend {
   }

   elu<T extends Tensor>(x: T): T {
+    if (ENV.getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
+      return this.packedUnaryOp(x, unary_packed_op.ELU, x.dtype) as T;
+    }
     const program = new UnaryOpProgram(x.shape, unary_op.ELU);
     return this.compileAndRun(program, [x]);
   }
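
Note: on WebGL, elu<T> now dispatches to the packed ELU shader when the WEBGL_PACK_UNARY_OPERATIONS flag is enabled and falls back to the plain UnaryOpProgram otherwise; callers are unchanged, and the same kernel also backs the fused activation path. A small usage sketch (output values are approximate):

    import * as tf from '@tensorflow/tfjs-core';

    async function eluOnWebGL() {
      const x = tf.tensor1d([-2, -0.5, 0, 3]);
      // Runs unary_packed_op.ELU when unary packing is enabled,
      // otherwise the unpacked unary_op.ELU program.
      const y = tf.elu(x);
      console.log(await y.data());  // ≈ [-0.8647, -0.3935, 0, 3]
    }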

tfjs-core/src/backends/webgl/unaryop_packed_gpu.ts

Lines changed: 11 additions & 0 deletions

@@ -42,6 +42,17 @@ export const RELU = `
   return result;
 `;

+export const ELU = `
+  vec4 result;
+
+  result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
+  result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
+  result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
+  result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);
+
+  return result;
+`;
+
 export class UnaryOpPackedProgram implements GPGPUProgram {
   variableNames = ['A'];
   userCode: string;
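
Note: the packed snippet applies ELU channel-wise to the packed vec4 texel: values at or above zero pass through, negative values map to exp(v) - 1.0. A rough CPU-side mirror of that per-channel logic (the Vec4 and eluVec4 names are illustrative, not part of the source):

    // Channel-wise ELU, mirroring the GLSL snippet above.
    type Vec4 = {r: number, g: number, b: number, a: number};

    function eluVec4(x: Vec4): Vec4 {
      const elu = (v: number) => v >= 0 ? v : Math.exp(v) - 1;
      return {r: elu(x.r), g: elu(x.g), b: elu(x.b), a: elu(x.a)};
    }

    // eluVec4({r: -3, g: 0, b: 1, a: -0.5})
    //   ≈ {r: -0.9502, g: 0, b: 1, a: -0.3935}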

tfjs-core/src/ops/fused_test.ts

Lines changed: 57 additions & 0 deletions

@@ -43,6 +43,19 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
     expectArraysClose(await c.data(), [0, 8, 0, 20]);
   });

+  it('A x B with elu', async () => {
+    const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
+    const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
+    const transposeA = false;
+    const transposeB = false;
+
+    const c = tf.fused.matMul(
+        {a, b, transposeA, transposeB, bias: null, activation: 'elu'});
+
+    expect(c.shape).toEqual([2, 2]);
+    expectArraysClose(await c.data(), [0, 8, -0.9502, 20]);
+  });
+
   it('A x B with prelu', async () => {
     const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
     const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);

@@ -106,6 +119,21 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
     expectArraysClose(await d.data(), [1, 9, 0, 21]);
   });

+  it('A x B with elu and broadcasted bias', async () => {
+    const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
+    const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
+    const c = tf.tensor1d([1, 1]);
+    const act: tf.fused.Activation = 'elu';
+    const transposeA = false;
+    const transposeB = false;
+
+    const d = tf.fused.matMul(
+        {a, b, transposeA, transposeB, bias: c, activation: act});
+
+    expect(d.shape).toEqual([2, 2]);
+    expectArraysClose(await d.data(), [1, 9, -0.8647, 21]);
+  });
+
   it('A x B with relu and broadcasted bias different rank', async () => {
     const a = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [2, 2, 3]);
     const b = tf.tensor3d([0, 1, -3, 2, 2, 1, 0, 1, -3, 2, 2, 1], [2, 3, 2]);

@@ -318,6 +346,35 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => {
     expectArraysClose(await result.data(), expected);
   });

+  it('basic with elu', async () => {
+    const inputDepth = 2;
+    const inShape: [number, number, number, number] = [2, 2, 2, inputDepth];
+    const outputDepth = 2;
+    const fSize = 1;
+    const pad = 0;
+    const stride = 1;
+
+    const x = tf.tensor4d(
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], inShape);
+    const w =
+        tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth, outputDepth]);
+
+    const result = tf.fused.conv2d({
+      x,
+      filter: w,
+      strides: stride,
+      pad,
+      dataFormat: 'NHWC',
+      dilations: [1, 1],
+      activation: 'elu'
+    });
+    expect(result.shape).toEqual([2, 2, 2, 2]);
+    const expected =
+        [-0.99326, 2, -1, 5, -1, 8, -1, 11, -1, 14, -1, 17, -1, 20, -1, 23];
+
+    expectArraysClose(await result.data(), expected);
+  });
+
   it('basic with prelu', async () => {
     const inputDepth = 2;
     const inShape: [number, number, number, number] = [2, 2, 2, inputDepth];
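
Note: the expected arrays in the new tests follow directly from the un-activated results. For 'A x B with elu', a x b = [[0, 8], [-3, 20]] and elu(-3) = exp(-3) - 1 ≈ -0.9502; with the broadcast bias [1, 1] the pre-activations become [[1, 9], [-2, 21]] and elu(-2) ≈ -0.8647; in the conv2d case the first output is elu(-5) ≈ -0.99326, and the more negative pre-activations saturate near -1, which is why the remaining negative entries are -1. A quick sanity check of that arithmetic (not part of the test file):

    const elu = (v: number): number => v >= 0 ? v : Math.exp(v) - 1;

    console.log(elu(-3).toFixed(4));   // -0.9502
    console.log(elu(-2).toFixed(4));   // -0.8647
    console.log(elu(-5).toFixed(5));   // -0.99326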

tfjs-core/src/ops/fused_util.ts

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@

 import {Tensor, Tensor3D} from '../tensor';

-export type Activation = 'linear'|'relu'|'prelu';
+export type Activation = 'linear'|'relu'|'prelu'|'elu';

 export type FusedBatchMatMulConfig = {
   a: Tensor3D,
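
Note: widening the Activation union is what lets 'elu' flow through the typed fused configs (the tests above rely on it via tf.fused.Activation). A hypothetical helper showing the widened union in an exhaustive switch; the describeActivation name and the formula strings are illustrative only:

    import * as tf from '@tensorflow/tfjs-core';

    // Exhaustive over tf.fused.Activation; the 'elu' case is the new member.
    function describeActivation(act: tf.fused.Activation): string {
      switch (act) {
        case 'linear': return 'x';
        case 'relu': return 'max(x, 0)';
        case 'prelu': return 'x >= 0 ? x : alpha * x';
        case 'elu': return 'x >= 0 ? x : exp(x) - 1';
      }
    }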
