Skip to content

Commit b536026

Browse files
committed
fix Operation.get_attr for all properties.
add Convolution class add convert_data_format add as_numpy_dtype
1 parent 958fdb8 commit b536026

File tree

24 files changed

+403
-57
lines changed

24 files changed

+403
-57
lines changed

TensorFlow.NET.sln

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T
1111
EndProject
1212
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "src\TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{0254BFF9-453C-4FE0-9609-3644559A79CE}"
1313
EndProject
14+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{3EEAFB06-BEF0-4261-BAAB-630EABD25290}"
15+
EndProject
1416
Global
1517
GlobalSection(SolutionConfigurationPlatforms) = preSolution
1618
Debug|Any CPU = Debug|Any CPU
@@ -33,6 +35,10 @@ Global
3335
{0254BFF9-453C-4FE0-9609-3644559A79CE}.Debug|Any CPU.Build.0 = Debug|Any CPU
3436
{0254BFF9-453C-4FE0-9609-3644559A79CE}.Release|Any CPU.ActiveCfg = Release|Any CPU
3537
{0254BFF9-453C-4FE0-9609-3644559A79CE}.Release|Any CPU.Build.0 = Release|Any CPU
38+
{3EEAFB06-BEF0-4261-BAAB-630EABD25290}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
39+
{3EEAFB06-BEF0-4261-BAAB-630EABD25290}.Debug|Any CPU.Build.0 = Debug|Any CPU
40+
{3EEAFB06-BEF0-4261-BAAB-630EABD25290}.Release|Any CPU.ActiveCfg = Release|Any CPU
41+
{3EEAFB06-BEF0-4261-BAAB-630EABD25290}.Release|Any CPU.Build.0 = Release|Any CPU
3642
EndGlobalSection
3743
GlobalSection(SolutionProperties) = preSolution
3844
HideSolutionNode = FALSE

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public partial class Graph : IPython, IDisposable
2121
public int _version;
2222
private int _next_id_counter;
2323
private List<String> _unfetchable_ops = new List<string>();
24+
private List<Tensor> _unfeedable_tensors = new List<Tensor>();
2425

2526
public string _name_stack = "";
2627
public string _graph_key;
@@ -366,6 +367,11 @@ public object get_collection_ref(string name)
366367
return _collections[name];
367368
}
368369

370+
public void prevent_feeding(Tensor tensor)
371+
{
372+
_unfeedable_tensors.Add(tensor);
373+
}
374+
369375
public void Dispose()
370376
{
371377
c_api.TF_DeleteGraph(_handle);

src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,16 @@ namespace Tensorflow.Keras.Engine
1010
public class InputSpec
1111
{
1212
public int ndim;
13+
Dictionary<int, int> axes;
1314

1415
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
15-
int? ndim = null)
16+
int? ndim = null,
17+
Dictionary<int, int> axes = null)
1618
{
1719
this.ndim = ndim.Value;
20+
if (axes == null)
21+
axes = new Dictionary<int, int>();
22+
this.axes = axes;
1823
}
1924
}
2025
}

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ protected virtual void build(TensorShape input_shape)
6666

6767
}
6868

69-
protected virtual void add_weight(string name,
69+
protected virtual RefVariable add_weight(string name,
7070
int[] shape,
7171
TF_DataType dtype = TF_DataType.DtInvalid,
7272
IInitializer initializer = null,
@@ -82,6 +82,8 @@ protected virtual void add_weight(string name,
8282
trainable: trainable.Value);
8383
backend.track_variable(variable);
8484
_trainable_weights.Add(variable);
85+
86+
return variable;
8587
}
8688
}
8789
}

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
using System.Collections.Generic;
33
using System.Text;
44
using Tensorflow.Keras.Engine;
5+
using Tensorflow.Keras.Utils;
6+
using Tensorflow.Operations;
57
using Tensorflow.Operations.Activation;
68

79
namespace Tensorflow.Keras.Layers
@@ -19,6 +21,9 @@ public class Conv : Tensorflow.Layers.Layer
1921
protected bool use_bias;
2022
protected IInitializer kernel_initializer;
2123
protected IInitializer bias_initializer;
24+
protected RefVariable kernel;
25+
protected RefVariable bias;
26+
protected Convolution _convolution_op;
2227

