@@ -13,17 +13,17 @@ namespace Tensorflow.Keras.Layers
1313 /// </summary>
1414 public class Bidirectional : Wrapper
1515 {
16- BidirectionalArgs _args ;
17- RNN _forward_layer ;
18- RNN _backward_layer ;
19- RNN _layer ;
20- bool _support_masking = true ;
2116 int _num_constants = 0 ;
17+ bool _support_masking = true ;
2218 bool _return_state ;
2319 bool _stateful ;
2420 bool _return_sequences ;
25- InputSpec _input_spec ;
21+ BidirectionalArgs _args ;
2622 RNNArgs _layer_args_copy ;
23+ RNN _forward_layer ;
24+ RNN _backward_layer ;
25+ RNN _layer ;
26+ InputSpec _input_spec ;
2727 public Bidirectional ( BidirectionalArgs args ) : base ( args )
2828 {
2929 _args = args ;
@@ -66,12 +66,16 @@ public Bidirectional(BidirectionalArgs args):base(args)
6666
6767 // Recreate the forward layer from the original layer config, so that it
6868 // will not carry over any state from the layer.
69- var actualType = _layer . GetType ( ) ;
70- if ( actualType == typeof ( LSTM ) )
69+ if ( _layer is LSTM )
7170 {
7271 var arg = _layer_args_copy as LSTMArgs ;
7372 _forward_layer = new LSTM ( arg ) ;
7473 }
74+ else if ( _layer is SimpleRNN )
75+ {
76+ var arg = _layer_args_copy as SimpleRNNArgs ;
77+ _forward_layer = new SimpleRNN ( arg ) ;
78+ }
7579 // TODO(Wanglongzhi2001), add GRU if case.
7680 else
7781 {
@@ -154,12 +158,18 @@ private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false)
154158 {
155159 config . GoBackwards = ! config . GoBackwards ;
156160 }
157- var actualType = layer . GetType ( ) ;
158- if ( actualType == typeof ( LSTM ) )
161+
162+ if ( layer is LSTM )
159163 {
160164 var arg = config as LSTMArgs ;
161165 return new LSTM ( arg ) ;
162166 }
167+ else if ( layer is SimpleRNN )
168+ {
169+ var arg = config as SimpleRNNArgs ;
170+ return new SimpleRNN ( arg ) ;
171+ }
172+ // TODO(Wanglongzhi2001), add GRU if case.
163173 else
164174 {
165175 return new RNN ( cell , config ) ;
@@ -183,7 +193,6 @@ public override void build(KerasShapesWrapper input_shape)
183193 protected override Tensors Call ( Tensors inputs , Tensors state = null , bool ? training = null , IOptionalArgs ? optional_args = null )
184194 {
185195 // `Bidirectional.call` implements the same API as the wrapped `RNN`.
186-
187196 Tensors forward_inputs ;
188197 Tensors backward_inputs ;
189198 Tensors forward_state ;
0 commit comments