Skip to content

Commit 8f26d56

Browse files
committed
feat: set_params()
1 parent b07fb9b commit 8f26d56

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

src/nf/nf_rnn_layer.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ module nf_rnn_layer
3939
procedure :: get_num_params
4040
procedure :: get_params
4141
procedure :: init
42-
!procedure :: set_params
42+
procedure :: set_params
4343

4444
end type rnn_layer
4545

src/nf/nf_rnn_layer_submodule.f90

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,32 @@ end function get_gradients
8585
module subroutine set_params(self, params)
8686
class(rnn_layer), intent(in out) :: self
8787
real, intent(in) :: params(:)
88+
integer :: first, last
8889

8990
! check if the number of parameters is correct
9091
if (size(params) /= self % get_num_params()) then
9192
error stop 'Error: number of parameters does not match'
9293
end if
9394

9495
! reshape the weights
96+
last = self % input_size * self % output_size
9597
self % weights = reshape( &
96-
params(:self % input_size * self % output_size), &
98+
params(:last), &
9799
[self % input_size, self % output_size] &
98100
)
99101

102+
! reshape the recurrent weights
103+
first = last + 1
104+
last = first + self % output_size * self % output_size
105+
self % recurrent = reshape( &
106+
params(first:last), &
107+
[self % output_size, self % output_size] &
108+
)
109+
100110
! reshape the biases
111+
first = last + 1
101112
self % biases = reshape( &
102-
params(self % input_size * self % output_size + 1:), &
113+
params(first:), &
103114
[self % output_size] &
104115
)
105116

0 commit comments

Comments
 (0)