Skip to content

Commit e75a111

Browse files
committed
lift_to_graph
1 parent 2093577 commit e75a111

File tree

10 files changed

+290
-49
lines changed

10 files changed

+290
-49
lines changed

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ public static void Update<T>(this IList<T> list, T element)
4646
}
4747
}
4848

49+
public static void difference_update<T>(this IList<T> list, IList<T> list2)
50+
{
51+
foreach(var el in list2)
52+
{
53+
if (list.Contains(el))
54+
list.Remove(el);
55+
}
56+
}
57+
4958
public static void add<T>(this IList<T> list, T element)
5059
=> list.Add(element);
5160

@@ -158,6 +167,13 @@ public static IEnumerable<int> range(int start, int end)
158167
return Enumerable.Range(start, end - start);
159168
}
160169

170+
public static IEnumerable<T> reversed<T>(IList<T> values)
171+
{
172+
var len = values.Count;
173+
for (int i = len - 1; i >= 0; i--)
174+
yield return values[i];
175+
}
176+
161177
public static T New<T>() where T : ITensorFlowObject, new()
162178
{
163179
var instance = new T();
@@ -284,7 +300,7 @@ public static float time()
284300
for (int i = 0; i < len; i++)
285301
yield return (i, values[i]);
286302
}
287-
303+
288304
public static IEnumerable<(int, T)> enumerate<T>(IEnumerable<T> values, int start = 0, int step = 1)
289305
{
290306
int i = 0;

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class ConcreteFunction : IDisposable
1414
{
1515
IntPtr _handle;
1616
FuncGraph func_graph;
17-
public Tensor[] CapturedInputs => func_graph.external_captures();
17+
public Tensor[] CapturedInputs => func_graph.external_captures;
1818

1919
public string Name
2020
{
@@ -37,7 +37,7 @@ public ConcreteFunction(string name)
3737
func_graph.as_default();
3838
}
3939

40-
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs)
40+
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
4141
{
4242
func_graph = graph;
4343

src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
9393
grad_ys: gradients_wrt_outputs.ToArray(),
9494
src_graph: _func_graph);
9595

96-
var captures_from_forward = backwards_graph.external_captures()
96+
var captures_from_forward = backwards_graph.external_captures
9797
.Where(x => !x.IsEagerTensor && x.graph == _func_graph)
9898
.ToArray();
9999
foreach(var capture in captures_from_forward)
@@ -105,7 +105,7 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
105105
var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}";
106106
var backward_function_attr = new Dictionary<string, string>();
107107
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
108-
gradients_wrt_outputs.append(backwards_graph.internal_captures());
108+
gradients_wrt_outputs.append(backwards_graph.internal_captures);
109109
backwards_graph.Inputs = gradients_wrt_outputs;
110110
backwards_graph.Outputs = gradients_wrt_inputs;
111111

src/TensorFlowNET.Core/Graphs/FuncGraph.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,20 @@ public class FuncGraph : Graph
2121
public Tensors Outputs { get; set; } = new Tensors();
2222
public Dictionary<string, string> Attrs { get; set; }
2323

24-
public Dictionary<long, (Tensor, Tensor)> _captures
24+
Dictionary<long, (Tensor, Tensor)> _captures
2525
= new Dictionary<long, (Tensor, Tensor)>();
2626

27-
public Tensor[] external_captures()
27+
public Tensor[] external_captures
2828
=> _captures.Select(x => x.Value.Item1).ToArray();
29+
public (Tensor, Tensor)[] captures
30+
=> _captures.Values.Select(x => x).ToArray();
2931

30-
public Tensor[] internal_captures()
32+
public Tensor[] internal_captures
3133
=> _captures.Select(x => x.Value.Item2).ToArray();
3234

