@@ -96,7 +96,209 @@ module function network_from_layers(layers) result(res)
   end function network_from_layers


-  module subroutine backward(self, output, loss)
+  module function network_from_keras(filename) result(res)
+    character(*), intent(in) :: filename
+    type(network) :: res
+    type(keras_layer), allocatable :: keras_layers(:)
+    type(layer), allocatable :: layers(:)
+    character(:), allocatable :: layer_name
+    character(:), allocatable :: object_name
+    integer :: n
+
+    keras_layers = get_keras_h5_layers(filename)
+
+    allocate(layers(size(keras_layers)))
+
+    do n = 1, size(layers)
+
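+      ! Map each Keras layer class name onto the corresponding
+      ! neural-fortran layer constructor.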
+      select case(keras_layers(n) % class)
+
+        case('Conv2D')
+
+          if (keras_layers(n) % kernel_size(1) &
+            /= keras_layers(n) % kernel_size(2)) &
+            error stop 'Non-square kernel in conv2d layer not supported.'
+
+          layers(n) = conv2d( &
+            keras_layers(n) % filters, &
+            ! FIXME add support for non-square kernel
+            keras_layers(n) % kernel_size(1), &
+            get_activation_by_name(keras_layers(n) % activation) &
+          )
+
+        case('Dense')
+
+          layers(n) = dense( &
+            keras_layers(n) % units(1), &
+            get_activation_by_name(keras_layers(n) % activation) &
+          )
+
+        case('Flatten')
+          layers(n) = flatten()
+
+        case('InputLayer')
+          if (size(keras_layers(n) % units) == 1) then
+            ! input1d
+            layers(n) = input(keras_layers(n) % units(1))
+          else
+            ! input3d
+            layers(n) = input(keras_layers(n) % units)
+          end if
+
+        case('MaxPooling2D')
+
+          if (keras_layers(n) % pool_size(1) &
+            /= keras_layers(n) % pool_size(2)) &
+            error stop 'Non-square pool in maxpool2d layer not supported.'
+
+          if (keras_layers(n) % strides(1) &
+            /= keras_layers(n) % strides(2)) &
+            error stop 'Unequal strides in maxpool2d layer are not supported.'
+
+          layers(n) = maxpool2d( &
+            ! FIXME add support for non-square pool and stride
+            keras_layers(n) % pool_size(1), &
+            keras_layers(n) % strides(1) &
+          )
+
+        case('Reshape')
+          layers(n) = reshape(keras_layers(n) % target_shape)
+
+        case default
+          error stop 'This Keras layer is not supported'
+
+      end select
+
+    end do
+
+    res = network(layers)
+
+    ! Loop over layers and read weights and biases from the Keras h5 file.
+    ! Layer 1 is the input layer and has no parameters; dense, conv2d, and
+    ! rnn layers have parameters to read, while the others are parameter-free.
+    do n = 2, size(res % layers)
+
+      layer_name = keras_layers(n) % name
+
+      select type(this_layer => res % layers(n) % p)
+
+        type is(conv2d_layer)
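+          ! Keras h5 files store each layer's parameters as datasets under
+          ! /model_weights/<layer name>/<layer name>/<parameter>:0.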
+          ! Read biases from file
+          object_name = '/model_weights/' // layer_name // '/' &
+            // layer_name // '/bias:0'
+          call get_hdf5_dataset(filename, object_name, this_layer % biases)
+
+          ! Read weights from file
+          object_name = '/model_weights/' // layer_name // '/' &
+            // layer_name // '/kernel:0'
+          call get_hdf5_dataset(filename, object_name, this_layer % kernel)
+
+        type is(dense_layer)
+
+          ! Read biases from file
+          object_name = '/model_weights/' // layer_name // '/' &
+            // layer_name // '/bias:0'
+          call get_hdf5_dataset(filename, object_name, this_layer % biases)
+
+          ! Read weights from file
+          object_name = '/model_weights/' // layer_name // '/' &
+            // layer_name // '/kernel:0'
+          call get_hdf5_dataset(filename, object_name, this_layer % weights)
+
+        type is(flatten_layer)
+          ! Nothing to do
+          continue
+
+        type is(maxpool2d_layer)
+          ! Nothing to do
+          continue
+
+        type is(reshape3d_layer)
+          ! Nothing to do
+          continue
+
+        type is(rnn_layer)
+
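+          ! Note: the "simple_rnn_cell_23" group name below comes from the
+          ! particular Keras model this was developed against; Keras numbers
+          ! cell groups sequentially, so other saved models may use a
+          ! different suffix.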
+          ! Read biases from file
+          object_name = '/model_weights/' // layer_name // '/' &
+            // layer_name // '/simple_rnn_cell_23/bias:0'
+          call get_hdf5_dataset(filename, object_name, this_layer % biases)
+
+          ! Read weights from file
+          object_name = '/model_weights/' // layer_name // '/' &
+            // layer_name // '/simple_rnn_cell_23/kernel:0'
+          call get_hdf5_dataset(filename, object_name, this_layer % weights)
+
+          ! Read recurrent weights from file
+          object_name = '/model_weights/' // layer_name // '/' &
+            // layer_name // '/simple_rnn_cell_23/recurrent_kernel:0'
+          call get_hdf5_dataset(filename, object_name, this_layer % recurrent)
+
+        class default
+          error stop 'Internal error in network_from_keras(); ' &
+            // 'mismatch in layer types between the Keras and ' &
+            // 'neural-fortran model layers.'
+
+      end select
+
+    end do
+
+  end function network_from_keras
+
+
+  pure function get_activation_by_name(activation_name) result(res)
+    ! Workaround to construct an activation_function instance, with some
+    ! hardcoded default parameters, from its name alone, since the
+    ! activation name is all that Keras files provide.
+    character(len=*), intent(in) :: activation_name
+    class(activation_function), allocatable :: res
+
+    select case(trim(activation_name))
+      case('elu')
+        allocate(res, source=elu(alpha=0.1))
+
+      case('exponential')
+        allocate(res, source=exponential())
+
+      case('gaussian')
+        allocate(res, source=gaussian())
+
+      case('linear')
+        allocate(res, source=linear())
+
+      case('relu')
+        allocate(res, source=relu())
+
+      case('leaky_relu')
+        allocate(res, source=leaky_relu(alpha=0.1))
+
+      case('sigmoid')
+        allocate(res, source=sigmoid())
+
+      case('softmax')
+        allocate(res, source=softmax())
+
+      case('softplus')
+        allocate(res, source=softplus())
+
+      case('step')
+        allocate(res, source=step())
+
+      case('tanh')
+        allocate(res, source=tanhf())
+
+      case('celu')
+        allocate(res, source=celu())
+
+      case default
+        error stop 'activation_name must be one of: ' // &
+          '"elu", "exponential", "gaussian", "linear", "relu", ' // &
+          '"leaky_relu", "sigmoid", "softmax", "softplus", "step", "tanh" or "celu".'
+    end select
+
+  end function get_activation_by_name
+
+  pure module subroutine backward(self, output, loss)
     class(network), intent(in out) :: self
     real, intent(in) :: output(:)
     class(loss_type), intent(in), optional :: loss
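
A minimal usage sketch for the new constructor, assuming network_from_keras is exposed through the generic network constructor interface of the top-level nf module, and with 'model.h5' as a placeholder path to a Keras-saved file containing only the supported layer types:

    program load_keras_model
      use nf, only: network
      implicit none
      type(network) :: net

      ! Placeholder file name; any Keras h5 model limited to the
      ! supported layer types should load this way.
      net = network('model.h5')

    end program load_keras_model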