Skip to content

Commit 1ef2ec1

Browse files
committed
GraphTensorArray write, size, stack.
1 parent a8a5156 commit 1ef2ec1

File tree

3 files changed

+76
-3
lines changed

3 files changed

+76
-3
lines changed

src/TensorFlowNET.Core/Operations/BasicRNNCell.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,14 @@ protected override void build(TensorShape inputs_shape)
6666
built = true;
6767
}
6868

69-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
69+
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
7070
{
7171
// Most basic RNN: output = new_state = act(W * input + U * state + B).
7272
var concat = array_ops.concat(new[] { inputs, state }, 1);
7373
var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable);
74-
return (inputs, inputs);
74+
gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);
75+
var output = _activation(gate_inputs, null);
76+
return new[] { output, output };
7577
}
7678
}
7779
}

src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ limitations under the License.
2222

2323
namespace Tensorflow.Operations
2424
{
25-
internal class _GraphTensorArray
25+
public class _GraphTensorArray
2626
{
2727
internal TF_DataType _dtype;
2828
public TF_DataType dtype => _dtype;
@@ -174,5 +174,57 @@ public Tensor read(Tensor index, string name = null)
174174

175175
return value;
176176
}
177+
178+
public TensorArray write(Tensor index, Tensor value, string name = null)
179+
{
180+
return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate
181+
{
182+
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
183+
_maybe_colocate_with(value);
184+
var flow_out = gen_data_flow_ops.tensor_array_write_v3(
185+
handle: _handle,
186+
index: index,
187+
value: value,
188+
flow_in: _flow,
189+
name: name);
190+
191+
return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
192+
});
193+
}
194+
195+
private Tensor size(string name = null)
196+
{
197+
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name);
198+
}
199+
200+
public Tensor stack(string name = null)
201+
{
202+
ops.colocate_with(_handle);
203+
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
204+
{
205+
return gather(math_ops.range(0, size()), name: name);
206+
});
207+
}
208+
209+
public Tensor gather(Tensor indices, string name = null)
210+
{
211+
var element_shape = new TensorShape();
212+
213+
if (_element_shape.Count > 0)
214+
element_shape = _element_shape[0];
215+
216+
var value = gen_data_flow_ops.tensor_array_gather_v3(
217+
handle: _handle,
218+
indices: indices,
219+
flow_in: _flow,
220+
dtype: _dtype,
221+
name: name,
222+
element_shape: element_shape);
223+
224+
//if (element_shape != null)
225+
//value.set_shape(-1, element_shape.dims);
226+
227+
return value;
228+
}
177229
}
178230
}

src/TensorFlowNET.Core/Operations/tensor_array_ops.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Operations;
45

56
namespace Tensorflow
67
{
@@ -29,5 +30,23 @@ public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow
2930
new_impl._element_shape = impl._element_shape;
3031
return new_ta;
3132
}
33+
34+
public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow)
35+
{
36+
var impl = old_ta;
37+
38+
var new_ta = new TensorArray(
39+
dtype: impl.dtype,
40+
handle: impl.handle,
41+
flow: flow,
42+
infer_shape: impl.infer_shape,
43+
colocate_with_first_write_call: impl.colocate_with_first_write_call);
44+
45+
var new_impl = new_ta._implementation;
46+
new_impl._dynamic_size = impl._dynamic_size;
47+
new_impl._colocate_with = impl._colocate_with;
48+
new_impl._element_shape = impl._element_shape;
49+
return new_ta;
50+
}
3251
}
3352
}

0 commit comments

Comments
 (0)