Skip to content

Commit e27145e

Browse files
committed
refactor TensorArray #903
1 parent 271b066 commit e27145e

File tree

9 files changed

+273
-81
lines changed

9 files changed

+273
-81
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
using System.Diagnostics;
2020
using System.Linq;
2121
using static Tensorflow.Binding;
22+
using Tensorflow.Operations;
2223

2324
namespace Tensorflow
2425
{
@@ -309,5 +310,27 @@ public Tensor zeros_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid,
309310
/// <returns></returns>
310311
public Tensor stop_gradient(Tensor x, string name = null)
311312
=> gen_array_ops.stop_gradient(x, name: name);
313+
314+
public TensorArray TensorArray(TF_DataType dtype, int size = 0, bool dynamic_size = false,
315+
bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true,
316+
bool infer_shape = true)
317+
=> tf.executing_eagerly() ?
318+
new _EagerTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size,
319+
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
320+
colocate_with_first_write_call: colocate_with_first_write_call) :
321+
new _GraphTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size,
322+
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
323+
colocate_with_first_write_call: colocate_with_first_write_call);
324+
325+
public TensorArray TensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
326+
bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true,
327+
bool infer_shape = true)
328+
=> tf.executing_eagerly() ?
329+
new _EagerTensorArray(dtype, size: size, dynamic_size: dynamic_size,
330+
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
331+
colocate_with_first_write_call: colocate_with_first_write_call) :
332+
new _GraphTensorArray(dtype, size: size, dynamic_size: dynamic_size,
333+
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
334+
colocate_with_first_write_call: colocate_with_first_write_call);
312335
}
313336
}

src/TensorFlowNET.Core/Operations/NnOps/rnn.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,9 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RnnCell cell, Tensor inputs, T
294294

