Skip to content

Commit 051855e

Browse files
committed
fix return value issue for weight_variable and bias_variable.
1 parent e5c0665 commit 051855e

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,21 @@ public static class nn
3030
public static Tensor conv2d(Tensor input, RefVariable filter, int[] strides, string padding, bool use_cudnn_on_gpu = true,
3131
string data_format= "NHWC", int[] dilations= null, string name = null)
3232
{
33-
if (dilations == null)
34-
dilations = new[] { 1, 1, 1, 1 };
35-
36-
return gen_nn_ops.conv2d(new Conv2dParams
33+
var parameters = new Conv2dParams
3734
{
3835
Input = input,
3936
Filter = filter,
4037
Strides = strides,
38+
Padding = padding,
4139
UseCudnnOnGpu = use_cudnn_on_gpu,
4240
DataFormat = data_format,
43-
Dilations = dilations,
4441
Name = name
45-
});
42+
};
43+
44+
if (dilations != null)
45+
parameters.Dilations = dilations;
46+
47+
return gen_nn_ops.conv2d(parameters);
4648
}
4749

4850
/// <summary>

test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ private Tensor conv_layer(Tensor x, int filter_size, int num_filters, int stride
158158
layer += b;
159159
return tf.nn.relu(layer);
160160
});
161-
162161
}
163162

164163
/// <summary>
@@ -195,7 +194,7 @@ private Tensor flatten_layer(Tensor layer)
195194
});
196195
}
197196

198-
private Tensor weight_variable(string name, int[] shape)
197+
private RefVariable weight_variable(string name, int[] shape)
199198
{
200199
var initer = tf.truncated_normal_initializer(stddev: 0.01f);
201200
return tf.get_variable(name,
@@ -210,7 +209,7 @@ private Tensor weight_variable(string name, int[] shape)
210209
/// <param name="name"></param>
211210
/// <param name="shape"></param>
212211
/// <returns></returns>
213-
private Tensor bias_variable(string name, int[] shape)
212+
private RefVariable bias_variable(string name, int[] shape)
214213
{
215214
var initial = tf.constant(0f, shape: shape, dtype: tf.float32);
216215
return tf.get_variable(name,

0 commit comments

Comments
 (0)