@@ -244,23 +244,6 @@ pure module subroutine forward_3d(self, input)
244244 end subroutine forward_3d
245245
246246
247- pure module subroutine forward_batch_3d(self, input)
248- class(network), intent (in out ) :: self
249- real , intent (in ) :: input(:,:,:,:)
250- integer :: n
251-
252- ! Set the input array into the input layer
253- select type (input_layer = > self % layers(1 ) % p); type is(input3d_layer)
254- call input_layer % set(input)
255- end select
256-
257- do n = 2 , size (self % layers)
258- call self % layers(n) % forward(self % layers(n - 1 ))
259- end do
260-
261- end subroutine forward_batch_3d
262-
263-
264247 module function output_1d (self , input ) result(res)
265248 class(network), intent (in out ) :: self
266249 real , intent (in ) :: input(:)
@@ -308,28 +291,66 @@ module function output_3d(self, input) result(res)
308291 end function output_3d
309292
310293
294+ module function output_batch_1d (self , input ) result(res)
295+ class(network), intent (in out ) :: self
296+ real , intent (in ) :: input(:,:)
297+ real , allocatable :: res(:,:)
298+ integer :: i, batch_size, num_layers, output_size
299+
300+ num_layers = size (self % layers)
301+ batch_size = size (input, dim= rank(input))
302+ output_size = product (self % layers(num_layers) % layer_shape)
303+
304+ allocate (res(output_size, batch_size))
305+
306+ batch: do concurrent(i = 1 :size (res, dim= 2 ))
307+
308+ call self % forward(input(:,i))
309+
310+ select type (output_layer = > self % layers(num_layers) % p)
311+ type is (dense_layer)
312+ res(:,i) = output_layer % output
313+ type is (flatten_layer)
314+ res(:,i) = output_layer % output
315+ class default
316+ error stop ' network % output not implemented for this output layer'
317+ end select
318+
319+ end do batch
320+
321+ end function output_batch_1d
322+
323+
311324 module function output_batch_3d (self , input ) result(res)
312325 class(network), intent (in out ) :: self
313326 real , intent (in ) :: input(:,:,:,:)
314327 real , allocatable :: res(:,:)
315- integer :: num_layers
328+ integer :: i, batch_size, num_layers, output_size
316329
317330 num_layers = size (self % layers)
331+ batch_size = size (input, dim= rank(input))
332+ output_size = product (self % layers(num_layers) % layer_shape)
318333
319- call self % forward(input)
334+ allocate (res(output_size, batch_size))
335+
336+ batch: do concurrent(i = 1 :batch_size)
337+
338+ call self % forward(input(:,:,:,i))
320339
321340 select type (output_layer = > self % layers(num_layers) % p)
322341 type is (conv2d_layer)
323342 ! FIXME flatten the result for now; find a better solution
324- res = pack (output_layer % output, .true. )
343+ res(:,i) = pack (output_layer % output, .true. )
325344 type is (dense_layer)
326- res = output_layer % output
345+ res(:,i) = output_layer % output
327346 type is (flatten_layer)
328- res = output_layer % output
347+ res(:,i) = output_layer % output
329348 class default
330349 error stop ' network % output not implemented for this output layer'
331350 end select
332351
352+ end do batch
353+
333354 end function output_batch_3d
334355
335356
0 commit comments