@@ -156,5 +156,37 @@ describe('convolution', () => {
156156 input1 [ 0 ] , input2 [ 0 ] , [ 2 , 2 ] , 'same' , 'NHWC' , [ 2 , 2 ] ) ;
157157 } ) ;
158158 } ) ;
159+
160+ describe ( 'AvgPool3D' , ( ) => {
161+ it ( 'should call tfc.avgPool3d' , ( ) => {
162+ spyOn ( tfc , 'avgPool3d' ) ;
163+ node . op = 'AvgPool3D' ;
164+ node . attrParams [ 'strides' ] = createNumericArrayAttr ( [ 1 , 2 , 2 , 2 , 1 ] ) ;
165+ node . attrParams [ 'pad' ] = createStrAttr ( 'same' ) ;
166+ node . attrParams [ 'kernelSize' ] =
167+ createNumericArrayAttr ( [ 1 , 2 , 2 , 2 , 1 ] ) ;
168+
169+ executeOp ( node , { input} , context ) ;
170+
171+ expect ( tfc . avgPool3d )
172+ . toHaveBeenCalledWith ( input [ 0 ] , [ 2 , 2 , 2 ] , [ 2 , 2 , 2 ] , 'same' ) ;
173+ } ) ;
174+ } ) ;
175+
176+ describe ( 'MaxPool3D' , ( ) => {
177+ it ( 'should call tfc.maxPool3d' , ( ) => {
178+ spyOn ( tfc , 'maxPool3d' ) ;
179+ node . op = 'MaxPool3D' ;
180+ node . attrParams [ 'strides' ] = createNumericArrayAttr ( [ 1 , 2 , 2 , 2 , 1 ] ) ;
181+ node . attrParams [ 'pad' ] = createStrAttr ( 'same' ) ;
182+ node . attrParams [ 'kernelSize' ] =
183+ createNumericArrayAttr ( [ 1 , 2 , 2 , 2 , 1 ] ) ;
184+
185+ executeOp ( node , { input} , context ) ;
186+
187+ expect ( tfc . maxPool3d )
188+ . toHaveBeenCalledWith ( input [ 0 ] , [ 2 , 2 , 2 ] , [ 2 , 2 , 2 ] , 'same' ) ;
189+ } ) ;
190+ } ) ;
159191 } ) ;
160192} ) ;
0 commit comments