@@ -23,11 +23,13 @@ module nf_network
2323
2424 procedure , private :: forward_1d
2525 procedure , private :: forward_3d
26+ procedure , private :: forward_batch_3d
2627 procedure , private :: output_1d
2728 procedure , private :: output_3d
29+ procedure , private :: output_batch_3d
2830
29- generic :: forward = > forward_1d, forward_3d
30- generic :: output = > output_1d, output_3d
31+ generic :: forward = > forward_1d, forward_3d, forward_batch_3d
32+ generic :: output = > output_1d, output_3d, output_batch_3d
3133
3234 end type network
3335
@@ -83,6 +85,20 @@ pure module subroutine forward_3d(self, input)
8385 ! ! 3-d input data
8486 end subroutine forward_3d
8587
88+ pure module subroutine forward_batch_3d(self, input)
89+ ! ! Apply a forward pass of a batch of data through the network.
90+ ! !
91+ ! ! This changes the state of layers on the network.
92+ ! ! Typically used only internally from the `train` method,
93+ ! ! but can be invoked by the user when creating custom optimizers.
94+ ! !
95+ ! ! This specific subroutine is for 3-d input data.
96+ class(network), intent (in out ) :: self
97+ ! ! Network instance
98+ real , intent (in ) :: input(:,:,:,:)
99+ ! ! 3-d input data; the 4th dimension is the batch
100+ end subroutine forward_batch_3d
101+
86102 end interface forward
87103
88104 interface output
@@ -107,6 +123,16 @@ module function output_3d(self, input) result(res)
107123 ! ! Output of the network
108124 end function output_3d
109125
126+ module function output_batch_3d (self , input ) result(res)
127+ ! ! Return the output of the network given an input batch of 3-d data.
128+ class(network), intent (in out ) :: self
129+ ! ! Network instance
130+ real , intent (in ) :: input(:,:,:,:)
131+ ! ! Input data; the 4th dimension is the batch
132+ real , allocatable :: res(:,:)
133+ ! ! Output of the network; the 2nd dimension is the batch
134+ end function output_batch_3d
135+
110136 end interface output
111137
112138 interface
0 commit comments