Skip to content

Commit adef7d7

Browse files
committed
Reading coefficients from h5f model
Note a hardcoded 'simple_rnn_cell_23' that must be resolved later.
1 parent 8c11911 commit adef7d7

File tree

1 file changed

+203
-1
lines changed

1 file changed

+203
-1
lines changed

src/nf/nf_network_submodule.f90

Lines changed: 203 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,209 @@ module function network_from_layers(layers) result(res)
9696
end function network_from_layers
9797

9898

99-
module subroutine backward(self, output, loss)
99+
module function network_from_keras(filename) result(res)
100+
character(*), intent(in) :: filename
101+
type(network) :: res
102+
type(keras_layer), allocatable :: keras_layers(:)
103+
type(layer), allocatable :: layers(:)
104+
character(:), allocatable :: layer_name
105+
character(:), allocatable :: object_name
106+
integer :: n
107+
108+
keras_layers = get_keras_h5_layers(filename)
109+
110+
allocate(layers(size(keras_layers)))
111+
112+
do n = 1, size(layers)
113+
114+
select case(keras_layers(n) % class)
115+
116+
case('Conv2D')
117+
118+
if (keras_layers(n) % kernel_size(1) &
119+
/= keras_layers(n) % kernel_size(2)) &
120+
error stop 'Non-square kernel in conv2d layer not supported.'
121+
122+
layers(n) = conv2d( &
123+
keras_layers(n) % filters, &
124+
!FIXME add support for non-square kernel
125+
keras_layers(n) % kernel_size(1), &
126+
get_activation_by_name(keras_layers(n) % activation) &
127+
)
128+
129+
case('Dense')
130+
131+
layers(n) = dense( &
132+
keras_layers(n) % units(1), &
133+
get_activation_by_name(keras_layers(n) % activation) &
134+
)
135+
136+
case('Flatten')
137+
layers(n) = flatten()
138+
139+
case('InputLayer')
140+
if (size(keras_layers(n) % units) == 1) then
141+
! input1d
142+
layers(n) = input(keras_layers(n) % units(1))
143+
else
144+
! input3d
145+
layers(n) = input(keras_layers(n) % units)
146+
end if
147+
148+
case('MaxPooling2D')
149+
150+
if (keras_layers(n) % pool_size(1) &
151+
/= keras_layers(n) % pool_size(2)) &
152+
error stop 'Non-square pool in maxpool2d layer not supported.'
153+
154+
if (keras_layers(n) % strides(1) &
155+
/= keras_layers(n) % strides(2)) &
156+
error stop 'Unequal strides in maxpool2d layer are not supported.'
157+
158+
layers(n) = maxpool2d( &
159+
!FIXME add support for non-square pool and stride
160+
keras_layers(n) % pool_size(1), &
161+
keras_layers(n) % strides(1) &
162+
)
163+
164+
case('Reshape')
165+
layers(n) = reshape(keras_layers(n) % target_shape)
166+
167+
case default
168+
error stop 'This Keras layer is not supported'
169+
170+
end select
171+
172+
end do
173+
174+
res = network(layers)
175+
176+
! Loop over layers and read weights and biases from the Keras h5 file
177+
! for each; currently only dense layers are implemented.
178+
do n = 2, size(res % layers)
179+
180+
layer_name = keras_layers(n) % name
181+
182+
select type(this_layer => res % layers(n) % p)
183+
184+
type is(conv2d_layer)
185+
! Read biases from file
186+
object_name = '/model_weights/' // layer_name // '/' &
187+
// layer_name // '/bias:0'
188+
call get_hdf5_dataset(filename, object_name, this_layer % biases)
189+
190+
! Read weights from file
191+
object_name = '/model_weights/' // layer_name // '/' &
192+
// layer_name // '/kernel:0'
193+
call get_hdf5_dataset(filename, object_name, this_layer % kernel)
194+
195+
type is(dense_layer)
196+
197+
! Read biases from file
198+
object_name = '/model_weights/' // layer_name // '/' &
199+
// layer_name // '/bias:0'
200+
call get_hdf5_dataset(filename, object_name, this_layer % biases)
201+
202+
! Read weights from file
203+
object_name = '/model_weights/' // layer_name // '/' &
204+
// layer_name // '/kernel:0'
205+
call get_hdf5_dataset(filename, object_name, this_layer % weights)
206+
207+
type is(flatten_layer)
208+
! Nothing to do
209+
continue
210+
211+
type is(maxpool2d_layer)
212+
! Nothing to do
213+
continue
214+
215+
type is(reshape3d_layer)
216+
! Nothing to do
217+
continue
218+
219+
type is(rnn_layer)
220+
221+
! Read biases from file
222+
object_name = '/model_weights/' // layer_name // '/' &
223+
// layer_name // '/simple_rnn_cell_23/bias:0'
224+
call get_hdf5_dataset(filename, object_name, this_layer % biases)
225+
226+
! Read weights from file
227+
object_name = '/model_weights/' // layer_name // '/' &
228+
// layer_name // '/simple_rnn_cell_23/kernel:0'
229+
call get_hdf5_dataset(filename, object_name, this_layer % weights)
230+
231+
! Read recurrent weights from file
232+
object_name = '/model_weights/' // layer_name // '/' &
233+
// layer_name // '/simple_rnn_cell_23/recurrent_kernel:0'
234+
call get_hdf5_dataset(filename, object_name, this_layer % recurrent)
235+
236+
class default
237+
error stop 'Internal error in network_from_keras(); ' &
238+
// 'mismatch in layer types between the Keras and ' &
239+
// 'neural-fortran model layers.'
240+
241+
end select
242+
243+
end do
244+
245+
end function network_from_keras
246+
247+
248+
pure function get_activation_by_name(activation_name) result(res)
249+
! Workaround to get activation_function with some
250+
! hardcoded default parameters by its name.
251+
! Need this function since we get only activation name
252+
! from keras files.
253+
character(len=*), intent(in) :: activation_name
254+
class(activation_function), allocatable :: res
255+
256+
select case(trim(activation_name))
257+
case('elu')
258+
allocate ( res, source = elu(alpha = 0.1) )
259+
260+
case('exponential')
261+
allocate ( res, source = exponential() )
262+
263+
case('gaussian')
264+
allocate ( res, source = gaussian() )
265+
266+
case('linear')
267+
allocate ( res, source = linear() )
268+
269+
case('relu')
270+
allocate ( res, source = relu() )
271+
272+
case('leaky_relu')
273+
allocate ( res, source = leaky_relu(alpha = 0.1) )
274+
275+
case('sigmoid')
276+
allocate ( res, source = sigmoid() )
277+
278+
case('softmax')
279+
allocate ( res, source = softmax() )
280+
281+
case('softplus')
282+
allocate ( res, source = softplus() )
283+
284+
case('step')
285+
allocate ( res, source = step() )
286+
287+
case('tanh')
288+
allocate ( res, source = tanhf() )
289+
290+
case('celu')
291+
allocate ( res, source = celu() )
292+
293+
case default
294+
error stop 'activation_name must be one of: ' // &
295+
'"elu", "exponential", "gaussian", "linear", "relu", ' // &
296+
'"leaky_relu", "sigmoid", "softmax", "softplus", "step", "tanh" or "celu".'
297+
end select
298+
299+
end function get_activation_by_name
300+
301+
pure module subroutine backward(self, output, loss)
100302
class(network), intent(in out) :: self
101303
real, intent(in) :: output(:)
102304
class(loss_type), intent(in), optional :: loss

0 commit comments

Comments
 (0)