Skip to content

Commit cdf39c5

Browse files
committed
tf.keras.layers. #570
1 parent 4053080 commit cdf39c5

File tree

22 files changed

+415
-153
lines changed

22 files changed

+415
-153
lines changed

src/TensorFlowNET.Core/APIs/tf.layers.cs

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public Tensor conv2d(Tensor inputs,
4040
string data_format= "channels_last",
4141
int[] dilation_rate = null,
4242
bool use_bias = true,
43-
IActivation activation = null,
43+
Activation activation = null,
4444
IInitializer kernel_initializer = null,
4545
IInitializer bias_initializer = null,
4646
bool trainable = true,
@@ -53,20 +53,23 @@ public Tensor conv2d(Tensor inputs,
5353
if (bias_initializer == null)
5454
bias_initializer = tf.zeros_initializer;
5555

56-
var layer = new Conv2D(filters,
57-
kernel_size: kernel_size,
58-
strides: strides,
59-
padding: padding,
60-
data_format: data_format,
61-
dilation_rate: dilation_rate,
62-
activation: activation,
63-
use_bias: use_bias,
64-
kernel_initializer: kernel_initializer,
65-
bias_initializer: bias_initializer,
66-
trainable: trainable,
67-
name: name);
56+
var layer = new Conv2D(new Conv2DArgs
57+
{
58+
Filters = filters,
59+
KernelSize = kernel_size,
60+
Strides = strides,
61+
Padding = padding,
62+
DataFormat = data_format,
63+
DilationRate = dilation_rate,
64+
Activation = activation,
65+
UseBias = use_bias,
66+
KernelInitializer = kernel_initializer,
67+
BiasInitializer = bias_initializer,
68+
Trainable = trainable,
69+
Name = name
70+
});
6871

69-
return layer.apply(inputs).Item1;
72+
return layer.Apply(inputs);
7073
}
7174

7275
/// <summary>
@@ -140,13 +143,16 @@ public Tensor max_pooling2d(Tensor inputs,
140143
string data_format = "channels_last",
141144
string name = null)
142145
{
143-
var layer = new MaxPooling2D(pool_size: pool_size,
144-
strides: strides,
145-
padding: padding,
146-
data_format: data_format,
147-
name: name);
146+
var layer = new MaxPooling2D(new MaxPooling2DArgs
147+
{
148+
PoolSize = pool_size,
149+
Strides = strides,
150+
Padding = padding,
151+
DataFormat = data_format,
152+
Name = name
153+
});
148154

149-
return layer.apply(inputs).Item1;
155+
return layer.Apply(inputs);
150156
}
151157

152158
/// <summary>

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,8 @@ public Tensor dropout(Tensor x, Tensor keep_prob = null, Tensor noise_shape = nu
6666
Tensor keep = null;
6767
if (keep_prob != null)
6868
keep = 1.0f - keep_prob;
69-
var rate_tensor = keep;
7069

71-
return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name);
70+
return nn_ops.dropout_v2(x, rate: rate.Value, noise_shape: noise_shape, seed: seed, name: name);
7271
}
7372

7473
/// <summary>

src/TensorFlowNET.Core/Framework/smart_module.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ public static Tensor[] smart_cond<T>(Tensor pred,
4141
name: name);
4242
}
4343

44+
public static Tensor smart_cond(bool pred,
45+
Func<Tensor> true_fn = null,
46+
Func<Tensor> false_fn = null,
47+
string name = null)
48+
{
49+
return pred ? true_fn() : false_fn();
50+
}
51+
4452
public static bool? smart_constant_value(Tensor pred)
4553
{
4654
var pred_value = tensor_util.constant_value(pred);
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class Conv2DArgs : ConvArgs
8+
{
9+
10+
}
11+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow.Keras.ArgsDefinition
7+
{
8+
public class ConvArgs : LayerArgs
9+
{
10+
public int Rank { get; set; } = 2;
11+
public int Filters { get; set; }
12+
public TensorShape KernelSize { get; set; } = 5;
13+
14+
/// <summary>
15+
/// specifying the stride length of the convolution.
16+
/// </summary>
17+
public TensorShape Strides { get; set; } = (1, 1);
18+
19+
public string Padding { get; set; } = "valid";
20+
public string DataFormat { get; set; }
21+
public TensorShape DilationRate { get; set; } = (1, 1);
22+
public int Groups { get; set; } = 1;
23+
public Activation Activation { get; set; }
24+
public bool UseBias { get; set; }
25+
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
26+
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
27+
public IInitializer KernelRegularizer { get; set; }
28+
public IInitializer BiasRegularizer { get; set; }
29+
public Action KernelConstraint { get; set; }
30+
public Action BiasConstraint { get; set; }
31+
}
32+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class DropoutArgs : LayerArgs
8+
{
9+
/// <summary>
10+
/// Float between 0 and 1. Fraction of the input units to drop.
11+
/// </summary>
12+
public float Rate { get; set; }
13+
14+
/// <summary>
15+
/// 1D integer tensor representing the shape of the
16+
/// binary dropout mask that will be multiplied with the input.
17+
/// </summary>
18+
public TensorShape NoiseShape { get; set; }
19+
20+
/// <summary>
21+
/// random seed.
22+
/// </summary>
23+
public int? Seed { get; set; }
24+
25+
public bool SupportsMasking { get; set; }
26+
}
27+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class FlattenArgs : LayerArgs
8+
{
9+
public string DataFormat { get; set; }
10+
}
11+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.Layers;
5+
6+
namespace Tensorflow.Keras.ArgsDefinition
7+
{
8+
public class MaxPooling2DArgs : Pooling2DArgs
9+
{
10+
11+
}
12+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class Pooling2DArgs : LayerArgs
8+
{
9+
/// <summary>
10+
/// The pooling function to apply, e.g. `tf.nn.max_pool2d`.
11+
/// </summary>
12+
public IPoolFunction PoolFunction { get; set; }
13+
14+
/// <summary>
15+
/// specifying the size of the pooling window.
16+
/// </summary>
17+
public TensorShape PoolSize { get; set; }
18+
19+
/// <summary>
20+
/// specifying the strides of the pooling operation.
21+
/// </summary>
22+
public TensorShape Strides { get; set; }
23+
24+
/// <summary>
25+
/// The padding method, either 'valid' or 'same'.
26+
/// </summary>
27+
public string Padding { get; set; } = "valid";
28+
29+
/// <summary>
30+
/// one of `channels_last` (default) or `channels_first`.
31+
/// </summary>
32+
public string DataFormat { get; set; }
33+
}
34+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
using Tensorflow.Keras.Utils;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow.Keras.Engine
9+
{
10+
public class Flatten : Layer
11+
{
12+
FlattenArgs args;
13+
InputSpec input_spec;
14+
bool _channels_first;
15+
16+
public Flatten(FlattenArgs args)
17+
: base(args)
18+
{
19+
args.DataFormat = conv_utils.normalize_data_format(args.DataFormat);
20+
input_spec = new InputSpec(min_ndim: 1);
21+
_channels_first = args.DataFormat == "channels_first";
22+
}
23+
24+
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
25+
{
26+
if (_channels_first)
27+
{
28+
throw new NotImplementedException("");
29+
}
30+
31+
if (tf.executing_eagerly())
32+
{
33+
return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 });
34+
}
35+
36+
throw new NotImplementedException("");
37+
}
38+
}
39+
}

0 commit comments

Comments
 (0)