Skip to content

Commit 7cad8ca

Browse files
committed
Instantiating a network from h5 with correct size
1 parent 472164e commit 7cad8ca

File tree

1 file changed

+43
-10
lines changed

1 file changed

+43
-10
lines changed

test/test_parse_keras_model.f90

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,24 @@ program test_parse_keras_model
33
use iso_fortran_env, only: stderr => error_unit
44
use nf_datasets, only: download_and_unpack, keras_model_dense_mnist_url
55
use nf_io_hdf5, only: get_h5_attribute_string
6+
use nf, only: layer, network, dense, input
67
use json_module
78

89
implicit none
910

1011
character(:), allocatable :: model_config_string
1112
character(*), parameter :: test_data_path = 'keras_dense_mnist.h5'
1213
type(json_core) :: json
13-
type(json_value), pointer :: model_config, layers, next_layer, layer
14-
character(:), allocatable :: class_name, layer_type
14+
type(json_value), pointer :: &
15+
model_config, layer_list, this_layer, layer_config
16+
character(:), allocatable :: class_name, layer_type, activation
17+
real, allocatable :: tmp_array(:)
18+
19+
type(layer), allocatable :: layers(:)
20+
type(network) :: net
21+
1522
logical :: found
16-
integer :: n, num_layers
23+
integer :: n, num_layers, num_elements
1724
logical :: file_exists
1825
logical :: ok = .true.
1926

@@ -24,18 +31,44 @@ program test_parse_keras_model
2431
get_h5_attribute_string(test_data_path, '.', 'model_config')
2532

2633
call json % parse(model_config, model_config_string)
27-
call json % get(model_config, 'config.layers', layers)
34+
call json % get(model_config, 'config.layers', layer_list)
35+
36+
num_layers = json % count(layer_list)
2837

29-
num_layers = json % count(layers)
38+
layers = [layer ::]
3039

40+
! Iterate over layers
3141
do n = 1, num_layers
32-
call json % get_child(layers, n, layer)
33-
!print *, 'Layer', n
34-
!call json % print(layer)
35-
call json % get(layer, 'class_name', layer_type)
36-
!print *, layer_type
42+
43+
! Get pointer to the layer
44+
call json % get_child(layer_list, n, this_layer)
45+
46+
! Get type of layer as a string
47+
call json % get(this_layer, 'class_name', layer_type)
48+
49+
! Get pointer to the layer config
50+
call json % get(this_layer, 'config', layer_config)
51+
52+
! Get size of layer and activation if applicable;
53+
! Instantiate neural-fortran layers at this time.
54+
if (layer_type == 'InputLayer') then
55+
call json % get(layer_config, 'batch_input_shape', tmp_array, found)
56+
num_elements = tmp_array(2)
57+
layers = [layers, input(num_elements)]
58+
else if (layer_type == 'Dense') then
59+
call json % get(layer_config, 'units', num_elements, found)
60+
call json % get(layer_config, 'activation', activation, found)
61+
layers = [layers, dense(num_elements, activation)]
62+
else
63+
error stop 'This layer is not supported'
64+
end if
65+
66+
print *, n, layer_type, num_elements, activation
3767
end do
3868

69+
net = network(layers)
70+
call net % print_info()
71+
3972
if (.not. num_layers == 3) then
4073
ok = .false.
4174
write(stderr, '(a)') 'Keras dense MNIST model has 3 layers.. failed'

0 commit comments

Comments
 (0)