Skip to content

Commit 3f4387e

Browse files
committed
Constructor for RNN
1 parent d2aacdd commit 3f4387e

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

src/nf/nf_layer_constructors.f90

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,29 @@ pure module function reshape(output_shape) result(res)
166166
!! Resulting layer instance
167167
end function reshape
168168

169+
pure module function rnn(layer_size, activation) result(res)
  !! Recurrent (fully-connected) layer constructor.
  !!
  !! This layer is a building block for recurrent, fully-connected
  !! networks, or for an output layer of a convolutional network.
  !! A recurrent layer must not be the first layer in the network.
  !!
  !! Example:
  !!
  !! ```
  !! use nf, only: rnn, layer, relu
  !! type(layer) :: rnn_layer
  !! rnn_layer = rnn(10)
  !! rnn_layer = rnn(10, activation=relu())
  !! ```
  integer, intent(in) :: layer_size
    !! The number of neurons in the recurrent layer
  class(activation_function), intent(in), optional :: activation
    !! Activation function instance (default tanh)
  type(layer) :: res
    !! Resulting layer instance
end function rnn
191+
169192
end interface
170193

171194
end module nf_layer_constructors

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
use nf_maxpool2d_layer, only: maxpool2d_layer
1010
use nf_reshape_layer, only: reshape3d_layer
1111
use nf_rnn_layer, only: rnn_layer
12-
use nf_activation, only: activation_function, relu, sigmoid
12+
use nf_activation, only: activation_function, relu, sigmoid, tanhf
1313

1414
implicit none
1515

@@ -135,4 +135,27 @@ pure module function reshape(output_shape) result(res)
135135

136136
end function reshape
137137

138+
pure module function rnn(layer_size, activation) result(res)
  !! Construct a recurrent layer of `layer_size` neurons, using the
  !! caller-supplied activation function or tanh when none is given.
  integer, intent(in) :: layer_size
  class(activation_function), intent(in), optional :: activation
  type(layer) :: res

  class(activation_function), allocatable :: act

  ! Resolve the activation first: a sourced copy of the caller's
  ! instance, or the tanh default.
  if (present(activation)) then
    allocate(act, source=activation)
  else
    allocate(act, source=tanhf())
  end if

  res % name = 'rnn'
  res % layer_shape = [layer_size]
  res % activation = act % get_name()

  ! The concrete layer implementation is held polymorphically in res % p.
  allocate(res % p, source=rnn_layer(layer_size, act))

end function rnn
159+
160+
138161
end submodule nf_layer_constructors_submodule

0 commit comments

Comments
 (0)