Skip to content

Commit e98cb11

Browse files
committed
Create a network from Keras file; setting weights and biases remains TODO
1 parent e20fb0a commit e98cb11

File tree

3 files changed

+126
-6
lines changed

3 files changed

+126
-6
lines changed

src/nf/nf_network.f90

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,25 @@ module nf_network
3232
end type network
3333

3434
interface network
35-
module function network_cons(layers) result(res)
36-
!! Create a new `network` instance.
35+
36+
module function network_from_layers(layers) result(res)
37+
!! Create a new `network` instance from an array of `layer` instances.
3738
type(layer), intent(in) :: layers(:)
38-
!! Input array of layer instances;
39+
!! Input array of `layer` instances;
3940
!! the first element must be an input layer.
4041
type(network) :: res
4142
!! An instance of the `network` type
42-
end function network_cons
43+
end function network_from_layers
44+
45+
module function network_from_keras(filename) result(res)
46+
!! Create a new `network` instance
47+
!! from a Keras model saved in an h5 file.
48+
character(*), intent(in) :: filename
49+
!! Path to the Keras model h5 file
50+
type(network) :: res
51+
!! An instance of the `network` type
52+
end function network_from_keras
53+
4354
end interface network
4455

4556
interface forward

src/nf/nf_network_submodule.f90

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
use nf_flatten_layer, only: flatten_layer
55
use nf_input1d_layer, only: input1d_layer
66
use nf_input3d_layer, only: input3d_layer
7+
use nf_keras, only: get_keras_h5_layers, keras_layer
78
use nf_layer, only: layer
9+
use nf_layer_constructors, only: dense, input
810
use nf_loss, only: quadratic_derivative
911
use nf_optimizers, only: sgd
1012
use nf_parallel, only: tile_indices
@@ -13,7 +15,7 @@
1315

1416
contains
1517

16-
module function network_cons(layers) result(res)
18+
module function network_from_layers(layers) result(res)
1719
type(layer), intent(in) :: layers(:)
1820
type(network) :: res
1921
integer :: n
@@ -45,7 +47,51 @@ module function network_cons(layers) result(res)
4547
call res % layers(n) % init(res % layers(n - 1))
4648
end do
4749

48-
end function network_cons
50+
end function network_from_layers
51+
52+
53+
module function network_from_keras(filename) result(res)
54+
character(*), intent(in) :: filename
55+
type(network) :: res
56+
type(keras_layer), allocatable :: keras_layers(:)
57+
type(layer), allocatable :: layers(:)
58+
integer :: n
59+
60+
keras_layers = get_keras_h5_layers(filename)
61+
62+
allocate(layers(size(keras_layers)))
63+
64+
do n = 1, size(layers)
65+
66+
select case(keras_layers(n) % class)
67+
68+
case('InputLayer')
69+
if (size(keras_layers(n) % num_elements) == 1) then
70+
! input1d
71+
layers(n) = input(keras_layers(n) % num_elements(1))
72+
else
73+
! input3d
74+
layers(n) = input(keras_layers(n) % num_elements)
75+
end if
76+
77+
case('Dense')
78+
layers(n) = dense( &
79+
keras_layers(n) % num_elements(1), &
80+
keras_layers(n) % activation &
81+
)
82+
83+
case default
84+
error stop 'This Keras layer is not supported'
85+
86+
end select
87+
88+
end do
89+
90+
res = network(layers)
91+
92+
!TODO read weights and biases from Keras file and set here
93+
94+
end function network_from_keras
4995

5096

5197
pure module subroutine backward(self, output)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
program test_dense_network_from_keras
2+
3+
use iso_fortran_env, only: stderr => error_unit
4+
use nf, only: network
5+
use nf_datasets, only: download_and_unpack, keras_model_dense_mnist_url
6+
7+
implicit none
8+
9+
type(network) :: net
10+
character(*), parameter :: test_data_path = 'keras_dense_mnist.h5'
11+
12+
logical :: file_exists
13+
logical :: ok = .true.
14+
15+
inquire(file=test_data_path, exist=file_exists)
16+
if (.not. file_exists) call download_and_unpack(keras_model_dense_mnist_url)
17+
18+
net = network(test_data_path)
19+
20+
if (.not. size(net % layers) == 3) then
21+
write(stderr, '(a)') 'dense network should have 3 layers.. failed'
22+
ok = .false.
23+
end if
24+
25+
if (.not. net % layers(1) % name == 'input') then
26+
write(stderr, '(a)') 'First layer should be an input layer.. failed'
27+
ok = .false.
28+
end if
29+
30+
if (.not. all(net % layers(1) % layer_shape == [784])) then
31+
write(stderr, '(a)') 'First layer should have shape [784].. failed'
32+
ok = .false.
33+
end if
34+
35+
if (.not. net % layers(2) % name == 'dense') then
36+
write(stderr, '(a)') 'Second layer should be a dense layer.. failed'
37+
ok = .false.
38+
end if
39+
40+
if (.not. all(net % layers(2) % layer_shape == [30])) then
41+
write(stderr, '(a)') 'Second layer should have shape [30].. failed'
42+
ok = .false.
43+
end if
44+
45+
if (.not. net % layers(3) % name == 'dense') then
46+
write(stderr, '(a)') 'Third layer should be a dense layer.. failed'
47+
ok = .false.
48+
end if
49+
50+
if (.not. all(net % layers(3) % layer_shape == [10])) then
51+
write(stderr, '(a)') 'Third layer should have shape [10].. failed'
52+
ok = .false.
53+
end if
54+
55+
if (ok) then
56+
print '(a)', 'test_dense_network_from_keras: All tests passed.'
57+
else
58+
write(stderr, '(a)') &
59+
'test_dense_network_from_keras: One or more tests failed.'
60+
stop 1
61+
end if
62+
63+
end program test_dense_network_from_keras

0 commit comments

Comments
 (0)