@@ -25,12 +25,14 @@ module mod_network
2525 procedure , public , pass(self) :: loss
2626 procedure , public , pass(self) :: output
2727 procedure , public , pass(self) :: save
28- procedure , public , pass(self) :: set_activation
28+ procedure , public , pass(self) :: set_activation_equal
29+ procedure , public , pass(self) :: set_activation_layers
2930 procedure , public , pass(self) :: sync
3031 procedure , public , pass(self) :: train_batch
3132 procedure , public , pass(self) :: train_single
3233 procedure , public , pass(self) :: update
3334
35+ generic, public :: set_activation = > set_activation_equal, set_activation_layers
3436 generic, public :: train = > train_batch, train_single
3537
3638 end type network_type
@@ -136,13 +138,18 @@ subroutine load(self, filename)
136138 ! Loads the network from file.
137139 class(network_type), intent (in out ) :: self
138140 character (len=* ), intent (in ) :: filename
139- integer (ik) :: fileunit, n, num_layers
141+ integer (ik) :: fileunit, n, num_layers, layer_idx
140142 integer (ik), allocatable :: dims(:)
143+ character (len= 100 ) :: buffer ! activation string
141144 open (newunit= fileunit, file= filename, status= ' old' , action= ' read' )
142145 read (fileunit, fmt=* ) num_layers
143146 allocate (dims(num_layers))
144147 read (fileunit, fmt=* ) dims
145148 call self % init(dims)
149+ do n = 1 , num_layers
150+ read (fileunit, fmt=* ) layer_idx, buffer
151+ call self % layers(layer_idx) % set_activation(trim (buffer))
152+ end do
146153 do n = 2 , size (self % dims)
147154 read (fileunit, fmt=* ) self % layers(n) % b
148155 end do
@@ -181,6 +188,9 @@ subroutine save(self, filename)
181188 open (newunit= fileunit, file= filename)
182189 write (fileunit, fmt=* ) size (self % dims)
183190 write (fileunit, fmt=* ) self % dims
191+ do n = 1 , size (self % dims)
192+ write (fileunit, fmt=* ) n, self % layers(n) % activation_str
193+ end do
184194 do n = 2 , size (self % dims)
185195 write (fileunit, fmt=* ) self % layers(n) % b
186196 end do
@@ -190,17 +200,23 @@ subroutine save(self, filename)
190200 close (fileunit)
191201 end subroutine save
192202
193- pure subroutine set_activation (self , activation )
203+ pure subroutine set_activation_equal (self , activation )
194204 ! A thin wrapper around layer % set_activation().
195205 ! This method can be used to set an activation function
196206 ! for all layers at once.
197207 class(network_type), intent (in out ) :: self
198208 character (len=* ), intent (in ) :: activation
199- integer :: n
200- do concurrent(n = 1 :size (self % layers))
201- call self % layers(n) % set_activation(activation)
202- end do
203- end subroutine set_activation
209+ call self % layers(:) % set_activation(activation)
210+ end subroutine set_activation_equal
211+
212+ pure subroutine set_activation_layers (self , activation )
213+ ! A thin wrapper around layer % set_activation().
214+ ! This method can be used to set different activation functions
215+ ! for each layer separately.
216+ class(network_type), intent (in out ) :: self
217+ character (len=* ), intent (in ) :: activation(size (self % layers))
218+ call self % layers(:) % set_activation(activation)
219+ end subroutine set_activation_layers
204220
205221 subroutine sync (self , image )
206222 ! Broadcasts network weights and biases from
0 commit comments