Skip to content

Commit 2b71b76

Browse files
committed
Parsing Input and Dense layers and their sizes from Keras h5 works
1 parent e027b13 commit 2b71b76

File tree

3 files changed

+95
-49
lines changed

3 files changed

+95
-49
lines changed

src/nf/nf_keras.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ module nf_keras
66
implicit none
77

88
private
9-
public :: get_keras_h5_layers
9+
public :: get_keras_h5_layers, keras_layer
1010

1111
type :: keras_layer
1212
!! Intermediate container to convey the Keras layer information

src/nf/nf_keras_submodule.f90

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
submodule(nf_keras) nf_keras_submodule
2+
3+
use json_module, only: json_core, json_value
4+
use nf_io_hdf5, only: get_h5_attribute_string
5+
6+
implicit none
7+
8+
contains
9+
10+
module function get_keras_h5_layers(filename) result(res)
11+
character(*), intent(in) :: filename
12+
type(keras_layer), allocatable :: res(:)
13+
14+
character(:), allocatable :: model_config_string
15+
16+
type(json_core) :: json
17+
type(json_value), pointer :: &
18+
model_config_json, layers_json, layer_json, layer_config_json
19+
20+
real, allocatable :: tmp_array(:)
21+
integer :: n, num_layers, num_elements
22+
logical :: found
23+
24+
model_config_string = get_h5_attribute_string(filename, '.', 'model_config')
25+
26+
call json % parse(model_config_json, model_config_string)
27+
call json % get(model_config_json, 'config.layers', layers_json)
28+
29+
num_layers = json % count(layers_json)
30+
31+
allocate(res(num_layers))
32+
33+
! Iterate over layers
34+
layers: do n = 1, num_layers
35+
36+
! Get pointer to the layer
37+
call json % get_child(layers_json, n, layer_json)
38+
39+
! Get type of layer as a string
40+
call json % get(layer_json, 'class_name', res(n) % type)
41+
42+
! Get pointer to the layer config
43+
call json % get(layer_json, 'config', layer_config_json)
44+
45+
! Get size of layer and activation if applicable;
46+
! Instantiate neural-fortran layers at this time.
47+
if (res(n) % type == 'InputLayer') then
48+
49+
call json % get(layer_config_json, 'batch_input_shape', tmp_array)
50+
res(n) % num_elements = [tmp_array(2)]
51+
52+
else if (res(n) % type == 'Dense') then
53+
54+
call json % get(layer_config_json, 'units', num_elements, found)
55+
res(n) % num_elements = [num_elements]
56+
57+
call json % get(layer_config_json, 'activation', res(n) % activation)
58+
59+
else
60+
61+
error stop 'This layer is not supported'
62+
63+
end if
64+
65+
end do layers
66+
67+
end function get_keras_h5_layers
68+
69+
end submodule nf_keras_submodule

test/test_parse_keras_model.f90

Lines changed: 25 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,76 +2,53 @@ program test_parse_keras_model
22

33
use iso_fortran_env, only: stderr => error_unit
44
use nf_datasets, only: download_and_unpack, keras_model_dense_mnist_url
5-
use nf_io_hdf5, only: get_h5_attribute_string
5+
use nf_keras, only: get_keras_h5_layers, keras_layer
66
use nf, only: layer, network, dense, input
7-
use json_module
87

98
implicit none
109

1110
character(:), allocatable :: model_config_string
1211
character(*), parameter :: test_data_path = 'keras_dense_mnist.h5'
13-
type(json_core) :: json
14-
type(json_value), pointer :: &
15-
model_config, layer_list, this_layer, layer_config
16-
character(:), allocatable :: class_name, layer_type, activation
17-
real, allocatable :: tmp_array(:)
12+
13+
type(keras_layer), allocatable :: keras_layers(:)
1814

1915
type(layer), allocatable :: layers(:)
2016
type(network) :: net
2117

22-
logical :: found
23-
integer :: n, num_layers, num_elements
18+
integer :: n
2419
logical :: file_exists
2520
logical :: ok = .true.
2621

2722
inquire(file=test_data_path, exist=file_exists)
2823
if (.not. file_exists) call download_and_unpack(keras_model_dense_mnist_url)
2924

30-
model_config_string = &
31-
get_h5_attribute_string(test_data_path, '.', 'model_config')
32-
33-
call json % parse(model_config, model_config_string)
34-
call json % get(model_config, 'config.layers', layer_list)
35-
36-
num_layers = json % count(layer_list)
37-
38-
layers = [layer ::]
39-
40-
! Iterate over layers
41-
do n = 1, num_layers
25+
keras_layers = get_keras_h5_layers(test_data_path)
4226

43-
! Get pointer to the layer
44-
call json % get_child(layer_list, n, this_layer)
45-
46-
! Get type of layer as a string
47-
call json % get(this_layer, 'class_name', layer_type)
48-
49-
! Get pointer to the layer config
50-
call json % get(this_layer, 'config', layer_config)
27+
if (size(keras_layers) /= 3) then
28+
ok = .false.
29+
write(stderr, '(a)') 'Keras dense MNIST model has 3 layers.. failed'
30+
end if
5131

52-
! Get size of layer and activation if applicable;
53-
! Instantiate neural-fortran layers at this time.
54-
if (layer_type == 'InputLayer') then
55-
call json % get(layer_config, 'batch_input_shape', tmp_array, found)
56-
num_elements = tmp_array(2)
57-
layers = [layers, input(num_elements)]
58-
else if (layer_type == 'Dense') then
59-
call json % get(layer_config, 'units', num_elements, found)
60-
call json % get(layer_config, 'activation', activation, found)
61-
layers = [layers, dense(num_elements, activation)]
62-
else
63-
error stop 'This layer is not supported'
64-
end if
32+
if (keras_layers(1) % type /= 'InputLayer') then
33+
ok = .false.
34+
write(stderr, '(a)') 'Keras first layer should be InputLayer.. failed'
35+
end if
6536

66-
print *, n, layer_type, num_elements, activation
67-
end do
37+
if (.not. all(keras_layers(1) % num_elements == [784])) then
38+
ok = .false.
39+
write(stderr, '(a)') 'Keras first layer should have 784 elements.. failed'
40+
end if
6841

69-
net = network(layers)
70-
call net % print_info()
42+
if (allocated(keras_layers(1) % activation)) then
43+
ok = .false.
44+
write(stderr, '(a)') &
45+
'Keras first layer activation should not be allocated.. failed'
46+
end if
7147

72-
if (.not. num_layers == 3) then
48+
if (.not. keras_layers(2) % type == 'Dense') then
7349
ok = .false.
74-
write(stderr, '(a)') 'Keras dense MNIST model has 3 layers.. failed'
50+
write(stderr, '(a)') &
51+
'Keras second and third layers should be dense.. failed'
7552
end if
7653

7754
if (ok) then

0 commit comments

Comments
 (0)