Skip to content

Commit 6898c1d

Browse files
committed
Improve testing of the maxpool2d forward pass
1 parent a50f5f8 commit 6898c1d

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

test/test_maxpool2d_layer.f90

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ program test_maxpool2d_layer
88

99
type(layer) :: maxpool_layer, input_layer
1010
integer, parameter :: pool_size = 2, stride = 2
11-
integer, parameter :: input_shape(3) = [3, 32, 32]
12-
integer, parameter :: output_shape(3) = [3, 16, 16]
11+
integer, parameter :: channels = 3, width = 32
12+
integer, parameter :: input_shape(3) = [channels, width, width]
13+
integer, parameter :: output_shape(3) = [channels, width / 2, width / 2]
1314
real, allocatable :: sample_input(:,:,:), output(:,:,:)
15+
integer :: i, j
1416
logical :: ok = .true.
1517

1618
maxpool_layer = maxpool2d(pool_size)
@@ -43,9 +45,11 @@ program test_maxpool2d_layer
4345
write(stderr, '(a)') 'maxpool2d layer input layer shape should be correct.. failed'
4446
end if
4547

46-
allocate(sample_input(3, 32, 32))
47-
sample_input = 0
48-
sample_input(:,2,2) = 1 ! Set lower-right corner pixel of the upper-left pool
48+
allocate(sample_input(channels, width, width))
49+
50+
do concurrent(i = 1:width, j = 1:width)
51+
sample_input(:,i,j) = i * j
52+
end do
4953

5054
select type(this_layer => input_layer % p); type is(input3d_layer)
5155
call this_layer % set(sample_input)
@@ -54,10 +58,15 @@ program test_maxpool2d_layer
5458
call maxpool_layer % forward(input_layer)
5559
call maxpool_layer % get_output(output)
5660

57-
if (.not. all(output(:,1,1) == [1, 1, 1])) then
58-
ok = .false.
59-
write(stderr, '(a)') 'maxpool2d layer forward pass correctly propagates the max value.. failed'
60-
end if
61+
do j = 1, width / 2
62+
do i = 1, width / 2
63+
! Since input is i*j, maxpool2d output must be stride*i * stride*j
64+
if (.not. all(output(:,i,j) == stride**2 * i * j)) then
65+
ok = .false.
66+
write(stderr, '(a)') 'maxpool2d layer forward pass correctly propagates the max value.. failed'
67+
end if
68+
end do
69+
end do
6170

6271
if (ok) then
6372
print '(a)', 'test_maxpool2d_layer: All tests passed.'

0 commit comments

Comments
 (0)