Skip to content

Commit ab1faa6

Browse files
committed
network % output is now generic for input ranks 1 and 3
1 parent 40a6f33 commit ab1faa6

File tree

2 files changed

+55
-15
lines changed

2 files changed

+55
-15
lines changed

src/nf/nf_network.f90

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@ module nf_network
1717
contains
1818

1919
procedure :: backward
20-
procedure :: output
2120
procedure :: print_info
2221
procedure :: train
2322
procedure :: update
2423

2524
procedure, private :: forward_1d
2625
procedure, private :: forward_3d
26+
procedure, private :: output_1d
27+
procedure, private :: output_3d
2728

2829
generic :: forward => forward_1d, forward_3d
30+
generic :: output => output_1d, output_3d
2931

3032
end type network
3133

@@ -72,6 +74,30 @@ end subroutine forward_3d
7274

7375
end interface forward
7476

77+
interface output
78+
79+
module function output_1d(self, input) result(res)
80+
!! Return the output of the network given the input 1-d array.
81+
class(network), intent(in out) :: self
82+
!! Network instance
83+
real, intent(in) :: input(:)
84+
!! Input data
85+
real, allocatable :: res(:)
86+
!! Output of the network
87+
end function output_1d
88+
89+
module function output_3d(self, input) result(res)
90+
!! Return the output of the network given the input 3-d array.
91+
class(network), intent(in out) :: self
92+
!! Network instance
93+
real, intent(in) :: input(:,:,:)
94+
!! Input data
95+
real, allocatable :: res(:)
96+
!! Output of the network
97+
end function output_3d
98+
99+
end interface output
100+
75101
interface
76102

77103
pure module subroutine backward(self, output)
@@ -85,16 +111,6 @@ pure module subroutine backward(self, output)
85111
!! Output data
86112
end subroutine backward
87113

88-
module function output(self, input) result(res)
89-
!! Return the output of the network given the input array.
90-
class(network), intent(in out) :: self
91-
!! Network instance
92-
real, intent(in) :: input(:)
93-
!! Input data
94-
real, allocatable :: res(:)
95-
!! Output of the network
96-
end function output
97-
98114
module subroutine print_info(self)
99115
!! Prints a brief summary of the network and its layers to the screen.
100116
class(network), intent(in) :: self

src/nf/nf_network_submodule.f90

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
submodule(nf_network) nf_network_submodule
22

33
use nf_dense_layer, only: dense_layer
4+
use nf_flatten_layer, only: flatten_layer
45
use nf_input1d_layer, only: input1d_layer
56
use nf_input3d_layer, only: input3d_layer
67
use nf_layer, only: layer
@@ -114,7 +115,7 @@ pure module subroutine forward_3d(self, input)
114115
end subroutine forward_3d
115116

116117

117-
module function output(self, input) result(res)
118+
module function output_1d(self, input) result(res)
118119
class(network), intent(in out) :: self
119120
real, intent(in) :: input(:)
120121
real, allocatable :: res(:)
@@ -124,11 +125,34 @@ module function output(self, input) result(res)
124125

125126
call self % forward(input)
126127

127-
select type(output_layer => self % layers(num_layers) % p); type is(dense_layer)
128-
res = output_layer % output
128+
select type(output_layer => self % layers(num_layers) % p)
129+
type is(dense_layer)
130+
res = output_layer % output
131+
type is(flatten_layer)
132+
res = output_layer % output
129133
end select
130134

131-
end function output
135+
end function output_1d
136+
137+
138+
module function output_3d(self, input) result(res)
139+
class(network), intent(in out) :: self
140+
real, intent(in) :: input(:,:,:)
141+
real, allocatable :: res(:)
142+
integer :: num_layers
143+
144+
num_layers = size(self % layers)
145+
146+
call self % forward(input)
147+
148+
select type(output_layer => self % layers(num_layers) % p)
149+
type is(dense_layer)
150+
res = output_layer % output
151+
type is(flatten_layer)
152+
res = output_layer % output
153+
end select
154+
155+
end function output_3d
132156

133157

134158
module subroutine print_info(self)

0 commit comments

Comments
 (0)