35+
public Tensor[] captured_inputs
36+
=> external_captures;
37+
3338
/// <summary>
3439
/// Construct a new FuncGraph.
3540
/// </summary>
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using System.Linq;
5+
using static Tensorflow.Binding;
6+
7+
namespace Tensorflow.Graphs
8+
{
9+
public class SubGraphUtility
10+
{
11+
/// <summary>
12+
/// Copies the tensor and all its inputs recursively to the outer graph.
13+
/// </summary>
14+
/// <param name="tensors"></param>
15+
/// <param name="graph"></param>
16+
/// <param name="add_sources"></param>
17+
/// <param name="handle_captures"></param>
18+
/// <param name="base_graph"></param>
19+
/// <returns></returns>
20+
public static Dictionary<ITensorOrOperation, Operation> lift_to_graph(Tensors init_tensors,
21+
FuncGraph graph,
22+
List<Tensor> sources,
23+
bool add_sources = false,
24+
bool handle_captures = false,
25+
Graph base_graph = null,
26+
Dictionary<ITensorOrOperation, Operation> op_map = null)
27+
{
28+
base_graph = base_graph ?? init_tensors[0].graph;
29+
op_map = op_map ?? new Dictionary<ITensorOrOperation, Operation>();
30+
var visited_ops = sources.Select(x => x.op).ToList();
31+
foreach (var init_tensor in init_tensors)
32+
{
33+
var src = map_subgraph(init_tensor, sources, visited_ops, add_sources);
34+
sources.AddRange(src);
35+
}
36+
37+
var ops_to_copy = new List<Operation>();
38+
var marked_ops = new List<Operation>();
39+
var ops_to_visit = new Stack<Operation>(init_tensors.Select(x => x.op));
40+
var unvisited_ops = new List<Operation>(ops_to_visit.ToList());
41+
while (unvisited_ops.Count > 0)
42+
{
43+
while(ops_to_visit.Count > 0)
44+
{
45+
var op = ops_to_visit.Pop();
46+
if (marked_ops.Contains(op))
47+
continue;
48+
marked_ops.Add(op);
49+
ops_to_copy.append(op);
50+
foreach(var inp in op.inputs)
51+
{
52+
53+
}
54+
}
55+
// difference_update
56+
unvisited_ops.difference_update(marked_ops);
57+
if (unvisited_ops.Count > 0)
58+
ops_to_visit.Push(unvisited_ops.Last());
59+
}
60+
61+
// When lifting from one FuncGraph to another, we will need to capture the
62+
// relevant tensors as well.
63+
var inverse_captures = new Dictionary<Tensor, Tensor>();
64+
Tensor[] internal_captures = null;
65+
if (base_graph is FuncGraph base_func_graph)
66+
{
67+
var captures = base_func_graph.captures;
68+
foreach (var (external_capture, internal_capture) in captures)
69+
inverse_captures[internal_capture] = external_capture;
70+
internal_captures = base_func_graph.internal_captures;
71+
}
72+
73+
graph.as_default();
74+
var source_ops = new List<Operation>();
75+
// Add the sources in the same order as the original graph.
76+
foreach (var s in internal_captures)
77+
{
78+
if (sources.Contains(s))
79+
{
80+
sources.Remove(s);
81+
source_ops.Add(s.op);
82+
_copy_source(s: s,
83+
graph: graph,
84+
op_map: op_map,
85+
handle_captures: handle_captures,
86+
inverse_captures: inverse_captures,
87+
base_graph: base_graph);
88+
}
89+
}
90+
91+
foreach(var op in reversed(ops_to_copy))
92+
{
93+
if (source_ops.Contains(op) || op_map.ContainsKey(op))
94+
continue;
95+
_copy_non_source(op, graph, op_map, base_graph);
96+
}
97+
98+
return op_map;
99+
}
100+
101+
static void _copy_source(Tensor s,
102+
FuncGraph graph,
103+
Dictionary<ITensorOrOperation, Operation> op_map,
104+
bool handle_captures,
105+
Dictionary<Tensor, Tensor> inverse_captures,
106+
Graph base_graph)
107+
{
108+
Tensor copied_placeholder = null;
109+
if (handle_captures && inverse_captures.ContainsKey(s))
110+
copied_placeholder = graph.capture(inverse_captures[s], name: s.op.name);
111+
else
112+
throw new NotImplementedException("");
113+
op_map[s] = copied_placeholder;
114+
// Add an entry for the op of the source tensor so that if there are any nodes
115+
// depending on that op via control dependencies it can work correctly.
116+
op_map[s.op] = copied_placeholder.op;
117+
}
118+
119+
static void _copy_non_source(Operation op, FuncGraph graph, Dictionary<ITensorOrOperation, Operation> op_map, Graph base_graph)
120+
{
121+
Operation copied_op = null;
122+
var copied_inputs = new Tensors();
123+
tf_with(ops.control_dependencies(new object[] { op }), delegate
124+
{
125+
// Create a new op in the destination graph if it doesn't exist before.
126+
var attrs = new Dictionary<string, AttrValue>();
127+
foreach (var attr_def in op.node_def.Attr)
128+
attrs[attr_def.Key] = attr_def.Value;
129+
130+
copied_op = graph.create_op(op.type,
131+
copied_inputs,
132+
dtypes: op.outputs.Select(x => x.dtype).ToArray(),
133+
attrs: attrs,
134+
name: op.name);
135+
});
136+
op_map[op] = copied_op;
137+
foreach (var (i, o) in enumerate(op.outputs))
138+
op_map[o] = copied_op.outputs[i];
139+
}
140+
141+
/// <summary>
142+
/// Walk a Graph and capture the subgraph between init_tensor and sources.
143+
/// </summary>
144+
/// <param name="init_tensor"></param>
145+
/// <param name="add_sources"></param>
146+
public static List<Tensor> map_subgraph(Tensor init_tensor,
147+
List<Tensor> sources,
148+
List<Operation> visited_ops,
149+
bool add_sources)
150+
{
151+
var ops_to_visit = new Stack<Operation>();
152+
ops_to_visit.Push(init_tensor.op);
153+
var extra_sources = new List<Tensor>();
154+
while (ops_to_visit.Count > 0)
155+
{
156+
var op = ops_to_visit.Pop();
157+
if (visited_ops.Contains(op))
158+
continue;
159+
visited_ops.Add(op);
160+
bool should_raise = false;
161+
if (should_raise)
162+
throw new RuntimeError($"Unable to lift tensor {init_tensor.name}.");
163+
if(op.type == "Placeholder")
164+
{
165+
extra_sources.AddRange(op.outputs);
166+
}
167+
foreach(var inp in op.inputs)
168+
{
169+
170+
}
171+
}
172+
return extra_sources;
173+
}
174+
}
175+
}

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -873,22 +873,6 @@ public static Tensor atan2(Tensor y, Tensor x, string name = null)
873873
return _op.output;
874874
}
875875

