@@ -43,12 +43,13 @@ pure module subroutine forward(self, input)
4343 integer :: i, j, n
4444 integer :: ii, jj
4545 integer :: iend, jend
46+ integer :: maxloc_xy(2 )
4647
4748 input_width = size (input, dim= 2 )
4849 input_height = size (input, dim= 2 )
4950
5051 ! Stride along the width and height of the input image
51- do concurrent( &
52+ stride_over_input: do concurrent( &
5253 i = 1 :input_width:self % stride, &
5354 j = 1 :input_height:self % stride &
5455 )
@@ -60,12 +61,19 @@ pure module subroutine forward(self, input)
6061 iend = i + self % pool_size - 1
6162 jend = j + self % pool_size - 1
6263
63- do concurrent(n = 1 :self % channels)
64- ! TODO find and store maxloc
65- self % output(n,ii,jj) = maxval (input(n,i:iend,j:jend))
66- end do
64+ maxpool_for_each_channel: do concurrent(n = 1 :self % channels)
6765
68- end do
66+ ! Get and store the location of the maximum value
67+ maxloc_xy = maxloc (input(n,i:iend,j:jend))
68+ self % maxloc_x(n,ii,jj) = maxloc_xy(1 ) + i - 1
69+ self % maxloc_y(n,ii,jj) = maxloc_xy(2 ) + j - 1
70+
71+ self % output(n,ii,jj) = &
72+ input(n,self % maxloc_x(n,ii,jj),self % maxloc_y(n,ii,jj))
73+
74+ end do maxpool_for_each_channel
75+
76+ end do stride_over_input
6977
7078 end subroutine forward
7179
0 commit comments