@@ -291,6 +291,69 @@ module function output_3d(self, input) result(res)
291291 end function output_3d
292292
293293
294+ module function output_batch_1d (self , input ) result(res)
295+ class(network), intent (in out ) :: self
296+ real , intent (in ) :: input(:,:)
297+ real , allocatable :: res(:,:)
298+ integer :: i, batch_size, num_layers, output_size
299+
300+ num_layers = size (self % layers)
301+ batch_size = size (input, dim= rank(input))
302+ output_size = product (self % layers(num_layers) % layer_shape)
303+
304+ allocate (res(output_size, batch_size))
305+
306+ batch: do concurrent(i = 1 :size (res, dim= 2 ))
307+
308+ call self % forward(input(:,i))
309+
310+ select type (output_layer = > self % layers(num_layers) % p)
311+ type is (dense_layer)
312+ res(:,i) = output_layer % output
313+ type is (flatten_layer)
314+ res(:,i) = output_layer % output
315+ class default
316+ error stop ' network % output not implemented for this output layer'
317+ end select
318+
319+ end do batch
320+
321+ end function output_batch_1d
322+
323+
324+ module function output_batch_3d (self , input ) result(res)
325+ class(network), intent (in out ) :: self
326+ real , intent (in ) :: input(:,:,:,:)
327+ real , allocatable :: res(:,:)
328+ integer :: i, batch_size, num_layers, output_size
329+
330+ num_layers = size (self % layers)
331+ batch_size = size (input, dim= rank(input))
332+ output_size = product (self % layers(num_layers) % layer_shape)
333+
334+ allocate (res(output_size, batch_size))
335+
336+ batch: do concurrent(i = 1 :batch_size)
337+
338+ call self % forward(input(:,:,:,i))
339+
340+ select type (output_layer = > self % layers(num_layers) % p)
341+ type is (conv2d_layer)
342+ ! FIXME flatten the result for now; find a better solution
343+ res(:,i) = pack (output_layer % output, .true. )
344+ type is (dense_layer)
345+ res(:,i) = output_layer % output
346+ type is (flatten_layer)
347+ res(:,i) = output_layer % output
348+ class default
349+ error stop ' network % output not implemented for this output layer'
350+ end select
351+
352+ end do batch
353+
354+ end function output_batch_3d
355+
356+
294357 module subroutine print_info (self )
295358 class(network), intent (in ) :: self
296359 call self % layers % print_info()
0 commit comments