@@ -26,14 +26,16 @@ module mod_network
2626 procedure , public , pass(self) :: output_batch
2727 procedure , public , pass(self) :: output_single
2828 procedure , public , pass(self) :: save
29- procedure , public , pass(self) :: set_activation
29+ procedure , public , pass(self) :: set_activation_equal
30+ procedure , public , pass(self) :: set_activation_layers
3031 procedure , public , pass(self) :: sync
3132 procedure , public , pass(self) :: train_batch
3233 procedure , public , pass(self) :: train_epochs
3334 procedure , public , pass(self) :: train_single
3435 procedure , public , pass(self) :: update
3536
3637 generic, public :: output = > output_batch, output_single
38+ generic, public :: set_activation = > set_activation_equal, set_activation_layers
3739 generic, public :: train = > train_batch, train_epochs, train_single
3840
3941 end type network_type
@@ -139,13 +141,18 @@ subroutine load(self, filename)
139141 ! Loads the network from file.
140142 class(network_type), intent (in out ) :: self
141143 character (len=* ), intent (in ) :: filename
142- integer (ik) :: fileunit, n, num_layers
144+ integer (ik) :: fileunit, n, num_layers, layer_idx
143145 integer (ik), allocatable :: dims(:)
146+ character (len= 100 ) :: buffer ! activation string
144147 open (newunit= fileunit, file= filename, status= ' old' , action= ' read' )
145148 read (fileunit, fmt=* ) num_layers
146149 allocate (dims(num_layers))
147150 read (fileunit, fmt=* ) dims
148151 call self % init(dims)
152+ do n = 1 , num_layers
153+ read (fileunit, fmt=* ) layer_idx, buffer
154+ call self % layers(layer_idx) % set_activation(trim (buffer))
155+ end do
149156 do n = 2 , size (self % dims)
150157 read (fileunit, fmt=* ) self % layers(n) % b
151158 end do
@@ -198,6 +205,9 @@ subroutine save(self, filename)
198205 open (newunit= fileunit, file= filename)
199206 write (fileunit, fmt=* ) size (self % dims)
200207 write (fileunit, fmt=* ) self % dims
208+ do n = 1 , size (self % dims)
209+ write (fileunit, fmt=* ) n, self % layers(n) % activation_str
210+ end do
201211 do n = 2 , size (self % dims)
202212 write (fileunit, fmt=* ) self % layers(n) % b
203213 end do
@@ -207,17 +217,23 @@ subroutine save(self, filename)
207217 close (fileunit)
208218 end subroutine save
209219
210- pure subroutine set_activation (self , activation )
220+ pure subroutine set_activation_equal (self , activation )
211221 ! A thin wrapper around layer % set_activation().
212222 ! This method can be used to set an activation function
213223 ! for all layers at once.
214224 class(network_type), intent (in out ) :: self
215225 character (len=* ), intent (in ) :: activation
216- integer :: n
217- do concurrent(n = 1 :size (self % layers))
218- call self % layers(n) % set_activation(activation)
219- end do
220- end subroutine set_activation
226+ call self % layers(:) % set_activation(activation)
227+ end subroutine set_activation_equal
228+
229+ pure subroutine set_activation_layers (self , activation )
230+ ! A thin wrapper around layer % set_activation().
231+ ! This method can be used to set different activation functions
232+ ! for each layer separately.
233+ class(network_type), intent (in out ) :: self
234+ character (len=* ), intent (in ) :: activation(size (self % layers))
235+ call self % layers(:) % set_activation(activation)
236+ end subroutine set_activation_layers
221237
222238 subroutine sync (self , image )
223239 ! Broadcasts network weights and biases from
@@ -227,8 +243,10 @@ subroutine sync(self, image)
227243 integer (ik) :: n
228244 if (num_images() == 1 ) return
229245 layers: do n = 1 , size (self % dims)
246+ #ifdef CAF
230247 call co_broadcast(self % layers(n) % b, image)
231248 call co_broadcast(self % layers(n) % w, image)
249+ #endif
232250 end do layers
233251 end subroutine sync
234252
0 commit comments