Skip to content

Commit 6a8665f

Browse files
committed
Add Conv2DTranspose #735
1 parent cfffc68 commit 6a8665f

35 files changed

+447
-40
lines changed

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,18 @@ public static int len(object a)
149149
return ndArray.ndim == 0 ? 1 : ndArray.shape[0];
150150
case IEnumerable enumerable:
151151
return enumerable.OfType<object>().Count();
152+
case TensorShape arr:
153+
return arr.ndim;
152154
}
153155
throw new NotImplementedException("len() not implemented for type: " + a.GetType());
154156
}
155157

156158
public static float min(float a, float b)
157159
=> Math.Min(a, b);
158160

161+
public static int max(int a, int b)
162+
=> Math.Max(a, b);
163+
159164
public static T[] list<T>(IEnumerable<T> list)
160165
=> list.ToArray();
161166

src/TensorFlowNET.Core/Framework/smart_module.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515
******************************************************************************/
1616

1717
using System;
18+
using System.Linq;
19+
using static Tensorflow.Binding;
1820

1921
namespace Tensorflow.Framework
2022
{
@@ -52,7 +54,14 @@ public static Tensor smart_cond(bool pred,
5254
{
5355
var pred_value = tensor_util.constant_value(pred);
5456
if (pred_value is null)
55-
return pred.eval(new Session(pred.graph));
57+
{
58+
var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray();
59+
var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status.Handle);
60+
if (!evaluated || c_api.TF_GetCode(tf.Status.Handle) != TF_Code.TF_OK)
61+
return null;
62+
else
63+
throw new NotImplementedException("");
64+
}
5665

5766
return pred_value;
5867
}

src/TensorFlowNET.Core/Graphs/c_api.graph.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,5 +322,18 @@ public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF
322322
[DllImport(TensorFlowLibName)]
323323

324324
public static extern void TF_UpdateEdge(IntPtr graph, TF_Output new_src, TF_Input dst, SafeStatusHandle status);
325+
326+
/// <summary>
327+
/// Attempts to evaluate `output`. This will only be possible if `output` doesn't
328+
/// depend on any graph inputs (this function is safe to call if this isn't the
329+
/// case though).
330+
/// </summary>
331+
/// <param name="graph"></param>
332+
/// <param name="output"></param>
333+
/// <param name="result"></param>
334+
/// <param name="status"></param>
335+
/// <returns></returns>
336+
[DllImport(TensorFlowLibName)]
337+
public static extern bool TF_TryEvaluateConstant(IntPtr graph, TF_Output output, IntPtr[] result, SafeStatusHandle status);
325338
}
326339
}

src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
5050
}
5151

5252
public override string ToString()
53-
=> $"min_ndim={min_ndim}, , axes={axes.Count}";
53+
=> $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}";
5454
}
5555
}

src/TensorFlowNET.Core/Operations/nn_impl.py.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,31 @@ namespace Tensorflow
2121
{
2222
public class nn_impl
2323
{
24+
public static Tensor conv2d_transpose(Tensor value = null,
25+
IVariableV1 filter = null,
26+
Tensor output_shape = null,
27+
TensorShape strides = null,
28+
string padding = "SAME",
29+
string data_format = "NHWC",
30+
string name = null,
31+
TensorShape dilations = null)
32+
{
33+
if (dilations == null)
34+
dilations = (1, 1, 1, 1);
35+
return tf_with(ops.name_scope(name, "conv2d_transpose", new { value, filter, output_shape }), scope =>
36+
{
37+
return gen_nn_ops.conv2d_backprop_input(
38+
input_sizes: output_shape,
39+
filter: filter.AsTensor(),
40+
out_backprop: value,
41+
strides: strides,
42+
padding: padding,
43+
data_format: data_format,
44+
dilations: dilations,
45+
name: name);
46+
});
47+
}
48+
2449
/// <summary>
2550
/// Normalizes along dimension `axis` using an L2 norm.
2651
/// </summary>
@@ -83,6 +108,23 @@ public static (Tensor, Tensor) moments(Tensor x,
83108
});
84109
}
85110

