Skip to content

Commit adae3aa

Browse files
committed
AdamOptimizer, reduce_prod
1 parent 7dbcb6c commit adae3aa

File tree

15 files changed

+155
-27
lines changed

15 files changed

+155
-27
lines changed

src/TensorFlowNET.Core/Framework/common_shapes.py.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,10 @@ public static Tensor _broadcast_shape_helper(Tensor shape_x, Tensor shape_y)
2929
{
3030
throw new NotFiniteNumberException();
3131
}
32+
33+
/// <summary>
/// Returns the value of <paramref name="tensor"/>'s <c>rank</c> property as a nullable int.
/// </summary>
/// <param name="tensor">The tensor whose rank is read.</param>
/// <returns>The value of <c>tensor.rank</c>.</returns>
public static int? rank(Tensor tensor) => tensor.rank;
3237
}
3338
}

src/TensorFlowNET.Core/Gradients/math_grad.py.cs

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,24 @@ public static (Tensor, Tensor) _MulGrad(Operation op, Tensor grad)
5757
return (reshape1, reshape2);
5858
}
5959

60+
/// <summary>
/// Gradient function for a mean reduction: takes the gradient produced by
/// <c>_SumGrad</c> and divides it by the ratio between the number of input
/// elements and the number of output elements.
/// </summary>
/// <param name="op">The forward operation; <c>op.inputs[0]</c> and <c>op.outputs[0]</c> supply the shapes.</param>
/// <param name="grad">The incoming gradient with respect to the op's output.</param>
/// <returns>
/// A tuple whose first item is the gradient for <c>op.inputs[0]</c>; the second item
/// is always <c>null</c>.
/// </returns>
public static (Tensor, Tensor) _MeanGrad(Operation op, Tensor grad)
{
    var sum_grad = _SumGrad(op, grad).Item1;

    // The element counts are computed in-graph from the dynamic shapes.
    // (The previously computed static shape tuples were never used and have been removed.)
    var input_shape_tensor = array_ops.shape(op.inputs[0]);
    var output_shape_tensor = array_ops.shape(op.outputs[0]);
    var factor = _safe_shape_div(
        math_ops.reduce_prod(input_shape_tensor),
        math_ops.reduce_prod(output_shape_tensor));

    // The factor is cast to the gradient's dtype before the division.
    return (math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null);
}
72+
73+
/// <summary>
/// Floor-divides <paramref name="x"/> by <paramref name="y"/> after clamping the
/// divisor to a minimum of 1.
/// </summary>
/// <param name="x">The dividend tensor.</param>
/// <param name="y">The divisor tensor; passed through <c>maximum(y, 1)</c> first.</param>
/// <returns>The result of <c>floordiv(x, maximum(y, 1))</c>.</returns>
private static Tensor _safe_shape_div(Tensor x, Tensor y)
    => math_ops.floordiv(x, gen_math_ops.maximum(y, 1));
77+
6078
public static (Tensor, Tensor) _SubGrad(Operation op, Tensor grad)
6179
{
6280
var x = op.inputs[0];
@@ -81,12 +99,25 @@ public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad
8199

82100
public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
83101
{
84-
if (op.inputs[0].NDims > -1)
85-
{
102+
var input_0_shape = op.inputs[0]._shape_tuple();
103+
Tensor input_shape = null;
86104

105+
if (input_0_shape != null)
106+
{
107+
var axes = tensor_util.constant_value(op.inputs[1]);
108+
if(!(axes is null))
109+
{
110+
var rank = axes.shape.Rank;
111+
grad = array_ops.reshape(grad, new int[] { 1 });
112+
if (!input_0_shape.Contains(-1))
113+
input_shape = constant_op.constant(input_0_shape);
114+
else
115+
input_shape = array_ops.shape(op.inputs[0]);
116+
return (gen_array_ops.tile(grad, input_shape), null);
117+
}
87118
}
88119

89-
var input_shape = array_ops.shape(op.inputs[0]);
120+
input_shape = array_ops.shape(op.inputs[0]);
90121
ops.colocate_with(input_shape);
91122
var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]);
92123
var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims);
@@ -95,11 +126,6 @@ public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
95126
return (gen_array_ops.tile(grad, tile_scaling), null);
96127
}
97128

