Skip to content

Commit ca1fa35

Browse files
committed
Add VariableTest.Add test in UnitTest
1 parent 838b12e commit ca1fa35

File tree

8 files changed

+120
-7
lines changed

8 files changed

+120
-7
lines changed

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ public Operation[] get_operations()
288288
return _nodes_by_name.Values.Select(x => x).ToArray();
289289
}
290290

291-
public object get_collection(string name)
291+
public object get_collection(string name, string scope = "")
292292
{
293293
return _collections.ContainsKey(name) ? _collections[name] : null;
294294
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class control_flow_ops
8+
{
9+
public static Operation group(Operation[] inputs)
10+
{
11+
return null;
12+
}
13+
}
14+
}

src/TensorFlowNET.Core/Python.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
/// <summary>
8+
/// Mapping C# functions to Python
9+
/// </summary>
10+
public class Python
11+
{
12+
protected void print(object obj)
13+
{
14+
Console.WriteLine(obj.ToString());
15+
}
16+
}
17+
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ public class RefVariable : VariableV1
1212
public bool _trainable;
1313
public Tensor _variable;
1414
public Tensor _snapshot;
15-
public Operation op;
15+
16+
private Operation _initializer_op;
17+
public Operation initializer => _initializer_op;
18+
public Operation op => _initializer_op;
1619

1720
public RefVariable(object initial_value,
1821
bool trainable = true,
@@ -81,7 +84,7 @@ private void _init_from_args(object initial_value,
8184
// have an issue if these other variables aren't initialized first by
8285
// using their initialized_value() method.
8386

84-
var _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op;
87+
_initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op;
8588

8689
if (!String.IsNullOrEmpty(caching_device))
8790
{
@@ -92,7 +95,6 @@ private void _init_from_args(object initial_value,
9295
_snapshot = gen_array_ops.identity(_variable, name = "read");
9396
}
9497

95-
op = _initializer_op;
9698
ops.add_to_collections(collections, this);
9799
}
98100
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public static partial class tf
8+
{
9+
public static Operation global_variables_initializer()
10+
{
11+
var g = variables.global_variables();
12+
return variables.variables_initializer(g as RefVariable[]);
13+
}
14+
}
15+
}

src/TensorFlowNET.Core/Variables/variables.py.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45

56
namespace Tensorflow
@@ -14,5 +15,32 @@ public static object trainable_variables()
1415
{
1516
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES);
1617
}
18+
19+
/// <summary>
20+
/// Returns global variables.
21+
/// </summary>
22+
/// <param name="scope">
23+
/// (Optional.) A string. If supplied, the resulting list is filtered
24+
/// to include only items whose `name` attribute matches `scope` using
25+
/// `re.match`. Items without a `name` attribute are never returned if a
26+
/// scope is supplied. The choice of `re.match` means that a `scope` without
27+
/// special tokens filters by prefix.
28+
/// </param>
29+
/// <returns>A list of `Variable` objects.</returns>
30+
public static object global_variables(string scope = "")
31+
{
32+
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
33+
}
34+
35+
/// <summary>
36+
/// Returns an Op that initializes a list of variables.
37+
/// </summary>
38+
/// <param name="var_list">List of `Variable` objects to initialize.</param>
39+
/// <param name="name">Optional name for the returned operation.</param>
40+
/// <returns>An Op that run the initializers of all the specified variables.</returns>
41+
public static Operation variables_initializer(RefVariable[] var_list, string name = "init")
42+
{
43+
return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray());
44+
}
1745
}
1846
}

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,23 @@ public static void add_to_collections<T>(List<string> names, T value)
2424
graph.add_to_collections(names, value);
2525
}
2626

27-
public static object get_collection(string key)
27+
/// <summary>
28+
/// Wrapper for `Graph.get_collection()` using the default graph.
29+
/// contains many standard names for collections.
30+
/// </summary>
31+
/// <param name="key">
32+
/// The key for the collection. For example, the `GraphKeys` class
33+
/// </param>
34+
/// <param name="scope"></param>
35+
/// <returns>
36+
/// The list of values in the collection with the given `name`, or
37+
/// an empty list if no value has been added to that collection. The
38+
/// list contains the values in the order under which they were
39+
/// collected.
40+
/// </returns>
41+
public static object get_collection(string key, string scope = "")
2842
{
29-
return get_default_graph().get_collection(key);
43+
return get_default_graph().get_collection(key, scope);
3044
}
3145

3246
public static Graph get_default_graph()

test/TensorFlowNET.UnitTest/VariableTest.cs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
namespace TensorFlowNET.UnitTest
88
{
99
[TestClass]
10-
public class VariableTest
10+
public class VariableTest : Python
1111
{
1212
[TestMethod]
1313
public void StringVar()
@@ -22,5 +22,28 @@ public void ScalarVar()
2222
var x = tf.Variable(3);
2323
var y = tf.Variable(6f);
2424
}
25+
26+
/// <summary>
27+
/// https://databricks.com/tensorflow/variables
28+
/// </summary>
29+
[TestMethod]
30+
public void Add()
31+
{
32+
var x = tf.Variable(0, name: "x");
33+
34+
var model = tf.global_variables_initializer();
35+
36+
using (var session = tf.Session())
37+
{
38+
/*session.run(model);
39+
for(int i = 0; i < 5; i++)
40+
{
41+
x = x + 1;
42+
var result = session.run(x);
43+
print(result);
44+
}*/
45+
}
46+
47+
}
2548
}
2649
}

0 commit comments

Comments
 (0)