Skip to content

Commit 48a11d4

Browse files
committed
CondContext, BatchNormalization.
1 parent c26fccf commit 48a11d4

File tree

16 files changed

+276
-13
lines changed

16 files changed

+276
-13
lines changed

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ public static Tensor embedding_lookup(RefVariable @params,
2626
name: name);
2727

2828
public static IActivation relu => new relu();
29+
30+
public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
31+
RefVariable scale,
32+
RefVariable offset,
33+
Tensor mean = null,
34+
Tensor variance = null,
35+
float epsilon = 0.001f,
36+
string data_format = "NHWC",
37+
bool is_training = true,
38+
string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
39+
epsilon: epsilon,
40+
data_format: data_format,
41+
is_training: is_training,
42+
name: name);
2943
}
3044
}
3145
}

src/TensorFlowNET.Core/Framework/smart_module.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ namespace Tensorflow.Framework
66
{
77
public class smart_module
88
{
9-
public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
9+
public static object smart_cond(Tensor pred,
10+
Func<(Tensor, Tensor, Tensor)> true_fn = null,
11+
Func<(Tensor, Tensor, Tensor)> false_fn = null,
12+
string name = null)
1013
{
1114
return control_flow_ops.cond(pred,
1215
true_fn: true_fn,

src/TensorFlowNET.Core/Graphs/Graph.Control.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Tensorflow
88
{
99
public partial class Graph
1010
{
11-
public Context _control_flow_context;
11+
public IControlFlowContext _control_flow_context;
1212

1313
private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>();
1414
public Queue<_ControlDependenciesController> _control_dependencies_stack
@@ -72,7 +72,7 @@ public _ControlDependenciesController control_dependencies(ITensorOrOperation[]
7272
/// Returns the current control flow context.
7373
/// </summary>
7474
/// <returns>A context object.</returns>
75-
public Context _get_control_flow_context()
75+
public IControlFlowContext _get_control_flow_context()
7676
{
7777
return _control_flow_context;
7878
}
@@ -81,7 +81,7 @@ public Context _get_control_flow_context()
8181
/// Sets the current control flow context.
8282
/// </summary>
8383
/// <param name="ctx">a context object.</param>
84-
public void _set_control_flow_context(Context ctx)
84+
public void _set_control_flow_context(IControlFlowContext ctx)
8585
{
8686
_control_flow_context = ctx;
8787
}

src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public class _ControlDependenciesController : IPython
1515
private List<ITensorOrOperation> _seen_nodes;
1616
private Queue<_ControlDependenciesController> _old_stack;
1717
private bool _new_stack;
18-
private Context _old_control_flow_context;
18+
private IControlFlowContext _old_control_flow_context;
1919

2020
public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray();
2121

src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,27 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
142142
var beta = this.beta;
143143
var gamma = this.gamma;
144144

145-
Action _fused_batch_norm_training = () =>
145+
Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () =>
146146
{
147-
147+
return tf.nn.fused_batch_norm(
148+
inputs,
149+
gamma,
150+
beta,
151+
epsilon: epsilon,
152+
data_format: _data_format);
148153
};
149154

150-
Action _fused_batch_norm_inference = () =>
155+
Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () =>
151156
{
152-
157+
return tf.nn.fused_batch_norm(
158+
inputs,
159+
gamma,
160+
beta,
161+
mean: moving_mean,
162+
variance: moving_variance,
163+
epsilon: epsilon,
164+
is_training: false,
165+
data_format: _data_format);
153166
};
154167

155168
tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);

src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ public static bool is_symbolic_tensor(Tensor tensor)
1818
return true;
1919
}
2020

21-
public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
21+
public static object smart_cond(Tensor pred,
22+
Func<(Tensor, Tensor, Tensor)> true_fn = null,
23+
Func<(Tensor, Tensor, Tensor)> false_fn = null,
24+
string name = null)
2225
{
2326
return smart_module.smart_cond(pred,
2427
true_fn: true_fn,
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations
6+
{
7+
/// <summary>
8+
/// The context for the conditional construct.
9+
/// </summary>
10+
public class CondContext : ControlFlowContext
11+
{
12+
private string _name;
13+
/// <summary>
14+
/// The boolean tensor for the cond predicate
15+
/// </summary>
16+
private Tensor _pred;
17+
/// <summary>
18+
/// The predicate tensor in this branch
19+
/// </summary>
20+
private Tensor _pivot;
21+
/// <summary>
22+
/// 0 or 1 representing this branch
23+
/// </summary>
24+
private int _branch;
25+
/// <summary>
26+
///
27+
/// </summary>
28+
private List<string> _values = new List<string>();
29+
private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();
30+
31+
/// <summary>
32+
///
33+
/// </summary>
34+
/// <param name="pred">The `boolean` tensor for the conditional predicate.</param>
35+
/// <param name="pivot">The predicate tensor in this branch.</param>
36+
/// <param name="branch">0 or 1 representing this branch.</param>
37+
/// <param name="name">Name of the `CondContext` python object.</param>
38+
/// <param name="context_def"></param>
39+
/// <param name="import_scope"></param>
40+
public CondContext(Tensor pred,
41+
Tensor pivot,
42+
int branch,
43+
string name = "cond_text",
44+
object context_def = null,
45+
string import_scope = null)
46+
{
47+
_name = ops.get_default_graph().unique_name(name);
48+
if (context_def != null)
49+
throw new NotImplementedException("CondContext context_def is not null");
50+
else
51+
{
52+
// Initializes the default fields.
53+
base.__init__();
54+
_pred = pred;
55+
_pivot = pivot;
56+
57+
// Values considered to have been already seen in this context. pred is not
58+
// included in this context.
59+
_values.Add(pred.name);
60+
_external_values[pred.name] = pred;
61+
_values.Add(pivot.name);
62+
pivot.op._set_control_flow_context(this);
63+
}
64+
}
65+
66+
public (Tensor, Tensor, Tensor) BuildCondBranch(Func<(Tensor, Tensor, Tensor)> fn)
67+
{
68+
// Add the subgraph defined by fn() to the graph.
69+
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
70+
var original_result = fn();
71+
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
72+
73+
return original_result;
74+
}
75+
}
76+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations
6+
{
7+
public abstract class ControlFlowContext : IPython, IControlFlowContext
8+
{
9+
protected Stack<IControlFlowContext> _context_stack;
10+
public ControlFlowContext()
11+
{
12+
_context_stack = new Stack<IControlFlowContext>();
13+
}
14+
15+
public void __init__()
16+
{
17+
18+
}
19+
20+
public void __enter__()
21+
{
22+
}
23+
24+
public virtual void Enter()
25+
{
26+
var graph = ops.get_default_graph();
27+
_context_stack.Push(graph._get_control_flow_context());
28+
graph._set_control_flow_context(this);
29+
}
30+
31+
public void Exit()
32+
{
33+
var graph = ops.get_default_graph();
34+
var last_context = _context_stack.Pop();
35+
graph._set_control_flow_context(last_context);
36+
}
37+
38+
public void __exit__()
39+
{
40+
}
41+
42+
public void Dispose()
43+
{
44+
}
45+
}
46+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public interface IControlFlowContext
8+
{
9+
}
10+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations
6+
{
7+
public class WhileContext : ControlFlowContext
8+
{
9+
}
10+
}

0 commit comments

Comments
 (0)