Skip to content

Commit 6e4bad4

Browse files
committed
Sequential #570
1 parent cdf39c5 commit 6e4bad4

File tree

26 files changed

+400
-81
lines changed

26 files changed

+400
-81
lines changed

src/TensorFlowNET.Core/APIs/tf.init.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ public IInitializer constant_initializer<T>(T value, TF_DataType dtype = TF_Data
2727
public IInitializer zeros_initializer => new Zeros();
2828
public IInitializer ones_initializer => new Ones();
2929
public IInitializer glorot_uniform_initializer => new GlorotUniform();
30-
public IInitializer uniform_initializer => new RandomUniform();
30+
public IInitializer random_uniform_initializer => new RandomUniform();
31+
public IInitializer orthogonal_initializer => new Orthogonal();
3132

3233
public variable_scope variable_scope(string name,
3334
string default_name = null,

src/TensorFlowNET.Core/Keras/Activations/Activations.Relu.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ public partial class Activations
2020
return results[0];
2121
}
2222

23-
throw new NotImplementedException("");
23+
var _op = tf.OpDefLib._apply_op_helper("Relu", name: name, args: new { features });
24+
25+
return _op.output;
2426
};
2527
}
2628
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Operations;
5+
using static Tensorflow.Binding;
6+
7+
namespace Tensorflow.Keras
8+
{
9+
public partial class Activations
10+
{
11+
public Activation Sigmoid = (features, name) =>
12+
{
13+
if (tf.executing_eagerly())
14+
{
15+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
16+
"Sigmoid", name,
17+
null,
18+
features);
19+
20+
return results[0];
21+
}
22+
23+
throw new NotImplementedException("");
24+
};
25+
}
26+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Operations;
5+
using static Tensorflow.Binding;
6+
7+
namespace Tensorflow.Keras
8+
{
9+
public partial class Activations
10+
{
11+
public Activation Tanh = (features, name) =>
12+
{
13+
if (tf.executing_eagerly())
14+
{
15+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
16+
"Tanh", name,
17+
null,
18+
features);
19+
20+
return results[0];
21+
}
22+
23+
throw new NotImplementedException("");
24+
};
25+
}
26+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class EmbeddingArgs : LayerArgs
8+
{
9+
public int InputDim { get; set; }
10+
public int OutputDim { get; set; }
11+
public bool MaskZero { get; set; }
12+
public int InputLength { get; set; } = -1;
13+
public IInitializer EmbeddingsInitializer { get; set; }
14+
}
15+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class LSTMArgs : RNNArgs
8+
{
9+
public int Units { get; set; }
10+
public Activation Activation { get; set; }
11+
public Activation RecurrentActivation { get; set; }
12+
public IInitializer KernelInitializer { get; set; }
13+
public IInitializer RecurrentInitializer { get; set; }
14+
public IInitializer BiasInitializer { get; set; }
15+
public bool UnitForgetBias { get; set; }
16+
public float Dropout { get; set; }
17+
public float RecurrentDropout { get; set; }
18+
public int Implementation { get; set; }
19+
public bool ReturnSequences { get; set; }
20+
public bool ReturnState { get; set; }
21+
public bool GoBackwards { get; set; }
22+
public bool Stateful { get; set; }
23+
public bool TimeMajor { get; set; }
24+
public bool Unroll { get; set; }
25+
}
26+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class LSTMCellArgs : LayerArgs
8+
{
9+
}
10+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class RNNArgs : LayerArgs
8+
{
9+
}
10+
}

src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,22 @@ public class InputSpec
2626
public int? ndim;
2727
public int? min_ndim;
2828
Dictionary<int, int> axes;
29+
TensorShape shape;
2930

3031
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
3132
int? ndim = null,
3233
int? min_ndim = null,
33-
Dictionary<int, int> axes = null)
34+
Dictionary<int, int> axes = null,
35+
TensorShape shape = null)
3436
{
3537
this.ndim = ndim;
3638
if (axes == null)
3739
axes = new Dictionary<int, int>();
3840
this.axes = axes;
3941
this.min_ndim = min_ndim;
42+
this.shape = shape;
43+
if (ndim == null && shape != null)
44+
this.ndim = shape.ndim;
4045
}
4146
}
4247
}

src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Text;
44
using Tensorflow.Keras.ArgsDefinition;
55
using Tensorflow.Keras.Layers;
6+
using Tensorflow.Operations.Activation;
67
using static Tensorflow.Binding;
78

89
namespace Tensorflow.Keras.Engine
@@ -100,5 +101,46 @@ protected Layer Flatten()
100101
_layers.Add(layer);
101102
return layer;
102103
}
104+
105+
protected Layer LSTM(int units,
106+
Activation activation = null,
107+
Activation recurrent_activation = null,
108+
bool use_bias = true,
109+
IInitializer kernel_initializer = null,
110+
IInitializer recurrent_initializer = null,
111+
IInitializer bias_initializer = null,
112+
bool unit_forget_bias = true,
113+
float dropout = 0f,
114+
float recurrent_dropout = 0f,
115+
int implementation = 2,
116+
bool return_sequences = false,
117+
bool return_state = false,
118+
bool go_backwards = false,
119+
bool stateful = false,
120+
bool time_major = false,
121+
bool unroll = false)
122+
{
123+
var layer = new LSTM(new LSTMArgs
124+
{
125+
Units = units,
126+
Activation = activation ?? tf.keras.activations.Tanh,
127+
RecurrentActivation = recurrent_activation ?? tf.keras.activations.Sigmoid,
128+
KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer,
129+
RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer,
130+
BiasInitializer = bias_initializer ?? tf.zeros_initializer,
131+
Dropout = dropout,
132+
RecurrentDropout = recurrent_dropout,
133+
Implementation = implementation,
134+
ReturnSequences = return_sequences,
135+
ReturnState = return_state,
136+
GoBackwards = go_backwards,
137+
Stateful = stateful,
138+
TimeMajor = time_major,
139+
Unroll = unroll
140+
});
141+
142+
_layers.Add(layer);
143+
return layer;
144+
}
103145
}
104146
}

0 commit comments

Comments
 (0)