Skip to content

Commit 3ffaf31

Browse files
syt123450caisq
authored andcommitted
Add non-default noiseShape and seed support to dropout layer (#556)
This PR makes the dropout layer support non-default `noiseShape` and `seed`, and adds related tests for these features in the `dropout` layer and the backend's `dropout` op. The `noiseShape` feature in the `dropout` layer will benefit further development, for example the `SpatialDropout2D` layer requested in [#1453](#1453). This PR is sent to ensure there won't be any breaking changes for further `noiseShape`-related features. FEATURE
1 parent 6670116 commit 3ffaf31

File tree

3 files changed

+85
-21
lines changed

3 files changed

+85
-21
lines changed

src/backend/tfjs_backend_test.ts

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,44 @@ describeMathCPUAndGPU('dropout', () => {
742742
}
743743
});
744744
}
745+
746+
it('Level=0.75, with noiseShape', () => {  // Checks K.dropout when noiseShape differs from the input shape.
747+
const x = tensor2d(range(1, 21), [10, 2]);  // [10, 2] input; presumably the values 1..20, i.e. all nonzero (verify `range`).
748+
const level = 0.75;
749+
const noiseShape = [10, 1];  // Size 1 on the second axis: the test expects one keep/drop decision per row.
750+
const y = K.dropout(x, level, noiseShape);
751+
expect(y.dtype).toEqual(x.dtype);
752+
expect(y.shape).toEqual(x.shape);
753+
const xValue = x.dataSync();
754+
const yValue = y.dataSync();
755+
let nKept = 0;  // Counts the elements that belong to rows treated as kept.
756+
for (let i = 0; i < x.shape[0]; i++) {
757+
const maskedValue = yValue[i * x.shape[1]];  // First element of row i, used as the indicator for the whole row.
758+
for (let j = 0; j < x.shape[1]; j++) {
759+
const indice = i * x.shape[1] + j;  // Flat index of element (i, j) in the dataSync() output.
760+
if (maskedValue !== 0) {
761+
nKept++;
762+
expect(yValue[indice]).toBeCloseTo(1 / (1 - level) * xValue[indice]);  // Kept values are expected to be scaled by 1 / (1 - level).
763+
} else {
764+
expect(yValue[indice]).toEqual(0);  // A zero row indicator means every element of the row must be zero.
765+
}
766+
}
767+
}
768+
const numel = K.countParams(x);
769+
expect(nKept).toBeLessThan(numel);  // Requires that at least one row was dropped.
770+
});
771+
772+
it('Level=0.75, with seed', () => {  // Checks that K.dropout with a fixed seed yields the literal values listed below.
773+
const x = tensor2d(range(1, 21), [10, 2]);
774+
const level = 0.75;
775+
const seed = 23;
776+
const y = K.dropout(x, level, null, seed);  // noiseShape is passed as null; only the seed is customized.
777+
expect(y.dtype).toEqual(x.dtype);
778+
expect(y.shape).toEqual(x.shape);
779+
const yValuesExpected =
780+
[0, 0, 12, 16, 0, 0, 0, 0, 0, 0, 0, 48, 52, 0, 0, 0, 68, 0, 76, 0];  // Expected output for seed 23, in flat order for shape [10, 2].
781+
expectTensorsClose(y, tensor2d(yValuesExpected, [10, 2]));
782+
});
745783
});
746784