98-
public static Tensor _safe_shape_div(Tensor x, Tensor y)
99-
{
100-
return math_ops.floordiv(x, gen_math_ops.maximum(y, 1));
101-
}
102-
103129
public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad)
104130
{
105131
var x = op.inputs[0];

src/TensorFlowNET.Core/Keras/Layers/Dense.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ protected override Tensor call(Tensor inputs, Tensor training = null)
6363
var rank = inputs.rank;
6464
if(rank > 2)
6565
{
66-
throw new NotImplementedException("");
66+
throw new NotImplementedException("call rank > 2");
6767
}
6868
else
6969
{

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public static Tensor rank(Tensor input, string name = null)
8282
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
8383
=> ones_like_impl(tensor, dtype, name, optimize);
8484

85-
public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
85+
/// <summary>
/// Builds a reshape of <paramref name="tensor"/> to <paramref name="shape"/> by
/// delegating to <c>gen_array_ops.reshape</c>.
/// </summary>
/// <param name="tensor">The value to reshape.</param>
/// <param name="shape">The target shape.</param>
/// <param name="name">Optional name for the operation; forwarded to the generated op wrapper.</param>
/// <returns>The tensor returned by <c>gen_array_ops.reshape</c>.</returns>
public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null)
{
    // Forward the caller's name; previously null was passed and the argument was silently ignored.
    return gen_array_ops.reshape(tensor, shape, name);
}

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ public static (Tensor, Tensor) broadcast_gradient_args(Tensor s0, Tensor s1, str
116116
return (_op.outputs[0], _op.outputs[1]);
117117
}
118118

119-
public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
119+
public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null)
120120
{
121121
var _op = _op_def_lib._apply_op_helper("Reshape", name, new { tensor, shape });
122122
return _op.outputs[0];

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ public static class gen_math_ops
2020
/// <param name="keep_dims"> An optional `bool`. Defaults to `False`. If true, retain reduced dimensions with length 1.</param>
2121
/// <param name="name"> A name for the operation (optional).</param>
2222
/// <returns> A `Tensor`. Has the same type as `input`.</returns>
23-
public static Tensor mean(Tensor input, Tensor axis, bool keep_dims= false, string name = null)
23+
public static Tensor mean<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null)
{
    // "axis" is handed to the op helper under the key "reduction_indices".
    var op_args = new { input, reduction_indices = axis, keep_dims };
    var mean_op = _op_def_lib._apply_op_helper("Mean", name, args: op_args);

    return mean_op.outputs[0];
}
2929

30-
public static Tensor mean(Tensor input, int[] axis, bool keep_dims = false, string name = null)
30+
/// <summary>
/// Applies the "Prod" op to <paramref name="input"/> along <paramref name="axis"/>.
/// </summary>
/// <param name="input">The value to reduce.</param>
/// <param name="axis">The dimensions to reduce; passed to the op as <c>reduction_indices</c>.</param>
/// <param name="keep_dims">Forwarded to the op as <c>keep_dims</c>.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>The first output of the created op.</returns>
public static Tensor prod<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null)
{
    var op_args = new { input, reduction_indices = axis, keep_dims };
    var prod_op = _op_def_lib._apply_op_helper("Prod", name, args: op_args);

    return prod_op.outputs[0];
}
@@ -186,7 +186,7 @@ public static Tensor maximum<T1, T2>(T1 x, T2 y, string name = null)
186186
return _op.outputs[0];
187187
}
188188

189-
public static Tensor _max(Tensor input, int[] axis, bool keep_dims=false, string name = null)
189+
public static Tensor _max<Tx, Ty>(Tx input, Ty axis, bool keep_dims=false, string name = null)
190190
{
191191
var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims });
192192

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System;
33
using System.Collections.Generic;
44
using System.Text;
5+
using Tensorflow.Framework;
56

67
namespace Tensorflow
78
{
@@ -39,9 +40,41 @@ public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, s
3940
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
{
    // Evaluated up front in both cases; only consumed when no explicit axis is given.
    var default_dims = _ReductionDims(input_tensor, axis);

    // The two calls bind different generic arguments (Tensor vs int[]), so they stay separate.
    Tensor mean = axis == null
        ? gen_math_ops.mean(input_tensor, default_dims, keepdims, name)
        : gen_math_ops.mean(input_tensor, axis, keepdims, name);

    return _may_reduce_to_scalar(keepdims, axis, mean);
}
54+
55+
/// <summary>
/// Computes the product of elements across dimensions of a tensor.
/// </summary>
/// <param name="input_tensor">The tensor to reduce.</param>
/// <param name="axis">The dimensions to reduce; when null, the result of <c>_ReductionDims</c> is used instead.</param>
/// <param name="keepdims">Forwarded to the underlying op and to <c>_may_reduce_to_scalar</c>.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>The value returned by <c>_may_reduce_to_scalar</c> for the reduced tensor.</returns>
public static Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
{
    // Evaluated up front in both cases; only consumed when no explicit axis is given.
    var default_dims = _ReductionDims(input_tensor, axis);

    // The two calls bind different generic arguments (Tensor vs int[]), so they stay separate.
    Tensor product = axis == null
        ? gen_math_ops.prod(input_tensor, default_dims, keepdims, name)
        : gen_math_ops.prod(input_tensor, axis, keepdims, name);

    return _may_reduce_to_scalar(keepdims, axis, product);
}
77+
4578
/// <summary>
4679
/// Returns (x - y)(x - y) element-wise.
4780
/// </summary>
@@ -134,7 +167,10 @@ public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bo
134167

135168
public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
{
    // Evaluated up front in both cases; only consumed when no explicit axis is given.
    var default_dims = _ReductionDims(input_tensor, axis);

    Tensor reduced;
    if (axis == null)
    {
        reduced = gen_math_ops._max(input_tensor, default_dims, keepdims, name);
    }
    else
    {
        reduced = gen_math_ops._max(input_tensor, axis, keepdims, name);
    }

    return _may_reduce_to_scalar(keepdims, axis, reduced);
}
139175

140176
/// <summary>
@@ -197,18 +233,19 @@ private static Tensor _ReductionDims(Tensor x, Tensor axis)
197233
}
198234
}
199235

