Skip to content

Commit bc4fbf1

Browse files
committed
MNIST CNN not finished yet.
1 parent b430d6f commit bc4fbf1

File tree

9 files changed

+267
-35
lines changed

9 files changed

+267
-35
lines changed

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,24 @@ public static partial class tf
2727
{
2828
public static class nn
2929
{
30+
public static Tensor conv2d(Tensor input, RefVariable filter, int[] strides, string padding, bool use_cudnn_on_gpu = true,
31+
string data_format= "NHWC", int[] dilations= null, string name = null)
32+
{
33+
if (dilations == null)
34+
dilations = new[] { 1, 1, 1, 1 };
35+
36+
return gen_nn_ops.conv2d(new Conv2dParams
37+
{
38+
Input = input,
39+
Filter = filter,
40+
Strides = strides,
41+
UseCudnnOnGpu = use_cudnn_on_gpu,
42+
DataFormat = data_format,
43+
Dilations = dilations,
44+
Name = name
45+
});
46+
}
47+
3048
/// <summary>
3149
/// Computes dropout.
3250
/// </summary>
@@ -90,7 +108,10 @@ public static Tensor[] fused_batch_norm(Tensor x,
90108
is_training: is_training,
91109
name: name);
92110

93-
public static IPoolFunction max_pool => new MaxPoolFunction();
111+
public static IPoolFunction max_pool_fn => new MaxPoolFunction();
112+
113+
public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
114+
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
94115

95116
public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
96117
=> gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);

src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public MaxPooling2D(
1212
int[] strides,
1313
string padding = "valid",
1414
string data_format = null,
15-
string name = null) : base(nn.max_pool, pool_size,
15+
string name = null) : base(nn.max_pool_fn, pool_size,
1616
strides,
1717
padding: padding,
1818
data_format: data_format,

src/TensorFlowNET.Core/Operations/nn_ops.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,38 @@ public static Tensor log_softmax(Tensor logits, int axis = -1, string name = nul
118118
return _softmax(logits, gen_nn_ops.log_softmax, axis, name);
119119
}
120120

121+
/// <summary>
122+
/// Performs the max pooling on the input.
123+
/// </summary>
124+
/// <param name="value">A 4-D `Tensor` of the format specified by `data_format`.</param>
125+
/// <param name="ksize">
126+
/// A list or tuple of 4 ints. The size of the window for each dimension
127+
/// of the input tensor.
128+
/// </param>
129+
/// <param name="strides">
130+
/// A list or tuple of 4 ints. The stride of the sliding window for
131+
/// each dimension of the input tensor.
132+
/// </param>
133+
/// <param name="padding">A string, either `'VALID'` or `'SAME'`. The padding algorithm.</param>
134+
/// <param name="data_format">A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.</param>
135+
/// <param name="name">Optional name for the operation.</param>
136+
/// <returns></returns>
137+
public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
138+
{
139+
return with(ops.name_scope(name, "MaxPool", value), scope =>
140+
{
141+
name = scope;
142+
value = ops.convert_to_tensor(value, name: "input");
143+
return gen_nn_ops.max_pool(
144+
value,
145+
ksize: ksize,
146+
strides: strides,
147+
padding: padding,
148+
data_format: data_format,
149+
name: name);
150+
});
151+
}
152+
121153
public static Tensor _softmax(Tensor logits, Func<Tensor, string, Tensor> compute_op, int dim = -1, string name = null)
122154
{
123155
logits = ops.convert_to_tensor(logits);

src/TensorFlowNET.Core/Tensors/TensorShape.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ public TensorShape(params int[] dims) : base(dims)
2424

2525
}
2626

27+
public TensorShape this[Slice slice]
28+
{
29+
get
30+
{
31+
return new TensorShape(Dimensions.Skip(slice.Start.Value)
32+
.Take(slice.Length.Value)
33+
.ToArray());
34+
}
35+
}
36+
2737
/// <summary>
2838
/// Returns True iff `self` is fully defined in every dimension.
2939
/// </summary>
@@ -38,6 +48,9 @@ public bool is_compatible_with(TensorShape shape2)
3848
throw new NotImplementedException("TensorShape is_compatible_with");
3949
}
4050

51+
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
4152
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
53+
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
54+
public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
4255
}
4356
}

