Skip to content

Commit 7ebc2b2

Browse files
committed
Dese layer
1 parent df1ae35 commit 7ebc2b2

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.Engine
99
/// </summary>
1010
public class InputSpec
1111
{
12-
public int ndim;
12+
public int? ndim;
1313
public int? min_ndim;
1414
Dictionary<int, int> axes;
1515

@@ -18,7 +18,7 @@ public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
1818
int? min_ndim = null,
1919
Dictionary<int, int> axes = null)
2020
{
21-
this.ndim = ndim.Value;
21+
this.ndim = ndim;
2222
if (axes == null)
2323
axes = new Dictionary<int, int>();
2424
this.axes = axes;
Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using Tensorflow.Keras.Engine;
56
using Tensorflow.Operations.Activation;
@@ -8,11 +9,13 @@ namespace Tensorflow.Keras.Layers
89
{
910
public class Dense : Tensorflow.Layers.Layer
1011
{
11-
protected int uints;
12+
protected int units;
1213
protected IActivation activation;
1314
protected bool use_bias;
1415
protected IInitializer kernel_initializer;
1516
protected IInitializer bias_initializer;
17+
protected RefVariable kernel;
18+
protected RefVariable bias;
1619

1720
public Dense(int units,
1821
IActivation activation,
@@ -21,13 +24,36 @@ public Dense(int units,
2124
IInitializer kernel_initializer = null,
2225
IInitializer bias_initializer = null) : base(trainable: trainable)
2326
{
24-
this.uints = units;
27+
this.units = units;
2528
this.activation = activation;
2629
this.use_bias = use_bias;
2730
this.kernel_initializer = kernel_initializer;
2831
this.bias_initializer = bias_initializer;
2932
this.supports_masking = true;
3033
this.input_spec = new InputSpec(min_ndim: 2);
3134
}
35+
36+
protected override void build(TensorShape input_shape)
37+
{
38+
var last_dim = input_shape.Dimensions.Last();
39+
var axes = new Dictionary<int, int>();
40+
axes[-1] = last_dim;
41+
input_spec = new InputSpec(min_ndim: 2, axes: axes);
42+
kernel = add_weight(
43+
"kernel",
44+
shape: new int[] { last_dim, units },
45+
initializer: kernel_initializer,
46+
dtype: _dtype,
47+
trainable: true);
48+
if (use_bias)
49+
bias = add_weight(
50+
"bias",
51+
shape: new int[] { units },
52+
initializer: bias_initializer,
53+
dtype: _dtype,
54+
trainable: true);
55+
56+
built = true;
57+
}
3258
}
3359
}

src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public TruncatedNormal(float mean = 0.0f,
2424

2525
public Tensor call(TensorShape shape, TF_DataType dtype)
2626
{
27-
throw new NotImplementedException("");
27+
return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed);
2828
}
2929

3030
public object get_config()

0 commit comments

Comments
 (0)