295295
Func<string, Shape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) =>
296296
{
297-
var ta = new TensorArray(dtype: dtype_,
297+
var ta = tf.TensorArray(dtype: dtype_,
298298
size: time_steps,
299-
element_shape: element_shape,
300-
tensor_array_name: base_name + name);
299+
element_shape: element_shape);
301300
return ta;
302301
};
303302

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
/*****************************************************************************
2+
Copyright 2022 Haiping Chen. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using System.Collections.Generic;
19+
using System.Linq;
20+
using Tensorflow.Framework;
21+
using static Tensorflow.Binding;
22+
23+
namespace Tensorflow.Operations
24+
{
25+
public class _EagerTensorArray : TensorArray
26+
{
27+
TF_DataType _dtype;
28+
public override TF_DataType dtype => _dtype;
29+
30+
/// <summary>
31+
/// Used to keep track of what tensors the TensorArray should be
32+
/// colocated with. We choose to colocate the TensorArray with the
33+
/// first tensor written to it.
34+
/// </summary>
35+
bool _colocate_with_first_write_call;
36+
public override bool colocate_with_first_write_call => _colocate_with_first_write_call;
37+
38+
bool _infer_shape;
39+
public override bool infer_shape => _infer_shape;
40+
public bool _dynamic_size;
41+
public Shape _element_shape;
42+
43+
public List<Tensor> _colocate_with;
44+
45+
Tensor _handle;
46+
public override Tensor handle => _handle;
47+
Tensor _flow;
48+
public override Tensor flow => _flow;
49+
bool _clear_after_read;
50+
List<Tensor> _tensor_array;
51+
52+
public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
53+
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
54+
bool infer_shape = true, Shape? element_shape = null,
55+
bool colocate_with_first_write_call = true, string name = null)
56+
{
57+
_flow = constant_op.constant(0);
58+
_infer_shape = infer_shape;
59+
_element_shape = element_shape ?? Shape.Null;
60+
_colocate_with_first_write_call = colocate_with_first_write_call;
61+
_dtype = dtype.as_base_dtype();
62+
_dynamic_size = dynamic_size;
63+
_clear_after_read = clear_after_read;
64+
_tensor_array = new List<Tensor>();
65+
}
66+
67+
public override TensorArray unstack(Tensor value, string name = null)
68+
{
69+
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
70+
{
71+
var num_elements = array_ops.shape(value)[0];
72+
return scatter(indices: math_ops.range(0, num_elements), value: value, name: name);
73+
});
74+
}
75+
76+
public TensorArray scatter(Tensor indices, Tensor value, string name = null)
77+
{
78+
/*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
79+
{
80+
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
81+
if (_infer_shape)
82+
{
83+
var shape = new Shape(value.shape.dims.Skip(1).ToArray());
84+
_merge_element_shape(shape);
85+
}
86+
87+
_maybe_colocate_with(value);
88+
var flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
89+
handle: _handle,
90+
indices: indices,
91+
value: value,
92+
flow_in: _flow,
93+
name: name);
94+
95+
var ta = new _EagerTensorArray(_dtype,
96+
infer_shape: _infer_shape,
97+
element_shape: _element_shape[0],
98+
dynamic_size: _dynamic_size,
99+
handle: _handle,
100+
flow: flow_out,
101+
colocate_with_first_write_call: _colocate_with_first_write_call);
102+
103+
104+
return ta;
105+
});*/
106+
throw new NotImplementedException("");
107+
}
108+
109+
public void _merge_element_shape(Shape shape)
110+
{
111+
_element_shape.concatenate(shape);
112+
}
113+
114+
public void _maybe_colocate_with(Tensor value)
115+
{
116+
_colocate_with.Add(value);
117+
}
118+
119+
public override Tensor read<T>(T index, string name = null)
120+
{
121+
int index_int = -1;
122+
if (index is int int_index)
123+
index_int = int_index;
124+
else if (index is Tensor tensor_index)
125+
index_int = tensor_index.numpy();
126+
else
127+
throw new ValueError("");
128+
129+
if (_clear_after_read)
130+
{
131+
_tensor_array[index_int] = null;
132+
}
133+
134+
return _tensor_array[index_int];
135+
}
136+
137+
public override TensorArray write(Tensor index, Tensor value, string name = null)
138+
{
139+
if (_infer_shape)
140+
_element_shape = _element_shape.merge_with(value.shape);
141+
_tensor_array.add(value);
142+
return this;
143+
}
144+
145+
public override TensorArray write<T>(int index, T value, string name = null)
146+
{
147+
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
148+
var index_tensor = ops.convert_to_tensor(index, name: "index");
149+
return write(index_tensor, value_tensor, name: name);
150+
}
151+
152+
private Tensor size(string name = null)
153+
{
154+
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name);
155+
}
156+
157+
public override Tensor stack(string name = null)
158+
{
159+
ops.colocate_with(_handle);
160+
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
161+
{
162+
return gather(math_ops.range(0, size()), name: name);
163+
});
164+
}
165+
166+
public override Tensor gather(Tensor indices, string name = null)
167+
{
168+
var element_shape = Shape.Null;
169+
170+
var value = gen_data_flow_ops.tensor_array_gather_v3(
171+
handle: _handle,
172+
indices: indices,
173+
flow_in: _flow,
174+
dtype: _dtype,
175+
name: name,
176+
element_shape: element_shape);
177+
178+
//if (element_shape != null)
179+
//value.set_shape(-1, element_shape.dims);
180+
181+
return value;
182+
}
183+
}
184+
}

src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ limitations under the License.
2121

2222
namespace Tensorflow.Operations
2323
{
24-
public class _GraphTensorArray
24+
public class _GraphTensorArray : TensorArray
2525
{
2626
internal TF_DataType _dtype;
2727
public TF_DataType dtype => _dtype;
@@ -47,7 +47,7 @@ public class _GraphTensorArray
4747

4848
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
4949
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
50-
bool infer_shape = true, Shape element_shape = null,
50+
bool infer_shape = true, Shape? element_shape = null,
5151
bool colocate_with_first_write_call = true, string name = null)
5252
{
5353
clear_after_read = clear_after_read ?? true;
@@ -108,7 +108,7 @@ public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = nu
108108
});
109109
}
110110

