Skip to content

Commit 1ae9bbc

Browse files
committed
Allow object[] to be shape parameters.
1 parent ceccf40 commit 1ae9bbc

24 files changed

+255
-90
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,15 @@ public Tensor check_numerics(Tensor tensor, string message, string name = null)
7878
/// <param name="axis"></param>
7979
/// <param name="name"></param>
8080
/// <returns>A `Tensor` resulting from concatenation of the input tensors.</returns>
81-
public Tensor concat(IList<Tensor> values, int axis, string name = "concat")
81+
public Tensor concat(IEnumerable<Tensor> values, int axis, string name = "concat")
8282
{
83-
if (values.Count == 1)
83+
if (values.Count() == 1)
8484
{
8585
return tf_with(ops.name_scope(name), scope =>
8686
{
8787
var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32);
8888
Debug.Assert(tensor.TensorShape.ndim == 0);
89-
return identity(values[0], name: scope);
89+
return identity(values.First(), name: scope);
9090
});
9191
}
9292

src/TensorFlowNET.Core/APIs/tf.reshape.cs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,18 @@ namespace Tensorflow
1919
public partial class tensorflow
2020
{
2121
public Tensor reshape(Tensor tensor,
22-
TensorShape shape,
23-
string name = null) => gen_array_ops.reshape(tensor, shape, name);
22+
TensorShape shape,
23+
string name = null)
24+
=> gen_array_ops.reshape(tensor, shape, name);
2425

2526
public Tensor reshape(Tensor tensor,
26-
Tensor[] shape,
27-
string name = null) => gen_array_ops.reshape(tensor, shape, name);
27+
Tensor shape,
28+
string name = null)
29+
=> gen_array_ops.reshape(tensor, shape, name);
2830

2931
public Tensor reshape(Tensor tensor,
30-
Tensor shape,
31-
string name = null) => gen_array_ops.reshape(tensor, shape, name);
32+
object[] shape,
33+
string name = null)
34+
=> gen_array_ops.reshape(tensor, shape, name);
3235
}
3336
}

src/TensorFlowNET.Core/APIs/tf.tile.cs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,22 @@ You may obtain a copy of the License at
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
******************************************************************************/
16+
using static Tensorflow.Binding;
1617

1718
namespace Tensorflow
1819
{
1920
public partial class tensorflow
2021
{
21-
public Tensor tile<T>(Tensor input,
22-
T multiples,
23-
string name = null) => gen_array_ops.tile(input, multiples, name);
22+
public Tensor tile(Tensor input, Tensor multiples, string name = null)
23+
=> gen_array_ops.tile(input, multiples, name);
24+
25+
public Tensor tile(Tensor input, object[] multiples, string name = null)
26+
=> gen_array_ops.tile(input, multiples, name);
27+
28+
public Tensor tile(Tensor input, TensorShape multiples, string name = null)
29+
{
30+
var multiples_tensor = constant_op.constant(multiples);
31+
return gen_array_ops.tile(input, multiples_tensor, name);
32+
}
2433
}
2534
}

src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,27 @@ namespace Tensorflow.Contexts
2828
/// </summary>
2929
public sealed partial class Context
3030
{
31-
// [DebuggerStepThrough]
32-
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tensor[] tensors)
31+
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params object[] args)
3332
{
34-
var shouldRunInEager = executing_eagerly()
35-
&& tensors.Count(x => x.IsEagerTensor) == tensors.Length;
36-
37-
if (shouldRunInEager)
38-
return eagerAction();
39-
else
33+
if (tf.Context.has_graph_arg(args))
34+
{
4035
return graphAction();
36+
}
37+
else
38+
{
39+
try
40+
{
41+
return eagerAction();
42+
}
43+
catch (InvalidArgumentError ex)
44+
{
45+
throw ex;
46+
}
47+
catch (Exception ex)
48+
{
49+
return graphAction();
50+
}
51+
}
4152
}
4253

4354
// [DebuggerStepThrough]
@@ -46,12 +57,7 @@ public Tensors RunInAutoMode2(Func<Tensors> graphAction,
4657
Action<Operation> recordGradient,
4758
Tensors tensors)
4859
{
49-
var shouldRunInEager = executing_eagerly()
50-
&& tensors.Count(x => x.IsEagerTensor) == tensors.Length;
51-
52-
if (shouldRunInEager)
53-
return eagerAction();
54-
else
60+
if (tf.Context.has_graph_arg(tensors))
5561
{
5662
if (executing_eagerly())
5763
{
@@ -68,6 +74,10 @@ public Tensors RunInAutoMode2(Func<Tensors> graphAction,
6874
return result;
6975
}
7076
}
77+
else
78+
{
79+
return eagerAction();
80+
}
7181
}
7282
}
7383
}

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
using Tensorflow.Eager;
2121
using static Tensorflow.Binding;
2222
using Google.Protobuf;
23+
using Tensorflow.Util;
2324

