@@ -17,15 +17,17 @@ module nf_network
1717 contains
1818
1919 procedure :: backward
20- procedure :: output
2120 procedure :: print_info
2221 procedure :: train
2322 procedure :: update
2423
2524 procedure , private :: forward_1d
2625 procedure , private :: forward_3d
26+ procedure , private :: output_1d
27+ procedure , private :: output_3d
2728
2829 generic :: forward = > forward_1d, forward_3d
30+ generic :: output = > output_1d, output_3d
2931
3032 end type network
3133
@@ -72,6 +74,30 @@ end subroutine forward_3d
7274
7375 end interface forward
7476
77+ interface output
78+
79+ module function output_1d (self , input ) result(res)
80+ ! ! Return the output of the network given the input 1-d array.
81+ class(network), intent (in out ) :: self
82+ ! ! Network instance
83+ real , intent (in ) :: input(:)
84+ ! ! Input data
85+ real , allocatable :: res(:)
86+ ! ! Output of the network
87+ end function output_1d
88+
89+ module function output_3d (self , input ) result(res)
90+ ! ! Return the output of the network given the input 3-d array.
91+ class(network), intent (in out ) :: self
92+ ! ! Network instance
93+ real , intent (in ) :: input(:,:,:)
94+ ! ! Input data
95+ real , allocatable :: res(:)
96+ ! ! Output of the network
97+ end function output_3d
98+
99+ end interface output
100+
75101 interface
76102
77103 pure module subroutine backward(self, output)
@@ -85,16 +111,6 @@ pure module subroutine backward(self, output)
85111 ! ! Output data
86112 end subroutine backward
87113
88- module function output (self , input ) result(res)
89- ! ! Return the output of the network given the input array.
90- class(network), intent (in out ) :: self
91- ! ! Network instance
92- real , intent (in ) :: input(:)
93- ! ! Input data
94- real , allocatable :: res(:)
95- ! ! Output of the network
96- end function output
97-
98114 module subroutine print_info (self )
99115 ! ! Prints a brief summary of the network and its layers to the screen.
100116 class(network), intent (in ) :: self
0 commit comments