Skip to content

Commit c3ae8b1

Browse files
committed
Build RNN graph, BasicRNNCell not implemented.
1 parent 1a5a7f8 commit c3ae8b1

File tree

8 files changed

+321
-1
lines changed

8 files changed

+321
-1
lines changed

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ public static Tensor bias_add(Tensor value, RefVariable bias, string data_format
127127
});
128128
}
129129

130+
public static rnn_cell_impl rnn_cell => new rnn_cell_impl();
131+
130132
public static Tensor softmax(Tensor logits, int axis = -1, string name = null)
131133
=> gen_nn_ops.softmax(logits, name);
132134

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class BasicRNNCell
8+
{
9+
}
10+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations
6+
{
7+
public class rnn_cell_impl
8+
{
9+
public BasicRNNCell BasicRNNCell(int num_units)
10+
{
11+
throw new NotImplementedException();
12+
}
13+
}
14+
}
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using NumSharp;
18+
using System;
19+
using System.Collections.Generic;
20+
using System.Text;
21+
using Tensorflow;
22+
using TensorFlowNET.Examples.Utility;
23+
using static Tensorflow.Python;
24+
25+
namespace TensorFlowNET.Examples.ImageProcess
26+
{
27+
/// <summary>
28+
/// Convolutional Neural Network classifier for Hand Written Digits
29+
/// CNN architecture with two convolutional layers, followed by two fully-connected layers at the end.
30+
/// Use Stochastic Gradient Descent (SGD) optimizer.
31+
/// http://www.easy-tensorflow.com/tf-tutorials/convolutional-neural-nets-cnns/cnn1
32+
/// </summary>
33+
public class DigitRecognitionRNN : IExample
34+
{
35+
public bool Enabled { get; set; } = false;
36+
public bool IsImportingGraph { get; set; } = false;
37+
38+
public string Name => "MNIST RNN";
39+
40+
string logs_path = "logs";
41+
42+
// Hyper-parameters
43+
int n_neurons = 128;
44+
float learning_rate = 0.001f;
45+
int batch_size = 128;
46+
int epochs = 10;
47+
48+
int n_steps = 28;
49+
int n_inputs = 28;
50+
int n_outputs = 10;
51+
52+
Datasets<DataSetMnist> mnist;
53+
54+
Tensor x, y;
55+
Tensor loss, accuracy, cls_prediction;
56+
Operation optimizer;
57+
58+
int display_freq = 100;
59+
float accuracy_test = 0f;
60+
float loss_test = 1f;
61+
62+
NDArray x_train, y_train;
63+
NDArray x_valid, y_valid;
64+
NDArray x_test, y_test;
65+
66+
public bool Run()
67+
{
68+
PrepareData();
69+
BuildGraph();
70+
71+
with(tf.Session(), sess =>
72+
{
73+
Train(sess);
74+
Test(sess);
75+
});
76+
77+
return loss_test < 0.09 && accuracy_test > 0.95;
78+
}
79+
80+
public Graph BuildGraph()
81+
{
82+
var graph = new Graph().as_default();
83+
84+
var X = tf.placeholder(tf.float32, new[] { -1, n_steps, n_inputs });
85+
var y = tf.placeholder(tf.int32, new[] { -1 });
86+
var cell = tf.nn.rnn_cell.BasicRNNCell(num_units: n_neurons);
87+
88+
return graph;
89+
}
90+
91+
public void Train(Session sess)
92+
{
93+
// Number of training iterations in each epoch
94+
var num_tr_iter = y_train.len / batch_size;
95+
96+
var init = tf.global_variables_initializer();
97+
sess.run(init);
98+
99+
float loss_val = 100.0f;
100+
float accuracy_val = 0f;
101+
102+
foreach (var epoch in range(epochs))
103+
{
104+
print($"Training epoch: {epoch + 1}");
105+
// Randomly shuffle the training data at the beginning of each epoch
106+
(x_train, y_train) = mnist.Randomize(x_train, y_train);
107+
108+
foreach (var iteration in range(num_tr_iter))
109+
{
110+
var start = iteration * batch_size;
111+
var end = (iteration + 1) * batch_size;
112+
var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);
113+
114+
// Run optimization op (backprop)
115+
sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
116+
117+
if (iteration % display_freq == 0)
118+
{
119+
// Calculate and display the batch loss and accuracy
120+
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
121+
loss_val = result[0];
122+
accuracy_val = result[1];
123+
print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
124+
}
125+
}
126+
127+
// Run validation after every epoch
128+
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid));
129+
loss_val = results1[0];
130+
accuracy_val = results1[1];
131+
print("---------------------------------------------------------");
132+
print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
133+
print("---------------------------------------------------------");
134+
}
135+
}
136+
137+
public void Test(Session sess)
138+
{
139+
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test));
140+
loss_test = result[0];
141+
accuracy_test = result[1];
142+
print("---------------------------------------------------------");
143+
print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
144+
print("---------------------------------------------------------");
145+
}
146+
147+
public void PrepareData()
148+
{
149+
mnist = MNIST.read_data_sets("mnist", one_hot: true);
150+
(x_train, y_train) = (mnist.train.data, mnist.train.labels);
151+
(x_valid, y_valid) = (mnist.validation.data, mnist.validation.labels);
152+
(x_test, y_test) = (mnist.test.data, mnist.test.labels);
153+
154+
print("Size of:");
155+
print($"- Training-set:\t\t{len(mnist.train.data)}");
156+
print($"- Validation-set:\t{len(mnist.validation.data)}");
157+
}
158+
159+
public Graph ImportGraph() => throw new NotImplementedException();
160+
161+
public void Predict(Session sess) => throw new NotImplementedException();
162+
}
163+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import tensorflow as tf
2+
from tensorflow.contrib import rnn
3+
4+
#import mnist dataset
5+
from tensorflow.examples.tutorials.mnist import input_data
6+
mnist=input_data.read_data_sets("/tmp/data/",one_hot=True)
7+
8+
#define constants
9+
#unrolled through 28 time steps
10+
time_steps=28
11+
#hidden LSTM units
12+
num_units=128
13+
#rows of 28 pixels
14+
n_input=28
15+
#learning rate for adam
16+
learning_rate=0.001
17+
#mnist is meant to be classified in 10 classes(0-9).
18+
n_classes=10
19+
#size of batch
20+
batch_size=128
21+
22+
23+
#weights and biases of appropriate shape to accomplish above task
24+
out_weights=tf.Variable(tf.random_normal([num_units,n_classes]))
25+
out_bias=tf.Variable(tf.random_normal([n_classes]))
26+
27+
#defining placeholders
28+
#input image placeholder
29+
x=tf.placeholder("float",[None,time_steps,n_input])
30+
#input label placeholder
31+
y=tf.placeholder("float",[None,n_classes])
32+
33+
#processing the input tensor from [batch_size,n_steps,n_input] to "time_steps" number of [batch_size,n_input] tensors
34+
input=tf.unstack(x ,time_steps,1)
35+
36+
#defining the network
37+
lstm_layer=rnn.BasicLSTMCell(num_units,forget_bias=1)
38+
outputs,_=rnn.static_rnn(lstm_layer,input,dtype="float32")
39+
40+
#converting last output of dimension [batch_size,num_units] to [batch_size,n_classes] by out_weight multiplication
41+
prediction=tf.matmul(outputs[-1],out_weights)+out_bias
42+
43+
#loss_function
44+
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
45+
#optimization
46+
opt=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
47+
48+
#model evaluation
49+
correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))
50+
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
51+
52+
#initialize variables
53+
init=tf.global_variables_initializer()
54+
with tf.Session() as sess:
55+
sess.run(init)
56+
iter=1
57+
while iter<800:
58+
batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)
59+
60+
batch_x=batch_x.reshape((batch_size,time_steps,n_input))
61+
62+
sess.run(opt, feed_dict={x: batch_x, y: batch_y})
63+
64+
if iter %10==0:
65+
acc=sess.run(accuracy,feed_dict={x:batch_x,y:batch_y})
66+
los=sess.run(loss,feed_dict={x:batch_x,y:batch_y})
67+
print("For iter ",iter)
68+
print("Accuracy ",acc)
69+
print("Loss ",los)
70+
print("__________________")
71+
72+
iter=iter+1
73+
74+
#calculating test accuracy
75+
test_data = mnist.test.images[:128].reshape((-1, time_steps, n_input))
76+
test_label = mnist.test.labels[:128]
77+
print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
78+
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import tensorflow as tf
2+
3+
# hyperparameters
4+
n_neurons = 128
5+
learning_rate = 0.001
6+
batch_size = 128
7+
n_epochs = 10
8+
# parameters
9+
n_steps = 28 # 28 rows
10+
n_inputs = 28 # 28 cols
11+
n_outputs = 10 # 10 classes
12+
# build a rnn model
13+
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
14+
y = tf.placeholder(tf.int32, [None])
15+
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
16+
output, state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
17+
logits = tf.layers.dense(state, n_outputs)
18+
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
19+
loss = tf.reduce_mean(cross_entropy)
20+
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
21+
prediction = tf.nn.in_top_k(logits, y, 1)
22+
accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))
23+
24+
# input data
25+
from tensorflow.examples.tutorials.mnist import input_data
26+
mnist = input_data.read_data_sets("MNIST_data/")
27+
X_test = mnist.test.images # X_test shape: [num_test, 28*28]
28+
X_test = X_test.reshape([-1, n_steps, n_inputs])
29+
y_test = mnist.test.labels
30+
31+
# initialize the variables
32+
init = tf.global_variables_initializer()
33+
# train the model
34+
with tf.Session() as sess:
35+
sess.run(init)
36+
n_batches = mnist.train.num_examples // batch_size
37+
for epoch in range(n_epochs):
38+
for batch in range(n_batches):
39+
X_train, y_train = mnist.train.next_batch(batch_size)
40+
X_train = X_train.reshape([-1, n_steps, n_inputs])
41+
sess.run(optimizer, feed_dict={X: X_train, y: y_train})
42+
loss_train, acc_train = sess.run(
43+
[loss, accuracy], feed_dict={X: X_train, y: y_train})
44+
print('Epoch: {}, Train Loss: {:.3f}, Train Acc: {:.3f}'.format(
45+
epoch + 1, loss_train, acc_train))
46+
loss_test, acc_test = sess.run(
47+
[loss, accuracy], feed_dict={X: X_test, y: y_test})
48+
print('Test Loss: {:.3f}, Test Acc: {:.3f}'.format(loss_test, acc_test))

