Skip to content

Commit 901d574

Browse files
committed
Expose learning_rate property in Optimizer.
1 parent 17a4fe0 commit 901d574

File tree

6 files changed

+49
-11
lines changed

6 files changed

+49
-11
lines changed

src/TensorFlowNET.Console/MemoryBasicTest.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,31 @@ public Action<int, int> VariableRead
5656
{
5757
var nd = np.zeros(1 * 256 * 256 * 3).astype(np.float32).reshape(1, 256, 256, 3);
5858
ResourceVariable variable = tf.Variable(nd);
59-
var nd2 = np.arange(1 * 256 * 256 * 3).astype(np.float32).reshape(1, 256, 256, 3);
60-
variable.assign(nd2);
6159

62-
for (int i = 0; i< 100; i++)
60+
for (int i = 0; i< 10; i++)
6361
{
6462
var v = variable.numpy();
6563
}
6664
};
6765

66+
public Action<int, int> VariableAssign
67+
=> (epoch, iterate) =>
68+
{
69+
ResourceVariable variable = tf.Variable(3112f);
70+
AssignVariable(variable);
71+
for (int i = 0; i < 100; i++)
72+
{
73+
var v = variable.numpy();
74+
if ((float)v != 1984f)
75+
throw new ValueError("");
76+
}
77+
};
78+
79+
void AssignVariable(IVariableV1 v)
80+
{
81+
using var tensor = tf.constant(1984f);
82+
v.assign(tensor);
83+
}
6884

6985
public Action<int, int> MathAdd
7086
=> (epoch, iterate) =>

src/TensorFlowNET.Console/Program.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ static void BasicTest(MemoryMonitor mm)
5252
// 100K float variable.
5353
mm.Execute(10, batchSize, basic.Variable);
5454

55+
mm.Execute(10, batchSize, basic.VariableRead);
56+
57+
mm.Execute(10, batchSize, basic.VariableAssign);
58+
5559
// 1 million math.
5660
mm.Execute(10, 100 * batchSize, basic.MathAdd);
5761

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ public Tensor sinh(Tensor x, string name = null)
118118
public Tensor cos(Tensor x, string name = null)
119119
=> gen_math_ops.cos(x, name);
120120

121+
public Tensor cos(float x, string name = null)
122+
=> gen_math_ops.cos(x, name);
123+
121124
/// <summary>
122125
/// Computes hyperbolic cosine of x element-wise.
123126
/// </summary>

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,18 @@ public static Tensor sinh(Tensor x, string name = null)
376376
return _op.outputs[0];
377377
}
378378

379-
public static Tensor cos(Tensor x, string name = null)
379+
public static Tensor cos<T>(T x, string name = null)
380380
{
381+
if (tf.executing_eagerly())
382+
{
383+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
384+
"Cos", name,
385+
null,
386+
x);
387+
388+
return results[0];
389+
}
390+
381391
var _op = tf.OpDefLib._apply_op_helper("Cos", name, args: new { x });
382392

383393
return _op.outputs[0];

src/TensorFlowNET.Core/Tensors/Tensor.String.cs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,17 @@ public unsafe byte[][] StringBytes()
9090
size *= s;
9191

9292
var buffer = new byte[size][];
93-
var data_start = c_api.TF_TensorData(_handle);
94-
var string_start = data_start + (int)(size * sizeof(ulong));
93+
var src = c_api.TF_TensorData(_handle);
94+
src += (int)(size * 8);
9595
for (int i = 0; i < buffer.Length; i++)
9696
{
97-
var len = *(byte*)string_start;
98-
buffer[i] = new byte[len];
99-
string_start += 1;
100-
Marshal.Copy(string_start, buffer[i], 0, len);
101-
string_start += len;
97+
IntPtr dst = IntPtr.Zero;
98+
ulong dstLen = 0;
99+
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle);
100+
tf.Status.Check(true);
101+
buffer[i] = new byte[(int)dstLen];
102+
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
103+
src += (int)read;
102104
}
103105

104106
return buffer;

src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ public class OptimizerV2 : Trackable, IOptimizer
2626
protected float _initial_decay = 0.0f;
2727
protected bool _use_locking = true;
2828

29+
public IVariableV1 lr
30+
=> _hyper_variables["learning_rate"];
31+
2932
Dictionary<string, Dictionary<string, IVariableV1>> _slots;
3033
List<string> _slot_names;
3134

0 commit comments

Comments
 (0)