
Commit 68f00b2

Update ASGD.step.

* Update src/TorchSharp/Optimizers/ASGD.cs.
  + Update ASGD.step.
    - Declare TorchSharp.Scalar explicitly.
    - Add FIXME for possibly unused weight_decay_scalar.
    - Cache weight_decay != 0.
Parent: 19d87cc

File tree

1 file changed: +11 −5 lines changed


src/TorchSharp/Optimizers/ASGD.cs

Lines changed: 11 additions & 5 deletions
@@ -140,6 +140,8 @@ public override Tensor step(Func<Tensor> closure = null)
                     var lambd = options.lambd.Value;
                     var alpha = options.alpha.Value;
                     var weight_decay = options.weight_decay.Value;
+                    var need_weight_decay = weight_decay != 0;
+                    using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
                     var t0 = options.t0.Value;
                     var lr = options.LearningRate.Value;

@@ -157,15 +159,19 @@ public override Tensor step(Func<Tensor> closure = null)

                     state.step += 1;

-                    grad = (weight_decay != 0)
-                        ? grad.add(param, alpha: weight_decay)
+                    grad = (need_weight_decay)
+                        ? grad.add(param, alpha: weight_decay_scalar)
                         : grad.alias();

-                    param.mul_(1 - lambd * state.eta);
-                    param.add_(grad, alpha: -state.eta);
+                    var lambd_eta_bar = 1 - lambd * state.eta;
+                    using var lambd_eta_bar_scalar = lambd_eta_bar.ToScalar();
+                    param.mul_(lambd_eta_bar_scalar);
+                    using var negative_eta_scalar = (-state.eta).ToScalar();
+                    param.add_(grad, alpha: negative_eta_scalar);

                     if (state.mu != 1) {
-                        state.ax.add_(param.sub(state.ax).mul(state.mu));
+                        using var mu_scalar = state.mu.ToScalar();
+                        state.ax.add_(param.sub(state.ax).mul(mu_scalar));
                     } else {
                         state.ax.copy_(param);
                     }
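
For context, the pattern the diff moves to is: materialize each TorchSharp.Scalar once with ToScalar(), hold it in a `using var` so its native handle is released deterministically at the end of the scope, and pass it explicitly, instead of relying on the implicit double-to-Scalar conversion at every call site. Below is a minimal sketch of that pattern; it is not code from the commit, and the tensor setup and values are illustrative only.

// Sketch only (not part of the commit): the explicit-Scalar pattern adopted above.
// Assumes TorchSharp's torch.ones, Tensor.mul_/add_, and the double.ToScalar()
// extension used in the diff.
using TorchSharp;
using static TorchSharp.torch;

var param = ones(3, 3);
var grad = ones(3, 3);
double eta = 0.01;

// Implicit style (old code): the double is converted to a temporary Scalar on
// each call, and its native handle waits for the garbage collector.
param.mul_(1 - eta);

// Explicit style (new code): build the Scalar once, reuse it, and let `using`
// dispose the native handle deterministically when the scope ends.
using var negative_eta_scalar = (-eta).ToScalar();
param.add_(grad, alpha: negative_eta_scalar);

The FIXME in the diff notes that weight_decay_scalar is constructed even when need_weight_decay is false; under this pattern the Scalar could instead be created only in the branch that actually uses it.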