2425
namespace Tensorflow.Contexts
2526
{
@@ -103,6 +104,29 @@ public void graph_mode(bool isFunc = false)
103104
public void eager_mode(bool isFunc = false)
104105
=> context_switches.Push(true, isFunc);
105106

107+
public bool switched_to_graph(params object[] args)
108+
{
109+
var switching_to_graph = has_graph_arg(args) && tf.Context.executing_eagerly();
110+
if (switching_to_graph)
111+
tf.Context.graph_mode(tf.Context.is_build_function());
112+
return switching_to_graph;
113+
}
114+
115+
public bool has_graph_arg(params object[] args)
116+
{
117+
var flatten_args = nest.flatten<object>(args);
118+
bool has_graph_arg = false;
119+
foreach (var el in flatten_args)
120+
{
121+
if (el is Tensor tensor && !tensor.IsEagerTensor)
122+
{
123+
has_graph_arg = true;
124+
break;
125+
}
126+
}
127+
return has_graph_arg;
128+
}
129+
106130
public void restore_mode()
107131
{
108132
context_switches.Pop();

src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ public bool RecordGradient(string op_name,
3838
}
3939
}*/
4040
}
41-
42-
tf.Logger.Debug($"RecordGradient: should_record={should_record}, op_name={op_name}");
41+
4342
if (!should_record) return should_record;
43+
tf.Logger.Debug($"RecordGradient: op_name={op_name}");
4444

4545
Tensor[] op_outputs;
4646
#pragma warning disable CS0219 // Variable is assigned but its value is never used

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public Tensor[] TFE_FastPathExecute(Context ctx,
5050

5151
var op_def = tf.get_default_graph().GetOpDef(opName);
5252

53-
var flattened_attrs = new List<object>(op_def.InputArg.Count);
53+
var flattened_attrs = new List<object>(op_def.Attr.Count * 2);
5454
var flattened_inputs = new List<Tensor>(op_def.InputArg.Count);
5555

5656
// Set non-inferred attrs, including setting defaults if the attr is passed in
@@ -221,23 +221,9 @@ bool AddInputToOp(object inputs,
221221
SafeTensorHandleHandle input_handle;
222222

223223
// ConvertToTensor();
224-
switch (inputs)
225-
{
226-
case EagerTensor input:
227-
input_handle = input.EagerTensorHandle;
228-
flattened_inputs.Add(input);
229-
break;
230-
case ResourceVariable variable:
231-
var var_tensor = variable.AsTensor();
232-
input_handle = var_tensor.EagerTensorHandle;
233-
flattened_inputs.Add(var_tensor);
234-
break;
235-
default:
236-
var tensor = tf.convert_to_tensor(inputs);
237-
input_handle = tensor.EagerTensorHandle;
238-
flattened_inputs.Add(tensor);
239-
break;
240-
}
224+
var tensor = tf.convert_to_tensor(inputs);
225+
input_handle = tensor.EagerTensorHandle;
226+
flattened_inputs.Add(tensor);
241227

242228
if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
243229
{

src/TensorFlowNET.Core/Framework/tensor_shape.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System;
33
using System.Linq;
44
using System.Text;
5+
using static Tensorflow.Binding;
56

67
namespace Tensorflow.Framework
78
{
@@ -65,5 +66,17 @@ public static int dimension_value(Dimension dimension)
6566

6667
public static TensorShape as_shape(this Shape shape)
6768
=> new TensorShape(shape.Dimensions);
69+
70+
public static TensorShape most_specific_compatible_shape(this TensorShape self, TensorShape other)
71+
{
72+
var dims = range(self.rank).Select(x => -1).ToArray();
73+
foreach(var (i, (d1, d2)) in enumerate(zip(self.dims, other.dims)))
74+
{
75+
if (d1 == d2)
76+
dims[i] = d1;
77+
}
78+
79+
return new TensorShape(dims);
80+
}
6881
}
6982
}

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ public Tensors FilteredCall(Tensors inputs)
134134
/// <param name="args"></param>
135135
/// <param name="captured_inputs"></param>
136136
/// <returns></returns>
137-
public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs)
137+
public Tensors CallFlat(Tensor[] args, Tensor[] captured_inputs)
138138
{
139139
var executing_eagerly = tf.Context.executing_eagerly();
140140
var default_graph = ops.get_default_graph();

src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,18 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
9999
if (input_index >= backward_function_inputs)
100100
break;
101101
}
102+
102103
tf.Logger.Debug($"Invoke backward function: {backward.Name}");
103-
return backward.CallFlat(processed_args, remapped_captures);
104+
var gradients = backward.CallFlat(processed_args, remapped_captures);
105+
106+
foreach (var unneeded_gradient_index in unneeded_gradients)
107+
{
108+
var index = Convert.ToInt32(unneeded_gradient_index);
109+
if (gradients.Length <= index)
110+
gradients.Insert(index, null);
111+
}
112+
113+
return gradients;
104114
};
105115

106116
return (_backward_function_wrapper, recorded_outputs);

0 commit comments

Comments
 (0)