test/TensorFlowNET.Examples/IExample.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1-
using System;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
218
using System.Collections.Generic;
319
using System.Text;
420
using Tensorflow;

test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs

Lines changed: 131 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using NumSharp;
1718
using System;
1819
using System.Collections.Generic;
1920
using System.Text;
@@ -65,7 +66,7 @@ public class DigitRecognitionCNN : IExample
6566

6667

6768
Tensor x, y;
68-
Tensor loss, accuracy;
69+
Tensor loss, accuracy, cls_prediction;
6970
Operation optimizer;
7071

7172
int display_freq = 100;
@@ -90,47 +91,148 @@ public Graph BuildGraph()
9091
{
9192
var graph = new Graph().as_default();
9293

93-
// Placeholders for inputs (x) and outputs(y)
94-
x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
95-
y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
94+
with(tf.name_scope("Input"), delegate
95+
{
96+
// Placeholders for inputs (x) and outputs(y)
97+
x = tf.placeholder(tf.float32, shape: (-1, img_h, img_w, n_channels), name: "X");
98+
y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
99+
});
96100

97-
// Create a fully-connected layer with h1 nodes as hidden layer
98-
var fc1 = fc_layer(x, h1, "FC1", use_relu: true);
99-
// Create a fully-connected layer with n_classes nodes as output layer
101+
var conv1 = conv_layer(x, filter_size1, num_filters1, stride1, name: "conv1");
102+
var pool1 = max_pool(conv1, ksize: 2, stride: 2, name: "pool1");
103+
var conv2 = conv_layer(pool1, filter_size2, num_filters2, stride2, name: "conv2");
104+
var pool2 = max_pool(conv2, ksize: 2, stride: 2, name: "pool2");
105+
var layer_flat = flatten_layer(pool2);
106+
var fc1 = fc_layer(layer_flat, h1, "FC1", use_relu: true);
100107
var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
101-
// Define the loss function, optimizer, and accuracy
102-
var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
103-
loss = tf.reduce_mean(logits, name: "loss");
104-
optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
105-
var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
106-
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
107108

108-
// Network predictions
109-
var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
109+
with(tf.variable_scope("Train"), delegate
110+
{
111+
with(tf.variable_scope("Loss"), delegate
112+
{
113+
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits), name: "loss");
114+
});
115+
116+
with(tf.variable_scope("Optimizer"), delegate
117+
{
118+
optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
119+
});
120+
121+
with(tf.variable_scope("Accuracy"), delegate
122+
{
123+
var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
124+
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
125+
});
126+
127+
with(tf.variable_scope("Prediction"), delegate
128+
{
129+
cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
130+
});
131+
});
110132

111133
return graph;
112134
}
113135

114-
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
136+
/// <summary>
137+
/// Create a 2D convolution layer
138+
/// </summary>
139+
/// <param name="x">input from previous layer</param>
140+
/// <param name="filter_size">size of each filter</param>
141+
/// <param name="num_filters">number of filters(or output feature maps)</param>
142+
/// <param name="stride">filter stride</param>
143+
/// <param name="name">layer name</param>
144+
/// <returns>The output array</returns>
145+
private Tensor conv_layer(Tensor x, int filter_size, int num_filters, int stride, string name)
146+
{
147+
return with(tf.variable_scope(name), delegate {
148+
149+
var num_in_channel = x.shape[x.NDims - 1];
150+
var shape = new[] { filter_size, filter_size, num_in_channel, num_filters };
151+
var W = weight_variable("W", shape);
152+
// var tf.summary.histogram("weight", W);
153+
var b = bias_variable("b", new[] { num_filters });
154+
// tf.summary.histogram("bias", b);
155+
var layer = tf.nn.conv2d(x, W,
156+
strides: new[] { 1, stride, stride, 1 },
157+
padding: "SAME");
158+
layer += b;
159+
return tf.nn.relu(layer);
160+
});
161+
162+
}
163+
164+
/// <summary>
165+
/// Create a max pooling layer
166+
/// </summary>
167+
/// <param name="x">input to max-pooling layer</param>
168+
/// <param name="ksize">size of the max-pooling filter</param>
169+
/// <param name="stride">stride of the max-pooling filter</param>
170+
/// <param name="name">layer name</param>
171+
/// <returns>The output array</returns>
172+
private Tensor max_pool(Tensor x, int ksize, int stride, string name)
115173
{
116-
var in_dim = x.shape[1];
174+
return tf.nn.max_pool(x,
175+
ksize: new[] { 1, ksize, ksize, 1 },
176+
strides: new[] { 1, stride, stride, 1 },
177+
padding: "SAME",
178+
name: name);
179+
}
117180