200-
private static object _ReductionDims(Tensor x, int[] axis)
236+
private static Tensor _ReductionDims(Tensor x, int[] axis)
201237
{
202238
if (axis != null)
203239
{
204-
return axis;
240+
// should return axis. or check before.
241+
return null;
205242
}
206243
else
207244
{
208-
var rank = array_ops.rank(x);
245+
var rank = common_shapes.rank(x);
209246
if (rank != null)
210247
{
211-
return constant_op.constant(np.arange(rank), TF_DataType.TF_INT32);
248+
return constant_op.constant(np.arange(rank.Value), TF_DataType.TF_INT32);
212249
}
213250
return range(0, rank, 1);
214251
}
@@ -303,5 +340,20 @@ public static Tensor conj(Tensor x, string name = null)
303340
return x;
304341
});
305342
}
343+
344+
/// <summary>
/// Divides <paramref name="x"/> by <paramref name="y"/>; a thin wrapper over
/// <c>_truediv_python3</c>.
/// </summary>
/// <param name="x">The numerator tensor.</param>
/// <param name="y">The denominator tensor.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>The tensor produced by <c>_truediv_python3</c>.</returns>
public static Tensor truediv(Tensor x, Tensor y, string name = null)
{
    return _truediv_python3(x, y, name);
}
346+
347+
/// <summary>
/// Creates a real division of <paramref name="x"/> by <paramref name="y"/> inside a
/// "truediv" name scope.
/// </summary>
/// <param name="x">The numerator tensor.</param>
/// <param name="y">The denominator tensor.</param>
/// <param name="name">Optional scope name; "truediv" is the default passed to <c>ops.name_scope</c>.</param>
/// <returns>The result of <c>gen_math_ops.real_div</c>.</returns>
public static Tensor _truediv_python3(Tensor x, Tensor y, string name = null)
{
    return with(ops.name_scope(name, "truediv", new { x, y }), scope =>
    {
        // The op is named after the resolved scope.
        name = scope;

        // NOTE(review): the base dtypes of x and y were read here but never compared;
        // the dead locals are removed. No dtype validation or casting is done in this method.
        return gen_math_ops.real_div(x, y, name: name);
    });
}
306358
}
307359
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ public long[] shape
7474

7575
/// <summary>
/// Returns this tensor's <c>shape</c> converted element-by-element to <c>int</c>.
/// </summary>
/// <returns>
/// <c>null</c> when <c>shape</c> is null; otherwise a new array holding each
/// dimension cast from <c>long</c> to <c>int</c>.
/// </returns>
public int[] _shape_tuple()
{
    return shape == null
        ? null
        : shape.Select(dim => (int)dim).ToArray();
}
7980

8081
public TensorShape getShape()

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ public static NDArray MakeNdarray(TensorProto tensor)
5151
if (tensor.TensorContent.Length > 0)
5252
return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype)
5353
.reshape(shape);
54+
else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
55+
;
56+
else if (tensor.Dtype == DataType.DtFloat)
57+
;
58+
else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype))
59+
if (tensor.IntVal.Count == 1)
60+
return np.repeat(np.array(tensor.IntVal[0]), Convert.ToInt32(num_elements))
61+
.reshape(shape);
62+
5463
throw new NotImplementedException("MakeNdarray");
5564
}
5665

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Train
6+
{
7+
/// <summary>
/// Optimizer that implements the Adam algorithm.
/// http://arxiv.org/abs/1412.6980
/// </summary>
public class AdamOptimizer : Optimizer
{
    // Hyper-parameters are captured once at construction and never reassigned.
    private readonly float _beta1;
    private readonly float _beta2;
    private readonly float _epsilon;

    /// <summary>
    /// Creates an Adam optimizer and stores its hyper-parameters.
    /// </summary>
    /// <param name="learning_rate">Learning rate; forwarded to the base <c>Optimizer</c>.</param>
    /// <param name="beta1">Stored as <c>_beta1</c>. Defaults to 0.9.</param>
    /// <param name="beta2">Stored as <c>_beta2</c>. Defaults to 0.999.</param>
    /// <param name="epsilon">Stored as <c>_epsilon</c>. Defaults to 1e-8.</param>
    /// <param name="use_locking">Forwarded to the base <c>Optimizer</c>.</param>
    /// <param name="name">Forwarded to the base <c>Optimizer</c>. Defaults to "Adam".</param>
    public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam")
        : base(learning_rate, use_locking, name)
    {
        _beta1 = beta1;
        _beta2 = beta2;
        _epsilon = epsilon;
    }
}
25+
}

0 commit comments

Comments
 (0)