Skip to content

Commit d1e1e05

Browse files
committed
inputs for rnn/while/TensorArrayReadV3 are incorrect #433
1 parent 8243807 commit d1e1e05

File tree

1 file changed

+26
-7
lines changed
  • src/TensorFlowNET.Core/Operations/NnOps

1 file changed

+26
-7
lines changed

src/TensorFlowNET.Core/Operations/NnOps/rnn.cs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using NumSharp;
1718
using System;
1819
using System.Collections.Generic;
1920
using System.Linq;
@@ -24,7 +25,7 @@ namespace Tensorflow.Operations
2425
{
2526
internal class rnn
2627
{
27-
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor,
28+
public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor,
2829
Tensor sequence_length = null, Tensor initial_state = null,
2930
TF_DataType dtype = TF_DataType.DtInvalid,
3031
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
@@ -79,7 +80,7 @@ public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor,
7980
/// <param name="sequence_length"></param>
8081
/// <param name="dtype"></param>
8182
/// <returns></returns>
82-
private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state,
83+
private static (Tensor, Tensor) _dynamic_rnn_loop(RnnCell cell, Tensor inputs, Tensor initial_state,
8384
int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid)
8485
{
8586
var state = initial_state;
@@ -170,11 +171,11 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, T
170171
flat_input_i.dtype));
171172
}
172173

173-
for (int i = 0; i < input_ta.Count; i++)
174+
input_ta = zip(input_ta, flat_input).Select(x =>
174175
{
175-
var (ta, input_) = (input_ta[i], flat_input[i]);
176-
ta.unstack(input_);
177-
}
176+
var (ta, input_) = (x.Item1, x.Item2);
177+
return ta.unstack(input_);
178+
}).ToList();
178179
}
179180

180181
// Make sure that we run at least 1 step, if necessary, to ensure
@@ -192,11 +193,29 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, T
192193
// Take a time step of the dynamic RNN.
193194
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) =>
194195
{
196+
Tensor[] input_t = null;
197+
var (time1, output_ta_t, state1) = (item.time, item.output_ta_t, item.state);
195198
if (in_graph_mode)
196199
{
197-
input_ta.Select(ta => ta.read(time)).ToArray();
200+
input_t = input_ta.Select(ta => ta.read(time1)).ToArray();
201+
// Restore some shape information
202+
foreach (var (input_, shape) in zip(input_t, inputs_got_shape))
203+
input_.set_shape(shape[new Slice(1)]);
204+
}
205+
else
206+
{
207+
// input_t = tuple(ta[time.numpy()] for ta in input_ta)
198208
}
199209

210+
var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t);
211+
// Keras RNN cells only accept state as list, even if it's a single tensor.
212+
// var is_keras_rnn_cell = _is_keras_rnn_cell(cell);
213+
(Tensor, Tensor) a = (null, null);
214+
if (sequence_length != null)
215+
throw new NotImplementedException("sequence_length != null");
216+
else
217+
a = cell.__call__(input_t_t, state1);
218+
200219
return item;
201220
};
202221

0 commit comments

Comments
 (0)