111+
public static Tensor batch_normalization(Tensor x,
112+
Tensor mean,
113+
Tensor variance,
114+
Tensor offset,
115+
Tensor scale,
116+
float variance_epsilon = 0.001f,
117+
string name = null)
118+
{
119+
return tf_with(ops.name_scope(name, "batchnorm", new { x, mean, variance, scale, offset }), scope =>
120+
{
121+
var inv = math_ops.rsqrt(variance + variance_epsilon);
122+
inv *= scale;
123+
return x * math_ops.cast(inv, x.dtype) + math_ops.cast(
124+
offset == null ? (-mean * inv) : (offset - mean * inv), x.dtype);
125+
});
126+
}
127+
86128
/// <summary>
87129
/// Batch normalization.
88130
/// </summary>

src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ public override bool Equals(Object obj)
1515
else if (rank != shape1.rank)
1616
return false;
1717
return Enumerable.SequenceEqual(shape1.dims, dims);
18+
case int[] shape2:
19+
if (rank != shape2.Length)
20+
return false;
21+
return Enumerable.SequenceEqual(dims, shape2);
1822
default:
1923
return false;
2024
}

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,5 +317,30 @@ public Tensor concatenate(Tensors tensors, int axis = -1)
317317

318318
return array_ops.concat(tensors, axis);
319319
}
320+
321+
public Tensor conv2d_transpose(Tensor x,
322+
IVariableV1 kernel,
323+
Tensor output_shape,
324+
TensorShape strides = null,
325+
string padding = "valid",
326+
string data_format = null,
327+
TensorShape dilation_rate = null)
328+
{
329+
var force_transpose = false;
330+
if (data_format == "channels_first" && !dilation_rate.Equals(new[] { 1, 1 }))
331+
force_transpose = true;
332+
// x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
333+
var tf_data_format = "NHWC";
334+
padding = padding.ToUpper();
335+
strides = new TensorShape(1, strides[0], strides[1], 1);
336+
if (dilation_rate.Equals(new[] { 1, 1 }))
337+
x = nn_impl.conv2d_transpose(x, kernel, output_shape, strides,
338+
padding: padding,
339+
data_format: tf_data_format);
340+
else
341+
throw new NotImplementedException("");
342+
343+
return x;
344+
}
320345
}
321346
}

src/TensorFlowNET.Keras/Engine/Functional.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,9 @@ void BuildMapHelper(Tensor tensor,
301301
nodes_in_decreasing_depth.append(node);
302302
}
303303

304-
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
304+
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
305305
{
306-
return run_internal_graph(inputs, is_training);
306+
return run_internal_graph(inputs, training.Value);
307307
}
308308

309309
Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null)

src/TensorFlowNET.Keras/Engine/Layer.Apply.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ public partial class Layer
1010
/// </summary>
1111
/// <param name="input"></param>
1212
/// <param name="state"></param>
13-
/// <param name="is_training"></param>
13+
/// <param name="training"></param>
1414
/// <returns></returns>
15-
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
15+
public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false)
1616
{
1717
callContext = callContext ?? new ThreadLocal<CallContext>()
1818
{
@@ -38,7 +38,7 @@ public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = fal
3838
if (!built)
3939
MaybeBuild(inputs);
4040

41-
outputs = Call(inputs, state: state, is_training: is_training);
41+
outputs = Call(inputs, state: state, training: training);
4242

4343
// memory leak
4444
// _set_connectivity_metadata_(inputs, outputs);

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
155155
/// <param name="state"></param>
156156
/// <param name="is_training"></param>
157157
/// <returns></returns>
158-
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
158+
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
159159
{
160160
throw new NotImplementedException("");
161161
}

0 commit comments

Comments
 (0)