test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
1212
[TestClass]
1313
public class CondTestCases : PythonTest
1414
{
15+
[Ignore("need tesnroflow expose AddControlInput API")]
1516
[TestMethod]
1617
public void testCondTrue_ConstOnly()
1718
{
@@ -31,6 +32,7 @@ public void testCondTrue_ConstOnly()
3132
});
3233
}
3334

35+
[Ignore("need tesnroflow expose AddControlInput API")]
3436
[TestMethod]
3537
public void testCondFalse_ConstOnly()
3638
{
@@ -50,6 +52,7 @@ public void testCondFalse_ConstOnly()
5052
});
5153
}
5254

55+
[Ignore("need tesnroflow expose AddControlInput API")]
5356
[TestMethod]
5457
public void testCondTrue()
5558
{
@@ -66,6 +69,7 @@ public void testCondTrue()
6669
assertEquals(result, 34);
6770
}
6871

72+
[Ignore("need tesnroflow expose AddControlInput API")]
6973
[TestMethod]
7074
public void testCondFalse()
7175
{

test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,12 @@ public void TestUniqueName()
6565
});
6666
}
6767

68+
[Ignore("need tesnroflow expose UpdateEdge API")]
6869
[TestMethod]
6970
public void TestCond()
7071
{
7172
var graph = tf.Graph().as_default();
72-
with<Graph>(graph, g =>
73+
with(graph, g =>
7374
{
7475
var x = constant_op.constant(10);
7576

0 commit comments

Comments
 (0)