1313 */
1414
1515import * as tfc from '@tensorflow/tfjs-core' ;
16- import { serialization , Tensor , Tensor3D , Tensor4D , tidy } from '@tensorflow/tfjs-core' ;
16+ import { serialization , Tensor , Tensor3D , Tensor4D , Tensor5D , tidy } from '@tensorflow/tfjs-core' ;
1717
1818import { imageDataFormat } from '../backend/common' ;
1919import * as K from '../backend/tfjs_backend' ;
@@ -27,7 +27,7 @@ import {convOutputLength} from '../utils/conv_utils';
2727import { assertPositiveInteger } from '../utils/generic_utils' ;
2828import { getExactlyOneShape , getExactlyOneTensor } from '../utils/types_utils' ;
2929
30- import { preprocessConv2DInput } from './convolutional' ;
30+ import { preprocessConv2DInput , preprocessConv3DInput } from './convolutional' ;
3131
3232/**
3333 * 2D pooling.
@@ -82,6 +82,52 @@ export function pool2d(
8282 } ) ;
8383}
8484
/**
 * 3D pooling.
 *
 * Validates the optional mode arguments, fills in their defaults, runs the
 * selected pooling op on the preprocessed input and, for `channelsFirst`
 * inputs, transposes the result back to the caller's layout.
 *
 * @param x Input tensor (rank 5).
 * @param poolSize Pooling window size, one entry per pooled dimension.
 *     Required; no default is applied.
 * @param strides strides. Defaults to [1, 1, 1].
 * @param padding padding. Defaults to 'valid'.
 * @param dataFormat data format. Defaults to the value returned by
 *     `imageDataFormat()`.
 * @param poolMode Mode of pooling. Defaults to 'max'.
 * @returns Result of the 3D pooling.
 */
export function pool3d(
    x: Tensor5D, poolSize: [number, number, number],
    strides?: [number, number, number], padding?: PaddingMode,
    dataFormat?: DataFormat, poolMode?: PoolMode): Tensor {
  return tidy(() => {
    // Validate the raw arguments before any defaults are substituted.
    checkDataFormat(dataFormat);
    checkPoolMode(poolMode);
    checkPaddingMode(padding);

    const windowStrides: [number, number, number] =
        strides == null ? [1, 1, 1] : strides;
    const layout: DataFormat =
        dataFormat == null ? imageDataFormat() : dataFormat;
    const mode: PoolMode = poolMode == null ? 'max' : poolMode;
    // A missing padding, and any padding other than 'same', pools as 'valid'.
    const paddingString = (padding === 'same') ? 'same' : 'valid';

    // The input is NDHWC after preprocessing.
    const ndhwc = preprocessConv3DInput(x as Tensor, layout) as Tensor5D;

    // Any mode other than 'max' is treated as 'avg'.
    const pooled: Tensor = mode === 'max' ?
        tfc.maxPool3d(ndhwc, poolSize, windowStrides, paddingString) :
        tfc.avgPool3d(ndhwc, poolSize, windowStrides, paddingString);

    // NDHWC -> NCDHW when the caller works in channelsFirst.
    return layout === 'channelsFirst' ?
        tfc.transpose(pooled, [0, 4, 1, 2, 3]) :
        pooled;
  });
}
85131
86132export declare interface Pooling1DLayerArgs extends LayerArgs {
87133 /**
@@ -370,6 +416,160 @@ export class AveragePooling2D extends Pooling2D {
370416}
371417serialization . registerClass ( AveragePooling2D ) ;
372418
export declare interface Pooling3DLayerArgs extends LayerArgs {
  /**
   * Factors by which to downscale in each dimension [depth, height, width].
   * Expects an integer or an array of 3 integers.
   *
   * For example, `[2, 2, 2]` will halve the input in three dimensions.
   * If only one integer is specified, the same window length
   * will be used for all dimensions.
   *
   * `Pooling3D` substitutes `[2, 2, 2]` when this is left unset.
   */
  poolSize?: number|[number, number, number];

  /**
   * The size of the stride in each dimension of the pooling window. Expects
   * an integer or an array of 3 integers.
   *
   * If `null`, defaults to `poolSize`.
   */
  strides?: number|[number, number, number];

  /**
   * The padding type to use for the pooling layer. `Pooling3D` uses
   * `'valid'` when this is left unset.
   */
  padding?: PaddingMode;

  /**
   * The data format to use for the pooling layer. `Pooling3D` uses
   * `'channelsLast'` when this is left unset.
   */
  dataFormat?: DataFormat;
}
444+
/**
 * Abstract class for different pooling 3D layers.
 *
 * Normalizes `poolSize` and `strides` into 3-element arrays, validates the
 * configuration, and delegates the actual pooling to the subclass-provided
 * `poolingFunction`.
 */
export abstract class Pooling3D extends Layer {
  protected readonly poolSize: [number, number, number];
  protected readonly strides: [number, number, number];
  protected readonly padding: PaddingMode;
  protected readonly dataFormat: DataFormat;

