@@ -8,7 +8,7 @@ program test_flatten_layer
88 implicit none
99
1010 type (layer) :: test_layer, input_layer
11- real , allocatable :: input_data(:,:,:)
11+ real , allocatable :: input_data(:,:,:), gradient(:,:,:)
1212 real , allocatable :: output(:)
1313 logical :: ok = .true.
1414
@@ -37,6 +37,8 @@ program test_flatten_layer
3737 write (stderr, ' (a)' ) ' flatten layer has an incorrect output shape.. failed'
3838 end if
3939
40+ ! Test forward pass - reshaping from 3-d to 1-d
41+
4042 select type (this_layer = > input_layer % p); type is(input3d_layer)
4143 call this_layer % set(reshape (real ([1 , 2 , 3 , 4 ]), [1 , 2 , 2 ]))
4244 end select
@@ -49,6 +51,21 @@ program test_flatten_layer
4951 write (stderr, ' (a)' ) ' flatten layer correctly propagates forward.. failed'
5052 end if
5153
54+ ! Test backward pass - reshaping from 1-d to 3-d
55+
56+ ! Calling backward() will set the values on the gradient component
57+ ! input_layer is used only to determine shape
58+ call test_layer % backward(input_layer, real ([1 , 2 , 3 , 4 ]))
59+
60+ select type (this_layer = > test_layer % p); type is(flatten_layer)
61+ gradient = this_layer % gradient
62+ end select
63+
64+ if (.not. all (gradient == reshape (real ([1 , 2 , 3 , 4 ]), [1 , 2 , 2 ]))) then
65+ ok = .false.
66+ write (stderr, ' (a)' ) ' flatten layer correctly propagates backward.. failed'
67+ end if
68+
5269 if (ok) then
5370 print ' (a)' , ' test_flatten_layer: All tests passed.'
5471 else
0 commit comments