876-
public static Tensor mul(Tensor x, Tensor y, string name = null)
877-
{
878-
if (tf.Context.executing_eagerly())
879-
{
880-
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
881-
"Mul", name,
882-
null,
883-
x, y);
884-
return results[0];
885-
}
886-
887-
var _op = tf.OpDefLib._apply_op_helper("Mul", name, args: new { x, y });
888-
889-
return _op.output;
890-
}
891-
892876
public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null)
893877
{
894878
if (tf.Context.executing_eagerly())

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ public static Tensor abs(Tensor x, string name = null)
4444
public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null)
4545
=> gen_math_ops.add(x, y, name);
4646

47+
public static Tensor add_v2(Tensor x, Tensor y, string name = null)
48+
=> tf.Context.RunInAutoMode2(
49+
() => tf.OpDefLib._apply_op_helper("AddV2", name, new { x, y }).output,
50+
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
51+
"AddV2", name,
52+
null,
53+
x, y).FirstOrDefault(),
54+
(op) =>
55+
{
56+
var attrs = new object[]
57+
{
58+
"T", op.get_attr<TF_DataType>("T")
59+
};
60+
tf.Runner.RecordGradient("AddV2", op.inputs, attrs, op.outputs);
61+
},
62+
new Tensors(x, y));
63+
4764
public static Tensor add_v2<Tx, Ty>(Tx x, Ty y, string name = null)
4865
=> gen_math_ops.add_v2(x, y, name);
4966

@@ -251,6 +268,23 @@ public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)
251268
public static Tensor sqrt(Tensor x, string name = null)
252269
=> gen_math_ops.sqrt(x, name: name);
253270

271+
public static Tensor multiply(Tensor x, Tensor y, string name = null)
272+
=> tf.Context.RunInAutoMode2(
273+
() => tf.OpDefLib._apply_op_helper("Mul", name, new { x, y }).output,
274+
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
275+
"Mul", name,
276+
null,
277+
x, y).FirstOrDefault(),
278+
(op) =>
279+
{
280+
var attrs = new object[]
281+
{
282+
"T", op.get_attr<TF_DataType>("T")
283+
};
284+
tf.Runner.RecordGradient("Mul", op.inputs, attrs, op.outputs);
285+
},
286+
new Tensors(x, y));
287+
254288
public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
255289
=> gen_math_ops.mul(x, y, name: name);
256290

src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -309,25 +309,19 @@ private static string div_or_truediv<Tx, Ty>(string name, Tx x, Ty y)
309309
private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y)
310310
{
311311
TF_DataType dtype = TF_DataType.DtInvalid;
312-
bool switchToGraphModeTemp = !tf.executing_eagerly();
313312

314313
if (x is Tensor tl)
315314
{
316315
dtype = tl.dtype.as_base_dtype();
317-
switchToGraphModeTemp = switchToGraphModeTemp || !tl.IsEagerTensor;
318316
}
319317

320318
if (y is Tensor tr)
321319
{
322320
dtype = tr.dtype.as_base_dtype();
323-
switchToGraphModeTemp = switchToGraphModeTemp || !tr.IsEagerTensor;
324321
}
325322

326323
return tf_with(ops.name_scope(null, name, new { x, y }), scope =>
327324
{
328-
if (switchToGraphModeTemp)
329-
tf.Context.graph_mode();
330-
331325
Tensor result;
332326
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
333327
var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y");
@@ -347,7 +341,7 @@ private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y)
347341
result = math_ops.truediv(x1, y1, name: scope);
348342
break;
349343
case "mul":
350-
result = gen_math_ops.mul(x1, y1, name: scope);
344+
result = math_ops.multiply(x1, y1, name: scope);
351345
break;
352346
case "sub":
353347
result = gen_math_ops.sub(x1, y1, name: scope);
@@ -359,9 +353,6 @@ private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y)
359353
throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}");
360354
}
361355

362-
if (switchToGraphModeTemp)
363-
tf.Context.restore_mode();
364-
365356
return result;
366357
});
367358
}

0 commit comments

Comments
 (0)