Skip to content

Commit ea45129

Browse files
authored
Merge pull request #17 from ivan-pi/master
Preserve network metadata in save() and load()
2 parents 7d81f7b + 3566212 commit ea45129

File tree

3 files changed

+46
-9
lines changed

3 files changed

+46
-9
lines changed

src/lib/mod_layer.f90

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ module mod_layer
1818
real(rk), allocatable :: z(:) ! arg. to activation function
1919
procedure(activation_function), pointer, nopass :: activation => null()
2020
procedure(activation_function), pointer, nopass :: activation_prime => null()
21+
character(len=:), allocatable :: activation_str ! activation character string
2122
contains
2223
procedure, public, pass(self) :: set_activation
2324
end type layer_type
@@ -115,7 +116,7 @@ subroutine dw_co_sum(dw)
115116
end do
116117
end subroutine dw_co_sum
117118

118-
pure subroutine set_activation(self, activation)
119+
pure elemental subroutine set_activation(self, activation)
119120
! Sets the activation function. Input string must match one of
120121
! provided activation functions, otherwise it defaults to sigmoid.
121122
! If activation not present, defaults to sigmoid.
@@ -125,21 +126,27 @@ pure subroutine set_activation(self, activation)
125126
case('gaussian')
126127
self % activation => gaussian
127128
self % activation_prime => gaussian_prime
129+
self % activation_str = 'gaussian'
128130
case('relu')
129131
self % activation => relu
130132
self % activation_prime => relu_prime
133+
self % activation_str = 'relu'
131134
case('sigmoid')
132135
self % activation => sigmoid
133136
self % activation_prime => sigmoid_prime
137+
self % activation_str = 'sigmoid'
134138
case('step')
135139
self % activation => step
136140
self % activation_prime => step_prime
141+
self % activation_str = 'step'
137142
case('tanh')
138143
self % activation => tanhf
139144
self % activation_prime => tanh_prime
145+
self % activation_str = 'tanh'
140146
case default
141147
self % activation => sigmoid
142148
self % activation_prime => sigmoid_prime
149+
self % activation_str = 'sigmoid'
143150
end select
144151
end subroutine set_activation
145152

src/lib/mod_network.f90

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@ module mod_network
2525
procedure, public, pass(self) :: loss
2626
procedure, public, pass(self) :: output
2727
procedure, public, pass(self) :: save
28-
procedure, public, pass(self) :: set_activation
28+
procedure, public, pass(self) :: set_activation_equal
29+
procedure, public, pass(self) :: set_activation_layers
2930
procedure, public, pass(self) :: sync
3031
procedure, public, pass(self) :: train_batch
3132
procedure, public, pass(self) :: train_single
3233
procedure, public, pass(self) :: update
3334

35+
generic, public :: set_activation => set_activation_equal, set_activation_layers
3436
generic, public :: train => train_batch, train_single
3537

3638
end type network_type
@@ -136,13 +138,18 @@ subroutine load(self, filename)
136138
! Loads the network from file.
137139
class(network_type), intent(in out) :: self
138140
character(len=*), intent(in) :: filename
139-
integer(ik) :: fileunit, n, num_layers
141+
integer(ik) :: fileunit, n, num_layers, layer_idx
140142
integer(ik), allocatable :: dims(:)
143+
character(len=100) :: buffer ! activation string
141144
open(newunit=fileunit, file=filename, status='old', action='read')
142145
read(fileunit, fmt=*) num_layers
143146
allocate(dims(num_layers))
144147
read(fileunit, fmt=*) dims
145148
call self % init(dims)
149+
do n = 1, num_layers
150+
read(fileunit, fmt=*) layer_idx, buffer
151+
call self % layers(layer_idx) % set_activation(trim(buffer))
152+
end do
146153
do n = 2, size(self % dims)
147154
read(fileunit, fmt=*) self % layers(n) % b
148155
end do
@@ -181,6 +188,9 @@ subroutine save(self, filename)
181188
open(newunit=fileunit, file=filename)
182189
write(fileunit, fmt=*) size(self % dims)
183190
write(fileunit, fmt=*) self % dims
191+
do n = 1, size(self % dims)
192+
write(fileunit, fmt=*) n, self % layers(n) % activation_str
193+
end do
184194
do n = 2, size(self % dims)
185195
write(fileunit, fmt=*) self % layers(n) % b
186196
end do
@@ -190,17 +200,23 @@ subroutine save(self, filename)
190200
close(fileunit)
191201
end subroutine save
192202

193-
pure subroutine set_activation(self, activation)
203+
pure subroutine set_activation_equal(self, activation)
194204
! A thin wrapper around layer % set_activation().
195205
! This method can be used to set an activation function
196206
! for all layers at once.
197207
class(network_type), intent(in out) :: self
198208
character(len=*), intent(in) :: activation
199-
integer :: n
200-
do concurrent(n = 1:size(self % layers))
201-
call self % layers(n) % set_activation(activation)
202-
end do
203-
end subroutine set_activation
209+
call self % layers(:) % set_activation(activation)
210+
end subroutine set_activation_equal
211+
212+
pure subroutine set_activation_layers(self, activation)
213+
! A thin wrapper around layer % set_activation().
214+
! This method can be used to set different activation functions
215+
! for each layer separately.
216+
class(network_type), intent(in out) :: self
217+
character(len=*), intent(in) :: activation(size(self % layers))
218+
call self % layers(:) % set_activation(activation)
219+
end subroutine set_activation_layers
204220

205221
subroutine sync(self, image)
206222
! Broadcasts network weights and biases from

src/tests/test_network_save.f90

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ program test_network_save
66
print *, 'Initializing 2 networks with random weights and biases'
77
net1 = network_type([768, 30, 10])
88
net2 = network_type([768, 30, 10])
9+
910
print *, 'Save network 1 into file'
1011
call net1 % save('test_network.dat')
1112
call net2 % load('test_network.dat')
@@ -15,4 +16,17 @@ program test_network_save
1516
all(net1 % layers(n) % w == net2 % layers(n) % w),&
1617
', biases equal:', all(net1 % layers(n) % b == net2 % layers(n) % b)
1718
end do
19+
print *, ''
20+
21+
print *, 'Setting different activation functions for each layer of network 1'
22+
call net1 % set_activation([character(len=10) :: 'sigmoid', 'tanh', 'gaussian'])
23+
print *, 'Save network 1 into file'
24+
call net1 % save('test_network.dat')
25+
call net2 % load('test_network.dat')
26+
print *, 'Load network 2 from file'
27+
do n = 1, size(net1 % layers)
28+
print *, 'Layer ', n, ', activation functions equal:',&
29+
associated(net1 % layers(n) % activation, net2 % layers(n) % activation),&
30+
'(network 1: ', net1 % layers(n) % activation_str, ', network 2: ', net2 % layers(n) % activation_str,')'
31+
end do
1832
end program test_network_save

0 commit comments

Comments
 (0)