Skip to content

Commit 6af53e6

Browse files
committed
override Adam _finish.
1 parent 93eb56e commit 6af53e6

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

src/TensorFlowNET.Core/Train/AdamOptimizer.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,25 @@ protected override void _create_slots(RefVariable[] var_list)
9797
}
9898
}
9999

100+
public override Operation _finish(Operation[] update_ops, string name_scope)
101+
{
102+
var operations = new List<ITensorOrOperation>();
103+
operations.AddRange(update_ops);
104+
105+
with(ops.control_dependencies(update_ops), delegate
106+
{
107+
var (beta1_power, beta2_power) = _get_beta_accumulators();
108+
ops.colocate_with(beta1_power);
109+
var update_beta1 = beta1_power.assign(beta1_power * _beta1_t, use_locking: _use_locking);
110+
var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking);
111+
112+
operations.Add(update_beta1);
113+
operations.Add(update_beta1);
114+
});
115+
116+
return control_flow_ops.group(operations.ToArray(), name: name_scope);
117+
}
118+
100119
private (RefVariable, RefVariable) _get_beta_accumulators()
101120
{
102121
ops.init_scope();

0 commit comments

Comments
 (0)