2328
public Conv(int rank,
2429
int filters,
@@ -53,11 +58,37 @@ protected override void build(TensorShape input_shape)
5358
int channel_axis = data_format == "channels_first" ? 1 : -1;
5459
int input_dim = input_shape.Dimensions[input_shape.NDim - 1];
5560
var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters };
56-
add_weight(name: "kernel",
61+
kernel = add_weight(name: "kernel",
5762
shape: kernel_shape,
5863
initializer: kernel_initializer,
5964
trainable: true,
6065
dtype: _dtype);
66+
if (use_bias)
67+
bias = add_weight(name: "bias",
68+
shape: new int[] { filters },
69+
initializer: bias_initializer,
70+
trainable: true,
71+
dtype: _dtype);
72+
73+
var axes = new Dictionary<int, int>();
74+
axes.Add(-1, input_dim);
75+
input_spec = new InputSpec(ndim: rank + 2, axes: axes);
76+
77+
string op_padding;
78+
if (padding == "causal")
79+
op_padding = "valid";
80+
else
81+
op_padding = padding;
82+
83+
var df = conv_utils.convert_data_format(data_format, rank + 2);
84+
_convolution_op = nn_ops.Convolution(input_shape,
85+
kernel.shape,
86+
op_padding.ToUpper(),
87+
strides,
88+
dilation_rate,
89+
data_format: df);
90+
91+
built = true;
6192
}
6293
}
6394
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Utils
6+
{
7+
public class conv_utils
8+
{
9+
public static string convert_data_format(string data_format, int ndim)
10+
{
11+
if (data_format == "channels_last")
12+
if (ndim == 3)
13+
return "NWC";
14+
else if (ndim == 4)
15+
return "NHWC";
16+
else if (ndim == 5)
17+
return "NDHWC";
18+
else
19+
throw new ValueError($"Input rank not supported: {ndim}");
20+
else if (data_format == "channels_first")
21+
if (ndim == 3)
22+
return "NCW";
23+
else if (ndim == 4)
24+
return "NCHW";
25+
else if (ndim == 5)
26+
return "NCDHW";
27+
else
28+
throw new ValueError($"Input rank not supported: {ndim}");
29+
else
30+
throw new ValueError($"Invalid data_format: {data_format}");
31+
}
32+
}
33+
}

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ public Tensor __call__(Tensor inputs,
6868
throw new NotImplementedException("");
6969
}
7070

