Skip to content

Commit 59ee7ef

Browse files
committed
add InputLayer
1 parent ac523ed commit 59ee7ef

File tree

8 files changed

+121
-9
lines changed

8 files changed

+121
-9
lines changed

src/TensorFlowNET.Core/APIs/keras.layers.cs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using Tensorflow.Keras;
56
using Tensorflow.Keras.Engine;
@@ -12,10 +13,30 @@ public static partial class keras
1213
public static class layers
1314
{
1415
public static Embedding Embedding(int input_dim, int output_dim,
15-
string embeddings_initializer = "uniform",
16-
bool mask_zero = false) => new Embedding(input_dim, output_dim,
17-
embeddings_initializer,
18-
mask_zero);
16+
IInitializer embeddings_initializer = null,
17+
bool mask_zero = false) => new Embedding(input_dim, output_dim,
18+
embeddings_initializer,
19+
mask_zero);
20+
21+
public static InputLayer Input(int[] batch_shape = null,
22+
TF_DataType dtype = TF_DataType.DtInvalid,
23+
string name = null,
24+
bool sparse = false,
25+
Tensor tensor = null)
26+
{
27+
var batch_size = batch_shape[0];
28+
var shape = batch_shape.Skip(1).ToArray();
29+
30+
var input_layer = new InputLayer(
31+
input_shape: shape,
32+
batch_size: batch_size,
33+
name: name,
34+
dtype: dtype,
35+
sparse: sparse,
36+
input_tensor: tensor);
37+
38+
throw new NotImplementedException("");
39+
}
1940
}
2041
}
2142
}

src/TensorFlowNET.Core/Keras/Engine/Network.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,16 @@ public class Network : Layer
1010
protected bool _is_compiled;
1111
protected bool _expects_training_arg;
1212
protected bool _compute_output_and_mask_jointly;
13+
/// <summary>
14+
/// All layers in order of horizontal graph traversal.
15+
/// Entries are unique. Includes input and output layers.
16+
/// </summary>
17+
protected List<Layer> _layers;
1318

1419
public Network(string name = null)
1520
: base(name: name)
1621
{
17-
22+
_init_subclassed_network(name);
1823
}
1924

2025
protected virtual void _init_subclassed_network(string name = null)
@@ -30,6 +35,7 @@ protected virtual void _base_init(string name = null)
3035
_expects_training_arg = false;
3136
_compute_output_and_mask_jointly = false;
3237
supports_masking = false;
38+
_layers = new List<Layer>();
3339
}
3440
}
3541
}

src/TensorFlowNET.Core/Keras/Engine/Sequential.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@ public void add(Layer layer)
2323
{
2424
built = false;
2525
var set_inputs = false;
26+
if(_layers.Count == 0)
27+
{
28+
var (batch_shape, dtype) = (layer._batch_input_shape, layer._dtype);
29+
if(batch_shape != null)
30+
{
31+
// Instantiate an input layer.
32+
var x = keras.layers.Input(
33+
batch_shape: batch_shape,
34+
dtype: dtype,
35+
name: layer._name + "_input");
36+
}
37+
}
2638
}
2739

2840
public void __exit__()

src/TensorFlowNET.Core/Keras/Layers/Embedding.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ public class Embedding : Layer
1212

1313
public Embedding(int input_dim, int output_dim,
1414
IInitializer embeddings_initializer = null,
15-
bool mask_zero = false)
15+
bool mask_zero = false,
16+
TF_DataType dtype = TF_DataType.TF_FLOAT,
17+
int[] input_shape = null) : base(dtype: dtype, input_shape: input_shape)
1618
{
1719
this.input_dim = input_dim;
1820
this.output_dim = output_dim;
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Layers
6+
{
7+
/// <summary>
8+
/// Layer to be used as an entry point into a Network (a graph of layers).
9+
/// </summary>
10+
public class InputLayer : Layer
11+
{
12+
public bool sparse;
13+
public int? batch_size;
14+
15+
public InputLayer(int[] input_shape = null,
16+
int? batch_size = null,
17+
TF_DataType dtype = TF_DataType.DtInvalid,
18+
string name = null,
19+
bool sparse = false,
20+
Tensor input_tensor = null)
21+
{
22+
built = true;
23+
this.sparse = sparse;
24+
this.batch_size = batch_size;
25+
this.supports_masking = true;
26+
27+
if(input_tensor == null)
28+
{
29+
var batch_input_shape = new int[] { batch_size.HasValue ? batch_size.Value : -1, -1 };
30+
31+
if (sparse)
32+
{
33+
throw new NotImplementedException("InputLayer sparse is true");
34+
}
35+
else
36+
{
37+
input_tensor = backend.placeholder(
38+
shape: batch_input_shape,
39+
dtype: dtype,
40+
name: name);
41+
}
42+
}
43+
}
44+
}
45+
}

src/TensorFlowNET.Core/Keras/Layers/Layer.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public class Layer : CheckpointableBase
2121
/// </summary>
2222
protected bool built;
2323
protected bool trainable;
24-
protected TF_DataType _dtype;
24+
public TF_DataType _dtype;
2525
/// <summary>
2626
/// A stateful layer is a layer whose updates are run during inference too,
2727
/// for instance stateful RNNs.
@@ -33,12 +33,16 @@ public class Layer : CheckpointableBase
3333
protected InputSpec input_spec;
3434
protected bool supports_masking;
3535
protected List<RefVariable> _trainable_weights;
36-
protected string _name;
36+
public string _name;
3737
protected string _base_name;
3838
protected bool _compute_previous_mask;
3939
protected List<Operation> _updates;
40+
public int[] _batch_input_shape;
4041

41-
public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
42+
public Layer(bool trainable = true,
43+
string name = null,
44+
TF_DataType dtype = TF_DataType.DtInvalid,
45+
int[] input_shape = null)
4246
{
4347
this.trainable = trainable;
4448
this._dtype = dtype;
@@ -49,6 +53,12 @@ public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_D
4953
_trainable_weights = new List<RefVariable>();
5054
_compute_previous_mask = false;
5155
_updates = new List<Operation>();
56+
57+
// Manage input shape information if passed.
58+
59+
_batch_input_shape = new int[] { -1, -1 };
60+
61+
_dtype = dtype;
5262
}
5363

5464
public Tensor __call__(Tensor inputs,

src/TensorFlowNET.Core/Keras/backend.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@ public static void track_variable(RefVariable v)
1111

1212
}
1313

14+
public static Tensor placeholder(int[] shape = null,
15+
int ndim = -1,
16+
TF_DataType dtype = TF_DataType.DtInvalid,
17+
bool sparse = false,
18+
string name = null)
19+
{
20+
if(sparse)
21+
{
22+
throw new NotImplementedException("placeholder sparse is true");
23+
}
24+
else
25+
{
26+
return gen_array_ops.placeholder(dtype: dtype, shape: new TensorShape(shape), name: name);
27+
}
28+
}
29+
1430
public static Graph get_graph()
1531
{
1632
return ops.get_default_graph();

0 commit comments

Comments
 (0)