@@ -26,18 +26,22 @@ module mod_network
2626
2727 procedure , public , pass(self) :: accuracy
2828 procedure , public , pass(self) :: backprop
29+ procedure , public , pass(self) :: fit_batch
2930 procedure , public , pass(self) :: fwdprop
3031 procedure , public , pass(self) :: init
3132 procedure , public , pass(self) :: load
3233 procedure , public , pass(self) :: loss
3334 procedure , public , pass(self) :: output
35+ procedure , public , pass(self) :: predict_batch
3436 procedure , public , pass(self) :: save
3537 procedure , public , pass(self) :: set_activation
3638 procedure , public , pass(self) :: sync
3739 procedure , public , pass(self) :: train_batch
3840 procedure , public , pass(self) :: train_single
3941 procedure , public , pass(self) :: update
4042
43+ generic, public :: fit = > fit_batch! , train_single
44+ generic, public :: predict = > predict_batch! , train_single
4145 generic, public :: train = > train_batch, train_single
4246
4347 end type network_type
@@ -117,6 +121,39 @@ pure subroutine backprop(self, y, dw, db)
117121
118122 end subroutine backprop
119123
124+ subroutine fit_batch (self , x , y , eta ,epochs ,batch_size )
125+ class(network_type), intent (in out ) :: self
126+ integer (ik),intent (in ),optional :: epochs,batch_size
127+ real (rk), intent (in ) :: x(:,:), y(:,:), eta
128+
129+ integer (ik):: i,n,nsamples,nbatch
130+ integer (ik):: num_epochs,num_batch_size
131+ integer (ik):: batch_start,batch_end
132+
133+ nsamples= size (y,dim= 2 )
134+
135+ num_epochs= 1
136+ if (present (epochs))num_epochs= epochs
137+
138+ num_batch_size= nsamples
139+ if (present (batch_size))num_batch_size= batch_size
140+
141+ nbatch= nsamples/ num_batch_size
142+
143+ epoch: do n= 1 ,num_epochs
144+ batch_end= 0
145+ mini_batches: do i= 1 ,nbatch
146+ batch_start= batch_end+1
147+ batch_end= batch_start+ batch_size-1
148+ if (i.eq. nbatch)batch_end= nsamples
149+
150+ call self% train(x(:,batch_start:batch_end),y(:,batch_start:batch_end),eta)
151+
152+ enddo mini_batches
153+ enddo epoch
154+
155+ end subroutine
156+
120157 pure subroutine fwdprop (self , x )
121158 ! Performs the forward propagation and stores arguments to activation
122159 ! functions and activations themselves for use in backprop.
@@ -188,6 +225,20 @@ pure function output(self, x) result(a)
188225 end associate
189226 end function output
190227
228+ pure function predict_batch (self , x ) result(a)
229+ class(network_type), intent (in ) :: self
230+ real (rk), intent (in ) :: x(:,:)
231+ real (rk), allocatable :: a(:,:)
232+
233+ integer (ik) :: i
234+
235+ allocate (a(self% dims(size (self% dims)),size (x,dim= 2 )))
236+ do i = 1 , size (x, dim= 2 )
237+ a(:,i)= self% output(x(:,i))
238+ enddo
239+
240+ end function predict_batch
241+
191242 subroutine save (self , filename )
192243 ! Saves the network to a file.
193244 class(network_type), intent (in out ) :: self
0 commit comments