Skip to content

Commit 8b6dc47

Browse files
committed
Merge branch 'master' into tf.keras-0.3.image-classification
2 parents 4be5fad + 89f40e5 commit 8b6dc47

25 files changed

+835
-51
lines changed

README.md

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -56,30 +56,32 @@ PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
5656

5757
Import TF.NET and Keras API in your project.
5858

59-
```cs
59+
```csharp
6060
using static Tensorflow.Binding;
6161
using static Tensorflow.KerasApi;
62+
using Tensorflow;
63+
using NumSharp;
6264
```
6365

6466
Linear Regression in `Eager` mode:
6567

66-
```c#
68+
```csharp
6769
// Parameters
6870
var training_steps = 1000;
6971
var learning_rate = 0.01f;
7072
var display_step = 100;
7173

7274
// Sample data
73-
var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
75+
var X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
7476
7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f);
75-
var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
77+
var Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
7678
2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
77-
var n_samples = train_X.shape[0];
79+
var n_samples = X.shape[0];
7880

7981
// We can set a fixed init value in order to demo
8082
var W = tf.Variable(-0.06f, name: "weight");
8183
var b = tf.Variable(-0.73f, name: "bias");
82-
var optimizer = tf.optimizers.SGD(learning_rate);
84+
var optimizer = keras.optimizers.SGD(learning_rate);
8385

8486
// Run training for the given number of steps.
8587
foreach (var step in range(1, training_steps + 1))
@@ -112,46 +114,40 @@ Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube)
112114
Toy version of `ResNet` in `Keras` functional API:
113115

114116
```csharp
117+
var layers = new LayersApi();
115118
// input layer
116119
var inputs = keras.Input(shape: (32, 32, 3), name: "img");
117-
118120
// convolutional layer
119121
var x = layers.Conv2D(32, 3, activation: "relu").Apply(inputs);
120122
x = layers.Conv2D(64, 3, activation: "relu").Apply(x);
121123
var block_1_output = layers.MaxPooling2D(3).Apply(x);
122-
123124
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_1_output);
124125
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
125-
var block_2_output = layers.add(x, block_1_output);
126-
126+
var block_2_output = layers.Add().Apply(new Tensors(x, block_1_output));
127127
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_2_output);
128128
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
129-
var block_3_output = layers.add(x, block_2_output);
130-
129+
var block_3_output = layers.Add().Apply(new Tensors(x, block_2_output));
131130
x = layers.Conv2D(64, 3, activation: "relu").Apply(block_3_output);
132131
x = layers.GlobalAveragePooling2D().Apply(x);
133132
x = layers.Dense(256, activation: "relu").Apply(x);
134133
x = layers.Dropout(0.5f).Apply(x);
135-
136134
// output layer
137135
var outputs = layers.Dense(10).Apply(x);
138-
139136
// build keras model
140-
model = keras.Model(inputs, outputs, name: "toy_resnet");
137+
var model = keras.Model(inputs, outputs, name: "toy_resnet");
141138
model.summary();
142-
143139
// compile keras model in tensorflow static graph
144140
model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
145-
loss: keras.losses.CategoricalCrossentropy(from_logits: true),
146-
metrics: new[] { "acc" });
147-
141+
loss: keras.losses.CategoricalCrossentropy(from_logits: true),
142+
metrics: new[] { "acc" });
148143
// prepare dataset
149144
var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
150-
145+
x_train = x_train / 255.0f;
146+
y_train = np_utils.to_categorical(y_train, 10);
151147
// training
152-
model.fit(x_train[new Slice(0, 1000)], y_train[new Slice(0, 1000)],
153-
batch_size: 64,
154-
epochs: 10,
148+
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)],
149+
batch_size: 64,
150+
epochs: 10,
155151
validation_split: 0.2f);
156152
```
157153

@@ -260,4 +256,4 @@ WeChat Sponsor 微信打赏:
260256

261257
TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/)
262258
<br>
263-
<a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp-stack.png" width="391" height="100" /></a>
259+
<a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp-stack.png" width="391" height="100" /></a>

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,27 @@ public static Tensor where(Tensor condition, object x = null, object y = null, s
506506
}
507507
}
508508

509+
510+
/// <summary>
/// Dispatches to <c>gen_array_ops.where</c> when neither branch value is given,
/// or to <c>gen_array_ops.select_v2</c> when both are given.
/// </summary>
/// <param name="condition">Condition tensor; converted with a preferred dtype of bool on the one-argument path.</param>
/// <param name="x">First branch value; must be supplied together with <paramref name="y"/>.</param>
/// <param name="y">Second branch value; must be supplied together with <paramref name="x"/>.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The output tensor of the underlying op.</returns>
/// <exception cref="ValueError">Exactly one of x and y is null.</exception>
public static Tensor where_v2(Tensor condition, object x = null, object y = null, string name = null)
{
    var has_x = x != null;
    var has_y = y != null;

    // Supplying only one of the two branch values is a usage error.
    if (has_x != has_y)
    {
        throw new ValueError("x and y must both be non-None or both be None.");
    }

    // Both branch values present: element selection.
    if (has_x)
    {
        return gen_array_ops.select_v2(condition, x, y, name);
    }

    // Neither present: run the single-argument op inside a "Where" name scope.
    return tf_with(ops.name_scope(name, "Where", new { condition }), scope =>
    {
        name = scope;
        condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition");
        return gen_array_ops.where(condition: condition, name: name);
    });
}
509530
/// <summary>
510531
/// Returns the shape of a tensor.
511532
/// </summary>

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,21 @@ public static Tensor select<Tx, Ty>(Tensor condition, Tx x, Ty y, string name =
423423
var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t = x, e = y });
424424
return _op.outputs[0];
425425
}
426+
/// <summary>
/// Invokes the "SelectV2" op, either through the eager fast path or by
/// adding an op to the graph, depending on the current context.
/// </summary>
/// <param name="condition">Condition tensor.</param>
/// <param name="x">Passed to the op as input <c>t</c>.</param>
/// <param name="y">Passed to the op as input <c>e</c>.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The first output of the op.</returns>
public static Tensor select_v2<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
{
    // Graph mode: build the op through the op-def library.
    if (!tf.Context.executing_eagerly())
    {
        var graph_op = tf.OpDefLib._apply_op_helper("SelectV2", name, new { condition, t = x, e = y });
        return graph_op.outputs[0];
    }

    // Eager mode: execute immediately on the context's device.
    var eager_outputs = tf.Runner.TFE_FastPathExecute(
        tf.Context,
        tf.Context.DeviceName,
        "SelectV2",
        name,
        null,
        condition, x, y);

    return eager_outputs[0];
}
426441

427442
public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor[] shape, string name = null)
428443
{

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,23 @@ public static Tensor log(Tensor x, string name = null)
714714

715715
return _op.outputs[0];
716716
}
717+
/// <summary>
/// Invokes the "Softplus" op, either through the eager fast path or by
/// adding an op to the graph, depending on the current context.
/// </summary>
/// <param name="features">Input tensor.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The first output of the op.</returns>
public static Tensor softplus(Tensor features, string name = null)
{
    // Graph mode: build the op through the op-def library.
    if (!tf.Context.executing_eagerly())
    {
        var graph_op = tf.OpDefLib._apply_op_helper("Softplus", name, args: new { features });
        return graph_op.outputs[0];
    }

    // Eager mode: execute immediately on the context's device.
    var eager_outputs = tf.Runner.TFE_FastPathExecute(
        tf.Context,
        tf.Context.DeviceName,
        "Softplus",
        name,
        null,
        features);

    return eager_outputs[0];
}
733+
718734
public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate = false, string name = null)
719735
=> tf.Context.RunInAutoMode(()
720736
=> tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, ()
@@ -1068,6 +1084,15 @@ public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null)
10681084

10691085
public static Tensor _abs(Tensor x, string name = null)
10701086
{
1087+
if (tf.Context.executing_eagerly())
1088+
{
1089+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
1090+
"Abs", name,
1091+
null,
1092+
x);
1093+
1094+
return results[0];
1095+
}
10711096
var _op = tf.OpDefLib._apply_op_helper("Abs", name, args: new { x });
10721097

10731098
return _op.output;
@@ -1202,6 +1227,15 @@ public static Tensor round(Tensor x, string name = "Round")
12021227
/// <returns></returns>
12031228
public static Tensor rsqrt(Tensor x, string name = null)
12041229
{
1230+
if (tf.Context.executing_eagerly())
1231+
{
1232+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
1233+
"Rsqrt", name,
1234+
null,
1235+
x);
1236+
1237+
return results[0];
1238+
}
12051239
var _op = tf.OpDefLib._apply_op_helper("Rsqrt", name, new { x });
12061240

12071241
return _op.outputs[0];

src/TensorFlowNET.Core/Operations/nn_impl.py.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ public class nn_impl
3131
/// <returns></returns>
3232
public static Tensor l2_normalize(Tensor x,
    int axis = 0,
    Tensor epsilon = null,
    string name = null)
{
    return tf_with(ops.name_scope(name, "l2_normalize", new { x }), scope =>
    {
        x = ops.convert_to_tensor(x, name: "x");
        var sq = math_ops.square(x);
        var square_sum = math_ops.reduce_sum(sq, axis, keepdims: true);

        // Lower bound applied to the squared norm before the reciprocal square root.
        // When the caller passes no epsilon, use a constant 1e-12 rather than
        // allocating a fresh variable on every call.
        var norm_floor = epsilon ?? tf.constant(1e-12f);

        var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, norm_floor));
        return math_ops.multiply(x, x_inv_norm, name: name);
    });
}

src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,19 @@ namespace Tensorflow.Keras.Losses
99
public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
1010
{
1111
float label_smoothing;
12-
13-
public CategoricalCrossentropy(bool from_logits = false,
12+
public CategoricalCrossentropy(
13+
bool from_logits = false,
1414
float label_smoothing = 0,
15-
string reduction = ReductionV2.AUTO,
16-
string name = "categorical_crossentropy") :
17-
base(reduction: reduction,
18-
name: name,
19-
from_logits: from_logits)
15+
string reduction = null,
16+
string name = null) :
17+
base(reduction: reduction,
18+
name: name == null ? "categorical_crossentropy" : name,
19+
from_logits: from_logits)
2020
{
2121
this.label_smoothing = label_smoothing;
2222
}
2323

24+
2425
public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
2526
{
2627
// Try to adjust the shape so that rank of labels = rank of logits - 1.
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
using static Tensorflow.KerasApi;
6+
7+
namespace Tensorflow.Keras.Losses
8+
{
9+
/// <summary>
/// Loss that l2-normalizes both inputs along the configured axis and returns
/// the negated sum of their element-wise product along that same axis.
/// </summary>
public class CosineSimilarity : LossFunctionWrapper, ILossFunc
{
    // Axis used for both the normalization and the reduction; set by the constructor.
    protected int axis = -1;

    /// <summary>
    /// Creates the loss.
    /// </summary>
    /// <param name="reduction">Forwarded unchanged to the base wrapper.</param>
    /// <param name="axis">Axis stored on the instance and used by <see cref="Apply"/>.</param>
    /// <param name="name">Loss name; "cosine_similarity" is used when null.</param>
    public CosineSimilarity(
        string reduction = null,
        int axis = -1,
        string name = null) :
        base(reduction: reduction, name: name ?? "cosine_similarity")
    {
        this.axis = axis;
    }

    /// <summary>
    /// Computes the loss for the given labels and predictions.
    /// </summary>
    /// <remarks>
    /// The <paramref name="axis"/> and <paramref name="from_logits"/> arguments are not read;
    /// the axis stored by the constructor is used instead.
    /// </remarks>
    public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
    {
        Tensor normalized_true = nn_impl.l2_normalize(y_true, axis: this.axis);
        Tensor normalized_pred = nn_impl.l2_normalize(y_pred, axis: this.axis);
        Tensor product = normalized_true * normalized_pred;

        return -math_ops.reduce_sum(product, axis: this.axis);
    }
}
28+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
using static Tensorflow.KerasApi;
6+
7+
namespace Tensorflow.Keras.Losses
8+
{
9+
/// <summary>
/// Huber loss: quadratic where the absolute error is at most <c>delta</c>,
/// linear beyond it, averaged over the last axis.
/// </summary>
public class Huber : LossFunctionWrapper, ILossFunc
{
    // Threshold between the quadratic and linear regions; assigned once in the constructor.
    protected Tensor delta;

    /// <summary>
    /// Creates the loss.
    /// </summary>
    /// <param name="reduction">Forwarded unchanged to the base wrapper.</param>
    /// <param name="delta">Threshold tensor; a constant 1.0 is used when null.</param>
    /// <param name="name">Loss name; "huber" is used when null.</param>
    public Huber(
        string reduction = null,
        Tensor delta = null,
        string name = null) :
        base(reduction: reduction, name: name ?? "huber")
    {
        // Build the default only when it is actually needed, and as a constant:
        // the threshold is a fixed hyperparameter, not trainable state.
        this.delta = delta ?? tf.constant(1.0);
    }

    /// <summary>
    /// Computes the loss for the given labels and predictions.
    /// </summary>
    /// <remarks>
    /// The <paramref name="from_logits"/> and <paramref name="axis"/> arguments are not read;
    /// the mean is always taken over axis -1.
    /// </remarks>
    public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
    {
        // All arithmetic is done in float32.
        Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT);
        Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT);
        Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT);

        Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
        Tensor abs_error = math_ops.abs(error);
        Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);

        // 0.5 * error^2 where |error| <= delta.
        Tensor quadratic = half * math_ops.pow(error, 2);
        // 0.5 * delta^2 + delta * (|error| - delta) otherwise.
        Tensor linear = half * math_ops.pow(delta, 2) + delta * (abs_error - delta);

        return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, quadratic, linear),
            axis: -1);
    }
}
36+
}

src/TensorFlowNET.Keras/Losses/ILossFunc.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
{
33
/// <summary>
/// Contract implemented by the Keras loss classes.
/// </summary>
public interface ILossFunc
{
    // Interface members are implicitly public; the explicit modifier is omitted
    // so all members are declared the same way and compile on pre-C# 8 compilers.

    /// <summary>Reduction setting of the loss (presumably one of the ReductionV2 values; confirm against implementers).</summary>
    string Reduction { get; }

    /// <summary>Name of the loss.</summary>
    string Name { get; }

    /// <summary>
    /// Evaluates the loss.
    /// </summary>
    /// <param name="y_true">Label tensor.</param>
    /// <param name="y_pred">Prediction tensor.</param>
    /// <param name="sample_weight">Optional weighting tensor; null when omitted.</param>
    /// <returns>The loss tensor.</returns>
    Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null);
}
89
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Operations;
5+
using static Tensorflow.Binding;
6+
using static Tensorflow.KerasApi;
7+
8+
namespace Tensorflow.Keras.Losses
9+
{
10+
/// <summary>
/// Log-cosh loss, computed as <c>x + softplus(-2x) - log(2)</c> with
/// <c>x = y_pred - y_true</c>, averaged over the last axis.
/// </summary>
public class LogCosh : LossFunctionWrapper, ILossFunc
{
    /// <summary>
    /// Creates the loss.
    /// </summary>
    /// <param name="reduction">Forwarded unchanged to the base wrapper.</param>
    /// <param name="name">Loss name; "log_cosh" is used when null.</param>
    public LogCosh(
        string reduction = null,
        string name = null) :
        // The default was previously the copy-pasted "huber", which mislabelled this loss.
        base(reduction: reduction, name: name ?? "log_cosh") { }

    /// <summary>
    /// Computes the loss for the given labels and predictions.
    /// </summary>
    /// <remarks>
    /// The <paramref name="from_logits"/> and <paramref name="axis"/> arguments are not read;
    /// the mean is always taken over axis -1.
    /// </remarks>
    public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
    {
        Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
        // Labels are cast to the prediction dtype before subtracting.
        Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
        Tensor x = y_pred_dispatch - y_true_cast;

        // log(2) is a fixed literal: build it from a constant instead of
        // allocating a new variable on every call.
        Tensor log_two = math_ops.cast(math_ops.log(tf.constant(2.0)), x.dtype);

        return gen_math_ops.mean(x + gen_math_ops.softplus(-2.0 * x) - log_two, axis: -1);
    }
}
28+
}

0 commit comments

Comments
 (0)