File tree Expand file tree Collapse file tree 2 files changed +14
-3
lines changed Expand file tree Collapse file tree 2 files changed +14
-3
lines changed Original file line number Diff line number Diff line change @@ -39,7 +39,7 @@ module nf_rnn_layer
3939 procedure :: get_num_params
4040 procedure :: get_params
4141 procedure :: init
42- ! procedure :: set_params
42+ procedure :: set_params
4343
4444 end type rnn_layer
4545
Original file line number Diff line number Diff line change @@ -85,21 +85,32 @@ end function get_gradients
8585 module subroutine set_params (self , params )
8686 class(rnn_layer), intent (in out ) :: self
8787 real , intent (in ) :: params(:)
88+ integer :: first, last
8889
8990 ! check if the number of parameters is correct
9091 if (size (params) /= self % get_num_params()) then
9192 error stop ' Error: number of parameters does not match'
9293 end if
9394
9495 ! reshape the weights
96+ last = self % input_size * self % output_size
9597 self % weights = reshape ( &
96- params(:self % input_size * self % output_size ), &
98+ params(:last ), &
9799 [self % input_size, self % output_size] &
98100 )
99101
102+ ! reshape the recurrent weights
103+ first = last + 1
104+ last = first + self % output_size * self % output_size
105+ self % recurrent = reshape ( &
106+ params(first:last), &
107+ [self % output_size, self % output_size] &
108+ )
109+
100110 ! reshape the biases
111+ first = last + 1
101112 self % biases = reshape ( &
102- params(self % input_size * self % output_size + 1 :), &
113+ params(first :), &
103114 [self % output_size] &
104115 )
105116
You can’t perform that action at this time.
0 commit comments