Skip to content

Commit 00e9802

Browse files
committed
Add a few tests for a convolutional network; no training yet
1 parent 8876c99 commit 00e9802

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ string(REGEX REPLACE "^ | $" "" LIBS "${LIBS}")
100100

101101
# tests
102102
enable_testing()
103-
foreach(execid input1d_layer input3d_layer dense_layer conv2d_layer dense_network)
103+
foreach(execid input1d_layer input3d_layer dense_layer conv2d_layer dense_network conv2d_network)
104104
add_executable(test_${execid} test/test_${execid}.f90)
105105
target_link_libraries(test_${execid} neural ${LIBS})
106106
add_test(test_${execid} bin/test_${execid})

test/test_conv2d_network.f90

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
program test_conv2d_network
2+
3+
use iso_fortran_env, only: stderr => error_unit
4+
use nf, only: conv2d, input, network
5+
6+
implicit none
7+
8+
type(network) :: net
9+
real, allocatable :: sample_input(:,:,:), output(:,:,:)
10+
logical :: ok = .true.
11+
12+
! 3-layer convolutional network
13+
net = network([ &
14+
input([3, 32, 32]), &
15+
conv2d(window_size=3, filters=16), &
16+
conv2d(window_size=3, filters=32) &
17+
])
18+
19+
if (.not. size(net % layers) == 3) then
20+
write(stderr, '(a)') 'conv2d network should have 3 layers.. failed'
21+
ok = .false.
22+
end if
23+
24+
allocate(sample_input(3, 32, 32))
25+
sample_input = 0
26+
27+
call net % forward(sample_input)
28+
call net % layers(3) % get_output(output)
29+
30+
if (.not. all(shape(output) == [32, 28, 28])) then
31+
write(stderr, '(a)') 'conv2d network output should have correct shape.. failed'
32+
ok = .false.
33+
end if
34+
35+
if (ok) then
36+
print '(a)', 'test_dense_network: All tests passed.'
37+
else
38+
write(stderr, '(a)') 'test_dense_network: One or more tests failed.'
39+
stop 1
40+
end if
41+
42+
end program test_conv2d_network

0 commit comments

Comments
 (0)