Skip to content

Commit bdd9bec

Browse files
committed
x_emb shape is not correct #189
1 parent 79aaae2 commit bdd9bec

File tree

10 files changed

+240
-61
lines changed

10 files changed

+240
-61
lines changed

src/TensorFlowNET.Core/APIs/tf.init.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public static variable_scope variable_scope(string name,
2121
public static variable_scope variable_scope(VariableScope scope,
2222
string default_name = null,
2323
object values = null,
24+
bool? reuse = null,
2425
bool auxiliary_name_scope = true) => new variable_scope(scope,
2526
default_name,
2627
values,

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ public partial class Graph : IPython, IDisposable
3737
/// </summary>
3838
private Dictionary<string, object> _collections = new Dictionary<string, object>();
3939

40+
public bool building_function;
41+
4042
public Graph()
4143
{
4244
_handle = c_api.TF_NewGraph();

src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ namespace Tensorflow.Keras.Engine
99
/// </summary>
1010
public class InputSpec
1111
{
12-
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid)
13-
{
12+
public int ndim;
1413

14+
/// <summary>
/// Creates an input specification.
/// </summary>
/// <param name="dtype">Expected data type of the input (not stored by this constructor).</param>
/// <param name="ndim">
/// Expected rank of the input. When omitted, <see cref="ndim"/> is set to -1
/// to mean "unspecified" instead of throwing on the null default.
/// </param>
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
    int? ndim = null)
{
    // Reading ndim.Value on the default null argument throws InvalidOperationException,
    // so fall back to a sentinel that cannot be a real rank.
    this.ndim = ndim ?? -1;
}
1619
}
1720
}
Lines changed: 24 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Keras.Utils;
45

56
namespace Tensorflow.Keras.Engine
67
{
@@ -12,77 +13,49 @@ namespace Tensorflow.Keras.Engine
1213
/// </summary>
1314
public class Layer : CheckpointableBase
1415
{
15-
protected bool trainable;
16-
protected string _name;
17-
protected TF_DataType _dtype;
18-
protected Graph _graph;
19-
protected string _base_name;
20-
protected VariableScope _scope;
21-
/// <summary>
22-
/// A stateful layer is a layer whose updates are run during inference too,
23-
/// for instance stateful RNNs.
24-
/// </summary>
25-
protected bool stateful;
2616
/// <summary>
2717
/// Indicates whether `build` needs to be called upon layer call, to create
2818
/// the layer's weights.
2919
/// </summary>
3020
protected bool built;
31-
/// <summary>
32-
/// Provides information about which inputs are compatible with the layer.
33-
/// </summary>
34-
protected InputSpec input_spec;
35-
protected bool supports_masking;
36-
37-
public Layer(bool trainable = true,
38-
string name = null,
39-
TF_DataType dtype = TF_DataType.DtInvalid)
40-
{
41-
this.trainable = trainable;
42-
this.stateful = false;
43-
this.built = false;
44-
this.supports_masking = false;
45-
_init_set_name(name);
46-
}
47-
48-
public Tensor apply(Tensor inputs)
49-
{
50-
return __call__(inputs);
51-
}
5221

5322
public Tensor __call__(Tensor inputs,
5423
VariableScope scope = null)
5524
{
56-
_set_scope(scope);
57-
_graph = ops._get_graph_from_inputs(new List<Tensor> { inputs }, graph: _graph);
58-
var scope_context_manager = tf.variable_scope(_scope);
25+
var input_list = new Tensor[] { inputs };
26+
27+
// We will attempt to build a TF graph if & only if all inputs are symbolic.
28+
// This is always the case in graph mode. It can also be the case in eager
29+
// mode when all inputs can be traced back to `keras.Input()` (when building
30+
// models using the functional API).
31+
bool build_graph = tf_utils.are_all_symbolic_tensors(input_list);
32+
33+
// Handle Keras mask propagation from previous layer to current layer.
34+
Python.with(new ops.name_scope(_name_scope()), delegate
35+
{
36+
if (!built)
37+
{
38+
_maybe_build(inputs);
39+
}
40+
});
5941

6042
throw new NotImplementedException("");
6143
}
6244

63-
private void _init_set_name(string name)
45+
protected virtual string _name_scope()
6446
{
65-
if (string.IsNullOrEmpty(name))
66-
(_name, _base_name) = _make_unique_name();
47+
return null;
6748
}
6849

69-
private (string, string) _make_unique_name()
50+
protected void _maybe_build(Tensor inputs)
7051
{
71-
string base_name = "conv2d";
72-
string name = base_layer_utils.unique_layer_name(base_name);
73-
return (name, base_name);
52+
var input_list = new Tensor[] { inputs };
53+
build(inputs.getShape());
7454
}
7555

76-
private void _set_scope(VariableScope scope = null)
56+
protected virtual void build(TensorShape input_shape)
7757
{
78-
if (_scope == null)
79-
{
80-
Python.with(tf.variable_scope(scope, default_name: _base_name), captured_scope =>
81-
{
82-
_scope = captured_scope;
83-
});
84-
}
85-
58+
8659
}
8760
}
8861
}

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
namespace Tensorflow.Keras.Layers
88
{
9-
public class Conv : Layer
9+
public class Conv : Tensorflow.Layers.Layer
1010
{
1111
protected int rank;
1212
protected int filters;
@@ -45,6 +45,15 @@ public Conv(int rank,
4545
this.use_bias = use_bias;
4646
this.kernel_initializer = kernel_initializer;
4747
this.bias_initializer = bias_initializer;
48+
input_spec = new InputSpec(ndim: rank + 2);
49+
}
50+
51+
/// <summary>
/// Derives the kernel shape from <paramref name="input_shape"/> and then calls add_weight().
/// </summary>
/// <param name="input_shape">Shape of the input tensor; its channel dimension is read here.</param>
protected override void build(TensorShape input_shape)
{
    // "channels_first" keeps channels at axis 1; any other format keeps them on the last axis.
    int channel_axis = data_format == "channels_first" ? 1 : input_shape.NDim - 1;
    int input_dim = input_shape.Dimensions[channel_axis];

    // NOTE(review): only two kernel dimensions are used, so this matches rank == 2 only.
    // kernel_shape is not consumed yet because add_weight() currently takes no arguments.
    var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters };
    add_weight();
}
4958
}
5059
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Keras.Utils
7+
{
8+
/// <summary>
/// Helpers for classifying tensors as symbolic (graph) tensors.
/// </summary>
public class tf_utils
{
    /// <summary>
    /// Returns true when every tensor in <paramref name="tensors"/> is symbolic.
    /// An empty array yields true.
    /// </summary>
    /// <param name="tensors">Tensors to inspect.</param>
    public static bool are_all_symbolic_tensors(Tensor[] tensors)
    {
        // Select(...).Count() only counted the projected elements, so it was always
        // equal to Length; All actually evaluates the predicate for each tensor.
        return tensors.All(x => is_symbolic_tensor(x));
    }

    /// <summary>
    /// Reports whether <paramref name="tensor"/> is symbolic.
    /// Currently a placeholder that always returns true.
    /// </summary>
    /// <param name="tensor">Tensor to inspect (not examined yet).</param>
    public static bool is_symbolic_tensor(Tensor tensor)
    {
        return true;
    }
}
20+
}
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.Engine;
5+
6+
namespace Tensorflow.Layers
7+
{
8+
public class Layer : Keras.Engine.Layer
9+
{
10+
protected bool trainable;
11+
protected string _name;
12+
protected TF_DataType _dtype;
13+
protected Graph _graph;
14+
protected string _base_name;
15+
protected VariableScope _scope;
16+
protected VariableScope _current_scope;
17+
/// <summary>
18+
/// A stateful layer is a layer whose updates are run during inference too,
19+
/// for instance stateful RNNs.
20+
/// </summary>
21+
protected bool stateful;
22+
/// <summary>
23+
/// Provides information about which inputs are compatible with the layer.
24+
/// </summary>
25+
protected InputSpec input_spec;
26+
protected bool supports_masking;
27+
protected bool? _reuse;
28+
29+
/// <summary>
/// Initializes the layer's flags and assigns its name.
/// </summary>
/// <param name="trainable">Stored in the <c>trainable</c> field.</param>
/// <param name="name">Layer name passed to <c>_init_set_name</c>.</param>
/// <param name="dtype">Data type stored in <c>_dtype</c>.</param>
/// <param name="_reuse">Stored in <c>_reuse</c>; read later by add_weight when opening the variable scope.</param>
public Layer(bool trainable = true,
    string name = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    bool? _reuse = null)
{
    this.trainable = trainable;
    // The dtype argument used to be dropped, leaving _dtype at its default value.
    this._dtype = dtype;
    this.stateful = false;
    this._reuse = _reuse;
    this.built = false;
    this.supports_masking = false;
    _init_set_name(name);
}
41+
42+
/// <summary>
/// Applies the layer to <paramref name="inputs"/> by forwarding to <c>__call__</c>.
/// </summary>
/// <param name="inputs">Input tensor.</param>
/// <returns>Whatever <c>__call__</c> returns for the same input.</returns>
public Tensor apply(Tensor inputs) => __call__(inputs);
46+
47+
/// <summary>
/// Prepares the layer's scope and graph, enters the layer's variable scope, and
/// delegates to the base layer call. The remainder is not implemented yet.
/// </summary>
/// <param name="inputs">Input tensor.</param>
/// <param name="scope">Optional variable scope used the first time the layer scope is captured.</param>
/// <exception cref="NotImplementedException">Always thrown at the end of the method.</exception>
public Tensor __call__(Tensor inputs,
    VariableScope scope = null)
{
    _set_scope(scope);
    _graph = ops._get_graph_from_inputs(new List<Tensor> { inputs }, graph: _graph);

    // A built layer previously left the context manager null, which was then handed
    // to Python.with. Re-enter the same scope with reuse enabled instead.
    variable_scope scope_context_manager = built
        ? tf.variable_scope(_scope, reuse: true, auxiliary_name_scope: false)
        : tf.variable_scope(_scope, auxiliary_name_scope: false);

    Python.with(scope_context_manager, scope2 => _current_scope = scope2);
    // Actually call layer
    var outputs = base.__call__(inputs);

    throw new NotImplementedException("");
}
70+
71+
/// <summary>
/// Resolves the default graph, makes sure the layer scope exists, then enters the
/// layer's variable scope and name scope. The innermost scope body is still empty.
/// </summary>
/// <exception cref="NotImplementedException">
/// Thrown when the default graph has <c>building_function</c> set.
/// </exception>
protected virtual void add_weight()
{
    var default_graph = ops.get_default_graph();

    // Building inside a function graph is not supported yet.
    if (default_graph.building_function)
    {
        throw new NotImplementedException("add_weight");
    }

    Graph init_graph = default_graph;
    RefVariable[] existing_variables = variables.global_variables().ToArray();

    var dtype = TF_DataType.TF_FLOAT;
    _set_scope();

    // Reuse when the layer is already built or the caller explicitly asked for it.
    var reuse = built || _reuse.GetValueOrDefault();
    var layer_scope = tf.variable_scope(_scope,
        reuse: reuse,
        auxiliary_name_scope: false);

    Python.with(layer_scope, scope =>
    {
        _current_scope = scope;
        Python.with(new ops.name_scope(_name_scope()), delegate
        {

        });
    });
}
102+
103+
/// <summary>
/// Sets <c>_name</c> and <c>_base_name</c>: an explicit name is used as given,
/// otherwise a unique name is generated.
/// </summary>
/// <param name="name">Requested layer name; null or empty requests a generated one.</param>
private void _init_set_name(string name)
{
    if (string.IsNullOrEmpty(name))
    {
        (_name, _base_name) = _make_unique_name();
    }
    else
    {
        // An explicit name used to be ignored, leaving both fields null.
        _name = name;
        _base_name = name;
    }
}
108+
109+
/// <summary>
/// Produces a unique layer name together with the base name it was derived from.
/// </summary>
/// <returns>A tuple of (unique name, base name).</returns>
private (string, string) _make_unique_name()
{
    // NOTE(review): the base name is fixed to "conv2d" for every layer type.
    const string base_name = "conv2d";
    return (base_layer_utils.unique_layer_name(base_name), base_name);
}
115+
116+
/// <summary>
/// Returns the original name scope of the variable scope currently held in <c>_current_scope</c>.
/// </summary>
protected override string _name_scope() => _current_scope.original_name_scope;
120+
121+
/// <summary>
/// Captures the layer's variable scope into <c>_scope</c> the first time it is needed;
/// later calls leave the stored scope untouched.
/// </summary>
/// <param name="scope">Scope to open; <c>_base_name</c> is passed as the default name.</param>
private void _set_scope(VariableScope scope = null)
{
    // Already captured: nothing to do.
    if (_scope != null)
    {
        return;
    }

    var scope_manager = tf.variable_scope(scope, default_name: _base_name);
    Python.with(scope_manager, entered_scope =>
    {
        _scope = entered_scope;
    });
}
131+
}
132+
}

src/TensorFlowNET.Core/Operations/embedding_ops.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public static Tensor _embedding_lookup_and_transform(RefVariable @params,
2828
if(np == 1)
2929
{
3030
var gather = array_ops.gather(@params, ids, name: name);
31-
var result = _clip(@params, ids, max_norm);
31+
var result = _clip(gather, ids, max_norm);
3232

3333
return array_ops.identity(result);
3434
}
@@ -37,7 +37,7 @@ public static Tensor _embedding_lookup_and_transform(RefVariable @params,
3737
});
3838
}
3939

40-
public static Tensor _clip(RefVariable @params, Tensor ids, string max_norm = null)
40+
public static Tensor _clip(Tensor @params, Tensor ids, string max_norm = null)
4141
{
4242
if (max_norm == null)
4343
return @params;

src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,13 @@ public void open_variable_scope(string scope_name)
2222
else
2323
variable_scopes_count[scope_name] = 1;
2424
}
25+
26+
/// <summary>
/// Returns the count recorded for <paramref name="scope_name"/>, or 0 when
/// there is no entry for that name.
/// </summary>
/// <param name="scope_name">Variable scope name to look up.</param>
public int variable_scope_count(string scope_name)
{
    // Single lookup instead of ContainsKey followed by the indexer.
    return variable_scopes_count.TryGetValue(scope_name, out var count) ? count : 0;
}
2533
}
2634
}

0 commit comments

Comments
 (0)