Skip to content

Commit 7dbcb6c

Browse files
committed
softmax_cross_entropy_with_logits
1 parent 7ebc2b2 commit 7dbcb6c

File tree

17 files changed

+215
-13
lines changed

17 files changed

+215
-13
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ public static Tensor transpose(Tensor a, int[] perm = null, string name = "trans
3434
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1)
3535
=> gen_array_ops.squeeze(input, axis, name);
3636

37-
public static Tensor one_hot(Tensor indices, int depth)
38-
{
39-
throw new NotImplementedException("one_hot");
40-
}
37+
public static Tensor one_hot(Tensor indices, int depth,
38+
Tensor on_value = null,
39+
Tensor off_value = null,
40+
TF_DataType dtype = TF_DataType.DtInvalid,
41+
int axis = -1,
42+
string name = null) => array_ops.one_hot(indices, depth, dtype: dtype, axis: axis, name: name);
4143
}
4244
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public static partial class tf
8+
{
9+
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
10+
=> ops.control_dependencies(control_inputs);
11+
}
12+
}

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ public static Tensor pow<T1, T2>(T1 x, T2 y)
3636
public static Tensor reduce_sum(Tensor input, int[] axis = null)
3737
=> math_ops.reduce_sum(input);
3838

39+
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
40+
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name);
41+
3942
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
4043
=> math_ops.cast(x, dtype, name);
4144

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@ public static Tensor[] fused_batch_norm(Tensor x,
4646

4747
public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
4848
=> gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);
49+
50+
public static Tensor bias_add(Tensor value, RefVariable bias, string data_format = null, string name = null)
51+
{
52+
return Python.with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope =>
53+
{
54+
name = scope;
55+
return gen_nn_ops.bias_add(value, bias, data_format: data_format, name: name);
56+
});
57+
}
58+
59+
public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
60+
=> nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
4961
}
5062
}
5163
}

src/TensorFlowNET.Core/Keras/Engine/Network.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Keras.Layers;
45

56
namespace Tensorflow.Keras.Engine
67
{

src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
using System.Linq;
44
using System.Text;
55
using Tensorflow.Keras.Utils;
6-
using Tensorflow.Layers;
76

87
namespace Tensorflow.Keras.Layers
98
{
10-
public class BatchNormalization : Layer
9+
public class BatchNormalization : Tensorflow.Layers.Layer
1110
{
1211
private bool _USE_V2_BEHAVIOR = true;
1312
private float momentum;

src/TensorFlowNET.Core/Keras/Layers/Dense.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Text;
55
using Tensorflow.Keras.Engine;
66
using Tensorflow.Operations.Activation;
7+
using static Tensorflow.tf;
78

89
namespace Tensorflow.Keras.Layers
910
{
@@ -55,5 +56,26 @@ protected override void build(TensorShape input_shape)
5556

5657
built = true;
5758
}
59+
60+
protected override Tensor call(Tensor inputs, Tensor training = null)
61+
{
62+
Tensor outputs = null;
63+
var rank = inputs.rank;
64+
if(rank > 2)
65+
{
66+
throw new NotImplementedException("");
67+
}
68+
else
69+
{
70+
outputs = gen_math_ops.mat_mul(inputs, kernel);
71+
}
72+
73+
if (use_bias)
74+
outputs = nn.bias_add(outputs, bias);
75+
if (activation != null)
76+
return activation.Activate(outputs);
77+
78+
return outputs;
79+
}
5880
}
5981
}

src/TensorFlowNET.Core/Keras/Engine/Layer.cs renamed to src/TensorFlowNET.Core/Keras/Layers/Layer.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Text;
5+
using Tensorflow.Keras.Engine;
56
using Tensorflow.Keras.Utils;
67

7-
namespace Tensorflow.Keras.Engine
8+
namespace Tensorflow.Keras.Layers
89
{
910
/// <summary>
1011
/// Base layer class.
@@ -106,7 +107,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
106107

107108
protected virtual Tensor call(Tensor inputs, Tensor training = null)
108109
{
109-
throw new NotImplementedException("Layer.call");
110+
return inputs;
110111
}
111112

112113
protected virtual string _name_scope()

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Text;
5-
using Tensorflow.Keras.Engine;
65

76
namespace Tensorflow.Layers
87
{
9-
public class Layer : Keras.Engine.Layer
8+
public class Layer : Keras.Layers.Layer
109
{
1110
protected Graph _graph;
1211

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,5 +108,23 @@ public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string n
108108

109109
return _op.outputs;
110110
}
111+
112+
/// <summary>
113+
/// Computes softmax cross entropy cost and gradients to backpropagate.
114+
/// </summary>
115+
/// <param name="features"></param>
116+
/// <param name="labels"></param>
117+
/// <param name="name"></param>
118+
/// <returns></returns>
119+
public static (Tensor, Tensor) softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = null)
120+
{
121+
var _op = _op_def_lib._apply_op_helper("SoftmaxCrossEntropyWithLogits", name: name, args: new
122+
{
123+
features,
124+
labels
125+
});
126+
127+
return (_op.outputs[0], _op.outputs[1]);
128+
}
111129
}
112130
}

0 commit comments

Comments
 (0)