71-
protected virtual void add_weight(string name,
71+
protected virtual RefVariable add_weight(string name,
7272
int[] shape,
7373
TF_DataType dtype = TF_DataType.DtInvalid,
7474
IInitializer initializer = null,
@@ -93,14 +93,14 @@ protected virtual void add_weight(string name,
9393

9494
_set_scope();
9595
var reuse = built || (_reuse != null && _reuse.Value);
96-
Python.with(tf.variable_scope(_scope,
96+
return Python.with(tf.variable_scope(_scope,
9797
reuse: reuse,
9898
auxiliary_name_scope: false), scope =>
9999
{
100100
_current_scope = scope;
101-
Python.with(ops.name_scope(_name_scope()), delegate
101+
return Python.with(ops.name_scope(_name_scope()), delegate
102102
{
103-
base.add_weight(name,
103+
var variable = base.add_weight(name,
104104
shape,
105105
dtype: dtype,
106106
initializer: initializer,
@@ -113,6 +113,12 @@ protected virtual void add_weight(string name,
113113
initializer: initializer1,
114114
trainable: trainable1);
115115
});
116+
117+
if(init_graph != null)
118+
{
119+
var trainable_variables = variables.trainable_variables();
120+
}
121+
return variable;
116122
});
117123
});
118124
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Operations
7+
{
8+
public class Convolution
9+
{
10+
public TensorShape input_shape;
11+
public TensorShape filter_shape;
12+
public string data_format;
13+
public int[] strides;
14+
public string name;
15+
public _WithSpaceToBatch conv_op;
16+
17+
public Convolution(TensorShape input_shape,
18+
TensorShape filter_shape,
19+
string padding,
20+
int[] strides,
21+
int[] dilation_rate,
22+
string name = null,
23+
string data_format = null)
24+
{
25+
var num_total_dims = filter_shape.NDim;
26+
var num_spatial_dims = num_total_dims - 2;
27+
int input_channels_dim;
28+
int[] spatial_dims;
29+
if (string.IsNullOrEmpty(data_format) || !data_format.StartsWith("NC"))
30+
{
31+
input_channels_dim = input_shape.Dimensions[num_spatial_dims + 1];
32+
spatial_dims = Enumerable.Range(1, num_spatial_dims).ToArray();
33+
}
34+
else
35+
{
36+
input_channels_dim = input_shape.Dimensions[1];
37+
spatial_dims = Enumerable.Range(2, num_spatial_dims).ToArray();
38+
}
39+
40+
this.input_shape = input_shape;
41+
this.filter_shape = filter_shape;
42+
this.data_format = data_format;
43+
this.strides = strides;
44+
this.name = name;
45+
46+
conv_op = new _WithSpaceToBatch(
47+
input_shape,
48+
dilation_rate: dilation_rate,
49+
padding: padding,
50+
build_op: _build_op,
51+
filter_shape: filter_shape,
52+
spatial_dims: spatial_dims,
53+
data_format: data_format);
54+
}
55+
56+
public _NonAtrousConvolution _build_op(int _, string padding)
57+
{
58+
return new _NonAtrousConvolution(input_shape,
59+
filter_shape: filter_shape,
60+
padding: padding,
61+
data_format: data_format,
62+
strides: strides,
63+
name: name);
64+
}
65+
}
66+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Operations
7+
{
8+
public class _NonAtrousConvolution
9+
{
10+
public string padding;
11+
public string name;
12+
public int[] strides;
13+
public string data_format;
14+
private Func<object, Tensor> conv_op;
15+
16+
public _NonAtrousConvolution(TensorShape input_shape,
17+
TensorShape filter_shape,
18+
string padding,
19+
string data_format,
20+
int[] strides,
21+
string name)
22+
{
23+
this.padding = padding;
24+
this.name = name;
25+
var conv_dims = input_shape.NDim - 2;
26+
if (conv_dims == 1)
27+
{
28+
throw new NotImplementedException("_NonAtrousConvolution conv_dims 1");
29+
}
30+
else if (conv_dims == 2)
31+
{
32+
var list = strides.ToList();
33+
34+
if (string.IsNullOrEmpty(data_format) || data_format == "NHWC")
35+
{
36+
data_format = "NHWC";
37+
list.Insert(0, 1);
38+
list.Add(1);
39+
}
40+
else if (data_format == "NCHW")
41+
list.InsertRange(0, new int[] { 1, 1 });
42+
else
43+
throw new ValueError("data_format must be \"NHWC\" or \"NCHW\".");
44+
45+
strides = list.ToArray();
46+
this.strides = strides;
47+
this.data_format = data_format;
48+
conv_op = gen_nn_ops.conv2d;
49+
}
50+
else if (conv_dims == 3)
51+
{
52+
throw new NotImplementedException("_NonAtrousConvolution conv_dims 3");
53+
}
54+
}
55+
}
56+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Operations
7+
{
8+
public class _WithSpaceToBatch
9+
{
10+
private _NonAtrousConvolution call;
11+
12+
public _WithSpaceToBatch(TensorShape input_shape,
13+
int[] dilation_rate,
14+
string padding,
15+
Func<int, string, _NonAtrousConvolution> build_op,
16+
TensorShape filter_shape = null,
17+
int[] spatial_dims = null,
18+
string data_format = null)
19+
{
20+
var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate");
21+
var rate_shape = dilation_rate_tensor.getShape();
22+
var num_spatial_dims = rate_shape.Dimensions[0];
23+
int starting_spatial_dim = -1;
24+
if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC"))
25+
starting_spatial_dim = 2;
26+
else
27+
starting_spatial_dim = 1;
28+
29+
if (spatial_dims == null)
30+
throw new NotImplementedException("_WithSpaceToBatch spatial_dims");
31+
32+
var orig_spatial_dims = spatial_dims;
33+
spatial_dims = spatial_dims.OrderBy(x => x).ToArray();
34+
if (!Enumerable.SequenceEqual(spatial_dims, orig_spatial_dims) || spatial_dims.Any(x => x < 1))
35+
throw new ValueError("spatial_dims must be a montonically increasing sequence of positive integers");
36+
37+
int expected_input_rank = -1;
38+
if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC"))
39+
expected_input_rank = spatial_dims.Last();
40+
else
41+
expected_input_rank = spatial_dims.Last() + 1;
42+
43+
var const_rate = tensor_util.constant_value(dilation_rate_tensor);
44+
var rate_or_const_rate = dilation_rate;
45+
if(!(const_rate is null))
46+
{
47+
if (const_rate.Data<int>().Count(x => x == 1) == const_rate.size)
48+
{
49+
call = build_op(num_spatial_dims, padding);
50+
return;
51+
}
52+
}
53+
}
54+
}
55+
}

0 commit comments

Comments
 (0)