Skip to content

Commit aecaa78

Browse files
committed
build_results can't handle when is_op is false condition.
1 parent d94e545 commit aecaa78

File tree

10 files changed

+93
-27
lines changed

10 files changed

+93
-27
lines changed

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ private NDArray _run<T>(T fetches, Dictionary<Tensor, NDArray> feed_dict = null)
7272
// or if the call is a partial run that specifies feeds.
7373
var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor);
7474

75-
return fetch_handler.build_results(null, results);
75+
return fetch_handler.build_results(this, results);
7676
}
7777

7878
/// <summary>

src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ namespace Tensorflow
1111
public class _ElementFetchMapper<T> : _FetchMapper<T>
1212
{
1313
private List<object> _unique_fetches = new List<object>();
14-
private Func<List<object>, NDArray> _contraction_fn;
14+
private Func<List<object>, object> _contraction_fn;
1515

16-
public _ElementFetchMapper(List<T> fetches, Func<List<object>, NDArray> contraction_fn)
16+
public _ElementFetchMapper(List<T> fetches, Func<List<object>, object> contraction_fn)
1717
{
1818
foreach(var fetch in fetches)
1919
{
@@ -32,10 +32,22 @@ public _ElementFetchMapper(List<T> fetches, Func<List<object>, NDArray> contract
3232
/// <returns></returns>
3333
public NDArray build_results(List<object> values)
3434
{
35-
if (values.Count == 0)
36-
return null;
37-
else
38-
return _contraction_fn(values);
35+
NDArray result = null;
36+
37+
if (values.Count > 0)
38+
{
39+
var ret = _contraction_fn(values);
40+
switch (ret)
41+
{
42+
case NDArray value:
43+
result = value;
44+
break;
45+
default:
46+
break;
47+
}
48+
}
49+
50+
return result;
3951
}
4052

4153
public List<object> unique_fetches()

src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public class _FetchHandler<T>
1616
private List<Tensor> _final_fetches = new List<Tensor>();
1717
private List<T> _targets = new List<T>();
1818

19-
public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null)
19+
public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, Action feed_handles = null)
2020
{
2121
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches);
2222
foreach(var fetch in _fetch_mapper.unique_fetches())
@@ -40,18 +40,32 @@ public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds =
4040
_final_fetches = _fetches;
4141
}
4242

43-
public NDArray build_results(Session session, NDArray[] tensor_values)
43+
public NDArray build_results(BaseSession session, NDArray[] tensor_values)
4444
{
4545
var full_values = new List<object>();
46+
if (_final_fetches.Count != tensor_values.Length)
47+
throw new InvalidOperationException("_final_fetches mismatch tensor_values");
4648

49+
int i = 0;
50+
int j = 0;
4751
foreach(var is_op in _ops)
4852
{
4953
if (is_op)
5054
{
5155
full_values.Add(null);
5256
}
57+
else
58+
{
59+
var value = tensor_values[j];
60+
j += 1;
61+
full_values.Add(value);
62+
}
63+
i += 1;
5364
}
5465

66+
if (j != tensor_values.Length)
67+
throw new InvalidOperationException("j mismatch tensor_values");
68+
5569
return _fetch_mapper.build_results(full_values);
5670
}
5771

src/TensorFlowNET.Core/Sessions/_FetchMapper.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ public _ElementFetchMapper<T> for_fetch(T fetch)
1010
{
1111
var fetches = new List<T> { fetch };
1212

13-
return new _ElementFetchMapper<T>(fetches, null);
13+
return new _ElementFetchMapper<T>(fetches, (List<object> fetched_vals) =>
14+
{
15+
return fetched_vals[0];
16+
});
1417
}
1518
}
1619
}

src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,5 @@ public static implicit operator Tensor(IntPtr handle)
2525
{
2626
return new Tensor(handle);
2727
}
28-
29-
public static implicit operator Tensor(RefVariable var)
30-
{
31-
return var._initial_value;
32-
}
3328
}
3429
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public partial class RefVariable
8+
{
9+
public static implicit operator _VariableScopeStore(RefVariable variable)
10+
{
11+
return null;
12+
}
13+
14+
public static implicit operator RefVariable(_VariableScopeStore store)
15+
{
16+
return null;
17+
}
18+
19+
public static implicit operator Tensor(RefVariable var)
20+
{
21+
return var._AsTensor();
22+
}
23+
}
24+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public partial class RefVariable
8+
{
9+
public static Tensor operator +(RefVariable t1, int t2)
10+
{
11+
var tensor1 = t1._AsTensor();
12+
var tensor2 = ops.convert_to_tensor(t2, tensor1.dtype, "y");
13+
return gen_math_ops.add(tensor1, tensor2);
14+
}
15+
}
16+
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace Tensorflow
66
{
7-
public class RefVariable : VariableV1
7+
public partial class RefVariable : VariableV1
88
{
99
public bool _in_graph_mode = true;
1010
public Tensor _initial_value;
@@ -106,14 +106,9 @@ public Tensor _ref()
106106
return _variable;
107107
}
108108

109-
public static implicit operator _VariableScopeStore(RefVariable variable)
109+
public Tensor _AsTensor()
110110
{
111-
return null;
112-
}
113-
114-
public static implicit operator RefVariable(_VariableScopeStore store)
115-
{
116-
return null;
111+
return _snapshot;
117112
}
118113
}
119114
}

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,22 @@ public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph gra
5959
return get_default_graph();
6060
}
6161

62-
public static Tensor convert_to_tensor(object value, string name = "")
62+
/// <summary>
63+
/// Converts the given `value` to a `Tensor`.
64+
/// </summary>
65+
/// <param name="value"></param>
66+
/// <param name="dtype"></param>
67+
/// <param name="name"></param>
68+
/// <returns></returns>
69+
public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
6370
{
6471
switch (value)
6572
{
6673
case Tensor val:
6774
return val;
6875
default:
6976
var nd = tensor_util.convert_to_numpy_ndarray(value);
70-
return tf.constant(nd, name);
77+
return constant_op.Constant(nd, name);
7178
}
7279
}
7380

test/TensorFlowNET.UnitTest/VariableTest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ public void Add()
3838
session.run(model);
3939
for(int i = 0; i < 5; i++)
4040
{
41-
// x = x + 1;
42-
var result = session.run(x);
41+
var x1 = x + 1;
42+
var result = session.run(x1);
4343
print(result);
4444
}
4545
}

0 commit comments

Comments
 (0)