747785
describeMathCPUAndGPU('biasAdd', () => {

src/layers/core.ts

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import {Activation as ActivationFn, getActivation, serializeActivation} from '..
1818
import * as K from '../backend/tfjs_backend';
1919
import {Constraint, ConstraintIdentifier, getConstraint, serializeConstraint} from '../constraints';
2020
import {DisposeResult, InputSpec, Layer, LayerArgs} from '../engine/topology';
21-
import {NotImplementedError, ValueError} from '../errors';
21+
import {ValueError} from '../errors';
2222
import {getInitializer, Initializer, InitializerIdentifier, serializeInitializer} from '../initializers';
2323
import {ActivationIdentifier} from '../keras_format/activation_config';
2424
import {Shape} from '../keras_format/common';
@@ -69,11 +69,6 @@ export class Dropout extends Layer {
6969
// So that the scalar doesn't get tidied up between executions.
7070
this.noiseShape = args.noiseShape;
7171
this.seed = args.seed;
72-
if (this.seed != null) {
73-
throw new NotImplementedError(
74-
'Non-default seed is not implemented in Dropout layer yet: ' +
75-
this.seed);
76-
}
7772
this.supportsMasking = true;
7873
}
7974

@@ -94,12 +89,6 @@ export class Dropout extends Layer {
9489
return tidy(() => {
9590
this.invokeCallHook(inputs, kwargs);
9691
const input = getExactlyOneTensor(inputs);
97-
if (this.noiseShape != null &&
98-
!util.arraysEqual(input.shape, this.noiseShape)) {
99-
throw new NotImplementedError(
100-
'Non-default noise shape is not implemented in Dropout ' +
101-
'layer yet: ' + JSON.stringify(this.noiseShape));
102-
}
10392
if (0 < this.rate && this.rate < 1) {
10493
const training =
10594
kwargs['training'] == null ? false : kwargs['training'];

src/layers/core_test.ts

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ describeMathCPUAndGPU('Dropout Layer', () => {
5252
const inputShape = [2, 3, 4];
5353
const trainingValues = [false, true];
5454
const dropoutRates = [0, 0.5];
55-
const noiseShapes = [null, inputShape];
56-
// TODO(cais): test non-default noiseShapes once they are supported.
55+
const noiseShapes = [null, inputShape, [2, 3, 1]];
5756

5857
for (const training of trainingValues) {
5958
for (const rate of dropoutRates) {
@@ -69,13 +68,37 @@ describeMathCPUAndGPU('Dropout Layer', () => {
6968
const xValue = x.dataSync();
7069
const yValue = y.dataSync();
7170
let nKept = 0;
72-
for (let i = 0; i < xValue.length; ++i) {
73-
if (yValue[i] !== 0) {
74-
nKept++;
75-
if (training) {
76-
expect(yValue[i]).toBeCloseTo(1 / (1 - rate));
77-
} else {
78-
expect(yValue[i]).toBeCloseTo(1);
71+
if (noiseShape === noiseShapes[2]) { // customized noiseShape
72+
for (let i = 0; i < x.shape[0]; ++i) {
73+
for (let j = 0; j < x.shape[1]; ++j) {
74+
const maskedValue =
75+
yValue[i * x.shape[1] * x.shape[2] + j * x.shape[2]];
76+
for (let k = 0; k < x.shape[2]; ++k) {
77+
const indice =
78+
i * x.shape[1] * x.shape[2] + j * x.shape[2] + k;
79+
if (training) {
80+
if (maskedValue === 0) {
81+
expect(yValue[indice]).toEqual(0);
82+
} else {
83+
nKept++;
84+
expect(yValue[indice]).toBeCloseTo(1 / (1 - rate));
85+
}
86+
} else {
87+
nKept++;
88+
expect(yValue[indice]).toEqual(1);
89+
}
90+
}
91+
}
92+
}
93+
} else { // default noiseShape
94+
for (let i = 0; i < xValue.length; ++i) {
95+
if (yValue[i] !== 0) {
96+
nKept++;
97+
if (training) {
98+
expect(yValue[i]).toBeCloseTo(1 / (1 - rate));
99+
} else {
100+
expect(yValue[i]).toBeCloseTo(1);
101+
}
79102
}
80103
}
81104
}
@@ -90,6 +113,20 @@ describeMathCPUAndGPU('Dropout Layer', () => {
90113
}
91114
}
92115
});
116+
117+
it('tensor with seed get specific value', () => {  // Registered as a spec (`it`), not a `describe`, so the body and its assertion run as a test instead of at suite-definition time.
118+
const training = true;  // Passed to apply() as the `training` kwarg.
119+
const rate = 0.5;
120+
const noiseShape = [2, 3, 4];  // Same as the input shape.
121+
const x = ones([2, 3, 4]);
122+
const seed = 23;  // Fixed seed so the output can be compared against literal values.
123+
const dropoutLayer = tfl.layers.dropout({rate, noiseShape, seed});
124+
const y = dropoutLayer.apply(x, {training}) as Tensor;
125+
const yValuesExpected = [
126+
0, 2, 2, 2, 0, 0, 2, 2, 0, 0, 2, 2, 2, 0, 0, 0, 2, 0, 2, 0, 2, 2, 2, 0
127+
];  // Expected output for seed 23: each entry is either 0 or 2, where 2 equals 1 / (1 - rate) for rate 0.5.
128+
expectTensorsClose(y, tensor3d(yValuesExpected, [2, 3, 4]));
129+
});
93130
});
94131

95132
describeMathCPU('Dense Layer: Symbolic', () => {

0 commit comments

Comments
 (0)