  /**
   * @param args Layer configuration. `poolSize` defaults to `[2, 2, 2]`,
   *     `strides` defaults to the resolved `poolSize`, `padding` defaults to
   *     `'valid'` and `dataFormat` defaults to `'channelsLast'`.
   * @throws ValueError If `poolSize` or `strides` is an Array whose length
   *     is not 3.
   */
  constructor(args: Pooling3DLayerArgs) {
    if (args.poolSize == null) {
      args.poolSize = [2, 2, 2];
    }
    super(args);
    if (Array.isArray(args.poolSize)) {
      // Mirror the `strides` check below: a wrong-length array would
      // otherwise only surface later, when `poolSize[2]` is read.
      if (args.poolSize.length !== 3) {
        throw new ValueError(
            `If the poolSize property of a 3D pooling layer is an Array, ` +
            `it is expected to have a length of 3, but received length ` +
            `${args.poolSize.length}.`);
      }
      this.poolSize = args.poolSize;
    } else {
      // `args.poolSize` is a number.
      this.poolSize = [args.poolSize, args.poolSize, args.poolSize];
    }
    if (args.strides == null) {
      this.strides = this.poolSize;
    } else if (Array.isArray(args.strides)) {
      if (args.strides.length !== 3) {
        throw new ValueError(
            `If the strides property of a 3D pooling layer is an Array, ` +
            `it is expected to have a length of 3, but received length ` +
            `${args.strides.length}.`);
      }
      this.strides = args.strides;
    } else {
      // `args.strides` is a number.
      this.strides = [args.strides, args.strides, args.strides];
    }
    assertPositiveInteger(this.poolSize, 'poolSize');
    assertPositiveInteger(this.strides, 'strides');
    this.padding = args.padding == null ? 'valid' : args.padding;
    this.dataFormat =
        args.dataFormat == null ? 'channelsLast' : args.dataFormat;
    checkDataFormat(this.dataFormat);
    checkPaddingMode(this.padding);

    this.inputSpec = [new InputSpec({ndim: 5})];
  }

  /**
   * Computes the output shape for a rank-5 input shape.
   *
   * The three pooled dimensions sit at indices 2..4 for `channelsFirst` and
   * 1..3 for `channelsLast`; the batch and channel entries are passed through
   * unchanged.
   */
  computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
    inputShape = getExactlyOneShape(inputShape);
    let depths =
        this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
    let rows =
        this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
    let cols =
        this.dataFormat === 'channelsFirst' ? inputShape[4] : inputShape[3];
    depths = convOutputLength(
        depths, this.poolSize[0], this.padding, this.strides[0]);
    rows =
        convOutputLength(rows, this.poolSize[1], this.padding, this.strides[1]);
    cols =
        convOutputLength(cols, this.poolSize[2], this.padding, this.strides[2]);
    if (this.dataFormat === 'channelsFirst') {
      return [inputShape[0], inputShape[1], depths, rows, cols];
    } else {
      return [inputShape[0], depths, rows, cols, inputShape[4]];
    }
  }

  /**
   * Pooling operation supplied by each concrete subclass.
   */
  protected abstract poolingFunction(
      inputs: Tensor, poolSize: [number, number, number],
      strides: [number, number, number], padding: PaddingMode,
      dataFormat: DataFormat): Tensor;

  /**
   * Applies the subclass's `poolingFunction` to the single input tensor,
   * using this layer's stored configuration.
   */
  call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {
    return tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      return this.poolingFunction(
          getExactlyOneTensor(inputs), this.poolSize, this.strides,
          this.padding, this.dataFormat);
    });
  }

  /**
   * Returns this layer's pooling settings merged with the base layer config.
   * Keys from the base config win when names collide.
   */
  getConfig(): serialization.ConfigDict {
    const config = {
      poolSize: this.poolSize,
      padding: this.padding,
      strides: this.strides,
      dataFormat: this.dataFormat
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }
}
534+
/**
 * 3D pooling layer whose pooling function runs `pool3d` in 'max' mode.
 */
export class MaxPooling3D extends Pooling3D {
  /** @nocollapse */
  static className = 'MaxPooling3D';

  constructor(args: Pooling3DLayerArgs) {
    super(args);
  }

  /**
   * Validates the data format and padding mode, then max-pools `inputs`.
   */
  protected poolingFunction(
      inputs: Tensor, poolSize: [number, number, number],
      strides: [number, number, number], padding: PaddingMode,
      dataFormat: DataFormat): Tensor {
    checkDataFormat(dataFormat);
    checkPaddingMode(padding);
    const volume = inputs as Tensor5D;
    return pool3d(volume, poolSize, strides, padding, dataFormat, 'max');
  }
}
serialization.registerClass(MaxPooling3D);
553+
/**
 * 3D pooling layer whose pooling function runs `pool3d` in 'avg' mode.
 */
export class AveragePooling3D extends Pooling3D {
  /** @nocollapse */
  static className = 'AveragePooling3D';

  constructor(args: Pooling3DLayerArgs) {
    super(args);
  }

  /**
   * Validates the data format and padding mode, then average-pools `inputs`.
   */
  protected poolingFunction(
      inputs: Tensor, poolSize: [number, number, number],
      strides: [number, number, number], padding: PaddingMode,
      dataFormat: DataFormat): Tensor {
    checkDataFormat(dataFormat);
    checkPaddingMode(padding);
    const volume = inputs as Tensor5D;
    return pool3d(volume, poolSize, strides, padding, dataFormat, 'avg');
  }
}
serialization.registerClass(AveragePooling3D);
572+
373573/**
374574 * Abstract class for different global pooling 1D layers.
375575 */
0 commit comments