Skip to content

Commit ff51917

Browse files
committed
_ElementFetchMapper.build_results
1 parent ea5d35e commit ff51917

File tree

12 files changed

+136
-40
lines changed

12 files changed

+136
-40
lines changed

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ private T _as_graph_element_locked<T>(T obj, bool allow_tensor = true, bool allo
7777

7878
var temp_obj = _as_graph_element(obj);
7979

80-
if(obj is Tensor && allow_tensor)
80+
if (obj is Tensor tensor && allow_tensor)
8181
{
82-
if ((obj as Tensor).Graph.Equals(this))
82+
if (tensor.Graph.Equals(this))
8383
{
8484
return obj;
8585
}
@@ -88,6 +88,17 @@ private T _as_graph_element_locked<T>(T obj, bool allow_tensor = true, bool allo
8888
throw new Exception($"Tensor {obj} is not an element of this graph.");
8989
}
9090
}
91+
else if (obj is Operation op && allow_operation)
92+
{
93+
if (op.Graph.Equals(this))
94+
{
95+
return obj;
96+
}
97+
else
98+
{
99+
throw new Exception($"Operation {obj} is not an element of this graph.");
100+
}
101+
}
91102

92103
throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}.");
93104
}

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45

56
namespace Tensorflow
@@ -10,14 +11,35 @@ public static Operation group(List<Operation> inputs, string name = "")
1011
{
1112
using(var namescope = new ops.name_scope<Operation>(name, "group_deps", inputs))
1213
{
14+
name = namescope;
15+
16+
var ops_on_device = new Dictionary<string, Operation[]>();
17+
1318
// Sorts *inputs according to their devices.
19+
foreach (var inp in inputs)
20+
{
21+
ops_on_device[inp.Device] = new Operation[] { inp };
22+
}
23+
24+
// 1-level tree. The root node is the returned NoOp node.
25+
if (ops_on_device.Count == 1)
26+
{
27+
return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name);
28+
}
1429

15-
return _GroupControlDeps("", name);
30+
// 2-level tree. The root node is the returned NoOp node.
31+
// deps contains 1 NoOp node for each device.
32+
return null;
1633
}
1734
}
1835

19-
private static Operation _GroupControlDeps(string dev, string name = "")
36+
private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "")
2037
{
38+
if (string.IsNullOrEmpty(dev))
39+
{
40+
return gen_control_flow_ops.no_op(name);
41+
}
42+
2143
return null;
2244
}
2345
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class gen_control_flow_ops
8+
{
9+
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
10+
11+
public static Operation no_op(string name = "")
12+
{
13+
var _op = _op_def_lib._apply_op_helper("NoOp", name);
14+
15+
return _op;
16+
}
17+
}
18+
}

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,12 @@ public virtual NDArray run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict
4040
return _run(fetches, feed_dict);
4141
}
4242

43-
private NDArray _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
43+
public virtual NDArray run(Operation fetches, Dictionary<Tensor, NDArray> feed_dict = null)
44+
{
45+
return _run(fetches, feed_dict);
46+
}
47+
48+
private NDArray _run<T>(T fetches, Dictionary<Tensor, NDArray> feed_dict = null)
4449
{
4550
var feed_dict_tensor = new Dictionary<Tensor, NDArray>();
4651

@@ -53,7 +58,7 @@ private NDArray _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = nul
5358
}
5459

5560
// Create a fetch handler to take care of the structure of fetches.
56-
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);
61+
var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor);
5762

5863
// Run request and get response.
5964
// We need to keep the returned movers alive for the following _do_run().
@@ -65,20 +70,34 @@ private NDArray _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = nul
6570

6671
// We only want to really perform the run if fetches or targets are provided,
6772
// or if the call is a partial run that specifies feeds.
68-
var results = _do_run(final_fetches, feed_dict_tensor);
73+
var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor);
6974

7075
return fetch_handler.build_results(null, results);
7176
}
7277

73-
private NDArray[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
78+
/// <summary>
79+
/// Runs a step based on the given fetches and feeds.
80+
/// </summary>
81+
/// <typeparam name="T"></typeparam>
82+
/// <param name="target_list">A list of operations to be run, but not fetched.</param>
83+
/// <param name="fetch_list"></param>
84+
/// <param name="feed_dict"></param>
85+
/// <returns>
86+
/// A list of numpy ndarrays, corresponding to the elements of
87+
/// `fetch_list`. If the ith element of `fetch_list` contains the
88+
/// name of an operation, the first Tensor output of that operation
89+
/// will be returned for that element.
90+
/// </returns>
91+
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
7492
{
7593
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray();
7694
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
95+
var targets = target_list;
7796

78-
return _call_tf_sessionrun(feeds, fetches);
97+
return _call_tf_sessionrun(feeds, fetches, target_list);
7998
}
8099

81-
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list)
100+
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
82101
{
83102
// Ensure any changes to the graph are reflected in the runtime.
84103
_extend_graph();
@@ -95,8 +114,8 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] f
95114
outputs: fetch_list,
96115
output_values: output_values,
97116
noutputs: fetch_list.Length,
98-
target_opers: IntPtr.Zero,
99-
ntargets: 0,
117+
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
118+
ntargets: target_list.Count,
100119
run_metadata: IntPtr.Zero,
101120
status: status);
102121

