@@ -11,14 +11,22 @@ module nf_network
1111 public :: network
1212
1313 type :: network
14+
1415 type (layer), allocatable :: layers(:)
16+
1517 contains
18+
1619 procedure :: backward
17- procedure :: forward
1820 procedure :: output
1921 procedure :: print_info
2022 procedure :: train
2123 procedure :: update
24+
25+ procedure , private :: forward_1d
26+ procedure , private :: forward_3d
27+
28+ generic :: forward = > forward_1d, forward_3d
29+
2230 end type network
2331
2432 interface network
@@ -32,29 +40,50 @@ module function network_cons(layers) result(res)
3240 end function network_cons
3341 end interface network
3442
35- interface
43+ interface forward
3644
37- pure module subroutine backward(self, output)
38- ! ! Apply one backward pass through the network.
45+ pure module subroutine forward_1d(self, input)
46+ ! ! Apply a forward pass through the network.
47+ ! !
3948 ! ! This changes the state of layers on the network.
4049 ! ! Typically used only internally from the `train` method,
4150 ! ! but can be invoked by the user when creating custom optimizers.
51+ ! !
52+ ! ! This specific subroutine is for 1-d input data.
4253 class(network), intent (in out ) :: self
4354 ! ! Network instance
44- real , intent (in ) :: output (:)
45- ! ! Output data
46- end subroutine backward
55+ real , intent (in ) :: input (:)
56+ ! ! 1-d input data
57+ end subroutine forward_1d
4758
48- pure module subroutine forward (self, input)
59+ pure module subroutine forward_3d (self, input)
4960 ! ! Apply a forward pass through the network.
61+ ! !
5062 ! ! This changes the state of layers on the network.
5163 ! ! Typically used only internally from the `train` method,
5264 ! ! but can be invoked by the user when creating custom optimizers.
65+ ! !
66+ ! ! This specific subroutine is for 3-d input data.
5367 class(network), intent (in out ) :: self
5468 ! ! Network instance
55- real , intent (in ) :: input(:)
56- ! ! Input data
57- end subroutine forward
69+ real , intent (in ) :: input(:,:,:)
70+ ! ! 3-d input data
71+ end subroutine forward_3d
72+
73+ end interface forward
74+
75+ interface
76+
77+ pure module subroutine backward(self, output)
78+ ! ! Apply one backward pass through the network.
79+ ! ! This changes the state of layers on the network.
80+ ! ! Typically used only internally from the `train` method,
81+ ! ! but can be invoked by the user when creating custom optimizers.
82+ class(network), intent (in out ) :: self
83+ ! ! Network instance
84+ real , intent (in ) :: output(:)
85+ ! ! Output data
86+ end subroutine backward
5887
5988 module function output (self , input ) result(res)
6089 ! ! Return the output of the network given the input array.
0 commit comments