Skip to content

Commit be5ca0a

Browse files
committed
Test the flatten backward pass
1 parent 09fc9a5 commit be5ca0a

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

test/test_flatten_layer.f90

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ program test_flatten_layer
88
implicit none
99

1010
type(layer) :: test_layer, input_layer
11-
real, allocatable :: input_data(:,:,:)
11+
real, allocatable :: input_data(:,:,:), gradient(:,:,:)
1212
real, allocatable :: output(:)
1313
logical :: ok = .true.
1414

@@ -37,6 +37,8 @@ program test_flatten_layer
3737
write(stderr, '(a)') 'flatten layer has an incorrect output shape.. failed'
3838
end if
3939

40+
! Test forward pass - reshaping from 3-d to 1-d
41+
4042
select type(this_layer => input_layer % p); type is(input3d_layer)
4143
call this_layer % set(reshape(real([1, 2, 3, 4]), [1, 2, 2]))
4244
end select
@@ -49,6 +51,21 @@ program test_flatten_layer
4951
write(stderr, '(a)') 'flatten layer correctly propagates forward.. failed'
5052
end if
5153

54+
! Test backward pass - reshaping from 1-d to 3-d
55+
56+
! Calling backward() will set the values on the gradient component
57+
! input_layer is used only to determine shape
58+
call test_layer % backward(input_layer, real([1, 2, 3, 4]))
59+
60+
select type(this_layer => test_layer % p); type is(flatten_layer)
61+
gradient = this_layer % gradient
62+
end select
63+
64+
if (.not. all(gradient == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
65+
ok = .false.
66+
write(stderr, '(a)') 'flatten layer correctly propagates backward.. failed'
67+
end if
68+
5269
if (ok) then
5370
print '(a)', 'test_flatten_layer: All tests passed.'
5471
else

0 commit comments

Comments
 (0)