181+
/// <summary>
182+
/// Flattens the output of the convolutional layer to be fed into fully-connected layer
183+
/// </summary>
184+
/// <param name="layer">input array</param>
185+
/// <returns>flattened array</returns>
186+
private Tensor flatten_layer(Tensor layer)
187+
{
188+
return with(tf.variable_scope("Flatten_layer"), delegate
189+
{
190+
var layer_shape = layer.TensorShape;
191+
var num_features = layer_shape[new Slice(1, 4)].Size;
192+
var layer_flat = tf.reshape(layer, new[] { -1, num_features });
193+
194+
return layer_flat;
195+
});
196+
}
197+
198+
private Tensor weight_variable(string name, int[] shape)
199+
{
118200
var initer = tf.truncated_normal_initializer(stddev: 0.01f);
119-
var W = tf.get_variable("W_" + name,
120-
dtype: tf.float32,
121-
shape: (in_dim, num_units),
122-
initializer: initer);
201+
return tf.get_variable(name,
202+
dtype: tf.float32,
203+
shape: shape,
204+
initializer: initer);
205+
}
123206

124-
var initial = tf.constant(0f, num_units);
125-
var b = tf.get_variable("b_" + name,
126-
dtype: tf.float32,
127-
initializer: initial);
207+
/// <summary>
208+
/// Create a bias variable with appropriate initialization
209+
/// </summary>
210+
/// <param name="name"></param>
211+
/// <param name="shape"></param>
212+
/// <returns></returns>
213+
private Tensor bias_variable(string name, int[] shape)
214+
{
215+
var initial = tf.constant(0f, shape: shape, dtype: tf.float32);
216+
return tf.get_variable(name,
217+
dtype: tf.float32,
218+
initializer: initial);
219+
}
220+
221+
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
222+
{
223+
return with(tf.variable_scope(name), delegate
224+
{
225+
var in_dim = x.shape[1];
128226

129-
var layer = tf.matmul(x, W) + b;
130-
if (use_relu)
131-
layer = tf.nn.relu(layer);
227+
var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units });
228+
var b = bias_variable("b_" + name, new[] { num_units });
132229

133-
return layer;
230+
var layer = tf.matmul(x, W) + b;
231+
if (use_relu)
232+
layer = tf.nn.relu(layer);
233+
234+
return layer;
235+
});
134236
}
135237

136238
public Graph ImportGraph() => throw new NotImplementedException();

test/TensorFlowNET.Examples/Utility/Compress.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1-
using ICSharpCode.SharpZipLib.Core;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using ICSharpCode.SharpZipLib.Core;
218
using ICSharpCode.SharpZipLib.GZip;
319
using ICSharpCode.SharpZipLib.Tar;
420
using System;

test/TensorFlowNET.Examples/Utility/DataSetMnist.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1-
using NumSharp;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using NumSharp;
218
using System;
319
using System.Collections.Generic;
420
using System.Text;

test/TensorFlowNET.Examples/Utility/Web.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1-
using System;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
218
using System.Collections.Generic;
319
using System.IO;
420
using System.Linq;

0 commit comments

Comments
 (0)