
Commit 3f0d806

Update Adam.step.
* Update src/TorchSharp/Optimizers/Adam.cs.
  + Update Adam.step.
    - Declare TorchSharp.Scalar explicitly.
    - Add FIXME for possibly unused weight_decay_scalar.
    - Cache weight_decay != 0.
    - Add FIXME for possibly missing denom disposal.
1 parent 645523c commit 3f0d806

1 file changed: +17 −10 lines changed

1 file changed

+17
-10
lines changed

src/TorchSharp/Optimizers/Adam.cs

Lines changed: 17 additions & 10 deletions
@@ -154,10 +154,18 @@ public override Tensor step(Func<Tensor> closure = null)
     var options = group.Options as Options;
     var beta1 = options.beta1.Value;
     var beta2 = options.beta2.Value;
+    using var beta1_scalar = beta1.ToScalar();
+    using var beta2_scalar = beta2.ToScalar();
+    var beta1_bar = 1 - beta1;
+    var beta2_bar = 1 - beta2;
+    using var beta1_bar_scalar = beta1_bar.ToScalar();
+    using var beta2_bar_scalar = beta2_bar.ToScalar();
     var weight_decay = options.weight_decay.Value;
+    var need_weight_decay = weight_decay != 0;
+    using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
     var amsgrad = options.amsgrad.Value;
     var maximize = options.maximize.Value;
-    var eps = options.eps.Value;
+    using var eps_scalar = options.eps.Value.ToScalar();
     var lr = options.LearningRate.Value;

     foreach (var param in group.Parameters) {
@@ -175,25 +183,24 @@ public override Tensor step(Func<Tensor> closure = null)
         var bias_correction1 = 1 - Math.Pow(beta1, state.step);
         var bias_correction2 = 1 - Math.Pow(beta2, state.step);

-        if (weight_decay != 0) {
-            grad = grad.add(param, alpha: weight_decay);
-        }
+        if (need_weight_decay) grad = grad.add(param, alpha: weight_decay_scalar);

-        state.exp_avg.mul_(beta1).add_(grad, alpha: 1 - beta1);
-        state.exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value: 1 - beta2);
+        state.exp_avg.mul_(beta1_scalar).add_(grad, alpha: beta1_bar_scalar);
+        state.exp_avg_sq.mul_(beta2_scalar).addcmul_(grad, grad.conj(), value: beta2_bar_scalar);

-        Tensor denom = null;
+        Tensor denom = null; // FIXME: Need dispose?
         if (amsgrad) {
             var t0 = state.max_exp_avg_sq;
             state.max_exp_avg_sq = torch.maximum(t0, state.exp_avg_sq).DetachFromDisposeScope();
             t0.Dispose();
-            denom = (state.max_exp_avg_sq.sqrt() / Math.Sqrt(bias_correction2)).add_(eps);
+            denom = (state.max_exp_avg_sq.sqrt() / Math.Sqrt(bias_correction2)).add_(eps_scalar);
         } else {
-            denom = (state.exp_avg_sq.sqrt() / Math.Sqrt(bias_correction2)).add_(eps);
+            denom = (state.exp_avg_sq.sqrt() / Math.Sqrt(bias_correction2)).add_(eps_scalar);
         }

         var step_size = lr / bias_correction1;
-        param.addcdiv_(state.exp_avg, denom, value: -step_size);
+        using var negative_step_size_scalar = (-step_size).ToScalar();
+        param.addcdiv_(state.exp_avg, denom, value: negative_step_size_scalar);
     }
 }, closure);
 }
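The core change above is a TorchSharp.Scalar caching pattern: each double coefficient is converted to a native Scalar handle once, before the per-parameter loop, rather than through the implicit double-to-Scalar conversion on every tensor call. Below is a minimal, self-contained sketch of that pattern, not the committed code: the class name, tensor shapes, and coefficient value are illustrative assumptions, and it presumes a project that references the TorchSharp package.

using System;
using TorchSharp;
using static TorchSharp.torch;

class ScalarCachingSketch   // hypothetical demo class, not part of TorchSharp
{
    static void Main()
    {
        var beta1 = 0.9;    // illustrative coefficient

        // Convert once, outside the hot loop, and dispose deterministically.
        using var beta1_scalar = beta1.ToScalar();
        using var beta1_bar_scalar = (1 - beta1).ToScalar();

        using var exp_avg = zeros(4);
        using var grad = ones(4);

        // Reuses the cached Scalar handles instead of allocating a new
        // Scalar through implicit conversion on every iteration.
        for (var i = 0; i < 10; i++)
            exp_avg.mul_(beta1_scalar).add_(grad, alpha: beta1_bar_scalar);

        Console.WriteLine(exp_avg[0].item<float>());
    }
}

The same reasoning applies to eps_scalar and negative_step_size_scalar in the diff: the conversion happens explicitly, and the `using` declarations guarantee the handles are disposed at the end of their scope.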

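One of the new FIXMEs asks whether weight_decay_scalar could be omitted when weight decay is disabled. As an illustration only, and not the fix adopted by the repository, the Scalar could be materialized conditionally, as in the sketch below; the names mirror the locals in the diff, and the 0.0 value is an assumption for the demo.

using System;
using TorchSharp;
using static TorchSharp.torch;

class WeightDecayScalarSketch   // hypothetical demo class, not part of TorchSharp
{
    static void Main()
    {
        var weight_decay = 0.0;                    // illustrative: decay disabled
        var need_weight_decay = weight_decay != 0;

        // Only create the native Scalar handle when it will actually be used;
        // `using var` accepts a null reference and simply skips Dispose for it.
        using var weight_decay_scalar = need_weight_decay
            ? weight_decay.ToScalar()
            : (Scalar)null;

        using var param = ones(4);
        var grad = zeros(4);
        if (need_weight_decay)
            grad = grad.add(param, alpha: weight_decay_scalar);

        Console.WriteLine(grad[0].item<float>());
    }
}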