@@ -244,7 +244,7 @@ pure module subroutine forward_3d(self, input)
244244 end subroutine forward_3d
245245
246246
247- module function output_1d (self , input ) result(res)
247+ module function predict_1d (self , input ) result(res)
248248 class(network), intent (in out ) :: self
249249 real , intent (in ) :: input(:)
250250 real , allocatable :: res(:)
@@ -263,10 +263,10 @@ module function output_1d(self, input) result(res)
263263 error stop ' network % output not implemented for this output layer'
264264 end select
265265
266- end function output_1d
266+ end function predict_1d
267267
268268
269- module function output_3d (self , input ) result(res)
269+ module function predict_3d (self , input ) result(res)
270270 class(network), intent (in out ) :: self
271271 real , intent (in ) :: input(:,:,:)
272272 real , allocatable :: res(:)
@@ -288,10 +288,10 @@ module function output_3d(self, input) result(res)
288288 error stop ' network % output not implemented for this output layer'
289289 end select
290290
291- end function output_3d
291+ end function predict_3d
292292
293293
294- module function output_batch_1d (self , input ) result(res)
294+ module function predict_batch_1d (self , input ) result(res)
295295 class(network), intent (in out ) :: self
296296 real , intent (in ) :: input(:,:)
297297 real , allocatable :: res(:,:)
@@ -318,10 +318,10 @@ module function output_batch_1d(self, input) result(res)
318318
319319 end do batch
320320
321- end function output_batch_1d
321+ end function predict_batch_1d
322322
323323
324- module function output_batch_3d (self , input ) result(res)
324+ module function predict_batch_3d (self , input ) result(res)
325325 class(network), intent (in out ) :: self
326326 real , intent (in ) :: input(:,:,:,:)
327327 real , allocatable :: res(:,:)
@@ -335,23 +335,23 @@ module function output_batch_3d(self, input) result(res)
335335
336336 batch: do concurrent(i = 1 :batch_size)
337337
338- call self % forward(input(:,:,:,i))
338+ call self % forward(input(:,:,:,i))
339339
340- select type (output_layer = > self % layers(num_layers) % p)
341- type is (conv2d_layer)
342- ! FIXME flatten the result for now; find a better solution
343- res(:,i) = pack (output_layer % output, .true. )
344- type is (dense_layer)
345- res(:,i) = output_layer % output
346- type is (flatten_layer)
347- res(:,i) = output_layer % output
348- class default
349- error stop ' network % output not implemented for this output layer'
350- end select
340+ select type (output_layer = > self % layers(num_layers) % p)
341+ type is (conv2d_layer)
342+ ! FIXME flatten the result for now; find a better solution
343+ res(:,i) = pack (output_layer % output, .true. )
344+ type is (dense_layer)
345+ res(:,i) = output_layer % output
346+ type is (flatten_layer)
347+ res(:,i) = output_layer % output
348+ class default
349+ error stop ' network % output not implemented for this output layer'
350+ end select
351351
352352 end do batch
353353
354- end function output_batch_3d
354+ end function predict_batch_3d
355355
356356
357357 module subroutine print_info (self )
0 commit comments