Skip to content

Commit 1192ef0

Browse files
committed
First test for maxpool2d % forward(); does not pass yet
1 parent 33b9c5d commit 1192ef0

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

test/test_maxpool2d_layer.f90

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ program test_maxpool2d_layer
88

99
type(layer) :: maxpool_layer, input_layer
1010
integer, parameter :: pool_size = 2, stride = 2
11+
integer, parameter :: input_shape(3) = [3, 32, 32]
12+
integer, parameter :: output_shape(3) = [3, 16, 16]
1113
real, allocatable :: sample_input(:,:,:), output(:,:,:)
1214
logical :: ok = .true.
1315

@@ -23,24 +25,40 @@ program test_maxpool2d_layer
2325
write(stderr, '(a)') 'maxpool2d layer should not be marked as initialized yet.. failed'
2426
end if
2527

26-
input_layer = input([3, 32, 32])
28+
input_layer = input(input_shape)
2729
call maxpool_layer % init(input_layer)
2830

2931
if (.not. maxpool_layer % initialized) then
3032
ok = .false.
3133
write(stderr, '(a)') 'maxpool2d layer should now be marked as initialized.. failed'
3234
end if
3335

34-
if (.not. all(maxpool_layer % input_layer_shape == [3, 32, 32])) then
36+
if (.not. all(maxpool_layer % input_layer_shape == input_shape)) then
3537
ok = .false.
3638
write(stderr, '(a)') 'maxpool2d layer input layer shape should be correct.. failed'
3739
end if
3840

39-
if (.not. all(maxpool_layer % layer_shape == [3, 16, 16])) then
41+
if (.not. all(maxpool_layer % layer_shape == output_shape)) then
4042
ok = .false.
4143
write(stderr, '(a)') 'maxpool2d layer input layer shape should be correct.. failed'
4244
end if
4345

46+
allocate(sample_input(3, 32, 32))
47+
sample_input = 0
48+
sample_input(:,2,2) = 1 ! Set lower-right corner pixel of the upper-left pool
49+
50+
select type(this_layer => input_layer % p); type is(input3d_layer)
51+
call this_layer % set(sample_input)
52+
end select
53+
54+
call maxpool_layer % forward(input_layer)
55+
call maxpool_layer % get_output(output)
56+
57+
if (.not. all(output(:,1,1) == [1, 1, 1])) then
58+
ok = .false.
59+
write(stderr, '(a)') 'maxpool2d layer forward pass correctly propagates the max value.. failed'
60+
end if
61+
4462
if (ok) then
4563
print '(a)', 'test_maxpool2d_layer: All tests passed.'
4664
else

0 commit comments

Comments
 (0)