Skip to content

Commit 3bfdedc

Browse files
committed
Change filter as IVariableV1 in conv2d.
1 parent a7f9599 commit 3bfdedc

File tree

11 files changed

+25
-18
lines changed

11 files changed

+25
-18
lines changed

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public Tensor conv2d(Tensor input, IVariableV1 filter, int[] strides, string pad
3232
var parameters = new Conv2dParams
3333
{
3434
Input = input,
35-
Filter = filter.AsTensor(),
35+
Filter = filter,
3636
Strides = strides,
3737
Padding = padding,
3838
UseCudnnOnGpu = use_cudnn_on_gpu,
@@ -153,7 +153,7 @@ public Tensor bias_add(Tensor value, IVariableV1 bias, string data_format = null
153153
return tf_with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope =>
154154
{
155155
name = scope;
156-
return gen_nn_ops.bias_add(value, bias.AsTensor(), data_format: data_format, name: name);
156+
return gen_nn_ops.bias_add(value, bias, data_format: data_format, name: name);
157157
});
158158
}
159159

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ public Tensor[] TFE_FastPathExecute(Context ctx,
172172

173173
SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status)
174174
{
175-
if (thread_local_eager_operation_map.find(ctx, out var op))
175+
/*if (thread_local_eager_operation_map.find(ctx, out var op))
176176
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle);
177177
else
178178
{
@@ -181,7 +181,8 @@ SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status)
181181
}
182182
183183
status.Check(true);
184-
return op;
184+
return op;*/
185+
return c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle);
185186
}
186187

187188
bool HasAccumulator()

src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ protected Tensors Call(Tensors inputs, Tensor state = null, bool is_training = f
8585
throw new NotImplementedException("BasicLstmCell call");
8686
}
8787
var gate_inputs = math_ops.matmul(array_ops.concat(new[] { (Tensor)inputs, h }, 1), _kernel.AsTensor());
88-
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor());
88+
gate_inputs = nn_ops.bias_add(gate_inputs, _bias);
8989

9090
// i = input_gate, j = new_input, f = forget_gate, o = output_gate
9191
var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one);

src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ protected Tensors Call(Tensors inputs, Tensor state = null, bool is_training = f
7171
// Most basic RNN: output = new_state = act(W * input + U * state + B).
7272
var concat = array_ops.concat(new Tensor[] { inputs, state }, 1);
7373
var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor());
74-
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor());
74+
gate_inputs = nn_ops.bias_add(gate_inputs, _bias);
7575
var output = _activation(gate_inputs, null);
7676
return new Tensors(output, output);
7777
}

src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public class Conv2dParams
4242
/// <summary>
4343
/// A 4-D tensor of shape
4444
/// </summary>
45-
public Tensor Filter { get; set; }
45+
public IVariableV1 Filter { get; set; }
4646

4747
/// <summary>
4848
/// An integer vector representing the tensor shape of `filter`

src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,10 @@ public Tensor Apply(Tensors input, IVariableV1 filters)
6060
name = scope;
6161
if (num_spatial_dims == 2)
6262
{
63-
var filters_tensor = filters.AsTensor();
64-
6563
result = gen_nn_ops.conv2d(new Conv2dParams
6664
{
6765
Input = input,
68-
Filter = filters_tensor,
66+
Filter = filters,
6967
Strides = strides,
7068
Padding = padding,
7169
DataFormat = data_format,

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ public static Tensor conv2d_backprop_input(Tensor input_sizes, Tensor filter, Te
171171
}
172172

173173
public static Tensor bias_add(Tensor value,
174-
Tensor bias,
174+
IVariableV1 bias,
175175
string data_format = null,
176176
string name = null)
177177
{

src/TensorFlowNET.Core/Operations/nn_ops.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,14 @@ public static ConvolutionInternal convolution_internal(string padding,
4646
/// <param name="name"></param>
4747
/// <returns></returns>
4848
public static Tensor bias_add(Tensor value,
49-
Tensor bias,
49+
IVariableV1 bias,
5050
string data_format = null,
5151
string name = null)
5252
{
5353
return tf_with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope =>
5454
{
5555
name = scope;
56-
value = ops.convert_to_tensor(value, name: "input");
57-
var bias_tensor = ops.convert_to_tensor(bias, dtype: value.dtype, name: "bias");
58-
return gen_nn_ops.bias_add(value, bias_tensor, data_format: data_format, name: name);
56+
return gen_nn_ops.bias_add(value, bias, data_format: data_format, name: name);
5957
});
6058
}
6159

src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool traini
110110
}
111111
else
112112
{
113-
outputs = nn_ops.bias_add(outputs, bias.AsTensor(), data_format: "NHWC");
113+
outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC");
114114
}
115115
}
116116

test/TensorFlowNET.UnitTest/Basics/VariableTest.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,5 +121,15 @@ public void ShouldReturnNegative()
121121
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 1, 2 }, neg_x.shape));
122122
Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray<int>()));
123123
}
124+
125+
[TestMethod]
126+
public void IdentityOriginalTensor()
127+
{
128+
var a = tf.Variable(5);
129+
var a_identity = tf.identity(a);
130+
a.assign_add(1);
131+
Assert.AreEqual(5, (int)a_identity.numpy());
132+
Assert.AreEqual(6, (int)a.numpy());
133+
}
124134
}
125135
}

0 commit comments

Comments
 (0)