Skip to content

Commit a50f5f8

Browse files
committed
Store max location in maxpool2d layer forward pass
1 parent ff0201c commit a50f5f8

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

src/nf_maxpool2d_layer_submodule.f90

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,13 @@ pure module subroutine forward(self, input)
4343
integer :: i, j, n
4444
integer :: ii, jj
4545
integer :: iend, jend
46+
integer :: maxloc_xy(2)
4647

4748
input_width = size(input, dim=2)
4849
input_height = size(input, dim=2)
4950

5051
! Stride along the width and height of the input image
51-
do concurrent( &
52+
stride_over_input: do concurrent( &
5253
i = 1:input_width:self % stride, &
5354
j = 1:input_height:self % stride &
5455
)
@@ -60,12 +61,19 @@ pure module subroutine forward(self, input)
6061
iend = i + self % pool_size - 1
6162
jend = j + self % pool_size - 1
6263

63-
do concurrent(n = 1:self % channels)
64-
!TODO find and store maxloc
65-
self % output(n,ii,jj) = maxval(input(n,i:iend,j:jend))
66-
end do
64+
maxpool_for_each_channel: do concurrent(n = 1:self % channels)
6765

68-
end do
66+
! Get and store the location of the maximum value
67+
maxloc_xy = maxloc(input(n,i:iend,j:jend))
68+
self % maxloc_x(n,ii,jj) = maxloc_xy(1) + i - 1
69+
self % maxloc_y(n,ii,jj) = maxloc_xy(2) + j - 1
70+
71+
self % output(n,ii,jj) = &
72+
input(n,self % maxloc_x(n,ii,jj),self % maxloc_y(n,ii,jj))
73+
74+
end do maxpool_for_each_channel
75+
76+
end do stride_over_input
6977

7078
end subroutine forward
7179

0 commit comments

Comments
 (0)