Skip to content

Commit 191aec9

Browse files
committed
Add specific network methods for batch forward and output
1 parent 75666fe commit 191aec9

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

src/nf/nf_network.f90

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@ module nf_network
2323

2424
procedure, private :: forward_1d
2525
procedure, private :: forward_3d
26+
procedure, private :: forward_batch_3d
2627
procedure, private :: output_1d
2728
procedure, private :: output_3d
29+
procedure, private :: output_batch_3d
2830

29-
generic :: forward => forward_1d, forward_3d
30-
generic :: output => output_1d, output_3d
31+
generic :: forward => forward_1d, forward_3d, forward_batch_3d
32+
generic :: output => output_1d, output_3d, output_batch_3d
3133

3234
end type network
3335

@@ -83,6 +85,20 @@ pure module subroutine forward_3d(self, input)
8385
!! 3-d input data
8486
end subroutine forward_3d
8587

88+
pure module subroutine forward_batch_3d(self, input)
89+
!! Apply a forward pass of a batch of data through the network.
90+
!!
91+
!! This changes the state of layers on the network.
92+
!! Typically used only internally from the `train` method,
93+
!! but can be invoked by the user when creating custom optimizers.
94+
!!
95+
!! This specific subroutine is for 3-d input data.
96+
class(network), intent(in out) :: self
97+
!! Network instance
98+
real, intent(in) :: input(:,:,:,:)
99+
!! 3-d input data; the 4th dimension is the batch
100+
end subroutine forward_batch_3d
101+
86102
end interface forward
87103

88104
interface output
@@ -107,6 +123,16 @@ module function output_3d(self, input) result(res)
107123
!! Output of the network
108124
end function output_3d
109125

126+
module function output_batch_3d(self, input) result(res)
127+
!! Return the output of the network given an input batch of 3-d data.
128+
class(network), intent(in out) :: self
129+
!! Network instance
130+
real, intent(in) :: input(:,:,:,:)
131+
!! Input data; the 4th dimension is the batch
132+
real, allocatable :: res(:,:)
133+
!! Output of the network; the 2nd dimension is the batch
134+
end function output_batch_3d
135+
110136
end interface output
111137

112138
interface

src/nf/nf_network_submodule.f90

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,23 @@ pure module subroutine forward_3d(self, input)
244244
end subroutine forward_3d
245245

246246

247+
pure module subroutine forward_batch_3d(self, input)
248+
class(network), intent(in out) :: self
249+
real, intent(in) :: input(:,:,:,:)
250+
integer :: n
251+
252+
! Set the input array into the input layer
253+
select type(input_layer => self % layers(1) % p); type is(input3d_layer)
254+
call input_layer % set(input)
255+
end select
256+
257+
do n = 2, size(self % layers)
258+
call self % layers(n) % forward(self % layers(n - 1))
259+
end do
260+
261+
end subroutine forward_batch_3d
262+
263+
247264
module function output_1d(self, input) result(res)
248265
class(network), intent(in out) :: self
249266
real, intent(in) :: input(:)
@@ -291,6 +308,31 @@ module function output_3d(self, input) result(res)
291308
end function output_3d
292309

293310

311+
module function output_batch_3d(self, input) result(res)
312+
class(network), intent(in out) :: self
313+
real, intent(in) :: input(:,:,:,:)
314+
real, allocatable :: res(:,:)
315+
integer :: num_layers
316+
317+
num_layers = size(self % layers)
318+
319+
call self % forward(input)
320+
321+
select type(output_layer => self % layers(num_layers) % p)
322+
type is(conv2d_layer)
323+
!FIXME flatten the result for now; find a better solution
324+
res = pack(output_layer % output, .true.)
325+
type is(dense_layer)
326+
res = output_layer % output
327+
type is(flatten_layer)
328+
res = output_layer % output
329+
class default
330+
error stop 'network % output not implemented for this output layer'
331+
end select
332+
333+
end function output_batch_3d
334+
335+
294336
module subroutine print_info(self)
295337
class(network), intent(in) :: self
296338
call self % layers % print_info()

0 commit comments

Comments
 (0)