Commit b406e88

Apply loss function if RNN is the output layer
1 parent 87fbbab commit b406e88

1 file changed: +5 −0 lines changed

src/nf/nf_network_submodule.f90

Lines changed: 5 additions & 0 deletions
@@ -333,6 +333,11 @@ pure module subroutine backward(self, output, loss)
             self % layers(n - 1), &
             self % loss % derivative(output, this_layer % output) &
           )
+        type is(rnn_layer)
+          call self % layers(n) % backward( &
+            self % layers(n - 1), &
+            quadratic_derivative(output, this_layer % output) &
+          )
       end select
     else
       ! Hidden layer; take the gradient from the next layer
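
Before this change, the select type construct had no branch for rnn_layer, so an RNN in the output position never received a loss gradient to start backpropagation. The new branch seeds it with quadratic_derivative, the derivative of the quadratic loss from neural-fortran's nf_loss module. Below is a minimal sketch of that loss pair, assuming the elementwise definitions conventionally used in nf_loss.f90 (the library's actual declarations may differ):

module nf_loss_sketch
  ! Hedged sketch: mirrors the quadratic/quadratic_derivative pair
  ! assumed to live in nf_loss; not the library's verbatim source.
  implicit none
contains

  pure function quadratic(true, predicted) result(res)
    ! Elementwise quadratic loss: 0.5 * (predicted - true)**2
    real, intent(in) :: true(:)       ! target values
    real, intent(in) :: predicted(:)  ! network outputs
    real :: res(size(true))
    res = (predicted - true)**2 / 2
  end function quadratic

  pure function quadratic_derivative(true, predicted) result(res)
    ! Gradient of the quadratic loss w.r.t. predicted: predicted - true
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res(size(true))
    res = predicted - true
  end function quadratic_derivative

end module nf_loss_sketch

Under that definition, quadratic_derivative(output, this_layer % output) evaluates to this_layer % output - output, i.e. the error signal handed to the RNN layer's backward pass.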
