Skip to content

Commit b545005

Browse files
committed
Assign layer_shape for maxpool2d in layer % init()
1 parent a29bfcf commit b545005

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

src/nf_layer_submodule.f90

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use nf_dense_layer, only: dense_layer
55
use nf_input1d_layer, only: input1d_layer
66
use nf_input3d_layer, only: input3d_layer
7+
use nf_maxpool2d_layer, only: maxpool2d_layer
78

89
implicit none
910

@@ -109,9 +110,13 @@ impure elemental module subroutine init(self, input)
109110
call this_layer % init(input % layer_shape)
110111
end select
111112

112-
! The shape of a conv2d layer is not known until we receive an input layer.
113-
select type(this_layer => self % p); type is(conv2d_layer)
114-
self % layer_shape = shape(this_layer % output)
113+
! The shape of conv2d or maxpool2d layers is not known
114+
! until we receive an input layer.
115+
select type(this_layer => self % p)
116+
type is(conv2d_layer)
117+
self % layer_shape = shape(this_layer % output)
118+
type is(maxpool2d_layer)
119+
self % layer_shape = shape(this_layer % output)
115120
end select
116121

117122
self % input_layer_shape = input % layer_shape

test/test_maxpool2d_layer.f90

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,24 @@ program test_maxpool2d_layer
2323
write(stderr, '(a)') 'maxpool2d layer should not be marked as initialized yet.. failed'
2424
end if
2525

26+
input_layer = input([3, 32, 32])
27+
call maxpool_layer % init(input_layer)
28+
29+
if (.not. maxpool_layer % initialized) then
30+
ok = .false.
31+
write(stderr, '(a)') 'maxpool2d layer should now be marked as initialized.. failed'
32+
end if
33+
34+
if (.not. all(maxpool_layer % input_layer_shape == [3, 32, 32])) then
35+
ok = .false.
36+
write(stderr, '(a)') 'maxpool2d layer input layer shape should be correct.. failed'
37+
end if
38+
39+
if (.not. all(maxpool_layer % layer_shape == [3, 16, 16])) then
40+
ok = .false.
41+
write(stderr, '(a)') 'maxpool2d layer input layer shape should be correct.. failed'
42+
end if
43+
2644
if (ok) then
2745
print '(a)', 'test_maxpool2d_layer: All tests passed.'
2846
else

0 commit comments

Comments
 (0)