@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414 limitations under the License.
1515******************************************************************************/
1616
17+ using NumSharp ;
1718using System ;
1819using System . Collections . Generic ;
1920using System . Linq ;
@@ -24,7 +25,7 @@ namespace Tensorflow.Operations
2425{
2526 internal class rnn
2627 {
27- public static ( Tensor , Tensor ) dynamic_rnn ( RNNCell cell , Tensor inputs_tensor ,
28+ public static ( Tensor , Tensor ) dynamic_rnn ( RnnCell cell , Tensor inputs_tensor ,
2829 Tensor sequence_length = null , Tensor initial_state = null ,
2930 TF_DataType dtype = TF_DataType . DtInvalid ,
3031 int ? parallel_iterations = null , bool swap_memory = false , bool time_major = false )
@@ -79,7 +80,7 @@ public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor,
7980 /// <param name="sequence_length"></param>
8081 /// <param name="dtype"></param>
8182 /// <returns></returns>
82- private static ( Tensor , Tensor ) _dynamic_rnn_loop ( RNNCell cell , Tensor inputs , Tensor initial_state ,
83+ private static ( Tensor , Tensor ) _dynamic_rnn_loop ( RnnCell cell , Tensor inputs , Tensor initial_state ,
8384 int parallel_iterations , bool swap_memory , Tensor sequence_length = null , TF_DataType dtype = TF_DataType . DtInvalid )
8485 {
8586 var state = initial_state ;
@@ -170,11 +171,11 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, T
170171 flat_input_i . dtype ) ) ;
171172 }
172173
173- for ( int i = 0 ; i < input_ta . Count ; i ++ )
174+ input_ta = zip ( input_ta , flat_input ) . Select ( x =>
174175 {
175- var ( ta , input_ ) = ( input_ta [ i ] , flat_input [ i ] ) ;
176- ta . unstack ( input_ ) ;
177- }
176+ var ( ta , input_ ) = ( x . Item1 , x . Item2 ) ;
177+ return ta . unstack ( input_ ) ;
178+ } ) . ToList ( ) ;
178179 }
179180
180181 // Make sure that we run at least 1 step, if necessary, to ensure
@@ -192,11 +193,29 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, T
192193 // Take a time step of the dynamic RNN.
193194 Func < BodyItemInRnnWhileLoop , BodyItemInRnnWhileLoop > _time_step = ( item ) =>
194195 {
196+ Tensor [ ] input_t = null ;
197+ var ( time1 , output_ta_t , state1 ) = ( item . time , item . output_ta_t , item . state ) ;
195198 if ( in_graph_mode )
196199 {
197- input_ta . Select ( ta => ta . read ( time ) ) . ToArray ( ) ;
200+ input_t = input_ta . Select ( ta => ta . read ( time1 ) ) . ToArray ( ) ;
201+ // Restore some shape information
202+ foreach ( var ( input_ , shape ) in zip ( input_t , inputs_got_shape ) )
203+ input_ . set_shape ( shape [ new Slice ( 1 ) ] ) ;
204+ }
205+ else
206+ {
207+ // input_t = tuple(ta[time.numpy()] for ta in input_ta)
198208 }
199209
210+ var input_t_t = nest . pack_sequence_as2 ( structure : inputs , flat_sequence : input_t ) ;
211+ // Keras RNN cells only accept state as list, even if it's a single tensor.
212+ // var is_keras_rnn_cell = _is_keras_rnn_cell(cell);
213+ ( Tensor , Tensor ) a = ( null , null ) ;
214+ if ( sequence_length != null )
215+ throw new NotImplementedException ( "sequence_length != null" ) ;
216+ else
217+ a = cell . __call__ ( input_t_t , state1 ) ;
218+
200219 return item ;
201220 } ;
202221
0 commit comments