Skip to content

Commit 8c2e26c

Browse files
syt123450caisq
authored andcommitted
Add maxPooling3d layer & averagePooling3d layer (#555)
This PR adds `averagePooling3d` layer & `maxPooling3d` layer(feature requested in [#1035](#1035)). Features in this PR: * Add `pool3d` function to compute 3D pooling * Add relative tests for `pool3d` function * Add `Pooling3DLayerArgs` interface * Add abstract `Pooling3D` class * Add `MaxPooling3D` class * Add `AveragePooling3D` class * Export `averagePooling3d` layer * Add relative tests for `averagePooling3d` layer * Export `maxPooling3d` layer * Add relative tests for `maxPooling3d` layer **Note:** This PR depends on PR [tensorflow/tfjs-core#1778](tensorflow/tfjs-core#1778), which adds `maxPool3d` op, `avgPool3d` op, and exports `Tensor5D` type. **Reference:** * [Keras AveragePooling3D layer](https://keras.io/layers/pooling/) * [Keras MaxPooling3D layer](https://keras.io/layers/pooling/) FEATURE
1 parent 3ffaf31 commit 8c2e26c

File tree

3 files changed

+619
-5
lines changed

3 files changed

+619
-5
lines changed

src/exports_layers.ts

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import {Add, Average, Concatenate, ConcatenateLayerArgs, Dot, DotLayerArgs, Maxi
2020
import {AlphaDropout, AlphaDropoutArgs, GaussianDropout, GaussianDropoutArgs, GaussianNoise, GaussianNoiseArgs} from './layers/noise';
2121
import {BatchNormalization, BatchNormalizationLayerArgs} from './layers/normalization';
2222
import {ZeroPadding2D, ZeroPadding2DLayerArgs} from './layers/padding';
23-
import {AveragePooling1D, AveragePooling2D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, Pooling1DLayerArgs, Pooling2DLayerArgs} from './layers/pooling';
23+
import {AveragePooling1D, AveragePooling2D, AveragePooling3D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, MaxPooling3D, Pooling1DLayerArgs, Pooling2DLayerArgs, Pooling3DLayerArgs} from './layers/pooling';
2424
import {GRU, GRUCell, GRUCellLayerArgs, GRULayerArgs, LSTM, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, RNNLayerArgs, SimpleRNN, SimpleRNNCell, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs, StackedRNNCells, StackedRNNCellsArgs} from './layers/recurrent';
2525
import {Bidirectional, BidirectionalLayerArgs, TimeDistributed, WrapperLayerArgs} from './layers/wrappers';
2626

@@ -918,6 +918,38 @@ export function avgPooling2d(args: Pooling2DLayerArgs): Layer {
918918
return averagePooling2d(args);
919919
}
920920

921+
/**
922+
* Average pooling operation for 3D data.
923+
*
924+
* Input shape
925+
* - If `dataFormat === channelsLast`:
926+
* 5D tensor with shape:
927+
* `[batchSize, depths, rows, cols, channels]`
928+
* - If `dataFormat === channelsFirst`:
929+
* 4D tensor with shape:
930+
* `[batchSize, channels, depths, rows, cols]`
931+
*
932+
* Output shape
933+
* - If `dataFormat=channelsLast`:
934+
* 5D tensor with shape:
935+
* `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
936+
* - If `dataFormat=channelsFirst`:
937+
* 5D tensor with shape:
938+
* `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
939+
*/
940+
/** @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */
941+
export function averagePooling3d(args: Pooling3DLayerArgs): Layer {
942+
return new AveragePooling3D(args);
943+
}
944+
export function avgPool3d(args: Pooling3DLayerArgs): Layer {
945+
return averagePooling3d(args);
946+
}
947+
// For backwards compatibility.
948+
// See https://github.com/tensorflow/tfjs/issues/152
949+
export function avgPooling3d(args: Pooling3DLayerArgs): Layer {
950+
return averagePooling3d(args);
951+
}
952+
921953
/**
922954
* Global average pooling operation for temporal data.
923955
*
@@ -1012,6 +1044,30 @@ export function maxPooling2d(args: Pooling2DLayerArgs): Layer {
10121044
return new MaxPooling2D(args);
10131045
}
10141046

1047+
/**
1048+
* Max pooling operation for 3D data.
1049+
*
1050+
* Input shape
1051+
* - If `dataFormat === channelsLast`:
1052+
* 5D tensor with shape:
1053+
* `[batchSize, depths, rows, cols, channels]`
1054+
* - If `dataFormat === channelsFirst`:
1055+
* 5D tensor with shape:
1056+
* `[batchSize, channels, depths, rows, cols]`
1057+
*
1058+
* Output shape
1059+
* - If `dataFormat=channelsLast`:
1060+
* 5D tensor with shape:
1061+
* `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
1062+
* - If `dataFormat=channelsFirst`:
1063+
* 5D tensor with shape:
1064+
* `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
1065+
*/
1066+
/** @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */
1067+
export function maxPooling3d(args: Pooling3DLayerArgs): Layer {
1068+
return new MaxPooling3D(args);
1069+
}
1070+
10151071
// Recurrent Layers.
10161072

10171073
/**

src/layers/pooling.ts

Lines changed: 202 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
*/
1414

1515
import * as tfc from '@tensorflow/tfjs-core';
16-
import {serialization, Tensor, Tensor3D, Tensor4D, tidy} from '@tensorflow/tfjs-core';
16+
import {serialization, Tensor, Tensor3D, Tensor4D, Tensor5D, tidy} from '@tensorflow/tfjs-core';
1717

1818
import {imageDataFormat} from '../backend/common';
1919
import * as K from '../backend/tfjs_backend';
@@ -27,7 +27,7 @@ import {convOutputLength} from '../utils/conv_utils';
2727
import {assertPositiveInteger} from '../utils/generic_utils';
2828
import {getExactlyOneShape, getExactlyOneTensor} from '../utils/types_utils';
2929

30-
import {preprocessConv2DInput} from './convolutional';
30+
import {preprocessConv2DInput, preprocessConv3DInput} from './convolutional';
3131

3232
/**
3333
* 2D pooling.
@@ -82,6 +82,52 @@ export function pool2d(
8282
});
8383
}
8484

85+
/**
86+
* 3D pooling.
87+
* @param x
88+
* @param poolSize. Default to [1, 1, 1].
89+
* @param strides strides. Defaults to [1, 1, 1].
90+
* @param padding padding. Defaults to 'valid'.
91+
* @param dataFormat data format. Defaults to 'channelsLast'.
92+
* @param poolMode Mode of pooling. Defaults to 'max'.
93+
* @returns Result of the 3D pooling.
94+
*/
95+
export function pool3d(
96+
x: Tensor5D, poolSize: [number, number, number],
97+
strides?: [number, number, number], padding?: PaddingMode,
98+
dataFormat?: DataFormat, poolMode?: PoolMode): Tensor {
99+
return tidy(() => {
100+
checkDataFormat(dataFormat);
101+
checkPoolMode(poolMode);
102+
checkPaddingMode(padding);
103+
if (strides == null) {
104+
strides = [1, 1, 1];
105+
}
106+
if (padding == null) {
107+
padding = 'valid';
108+
}
109+
if (dataFormat == null) {
110+
dataFormat = imageDataFormat();
111+
}
112+
if (poolMode == null) {
113+
poolMode = 'max';
114+
}
115+
116+
// x is NDHWC after preprocessing.
117+
x = preprocessConv3DInput(x as Tensor, dataFormat) as Tensor5D;
118+
let y: Tensor;
119+
const paddingString = (padding === 'same') ? 'same' : 'valid';
120+
if (poolMode === 'max') {
121+
y = tfc.maxPool3d(x, poolSize, strides, paddingString);
122+
} else { // 'avg'
123+
y = tfc.avgPool3d(x, poolSize, strides, paddingString);
124+
}
125+
if (dataFormat === 'channelsFirst') {
126+
y = tfc.transpose(y, [0, 4, 1, 2, 3]); // NDHWC -> NCDHW.
127+
}
128+
return y;
129+
});
130+
}
85131

86132
export declare interface Pooling1DLayerArgs extends LayerArgs {
87133
/**
@@ -370,6 +416,160 @@ export class AveragePooling2D extends Pooling2D {
370416
}
371417
serialization.registerClass(AveragePooling2D);
372418

419+
export declare interface Pooling3DLayerArgs extends LayerArgs {
420+
/**
421+
* Factors by which to downscale in each dimension [depth, height, width].
422+
* Expects an integer or an array of 3 integers.
423+
*
424+
* For example, `[2, 2, 2]` will halve the input in three dimensions.
425+
* If only one integer is specified, the same window length
426+
* will be used for all dimensions.
427+
*/
428+
poolSize?: number|[number, number, number];
429+
430+
/**
431+
* The size of the stride in each dimension of the pooling window. Expects
432+
* an integer or an array of 3 integers. Integer, tuple of 3 integers, or
433+
* None.
434+
*
435+
* If `null`, defaults to `poolSize`.
436+
*/
437+
strides?: number|[number, number, number];
438+
439+
/** The padding type to use for the pooling layer. */
440+
padding?: PaddingMode;
441+
/** The data format to use for the pooling layer. */
442+
dataFormat?: DataFormat;
443+
}
444+
445+
/**
446+
* Abstract class for different pooling 3D layers.
447+
*/
448+
export abstract class Pooling3D extends Layer {
449+
protected readonly poolSize: [number, number, number];
450+
protected readonly strides: [number, number, number];
451+
protected readonly padding: PaddingMode;
452+
protected readonly dataFormat: DataFormat;
453+
454+
constructor(args: Pooling3DLayerArgs) {
455+
if (args.poolSize == null) {
456+
args.poolSize = [2, 2, 2];
457+
}
458+
super(args);
459+
this.poolSize = Array.isArray(args.poolSize) ?
460+
args.poolSize :
461+
[args.poolSize, args.poolSize, args.poolSize];
462+
if (args.strides == null) {
463+
this.strides = this.poolSize;
464+
} else if (Array.isArray(args.strides)) {
465+
if (args.strides.length !== 3) {
466+
throw new ValueError(
467+
`If the strides property of a 3D pooling layer is an Array, ` +
468+
`it is expected to have a length of 3, but received length ` +
469+
`${args.strides.length}.`);
470+
}
471+
this.strides = args.strides;
472+
} else {
473+
// `config.strides` is a number.
474+
this.strides = [args.strides, args.strides, args.strides];
475+
}
476+
assertPositiveInteger(this.poolSize, 'poolSize');
477+
assertPositiveInteger(this.strides, 'strides');
478+
this.padding = args.padding == null ? 'valid' : args.padding;
479+
this.dataFormat =
480+
args.dataFormat == null ? 'channelsLast' : args.dataFormat;
481+
checkDataFormat(this.dataFormat);
482+
checkPaddingMode(this.padding);
483+
484+
this.inputSpec = [new InputSpec({ndim: 5})];
485+
}
486+
487+
computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
488+
inputShape = getExactlyOneShape(inputShape);
489+
let depths =
490+
this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
491+
let rows =
492+
this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
493+
let cols =
494+
this.dataFormat === 'channelsFirst' ? inputShape[4] : inputShape[3];
495+
depths = convOutputLength(
496+
depths, this.poolSize[0], this.padding, this.strides[0]);
497+
rows =
498+
convOutputLength(rows, this.poolSize[1], this.padding, this.strides[1]);
499+
cols =
500+
convOutputLength(cols, this.poolSize[2], this.padding, this.strides[2]);
501+
if (this.dataFormat === 'channelsFirst') {
502+
return [inputShape[0], inputShape[1], depths, rows, cols];
503+
} else {
504+
return [inputShape[0], depths, rows, cols, inputShape[4]];
505+
}
506+
}
507+
508+
protected abstract poolingFunction(
509+
inputs: Tensor, poolSize: [number, number, number],
510+
strides: [number, number, number], padding: PaddingMode,
511+
dataFormat: DataFormat): Tensor;
512+
513+
call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {
514+
return tidy(() => {
515+
this.invokeCallHook(inputs, kwargs);
516+
return this.poolingFunction(
517+
getExactlyOneTensor(inputs), this.poolSize, this.strides,
518+
this.padding, this.dataFormat);
519+
});
520+
}
521+
522+
getConfig(): serialization.ConfigDict {
523+
const config = {
524+
poolSize: this.poolSize,
525+
padding: this.padding,
526+
strides: this.strides,
527+
dataFormat: this.dataFormat
528+
};
529+
const baseConfig = super.getConfig();
530+
Object.assign(config, baseConfig);
531+
return config;
532+
}
533+
}
534+
535+
export class MaxPooling3D extends Pooling3D {
536+
/** @nocollapse */
537+
static className = 'MaxPooling3D';
538+
constructor(args: Pooling3DLayerArgs) {
539+
super(args);
540+
}
541+
542+
protected poolingFunction(
543+
inputs: Tensor, poolSize: [number, number, number],
544+
strides: [number, number, number], padding: PaddingMode,
545+
dataFormat: DataFormat): Tensor {
546+
checkDataFormat(dataFormat);
547+
checkPaddingMode(padding);
548+
return pool3d(
549+
inputs as Tensor5D, poolSize, strides, padding, dataFormat, 'max');
550+
}
551+
}
552+
serialization.registerClass(MaxPooling3D);
553+
554+
export class AveragePooling3D extends Pooling3D {
555+
/** @nocollapse */
556+
static className = 'AveragePooling3D';
557+
constructor(args: Pooling3DLayerArgs) {
558+
super(args);
559+
}
560+
561+
protected poolingFunction(
562+
inputs: Tensor, poolSize: [number, number, number],
563+
strides: [number, number, number], padding: PaddingMode,
564+
dataFormat: DataFormat): Tensor {
565+
checkDataFormat(dataFormat);
566+
checkPaddingMode(padding);
567+
return pool3d(
568+
inputs as Tensor5D, poolSize, strides, padding, dataFormat, 'avg');
569+
}
570+
}
571+
serialization.registerClass(AveragePooling3D);
572+
373573
/**
374574
* Abstract class for different global pooling 1D layers.
375575
*/

0 commit comments

Comments
 (0)