
Commit d6448a5

Author: Vandenplas, Jeremie (committed)

Commit message: modified as suggested by Milan

1 parent a65df50, commit d6448a5

4 files changed: 45 additions, 53 deletions

src/lib/mod_network.f90 (40 additions, 48 deletions)
@@ -26,23 +26,22 @@ module mod_network
 
     procedure, public, pass(self) :: accuracy
     procedure, public, pass(self) :: backprop
-    procedure, public, pass(self) :: fit_batch
     procedure, public, pass(self) :: fwdprop
     procedure, public, pass(self) :: init
     procedure, public, pass(self) :: load
     procedure, public, pass(self) :: loss
-    procedure, public, pass(self) :: output
-    procedure, public, pass(self) :: predict_batch
+    procedure, public, pass(self) :: output_batch
+    procedure, public, pass(self) :: output_single
     procedure, public, pass(self) :: save
     procedure, public, pass(self) :: set_activation
     procedure, public, pass(self) :: sync
     procedure, public, pass(self) :: train_batch
+    procedure, public, pass(self) :: train_epochs
     procedure, public, pass(self) :: train_single
     procedure, public, pass(self) :: update
 
-    generic, public :: fit => fit_batch!, train_single
-    generic, public :: predict => predict_batch!, train_single
-    generic, public :: train => train_batch, train_single
+    generic, public :: output => output_batch, output_single
+    generic, public :: train => train_batch, train_epochs, train_single
 
   end type network_type
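This hunk replaces the former fit and predict generics with a single pair: output now resolves to output_single (rank-1 input) or output_batch (rank-2 input), and train resolves to train_single, train_batch, or the new train_epochs depending on argument rank and count. A minimal resolution sketch, assuming the repository's mod_kinds and mod_network modules and the network_type(dims) constructor; the program name and toy dimensions below are hypothetical and not part of this commit:

! Hypothetical demo of how the renamed generics resolve (not part of this commit).
program generic_resolution_demo
  use mod_kinds, only: ik, rk
  use mod_network, only: network_type
  implicit none
  type(network_type) :: net
  integer(ik) :: num_epochs, batch_size
  real(rk) :: x1(3), xb(3,10), y1(2), yb(2,10)
  net = network_type([3, 5, 2])          ! 3 inputs, one hidden layer of 5, 2 outputs
  x1 = 0.5_rk; xb = 0.5_rk; y1 = 1._rk; yb = 1._rk
  num_epochs = 5; batch_size = 5
  print *, net % output(x1)                                ! rank-1 x        -> output_single
  print *, net % output(xb)                                ! rank-2 x        -> output_batch
  call net % train(x1, y1, 0.1_rk)                         ! rank-1 x, y     -> train_single
  call net % train(xb, yb, 0.1_rk)                         ! rank-2 x, y     -> train_batch
  call net % train(xb, yb, 0.1_rk, num_epochs, batch_size) ! + epochs, batch -> train_epochs
end program generic_resolution_demo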

@@ -121,44 +120,6 @@ pure subroutine backprop(self, y, dw, db)
 
   end subroutine backprop
 
-  subroutine fit_batch(self, x, y, eta,epochs,batch_size)
-    !Performs the training for n epochs with mini-bachtes of size equal to batch_size
-    class(network_type), intent(in out) :: self
-    integer(ik),intent(in),optional::epochs,batch_size
-    real(rk), intent(in) :: x(:,:), y(:,:), eta
-
-    integer(ik)::i,n,nsamples,nbatch
-    integer(ik)::num_epochs,num_batch_size
-    integer(ik)::batch_start,batch_end
-
-    real(rk)::pos
-
-    nsamples=size(y,dim=2)
-
-    num_epochs=1
-    if(present(epochs))num_epochs=epochs
-
-    num_batch_size=nsamples
-    if(present(batch_size))num_batch_size=batch_size
-
-    nbatch=nsamples/num_batch_size
-
-    epoch: do n=1,num_epochs
-      mini_batches: do i=1,nbatch
-
-        !pull a random mini-batch from the dataset
-        call random_number(pos)
-        batch_start=int(pos*(nsamples-num_batch_size+1))
-        if(batch_start.eq.0)batch_start=1
-        batch_end=batch_start+batch_size-1
-
-        call self%train(x(:,batch_start:batch_end),y(:,batch_start:batch_end),eta)
-
-      enddo mini_batches
-    enddo epoch
-
-  end subroutine
-
   pure subroutine fwdprop(self, x)
     ! Performs the forward propagation and stores arguments to activation
     ! functions and activations themselves for use in backprop.
@@ -216,7 +177,7 @@ pure real(rk) function loss(self, x, y)
     loss = 0.5 * sum((y - self % output(x))**2) / size(x)
   end function loss
 
-  pure function output(self, x) result(a)
+  pure function output_single(self, x) result(a)
     ! Use forward propagation to compute the output of the network.
     class(network_type), intent(in) :: self
     real(rk), intent(in) :: x(:)
@@ -228,9 +189,9 @@ pure function output(self, x) result(a)
         a = self % activation(matmul(transpose(layers(n-1) % w), a) + layers(n) % b)
       end do
     end associate
-  end function output
+  end function output_single
 
-  pure function predict_batch(self, x) result(a)
+  pure function output_batch(self, x) result(a)
     class(network_type), intent(in) :: self
     real(rk), intent(in) :: x(:,:)
     real(rk), allocatable :: a(:,:)
@@ -242,7 +203,7 @@ pure function predict_batch(self, x) result(a)
       a(:,i)=self%output(x(:,i))
     enddo
 
-  end function predict_batch
+  end function output_batch
 
   subroutine save(self, filename)
     ! Saves the network to a file.
@@ -342,6 +303,37 @@ subroutine train_batch(self, x, y, eta)
 
   end subroutine train_batch
 
+  subroutine train_epochs(self, x, y, eta,num_epochs,num_batch_size)
+    !Performs the training for num_epochs epochs with mini-batches of size equal to num_batch_size
+    class(network_type), intent(in out) :: self
+    integer(ik),intent(in)::num_epochs,num_batch_size
+    real(rk), intent(in) :: x(:,:), y(:,:), eta
+
+    integer(ik)::i,n,nsamples,nbatch
+    integer(ik)::batch_start,batch_end
+
+    real(rk)::pos
+
+    nsamples=size(y,dim=2)
+
+    nbatch=nsamples/num_batch_size
+
+    epoch: do n=1,num_epochs
+      mini_batches: do i=1,nbatch
+
+        !pull a random mini-batch from the dataset
+        call random_number(pos)
+        batch_start=int(pos*(nsamples-num_batch_size+1))
+        if(batch_start.eq.0)batch_start=1
+        batch_end=batch_start+num_batch_size-1
+
+        call self%train(x(:,batch_start:batch_end),y(:,batch_start:batch_end),eta)
+
+      enddo mini_batches
+    enddo epoch
+
+  end subroutine train_epochs
+
   pure subroutine train_single(self, x, y, eta)
     ! Trains a network using a single set of input data x and output data y,
     ! and learning rate eta.
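The new train_epochs picks, for each mini-batch, a random contiguous window of num_batch_size consecutive samples: batch_start is drawn with random_number, clamped away from zero, and batch_end = batch_start + num_batch_size - 1 always stays within the nsamples columns of x and y. The window arithmetic can be exercised in isolation; the following throwaway sketch (program name and sample counts are hypothetical, not part of the commit) mirrors the index computation from the diff above:

! Standalone check of the mini-batch window arithmetic used by train_epochs.
program batch_window_check
  implicit none
  integer, parameter :: nsamples = 100, num_batch_size = 32
  integer :: i, batch_start, batch_end
  real :: pos
  do i = 1, 5
    call random_number(pos)                                ! pos in [0, 1)
    batch_start = int(pos*(nsamples - num_batch_size + 1))
    if (batch_start == 0) batch_start = 1                  ! avoid a zero index
    batch_end = batch_start + num_batch_size - 1           ! never exceeds nsamples
    print '(a,i0,a,i0)', 'window ', batch_start, '..', batch_end
  end do
end program batch_window_check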

src/tests/example_mnist.f90 (1 addition, 1 deletion)

@@ -30,7 +30,7 @@ program example_mnist
       net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
   end if
 
-  call net%fit(tr_images,label_digits(tr_labels),eta=3._rk,epochs=num_epochs,batch_size=batch_size)
+  call net%train(tr_images,label_digits(tr_labels),3._rk,num_epochs,batch_size)
 
   if (this_image() == 1) then
     write(*, '(a,f5.2,a)') 'Epochs done, Accuracy: ',&
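Unlike the removed fit_batch, where epochs and batch_size were optional keyword arguments, train_epochs declares num_epochs and num_batch_size as required dummies, so the example now passes them positionally. An equivalent keyword form, using the dummy names of train_epochs, resolves to the same specific procedure; a sketch reusing the variables already declared in example_mnist.f90 (not part of the commit):

! Keyword-argument form of the same training call (sketch only).
call net % train(tr_images, label_digits(tr_labels), eta=3._rk, &
                 num_epochs=num_epochs, num_batch_size=batch_size)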

src/tests/example_montesinos_multi.f90 (2 additions, 2 deletions)

@@ -21,7 +21,7 @@ program example_montesinos_multi
  num_epochs=50
 
  !training
- call net%fit(x_tr,y_tr,3._rk,epochs=num_epochs,batch_size=batch_size)
+ call net%train(x_tr,y_tr,3._rk,num_epochs,batch_size)
 
  call net%sync(1)
 
@@ -30,7 +30,7 @@ program example_montesinos_multi
  call readfile('../data/montesinos_multi/x_ts.dat',nx1_ts,nx2_ts,x_ts)
 
  if(this_image().eq.1)then
-  write(*,*)'Correlation(s): ',corr_array(net%predict(x_ts),y_ts)
+  write(*,*)'Correlation(s): ',corr_array(net%output(x_ts),y_ts)
  endif
 
  contains

src/tests/example_montesinos_uni.f90 (2 additions, 2 deletions)

@@ -22,7 +22,7 @@ program example_montesinos_uni
  num_epochs=20
 
  !training
- call net%fit(x_tr,y_tr,3._rk,epochs=num_epochs,batch_size=batch_size)
+ call net%train(x_tr,y_tr,3._rk,num_epochs,batch_size)
 
  call net%sync(1)
 
@@ -31,7 +31,7 @@ program example_montesinos_uni
  call readfile('../data/montesinos_uni/x_ts.dat',nx1_ts,nx2_ts,x_ts)
 
  if(this_image().eq.1)then
-  write(*,*)'Correlation(s): ',corr_array(net%predict(x_ts),y_ts)
+  write(*,*)'Correlation(s): ',corr_array(net%output(x_ts),y_ts)
  endif
 
  contains
