
Commit 19d87cc

Update Adamax.step.
* Update src/TorchSharp/Optimizers/Adamax.cs.
  + Update Adamax.step.
    - Declare TorchSharp.Scalar explicitly.
    - Add FIXME for possible unused weight_decay_scalar.
    - Cache weight_decay != 0.
    - Add FIXME for CA1806.
1 parent 3f0d806 commit 19d87cc
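
Context for "Declare TorchSharp.Scalar explicitly": in TorchSharp, passing a C# double where a Scalar is expected goes through an implicit conversion that allocates a fresh native Scalar handle on each call, leaving cleanup to the finalizer. Hoisting the conversion out of the per-parameter loop and disposing it with `using` is the pattern this commit applies. A minimal sketch of that pattern, with illustrative names (ScaleInPlace is not part of the commit):

    using TorchSharp;
    using static TorchSharp.torch;

    static void ScaleInPlace(Tensor[] parameters, double beta)
    {
        // Without the pattern: t.mul_(beta) implicitly converts the double
        // to a new native Scalar on every iteration.
        // With it: convert once, reuse across the loop, dispose deterministically.
        using var beta_scalar = beta.ToScalar();
        foreach (var t in parameters)
            t.mul_(beta_scalar);
    }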

File tree

1 file changed: +15 -8 lines changed

src/TorchSharp/Optimizers/Adamax.cs

Lines changed: 15 additions & 8 deletions
@@ -142,8 +142,14 @@ public override Tensor step(Func<Tensor> closure = null)
                     var options = group.Options as Options;
                     var beta1 = options.beta1.Value;
                     var beta2 = options.beta2.Value;
-                    var eps = options.eps.Value;
+                    using var beta1_scalar = beta1.ToScalar();
+                    using var beta2_scalar = beta2.ToScalar();
+                    var beta1_bar = 1 - beta1;
+                    using var beta1_bar_scalar = beta1_bar.ToScalar();
+                    using var eps_scalar = options.eps.Value.ToScalar();
                     var weight_decay = options.weight_decay.Value;
+                    var need_weight_decay = weight_decay != 0;
+                    using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
                     var lr = options.LearningRate.Value;
 
                     foreach (var param in group.Parameters) {
@@ -161,21 +167,22 @@ public override Tensor step(Func<Tensor> closure = null)
                         var exp_avg = state.exp_avg;
                         var exp_inf = state.exp_inf;
 
-                        grad = (weight_decay != 0)
-                            ? grad.add(param, alpha: weight_decay)
+                        grad = (need_weight_decay)
+                            ? grad.add(param, alpha: weight_decay_scalar)
                             : grad.alias();
 
-                        exp_avg.mul_(beta1).add_(grad, alpha: 1 - beta1);
+                        exp_avg.mul_(beta1_scalar).add_(grad, alpha: beta1_bar_scalar);
 
                         var norm_buf = torch.cat(new Tensor[] {
-                            exp_inf.mul_(beta2).unsqueeze(0),
-                            grad.abs().add_(eps).unsqueeze_(0)
+                            exp_inf.mul_(beta2_scalar).unsqueeze(0),
+                            grad.abs().add_(eps_scalar).unsqueeze_(0)
                         }, 0);
 
-                        torch.amax(norm_buf, new long[] { 0 }, false, exp_inf);
+                        torch.amax(norm_buf, new long[] { 0 }, false, exp_inf); // FIXME: CA1806?
 
                         var clr = lr / (1 - Math.Pow(beta1, state.step));
-                        param.addcdiv_(exp_avg, exp_inf, value: -clr);
+                        using var negative_clr_scalar = (-clr).ToScalar();
+                        param.addcdiv_(exp_avg, exp_inf, value: negative_clr_scalar);
                     }
                 }, closure);
             }
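
For reference, the update rule this step() implements is Adamax, the infinity-norm variant of Adam from Kingma & Ba; reading exp_avg as m, exp_inf as u, and state.step as t, in LaTeX:

    % Adamax update as implemented above (lambda = weight_decay, lr = LearningRate)
    g_t \gets g_t + \lambda \theta_{t-1} \quad \text{(if } \lambda \ne 0\text{)}
    m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
    u_t = \max\bigl(\beta_2\, u_{t-1},\ |g_t| + \epsilon\bigr)
    \theta_t = \theta_{t-1} - \frac{lr}{1 - \beta_1^{\,t}} \cdot \frac{m_t}{u_t}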

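On the two FIXMEs, a hedged sketch of one possible resolution; this is illustrative only and not part of the commit:

    // 1. Allocate weight_decay_scalar only when weight decay is active; the
    //    C# using statement null-checks before calling Dispose, so null is safe.
    using var weight_decay_scalar = need_weight_decay ? weight_decay.ToScalar() : null;

    // 2. CA1806 ("Do not ignore method results"): torch.amax writes its result
    //    into exp_inf through the out argument, so the returned Tensor is
    //    deliberately unused; an explicit discard records that intent.
    _ = torch.amax(norm_buf, new long[] { 0 }, false, exp_inf);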