Skip to content

Commit b048b62

Browse files
committed
fix Embedding layer.
1 parent 67a70bf commit b048b62

File tree

3 files changed

+6
-29
lines changed

3 files changed

+6
-29
lines changed

src/TensorFlowNET.Keras/Engine/Model.Predict.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ public Tensors predict(Tensor x,
6060
// callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
6161
}
6262
GC.Collect();
63-
GC.WaitForPendingFinalizers();
6463
}
6564
// callbacks.on_predict_end()
6665
return outputs;

src/TensorFlowNET.Keras/Layers/Core/Embedding.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ public Embedding(EmbeddingArgs args)
3838
: base(new LayerArgs // copy args
3939
{
4040
DType = args.DType,
41-
Name = args.Name
41+
Name = args.Name,
42+
InputShape = args.InputShape,
43+
BatchSize = args.BatchSize
4244
})
4345
{
4446
this.args = args;

test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -82,37 +82,13 @@ public void TensorFlowOpLayer()
8282
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
8383
/// </summary>
8484
[TestMethod]
85-
public void Embedding_Simple()
86-
{
87-
var emb = keras.layers.Embedding(256, 12, input_length: 4);
88-
var input_array = np.arange(12).reshape((3, 4)).astype(np.float32);
89-
var output = emb.Apply(input_array);
90-
Assert.AreEqual((3, 4, 12), output.shape);
91-
}
92-
93-
/// <summary>
94-
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
95-
/// </summary>
96-
[TestMethod]
97-
[Ignore]
9885
public void Embedding()
9986
{
10087
var model = keras.Sequential();
101-
var layer = keras.layers.Embedding(7, 2, input_length: 4);
88+
var layer = keras.layers.Embedding(1000, 64, input_length: 10);
10289
model.add(layer);
103-
// the model will take as input an integer matrix of size (batch,
104-
// input_length).
105-
// the largest integer (i.e. word index) in the input should be no larger
106-
// than 999 (vocabulary size).
107-
// now model.output_shape == (None, 10, 64), where None is the batch
108-
// dimension.
109-
var input_array = np.array(new int[,]
110-
{
111-
{ 1, 2, 3, 4 },
112-
{ 2, 3, 4, 5 },
113-
{ 3, 4, 5, 6 }
114-
});
115-
// model.compile("rmsprop", "mse");
90+
var input_array = np.random.randint(1000, size: (32, 10));
91+
model.compile("rmsprop", "mse", new[] { "accuracy" });
11692
var output_array = model.predict(input_array);
11793
Assert.AreEqual((32, 10, 64), output_array.shape);
11894
}

0 commit comments

Comments
 (0)