src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,37 @@ namespace Tensorflow
88
/// <summary>
99
/// Fetch mapper for singleton tensors and ops.
1010
/// </summary>
11-
public class _ElementFetchMapper : _FetchMapper
11+
public class _ElementFetchMapper<T> : _FetchMapper<T>
1212
{
13-
private List<Object> _unique_fetches = new List<object>();
14-
private Action _contraction_fn;
13+
private List<object> _unique_fetches = new List<object>();
14+
private Func<List<object>> _contraction_fn;
1515

16-
public _ElementFetchMapper(List<Tensor> fetches, Action contraction_fn)
16+
public _ElementFetchMapper(List<T> fetches, Func<List<object>> contraction_fn)
1717
{
18-
foreach(var tensor in fetches)
18+
foreach(var fetch in fetches)
1919
{
20-
var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true);
21-
_unique_fetches.Add(fetch);
20+
var g = ops.get_default_graph();
21+
var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true);
22+
_unique_fetches.Add(el);
2223
}
24+
25+
_contraction_fn = contraction_fn;
2326
}
2427

25-
public NDArray build_results(NDArray[] values)
28+
/// <summary>
29+
/// Build results matching the original fetch shape.
30+
/// </summary>
31+
/// <param name="values"></param>
32+
/// <returns></returns>
33+
public NDArray build_results(List<object> values)
2634
{
27-
return values[0];
35+
if (values.Count == 0)
36+
return null;
37+
else
38+
return _contraction_fn(values);
2839
}
2940

30-
public List<Object> unique_fetches()
41+
public List<object> unique_fetches()
3142
{
3243
return _unique_fetches;
3344
}

src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,26 @@ namespace Tensorflow
88
/// <summary>
99
/// Handler for structured fetches.
1010
/// </summary>
11-
public class _FetchHandler
11+
public class _FetchHandler<T>
1212
{
13-
private _ElementFetchMapper _fetch_mapper;
13+
private _ElementFetchMapper<T> _fetch_mapper;
1414
private List<Tensor> _fetches = new List<Tensor>();
1515
private List<bool> _ops = new List<bool>();
1616
private List<Tensor> _final_fetches = new List<Tensor>();
17-
private List<object> _targets = new List<object>();
17+
private List<T> _targets = new List<T>();
1818

19-
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null)
19+
public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null)
2020
{
21-
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
21+
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches);
2222
foreach(var fetch in _fetch_mapper.unique_fetches())
2323
{
2424
switch (fetch)
2525
{
26+
case Operation val:
27+
_assert_fetchable(graph, val);
28+
_targets.Add((T)(object)val);
29+
_ops.Add(true);
30+
break;
2631
case Tensor val:
2732
_assert_fetchable(graph, val.op);
2833
_fetches.Add(val);
@@ -35,9 +40,19 @@ public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> fe
3540
_final_fetches = _fetches;
3641
}
3742

38-
public NDArray build_results(Session session, NDArray[] results)
43+
public NDArray build_results(Session session, NDArray[] tensor_values)
3944
{
40-
return _fetch_mapper.build_results(results);
45+
var full_values = new List<object>();
46+
47+
foreach(var is_op in _ops)
48+
{
49+
if (is_op)
50+
{
51+
full_values.Add(null);
52+
}
53+
}
54+
55+
return _fetch_mapper.build_results(full_values);
4156
}
4257

4358
private void _assert_fetchable(Graph graph, Operation op)
@@ -53,7 +68,7 @@ public List<Tensor> fetches()
5368
return _final_fetches;
5469
}
5570

56-
public List<Object> targets()
71+
public List<T> targets()
5772
{
5873
return _targets;
5974
}

src/TensorFlowNET.Core/Sessions/_FetchMapper.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
namespace Tensorflow
66
{
7-
public class _FetchMapper
7+
public class _FetchMapper<T>
88
{
9-
public _ElementFetchMapper for_fetch(Tensor fetch)
9+
public _ElementFetchMapper<T> for_fetch(T fetch)
1010
{
11-
var fetches = new List<Tensor> { fetch };
11+
var fetches = new List<T> { fetch };
1212

13-
return new _ElementFetchMapper(fetches, null);
13+
return new _ElementFetchMapper<T>(fetches, null);
1414
}
1515
}
1616
}

src/TensorFlowNET.Core/Sessions/c_api.session.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public partial class c_api
8787
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options,
8888
TF_Output[] inputs, IntPtr[] input_values, int ninputs,
8989
TF_Output[] outputs, IntPtr[] output_values, int noutputs,
90-
IntPtr target_opers, int ntargets,
90+
IntPtr[] target_opers, int ntargets,
9191
IntPtr run_metadata,
9292
IntPtr status);
9393
}

src/TensorFlowNET.Core/Variables/variables.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public static List<RefVariable> global_variables(string scope = "")
4242
/// <returns>An Op that run the initializers of all the specified variables.</returns>
4343
public static Operation variables_initializer(RefVariable[] var_list, string name = "init")
4444
{
45-
return control_flow_ops.group(var_list.Select(x => x.initializer).ToList());
45+
return control_flow_ops.group(var_list.Select(x => x.initializer).ToList(), name);
4646
}
4747
}
4848
}
13.1 MB
Binary file not shown.

0 commit comments

Comments
 (0)