111-
public TensorArray unstack(Tensor value, string name = null)
111+
public override TensorArray unstack(Tensor value, string name = null)
112112
{
113113
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
114114
{
@@ -119,7 +119,7 @@ public TensorArray unstack(Tensor value, string name = null)
119119

120120
public TensorArray scatter(Tensor indices, Tensor value, string name = null)
121121
{
122-
return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
122+
/*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
123123
{
124124
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
125125
if (_infer_shape)
@@ -136,17 +136,17 @@ public TensorArray scatter(Tensor indices, Tensor value, string name = null)
136136
flow_in: _flow,
137137
name: name);
138138
139-
var ta = new TensorArray(_dtype,
139+
var ta = new _GraphTensorArray(_dtype,
140140
infer_shape: _infer_shape,
141141
element_shape: _element_shape[0],
142142
dynamic_size: _dynamic_size,
143143
handle: _handle,
144144
flow: flow_out,
145145
colocate_with_first_write_call: _colocate_with_first_write_call);
146146
147-
148147
return ta;
149-
});
148+
});*/
149+
throw new NotImplementedException("");
150150
}
151151

152152
public void _merge_element_shape(Shape shape)
@@ -159,11 +159,11 @@ public void _maybe_colocate_with(Tensor value)
159159
_colocate_with.Add(value);
160160
}
161161

162-
public Tensor read(Tensor index, string name = null)
162+
public override Tensor read<T>(T index, string name = null)
163163
{
164164
var value = gen_data_flow_ops.tensor_array_read_v3(
165165
handle: _handle,
166-
index: index,
166+
index: constant_op.constant(index),
167167
flow_in: _flow,
168168
dtype: _dtype,
169169
name: name);
@@ -174,11 +174,10 @@ public Tensor read(Tensor index, string name = null)
174174
return value;
175175
}
176176

177-
public TensorArray write(Tensor index, Tensor value, string name = null)
177+
public override TensorArray write(Tensor index, Tensor value, string name = null)
178178
{
179179
return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate
180180
{
181-
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
182181
_maybe_colocate_with(value);
183182
var flow_out = gen_data_flow_ops.tensor_array_write_v3(
184183
handle: _handle,
@@ -191,12 +190,19 @@ public TensorArray write(Tensor index, Tensor value, string name = null)
191190
});
192191
}
193192

193+
public override TensorArray write<T>(int index, T value, string name = null)
194+
{
195+
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
196+
var index_tensor = ops.convert_to_tensor(index, name: "index");
197+
return write(index_tensor, value_tensor);
198+
}
199+
194200
private Tensor size(string name = null)
195201
{
196202
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name);
197203
}
198204

199-
public Tensor stack(string name = null)
205+
public override Tensor stack(string name = null)
200206
{
201207
ops.colocate_with(_handle);
202208
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
@@ -205,7 +211,7 @@ public Tensor stack(string name = null)
205211
});
206212
}
207213

208-
public Tensor gather(Tensor indices, string name = null)
214+
public override Tensor gather(Tensor indices, string name = null)
209215
{
210216
var element_shape = Shape.Null;
211217

src/TensorFlowNET.Core/Operations/functional_ops.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ public static Tensor scan(
8787
// n = array_ops.shape(elems_flat[0])[0];
8888
//}
8989

90-
var elems_ta = elems_flat.Select(elem => new TensorArray(
90+
var elems_ta = elems_flat.Select(elem => tf.TensorArray(
9191
elem.dtype,
92-
size: tf.constant(n),
92+
size: n,
9393
dynamic_size: false,
9494
element_shape: elem.shape.dims.Skip(1).ToArray(),
9595
infer_shape: true)).ToList();
@@ -113,9 +113,9 @@ public static Tensor scan(
113113
i = 0;
114114
}
115115

116-
var accs_ta = a_flat.Select(init => new TensorArray(
116+
var accs_ta = a_flat.Select(init => tf.TensorArray(
117117
dtype: init.dtype,
118-
size: tf.constant(n),
118+
size: n,
119119
element_shape: infer_shape ? init.shape : null,
120120
dynamic_size: false,
121121
infer_shape: infer_shape)).ToArray();
@@ -124,7 +124,7 @@ public static Tensor scan(
124124
{
125125
for (int index = 0; index < accs_ta.Length; index++)
126126
{
127-
accs_ta[index].write(tf.constant(reverse ? n - 1 : 0), a_flat[index]);
127+
accs_ta[index].write(reverse ? n - 1 : 0, a_flat[index]);
128128
}
129129
}
130130

0 commit comments

Comments
 (0)