Skip to content

Commit e20fb0a

Browse files
committed
Read Keras layer names
1 parent a526a39 commit e20fb0a

File tree

3 files changed

+19
-18
lines changed

3 files changed

+19
-18
lines changed

src/nf/nf_keras.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module nf_keras
1111
type :: keras_layer
1212
!! Intermediate container to convey the Keras layer information
1313
!! to neural-fortran layer constructors.
14-
character(:), allocatable :: type
14+
character(:), allocatable :: class
1515
character(:), allocatable :: name
1616
character(:), allocatable :: activation
1717
integer, allocatable :: num_elements(:)

src/nf/nf_keras_submodule.f90

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,30 +39,31 @@ module function get_keras_h5_layers(filename) result(res)
3939
call json % get_child(layers_json, n, layer_json)
4040

4141
! Get type of layer as a string
42-
call json % get(layer_json, 'class_name', res(n) % type)
42+
call json % get(layer_json, 'class_name', res(n) % class)
4343

4444
! Get pointer to the layer config
4545
call json % get(layer_json, 'config', layer_config_json)
4646

47+
! Get layer name
48+
call json % get(layer_config_json, 'name', res(n) % name)
49+
4750
! Get size of layer and activation if applicable;
4851
! Instantiate neural-fortran layers at this time.
49-
if (res(n) % type == 'InputLayer') then
50-
51-
call json % get(layer_config_json, 'batch_input_shape', tmp_array)
52-
res(n) % num_elements = [tmp_array(2)]
53-
54-
else if (res(n) % type == 'Dense') then
55-
56-
call json % get(layer_config_json, 'units', num_elements, found)
57-
res(n) % num_elements = [num_elements]
52+
select case(res(n) % class)
5853

59-
call json % get(layer_config_json, 'activation', res(n) % activation)
54+
case('InputLayer')
55+
call json % get(layer_config_json, 'batch_input_shape', tmp_array)
56+
res(n) % num_elements = [tmp_array(2)]
6057

61-
else
62-
63-
error stop 'This layer is not supported'
58+
case('Dense')
59+
call json % get(layer_config_json, 'units', num_elements, found)
60+
res(n) % num_elements = [num_elements]
61+
call json % get(layer_config_json, 'activation', res(n) % activation)
6462

65-
end if
63+
case default
64+
error stop 'This Keras layer is not supported'
65+
66+
end select
6667

6768
end do layers
6869

test/test_keras_read_model.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ program test_keras_read_model
2929
write(stderr, '(a)') 'Keras dense MNIST model has 3 layers.. failed'
3030
end if
3131

32-
if (keras_layers(1) % type /= 'InputLayer') then
32+
if (keras_layers(1) % class /= 'InputLayer') then
3333
ok = .false.
3434
write(stderr, '(a)') 'Keras first layer should be InputLayer.. failed'
3535
end if
@@ -45,7 +45,7 @@ program test_keras_read_model
4545
'Keras first layer activation should not be allocated.. failed'
4646
end if
4747

48-
if (.not. keras_layers(2) % type == 'Dense') then
48+
if (.not. keras_layers(2) % class == 'Dense') then
4949
ok = .false.
5050
write(stderr, '(a)') &
5151
'Keras second and third layers should be dense.. failed'

0 commit comments

Comments
 (0)