Skip to content

Commit df1ae35

Browse files
committed
Pool interface, Dense layer
1 parent 0fac8e9 commit df1ae35

File tree

18 files changed

+297
-17
lines changed

18 files changed

+297
-17
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,13 @@ public static Tensor expand_dims(Tensor input, int axis = -1, string name = null
3030
/// <returns></returns>
3131
public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false)
3232
=> array_ops.transpose(a, perm, name, conjugate);
33+
34+
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1)
35+
=> gen_array_ops.squeeze(input, axis, name);
36+
37+
public static Tensor one_hot(Tensor indices, int depth)
38+
{
39+
throw new NotImplementedException("one_hot");
40+
}
3341
}
3442
}

src/TensorFlowNET.Core/APIs/tf.init.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,13 @@ public static variable_scope variable_scope(VariableScope scope,
2727
default_name,
2828
values,
2929
auxiliary_name_scope);
30+
31+
public static IInitializer truncated_normal_initializer(float mean = 0.0f,
32+
float stddev = 1.0f,
33+
int? seed = null,
34+
TF_DataType dtype = TF_DataType.DtInvalid) => new TruncatedNormal(mean: mean,
35+
stddev: stddev,
36+
seed: seed,
37+
dtype: dtype);
3038
}
3139
}

src/TensorFlowNET.Core/APIs/tf.layers.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,26 @@ public static Tensor max_pooling2d(Tensor inputs,
126126

127127
return layer.apply(inputs);
128128
}
129+
130+
public static Tensor dense(Tensor inputs,
131+
int units,
132+
IActivation activation = null,
133+
bool use_bias = true,
134+
IInitializer kernel_initializer = null,
135+
IInitializer bias_initializer = null,
136+
bool trainable = true,
137+
string name = null,
138+
bool? reuse = null)
139+
{
140+
if (bias_initializer == null)
141+
bias_initializer = tf.zeros_initializer;
142+
143+
var layer = new Dense(units, activation,
144+
use_bias: use_bias,
145+
kernel_initializer: kernel_initializer);
146+
147+
return layer.apply(inputs);
148+
}
129149
}
130150
}
131151
}

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,40 @@ namespace Tensorflow
66
{
77
public static partial class tf
88
{
9-
public static Tensor add(Tensor a, Tensor b) => gen_math_ops.add(a, b);
9+
public static Tensor add(Tensor a, Tensor b)
10+
=> gen_math_ops.add(a, b);
1011

11-
public static Tensor sub(Tensor a, Tensor b) => gen_math_ops.sub(a, b);
12+
public static Tensor sub(Tensor a, Tensor b)
13+
=> gen_math_ops.sub(a, b);
1214

13-
public static Tensor sqrt(Tensor a, string name = null) => gen_math_ops.sqrt(a, name);
15+
public static Tensor sqrt(Tensor a, string name = null)
16+
=> gen_math_ops.sqrt(a, name);
1417

1518
public static Tensor subtract<T>(Tensor x, T[] y, string name = null) where T : struct
1619
=> gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name);
1720

18-
public static Tensor multiply(Tensor x, Tensor y) => gen_math_ops.mul(x, y);
21+
public static Tensor multiply(Tensor x, Tensor y)
22+
=> gen_math_ops.mul(x, y);
1923

2024
public static Tensor divide<T>(Tensor x, T[] y, string name = null) where T : struct
2125
=> x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");
2226

23-
public static Tensor pow<T1, T2>(T1 x, T2 y) => gen_math_ops.pow(x, y);
27+
public static Tensor pow<T1, T2>(T1 x, T2 y)
28+
=> gen_math_ops.pow(x, y);
2429

2530
/// <summary>
2631
/// Computes the sum of elements across dimensions of a tensor.
2732
/// </summary>
2833
/// <param name="input"></param>
2934
/// <param name="axis"></param>
3035
/// <returns></returns>
31-
public static Tensor reduce_sum(Tensor input, int[] axis = null) => math_ops.reduce_sum(input);
36+
public static Tensor reduce_sum(Tensor input, int[] axis = null)
37+
=> math_ops.reduce_sum(input);
3238

3339
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
3440
=> math_ops.cast(x, dtype, name);
41+
42+
public static Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
43+
=> gen_math_ops.arg_max(input, axis, name: name, output_type: output_type);
3544
}
3645
}

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ public static Tensor[] fused_batch_norm(Tensor x,
4242
is_training: is_training,
4343
name: name);
4444

45-
public static Tensor max_pool() => gen_nn_ops.max_pool();
45+
public static IPoolFunction max_pool => new MaxPoolFunction();
46+
47+
public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
48+
=> gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);
4649
}
4750
}
4851
}

src/TensorFlowNET.Core/APIs/tf.reshape.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,8 @@ public static Tensor reshape(Tensor tensor,
1010
Tensor shape,
1111
string name = null) => gen_array_ops.reshape(tensor, shape, name);
1212

13+
public static Tensor reshape(Tensor tensor,
14+
int[] shape,
15+
string name = null) => gen_array_ops.reshape(tensor, shape, name);
1316
}
1417
}

src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,19 @@ namespace Tensorflow.Keras.Engine
1010
public class InputSpec
1111
{
1212
public int ndim;
13+
public int? min_ndim;
1314
Dictionary<int, int> axes;
1415

15-
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
16+
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
1617
int? ndim = null,
18+
int? min_ndim = null,
1719
Dictionary<int, int> axes = null)
1820
{
1921
this.ndim = ndim.Value;
2022
if (axes == null)
2123
axes = new Dictionary<int, int>();
2224
this.axes = axes;
25+
this.min_ndim = min_ndim;
2326
}
2427
}
2528
}

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ protected void _maybe_build(Tensor inputs)
122122

123123
protected virtual void build(TensorShape input_shape)
124124
{
125-
throw new NotImplementedException("Layer.build");
125+
built = true;
126126
}
127127

128128
protected virtual RefVariable add_weight(string name,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.Engine;
5+
using Tensorflow.Operations.Activation;
6+
7+
namespace Tensorflow.Keras.Layers
8+
{
9+
public class Dense : Tensorflow.Layers.Layer
10+
{
11+
protected int uints;
12+
protected IActivation activation;
13+
protected bool use_bias;
14+
protected IInitializer kernel_initializer;
15+
protected IInitializer bias_initializer;
16+
17+
public Dense(int units,
18+
IActivation activation,
19+
bool use_bias = true,
20+
bool trainable = false,
21+
IInitializer kernel_initializer = null,
22+
IInitializer bias_initializer = null) : base(trainable: trainable)
23+
{
24+
this.uints = units;
25+
this.activation = activation;
26+
this.use_bias = use_bias;
27+
this.kernel_initializer = kernel_initializer;
28+
this.bias_initializer = bias_initializer;
29+
this.supports_masking = true;
30+
this.input_spec = new InputSpec(min_ndim: 2);
31+
}
32+
}
33+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public interface IPoolFunction
8+
{
9+
Tensor Apply(Tensor value,
10+
int[] ksize,
11+
int[] strides,
12+
string padding,
13+
string data_format = "NHWC",
14+
string name = null);
15+
}
16+
}

0 commit comments

Comments
 (0)