Skip to content

Commit ac523ed

Browse files
committed
add Embedding layer
1 parent 5442727 commit ac523ed

File tree

13 files changed

+179
-16
lines changed

13 files changed

+179
-16
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras;
5+
using Tensorflow.Keras.Engine;
6+
using Tensorflow.Keras.Layers;
7+
8+
namespace Tensorflow
9+
{
10+
public static partial class keras
11+
{
12+
public static class layers
13+
{
14+
public static Embedding Embedding(int input_dim, int output_dim,
15+
string embeddings_initializer = "uniform",
16+
bool mask_zero = false) => new Embedding(input_dim, output_dim,
17+
embeddings_initializer,
18+
mask_zero);
19+
}
20+
}
21+
}

src/TensorFlowNET.Core/APIs/tf.init.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ public static partial class tf
1010
public static IInitializer zeros_initializer => new Zeros();
1111
public static IInitializer ones_initializer => new Ones();
1212
public static IInitializer glorot_uniform_initializer => new GlorotUniform();
13-
13+
public static IInitializer uniform_initializer => new RandomUniform();
14+
1415
public static variable_scope variable_scope(string name,
1516
string default_name = null,
1617
object values = null,

src/TensorFlowNET.Core/Keras/Engine/Model.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44

55
namespace Tensorflow.Keras.Engine
66
{
7-
internal class Model : Network
7+
public class Model : Network
88
{
9+
public Model(string name = null)
10+
: base(name: name)
11+
{
12+
13+
}
914
}
1015
}

src/TensorFlowNET.Core/Keras/Engine/Network.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,29 @@ namespace Tensorflow.Keras.Engine
77
{
88
public class Network : Layer
99
{
10+
protected bool _is_compiled;
11+
protected bool _expects_training_arg;
12+
protected bool _compute_output_and_mask_jointly;
13+
14+
public Network(string name = null)
15+
: base(name: name)
16+
{
17+
18+
}
19+
20+
protected virtual void _init_subclassed_network(string name = null)
21+
{
22+
_base_init(name: name);
23+
}
24+
25+
protected virtual void _base_init(string name = null)
26+
{
27+
_init_set_name(name);
28+
trainable = true;
29+
_is_compiled = false;
30+
_expects_training_arg = false;
31+
_compute_output_and_mask_jointly = false;
32+
supports_masking = false;
33+
}
1034
}
1135
}
Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,38 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Keras.Layers;
45

56
namespace Tensorflow.Keras.Engine
67
{
7-
public class Sequential : Network, IPython
8+
public class Sequential : Model, IPython
89
{
9-
public void Dispose()
10+
public Sequential(string name = null)
11+
: base(name: name)
1012
{
11-
throw new NotImplementedException();
13+
supports_masking = true;
14+
_compute_output_and_mask_jointly = true;
1215
}
1316

1417
public void __enter__()
1518
{
16-
throw new NotImplementedException();
19+
20+
}
21+
22+
public void add(Layer layer)
23+
{
24+
built = false;
25+
var set_inputs = false;
1726
}
1827

1928
public void __exit__()
2029
{
21-
throw new NotImplementedException();
30+
31+
}
32+
33+
public void Dispose()
34+
{
35+
2236
}
2337
}
2438
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Layers
6+
{
7+
public class Embedding : Layer
8+
{
9+
private int input_dim;
10+
private int output_dim;
11+
private bool mask_zero;
12+
13+
public Embedding(int input_dim, int output_dim,
14+
IInitializer embeddings_initializer = null,
15+
bool mask_zero = false)
16+
{
17+
this.input_dim = input_dim;
18+
this.output_dim = output_dim;
19+
if (embeddings_initializer == null)
20+
embeddings_initializer = tf.uniform_initializer;
21+
this.mask_zero = mask_zero;
22+
supports_masking = mask_zero;
23+
}
24+
}
25+
}

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@ public class CondContext : ControlFlowContext
1717
private Tensor _pred;
1818
public Tensor pred => _pred;
1919

20-
/// <summary>
21-
/// The predicate tensor in this branch
22-
/// </summary>
23-
private Tensor _pivot;
24-
2520
/// <summary>
2621
/// 0 or 1 representing this branch
2722
/// </summary>

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ namespace Tensorflow.Operations
66
{
77
public abstract class ControlFlowContext : IPython, IControlFlowContext
88
{
9+
/// <summary>
10+
/// The predicate tensor in this branch
11+
/// </summary>
12+
protected Tensor _pivot;
13+
914
protected Stack<IControlFlowContext> _context_stack;
1015
public ControlFlowContext()
1116
{
@@ -28,6 +33,29 @@ public virtual void Enter()
2833
graph._set_control_flow_context(this);
2934
}
3035

36+
public void AddOp(Operation op)
37+
{
38+
_AddOpInternal(op);
39+
}
40+
41+
protected virtual void _AddOpInternal(Operation op)
42+
{
43+
if(op.inputs.Length == 0)
44+
{
45+
_RemoveExternalControlEdges(op);
46+
op._add_control_input(_pivot.op);
47+
}
48+
else
49+
{
50+
51+
}
52+
}
53+
54+
protected virtual void _RemoveExternalControlEdges(Operation op)
55+
{
56+
var internal_control_inputs = op.control_inputs;
57+
}
58+
3159
public void Exit()
3260
{
3361
var graph = ops.get_default_graph();

src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ namespace Tensorflow
66
{
77
public interface IControlFlowContext
88
{
9+
void AddOp(Operation op);
910
}
1011
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations.Initializers
6+
{
7+
public class RandomUniform : IInitializer
8+
{
9+
private int? seed;
10+
private float minval;
11+
private float maxval;
12+
private TF_DataType dtype;
13+
14+
public RandomUniform()
15+
{
16+
17+
}
18+
19+
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
20+
{
21+
return random_ops.random_uniform(shape,
22+
minval: minval,
23+
maxval: maxval,
24+
dtype: dtype,
25+
seed: seed);
26+
}
27+
28+
public object get_config()
29+
{
30+
return new {
31+
minval,
32+
maxval,
33+
seed,
34+
dtype
35+
};
36+
}
37+
}
38+
}

0 commit comments

Comments
 (0)