Skip to content

Commit 6eaf191

Browse files
committed
Enable network % forward() for 3-d input data
1 parent 6bd4f94 commit 6eaf191

File tree

2 files changed

+60
-13
lines changed

2 files changed

+60
-13
lines changed

src/nf_network.f90

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,22 @@ module nf_network
1111
public :: network
1212

1313
type :: network
14+
1415
type(layer), allocatable :: layers(:)
16+
1517
contains
18+
1619
procedure :: backward
17-
procedure :: forward
1820
procedure :: output
1921
procedure :: print_info
2022
procedure :: train
2123
procedure :: update
24+
25+
procedure, private :: forward_1d
26+
procedure, private :: forward_3d
27+
28+
generic :: forward => forward_1d, forward_3d
29+
2230
end type network
2331

2432
interface network
@@ -32,29 +40,50 @@ module function network_cons(layers) result(res)
3240
end function network_cons
3341
end interface network
3442

35-
interface
43+
interface forward
3644

37-
pure module subroutine backward(self, output)
38-
!! Apply one backward pass through the network.
45+
pure module subroutine forward_1d(self, input)
46+
!! Apply a forward pass through the network.
47+
!!
3948
!! This changes the state of layers on the network.
4049
!! Typically used only internally from the `train` method,
4150
!! but can be invoked by the user when creating custom optimizers.
51+
!!
52+
!! This specific subroutine is for 1-d input data.
4253
class(network), intent(in out) :: self
4354
!! Network instance
44-
real, intent(in) :: output(:)
45-
!! Output data
46-
end subroutine backward
55+
real, intent(in) :: input(:)
56+
!! 1-d input data
57+
end subroutine forward_1d
4758

48-
pure module subroutine forward(self, input)
59+
pure module subroutine forward_3d(self, input)
4960
!! Apply a forward pass through the network.
61+
!!
5062
!! This changes the state of layers on the network.
5163
!! Typically used only internally from the `train` method,
5264
!! but can be invoked by the user when creating custom optimizers.
65+
!!
66+
!! This specific subroutine is for 3-d input data.
5367
class(network), intent(in out) :: self
5468
!! Network instance
55-
real, intent(in) :: input(:)
56-
!! Input data
57-
end subroutine forward
69+
real, intent(in) :: input(:,:,:)
70+
!! 3-d input data
71+
end subroutine forward_3d
72+
73+
end interface forward
74+
75+
interface
76+
77+
pure module subroutine backward(self, output)
78+
!! Apply one backward pass through the network.
79+
!! This changes the state of layers on the network.
80+
!! Typically used only internally from the `train` method,
81+
!! but can be invoked by the user when creating custom optimizers.
82+
class(network), intent(in out) :: self
83+
!! Network instance
84+
real, intent(in) :: output(:)
85+
!! Output data
86+
end subroutine backward
5887

5988
module function output(self, input) result(res)
6089
!! Return the output of the network given the input array.

src/nf_network_submodule.f90

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
use nf_dense_layer, only: dense_layer
44
use nf_input1d_layer, only: input1d_layer
5+
use nf_input3d_layer, only: input3d_layer
56
use nf_layer, only: layer
67
use nf_loss, only: quadratic_derivative
78
use nf_optimizers, only: sgd
@@ -80,7 +81,7 @@ pure module subroutine backward(self, output)
8081
end subroutine backward
8182

8283

83-
pure module subroutine forward(self, input)
84+
pure module subroutine forward_1d(self, input)
8485
class(network), intent(in out) :: self
8586
real, intent(in) :: input(:)
8687
integer :: n
@@ -94,7 +95,24 @@ pure module subroutine forward(self, input)
9495
call self % layers(n) % forward(self % layers(n - 1))
9596
end do
9697

97-
end subroutine forward
98+
end subroutine forward_1d
99+
100+
101+
pure module subroutine forward_3d(self, input)
102+
class(network), intent(in out) :: self
103+
real, intent(in) :: input(:,:,:)
104+
integer :: n
105+
106+
! Set the input array into the input layer
107+
select type(input_layer => self % layers(1) % p); type is(input3d_layer)
108+
call input_layer % set(input)
109+
end select
110+
111+
do n = 2, size(self % layers)
112+
call self % layers(n) % forward(self % layers(n - 1))
113+
end do
114+
115+
end subroutine forward_3d
98116

99117

100118
module function output(self, input) result(res)

0 commit comments

Comments
 (0)