
Commit 645523c

Update Adagrad.step.
* Update src/TorchSharp/Optimizers/Adagrad.cs.
  + Update Adagrad.step.
  + Declare TorchSharp.Scalar explicitly.
  + Add FIXME for possible unused weight_decay_scalar.
  + Add FIXME for possible unused initial_accumulator_value.
  + Cache weight_decay != 0.
1 parent 0fb172d commit 645523c
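
The core change hoists the TorchSharp.Scalar conversions out of the per-parameter loop and binds them with using declarations, rather than letting each tensor call convert its double argument implicitly on every iteration. A minimal sketch of that pattern, separate from the TorchSharp source (the ScalarCachingSketch class, ScaleAll helper, and the mul_ call are illustrative assumptions; ToScalar() and the disposable Scalar type are taken from the diff below):

// Sketch only: hoisting a Scalar conversion out of a hot loop, as this commit does in Adagrad.step.
// Assumes TorchSharp's numeric ToScalar() extensions and an IDisposable Scalar, both visible in the diff.
using TorchSharp;
using static TorchSharp.torch;

static class ScalarCachingSketch
{
    public static void ScaleAll(Tensor[] parameters, double factor)
    {
        // Convert once, up front; 'using' disposes the native Scalar when the method exits.
        // Passing 'factor' directly would allocate a fresh Scalar implicitly on every mul_ call.
        using var factor_scalar = factor.ToScalar();
        foreach (var p in parameters)
            p.mul_(factor_scalar);   // hypothetical in-place scale, standing in for add_/addcdiv_ in step()
    }
}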

File tree

1 file changed: +10 -8 lines changed


src/TorchSharp/Optimizers/Adagrad.cs

Lines changed: 10 additions & 8 deletions
@@ -139,9 +139,12 @@ public override Tensor step(Func<Tensor> closure = null)
                 var options = group.Options as Options;
                 var lr_decay = options.lr_decay.Value;
                 var weight_decay = options.weight_decay.Value;
-                var eps = options.eps.Value;
-                var initial_accumulator_value = options.initial_accumulator_value.Value;
+                var need_weight_decay = weight_decay != 0;
+                using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
+                using var eps_scalar = options.eps.Value.ToScalar();
+                var initial_accumulator_value = options.initial_accumulator_value.Value; // FIXME: Unused?
                 var lr = options.LearningRate.Value;
+                using var one_scalar = 1.ToScalar();

                 foreach (var param in group.Parameters) {

@@ -153,20 +156,19 @@ public override Tensor step(Func<Tensor> closure = null)

                     state.step += 1;

-                    if (weight_decay != 0) {
-                        grad = grad.add(param, alpha: weight_decay);
-                    }
+                    if (need_weight_decay) grad = grad.add(param, alpha: weight_decay_scalar);

                     var clr = lr / (1 + (state.step - 1) * lr_decay);
+                    using var negative_clr_scalar = (-clr).ToScalar();

                     if (grad.is_sparse)
                         throw new NotImplementedException("Adagrad optimization over sparse parameters");
                     if (torch.is_complex(grad))
                         throw new NotImplementedException("Adagrad optimization over complex parameters");

-                    state.sum.addcmul_(grad, grad, value: 1);
-                    var std = state.sum.sqrt().add_(eps);
-                    param.addcdiv_(grad, std, value: -clr);
+                    state.sum.addcmul_(grad, grad, value: one_scalar);
+                    var std = state.sum.sqrt().add_(eps_scalar);
+                    param.addcdiv_(grad, std, value: negative_clr_scalar);
                 }
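
For context, the rewritten step() is what runs on every optimizer.step() call in a training loop. A hedged usage sketch, not taken from the repository (nn.Linear, mse_loss, and the tensor shapes are illustrative, and the Adagrad factory arguments are assumed to mirror the option names read at the top of step()):

// Usage sketch only: exercising the updated Adagrad.step from a small training loop.
using TorchSharp;
using static TorchSharp.torch;

var model = nn.Linear(10, 1);
// lr, lr_decay, and weight_decay correspond to the options read in the diff above;
// the factory signature is assumed to follow PyTorch's Adagrad.
var optimizer = optim.Adagrad(model.parameters(), lr: 0.01, lr_decay: 0.0, weight_decay: 0.01);

var x = rand(32, 10);
var y = rand(32, 1);

for (int epoch = 0; epoch < 5; epoch++) {
    optimizer.zero_grad();
    var loss = nn.functional.mse_loss(model.forward(x), y);
    loss.backward();
    optimizer.step();   // invokes the Adagrad.step shown in this commit
}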