@@ -8,9 +8,11 @@ program test_maxpool2d_layer
88
99 type (layer) :: maxpool_layer, input_layer
1010 integer , parameter :: pool_size = 2 , stride = 2
11- integer , parameter :: input_shape(3 ) = [3 , 32 , 32 ]
12- integer , parameter :: output_shape(3 ) = [3 , 16 , 16 ]
11+ integer , parameter :: channels = 3 , width = 32
12+ integer , parameter :: input_shape(3 ) = [channels, width, width]
13+ integer , parameter :: output_shape(3 ) = [channels, width / 2 , width / 2 ]
1314 real , allocatable :: sample_input(:,:,:), output(:,:,:)
15+ integer :: i, j
1416 logical :: ok = .true.
1517
1618 maxpool_layer = maxpool2d(pool_size)
@@ -43,9 +45,11 @@ program test_maxpool2d_layer
4345 write (stderr, ' (a)' ) ' maxpool2d layer input layer shape should be correct.. failed'
4446 end if
4547
46- allocate (sample_input(3 , 32 , 32 ))
47- sample_input = 0
48- sample_input(:,2 ,2 ) = 1 ! Set lower-right corner pixel of the upper-left pool
48+ allocate (sample_input(channels, width, width))
49+
50+ do concurrent(i = 1 :width, j = 1 :width)
51+ sample_input(:,i,j) = i * j
52+ end do
4953
5054 select type (this_layer = > input_layer % p); type is(input3d_layer)
5155 call this_layer % set(sample_input)
@@ -54,10 +58,15 @@ program test_maxpool2d_layer
5458 call maxpool_layer % forward(input_layer)
5559 call maxpool_layer % get_output(output)
5660
57- if (.not. all (output(:,1 ,1 ) == [1 , 1 , 1 ])) then
58- ok = .false.
59- write (stderr, ' (a)' ) ' maxpool2d layer forward pass correctly propagates the max value.. failed'
60- end if
61+ do j = 1 , width / 2
62+ do i = 1 , width / 2
63+ ! Since input is i*j, maxpool2d output must be stride*i * stride*j
64+ if (.not. all (output(:,i,j) == stride** 2 * i * j)) then
65+ ok = .false.
66+ write (stderr, ' (a)' ) ' maxpool2d layer forward pass correctly propagates the max value.. failed'
67+ end if
68+ end do
69+ end do
6170
6271 if (ok) then
6372 print ' (a)' , ' test_maxpool2d_layer: All tests passed.'
0 commit comments