Skip to content

Commit 9665bf5

Browse files
committed
Aadded ContextOptions
1 parent 0f49cc9 commit 9665bf5

File tree

8 files changed

+109
-8
lines changed

8 files changed

+109
-8
lines changed

src/TensorFlowNET.Core/Eager/Context.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using System.Collections.Generic;
33
using System.Text;
44

5-
namespace Tensorflow
5+
namespace Tensorflow.Eager
66
{
77
public class Context
88
{
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Eager
6+
{
7+
public class ContextOptions : IDisposable
8+
{
9+
private IntPtr _handle;
10+
11+
public ContextOptions()
12+
{
13+
_handle = c_api.TFE_NewContextOptions();
14+
}
15+
16+
public void Dispose()
17+
{
18+
c_api.TFE_DeleteContextOptions(_handle);
19+
}
20+
21+
public static implicit operator IntPtr(ContextOptions ctx)
22+
{
23+
return ctx._handle;
24+
}
25+
}
26+
}

src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,21 @@ private static void _PendingCount(List<Operation> to_ops, List<Operation> from_o
102102
/// <param name="func_graphs"></param>
103103
private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs)
104104
{
105-
foreach(var op in from_ops)
105+
Queue<Operation> queue = new Queue<Operation>(from_ops);
106+
while (queue.Count > 0)
106107
{
107-
reached_ops.Add(op);
108-
foreach(var output in op.outputs)
108+
var op = queue.Dequeue();
109+
110+
if (!reached_ops.Contains(op))
109111
{
110-
reached_ops.AddRange(_Consumers(output, func_graphs));
112+
reached_ops.Add(op);
113+
foreach (var output in op.outputs)
114+
{
115+
var c = _Consumers(output, func_graphs).ToList();
116+
c.ForEach(x => queue.Enqueue(x));
117+
}
111118
}
112119
}
113-
114-
reached_ops.Reverse();
115120
}
116121

117122
/// <summary>

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ public class RefVariable : VariableV1
1212
public bool _trainable;
1313
public Tensor _variable;
1414
public Tensor _snapshot;
15+
public Operation op;
1516

1617
public RefVariable(object initial_value,
1718
bool trainable = true,
@@ -91,6 +92,7 @@ private void _init_from_args(object initial_value,
9192
_snapshot = gen_array_ops.identity(_variable, name = "read");
9293
}
9394

95+
op = _initializer_op;
9496
ops.add_to_collections(collections, this);
9597
}
9698
}

tensorflowlib/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,8 @@ Here are some pre-built TensorFlow binaries you can use for each platform:
44
- CPU-only: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.12.0.tar.gz
55
- GPU-enabled: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.12.0.tar.gz
66
- Mac: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.12.0.tar.gz
7-
- Windows: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.12.0.zip
7+
- Windows: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.12.0.zip
8+
9+
https://www.tensorflow.org/install/source_windows
10+
pacman -S git patch unzip
11+
bazel build --config=opt //tensorflow:libtensorflow.so

test/TensorFlowNET.Examples/LinearRegression.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ public void Run()
5050
// Note, minimize() knows to modify W and b because Variable objects are trainable=True by default
5151
var optimizer = tf.train.GradientDescentOptimizer(learning_rate);
5252
optimizer.minimize(cost);
53+
54+
5355
}
5456
}
5557
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow;
6+
7+
namespace TensorFlowNET.UnitTest
8+
{
9+
[TestClass]
10+
public class ConsumersTest : CApiTest
11+
{
12+
[TestMethod]
13+
public void Constant()
14+
{
15+
var X = tf.placeholder(tf.float64);
16+
var W = tf.constant(1.0D);
17+
18+
var mul = tf.multiply(X, W);
19+
EXPECT_EQ(1, X.op.OutputNumConsumers(0));
20+
EXPECT_EQ(1, W.op.OutputNumConsumers(0));
21+
}
22+
23+
[TestMethod]
24+
public void Variable()
25+
{
26+
var X = tf.placeholder(tf.float64);
27+
var W = tf.Variable(1.0D, name: "var");
28+
29+
var mul = tf.multiply(X, W);
30+
EXPECT_EQ(1, X.op.OutputNumConsumers(0));
31+
EXPECT_EQ(1, W.op.OutputNumConsumers(0));
32+
}
33+
}
34+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow;
6+
7+
namespace TensorFlowNET.UnitTest.Eager
8+
{
9+
/// <summary>
10+
/// tensorflow\c\eager\c_api_test.cc
11+
/// </summary>
12+
[TestClass]
13+
public class CApiVariableTest : CApiTest, IDisposable
14+
{
15+
Status status = new Status();
16+
17+
[TestMethod]
18+
public void Variables()
19+
{
20+
21+
}
22+
23+
public void Dispose()
24+
{
25+
26+
}
27+
}
28+
}

0 commit comments

Comments
 (0)