Skip to content

Commit b710872

Browse files
committed
Specific procedures for batch output
1 parent 191aec9 commit b710872

File tree

2 files changed

+58
-41
lines changed

2 files changed

+58
-41
lines changed

src/nf/nf_network.f90

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ module nf_network
2323

2424
procedure, private :: forward_1d
2525
procedure, private :: forward_3d
26-
procedure, private :: forward_batch_3d
2726
procedure, private :: output_1d
2827
procedure, private :: output_3d
28+
procedure, private :: output_batch_1d
2929
procedure, private :: output_batch_3d
3030

31-
generic :: forward => forward_1d, forward_3d, forward_batch_3d
32-
generic :: output => output_1d, output_3d, output_batch_3d
31+
generic :: forward => forward_1d, forward_3d
32+
generic :: output => output_1d, output_3d, output_batch_1d, output_batch_3d
3333

3434
end type network
3535

@@ -85,20 +85,6 @@ pure module subroutine forward_3d(self, input)
8585
!! 3-d input data
8686
end subroutine forward_3d
8787

88-
pure module subroutine forward_batch_3d(self, input)
89-
!! Apply a forward pass of a batch of data through the network.
90-
!!
91-
!! This changes the state of layers on the network.
92-
!! Typically used only internally from the `train` method,
93-
!! but can be invoked by the user when creating custom optimizers.
94-
!!
95-
!! This specific subroutine is for 3-d input data.
96-
class(network), intent(in out) :: self
97-
!! Network instance
98-
real, intent(in) :: input(:,:,:,:)
99-
!! 3-d input data; the 4th dimension is the batch
100-
end subroutine forward_batch_3d
101-
10288
end interface forward
10389

10490
interface output
@@ -123,14 +109,24 @@ module function output_3d(self, input) result(res)
123109
!! Output of the network
124110
end function output_3d
125111

112+
module function output_batch_1d(self, input) result(res)
113+
!! Return the output of the network given an input batch of 3-d data.
114+
class(network), intent(in out) :: self
115+
!! Network instance
116+
real, intent(in) :: input(:,:)
117+
!! Input data; the last dimension is the batch
118+
real, allocatable :: res(:,:)
119+
!! Output of the network; the last dimension is the batch
120+
end function output_batch_1d
121+
126122
module function output_batch_3d(self, input) result(res)
127123
!! Return the output of the network given an input batch of 3-d data.
128124
class(network), intent(in out) :: self
129125
!! Network instance
130126
real, intent(in) :: input(:,:,:,:)
131-
!! Input data; the 4th dimension is the batch
127+
!! Input data; the last dimension is the batch
132128
real, allocatable :: res(:,:)
133-
!! Output of the network; the 2nd dimension is the batch
129+
!! Output of the network; the last dimension is the batch
134130
end function output_batch_3d
135131

136132
end interface output

src/nf/nf_network_submodule.f90

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -244,23 +244,6 @@ pure module subroutine forward_3d(self, input)
244244
end subroutine forward_3d
245245

246246

247-
pure module subroutine forward_batch_3d(self, input)
248-
class(network), intent(in out) :: self
249-
real, intent(in) :: input(:,:,:,:)
250-
integer :: n
251-
252-
! Set the input array into the input layer
253-
select type(input_layer => self % layers(1) % p); type is(input3d_layer)
254-
call input_layer % set(input)
255-
end select
256-
257-
do n = 2, size(self % layers)
258-
call self % layers(n) % forward(self % layers(n - 1))
259-
end do
260-
261-
end subroutine forward_batch_3d
262-
263-
264247
module function output_1d(self, input) result(res)
265248
class(network), intent(in out) :: self
266249
real, intent(in) :: input(:)
@@ -308,28 +291,66 @@ module function output_3d(self, input) result(res)
308291
end function output_3d
309292

310293

294+
module function output_batch_1d(self, input) result(res)
295+
class(network), intent(in out) :: self
296+
real, intent(in) :: input(:,:)
297+
real, allocatable :: res(:,:)
298+
integer :: i, batch_size, num_layers, output_size
299+
300+
num_layers = size(self % layers)
301+
batch_size = size(input, dim=rank(input))
302+
output_size = product(self % layers(num_layers) % layer_shape)
303+
304+
allocate(res(output_size, batch_size))
305+
306+
batch: do concurrent(i = 1:size(res, dim=2))
307+
308+
call self % forward(input(:,i))
309+
310+
select type(output_layer => self % layers(num_layers) % p)
311+
type is(dense_layer)
312+
res(:,i) = output_layer % output
313+
type is(flatten_layer)
314+
res(:,i) = output_layer % output
315+
class default
316+
error stop 'network % output not implemented for this output layer'
317+
end select
318+
319+
end do batch
320+
321+
end function output_batch_1d
322+
323+
311324
module function output_batch_3d(self, input) result(res)
312325
class(network), intent(in out) :: self
313326
real, intent(in) :: input(:,:,:,:)
314327
real, allocatable :: res(:,:)
315-
integer :: num_layers
328+
integer :: i, batch_size, num_layers, output_size
316329

317330
num_layers = size(self % layers)
331+
batch_size = size(input, dim=rank(input))
332+
output_size = product(self % layers(num_layers) % layer_shape)
318333

319-
call self % forward(input)
334+
allocate(res(output_size, batch_size))
335+
336+
batch: do concurrent(i = 1:batch_size)
337+
338+
call self % forward(input(:,:,:,i))
320339

321340
select type(output_layer => self % layers(num_layers) % p)
322341
type is(conv2d_layer)
323342
!FIXME flatten the result for now; find a better solution
324-
res = pack(output_layer % output, .true.)
343+
res(:,i) = pack(output_layer % output, .true.)
325344
type is(dense_layer)
326-
res = output_layer % output
345+
res(:,i) = output_layer % output
327346
type is(flatten_layer)
328-
res = output_layer % output
347+
res(:,i) = output_layer % output
329348
class default
330349
error stop 'network % output not implemented for this output layer'
331350
end select
332351

352+
end do batch
353+
333354
end function output_batch_3d
334355

335356

0 commit comments

Comments
 (0)