Skip to content

Commit cee4875

Browse files
Declare TorchSharp.Scalar explicitly.
* Update src/TorchSharp/Optimizers/RMSprop.cs. + Update RMSprop.step. - Declare TorchSharp.Scalar explicitly. - Add FIXME for possibly unused momentum_scalar. - Add FIXME for possibly unused weight_decay_scalar. - Cache momentum > 0. - Cache weight_decay != 0. - Add FIXME for possibly missing avg dispose.
1 parent 997c679 commit cee4875

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

src/TorchSharp/Optimizers/RMSprop.cs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,20 @@ public override Tensor step(Func<Tensor> closure = null)
152152
var options = group.Options as Options;
153153
var maximize = options.maximize.Value;
154154
var momentum = options.momentum.Value;
155+
var need_momentum = momentum > 0;
156+
using var momentum_scalar = momentum.ToScalar(); // FIXME: Omit if not need_momentum?
155157
var alpha = options.alpha.Value;
158+
var alpha_bar = 1 - alpha;
159+
using var alpha_scalar = alpha.ToScalar();
160+
using var alpha_bar_scalar = alpha_bar.ToScalar();
156161
var weight_decay = options.weight_decay.Value;
162+
var need_weight_decay = weight_decay != 0;
163+
using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
157164
var centered = options.centered.Value;
158-
var eps = options.eps.Value;
165+
using var negative_one_scalar = (-1).ToScalar();
166+
using var eps_scalar = options.eps.Value.ToScalar();
159167
var lr = options.LearningRate.Value;
168+
using var negative_lr_scalar = (-lr).ToScalar();
160169

161170
foreach (var param in group.Parameters) {
162171

@@ -170,28 +179,26 @@ public override Tensor step(Func<Tensor> closure = null)
170179

171180
state.step += 1;
172181

173-
if (weight_decay != 0) {
174-
grad = grad.add(param, alpha: weight_decay);
175-
}
182+
if (need_weight_decay) grad = grad.add(param, alpha: weight_decay_scalar);
176183

177-
state.square_avg.mul_(alpha).addcmul_(grad, grad, value: 1 - alpha);
184+
state.square_avg.mul_(alpha_scalar).addcmul_(grad, grad, value: alpha_bar_scalar);
178185

179-
Tensor avg = null;
186+
Tensor avg = null; // FIXME: Need dispose?
180187

181188
if (centered) {
182189
var grad_avg = state.grad_avg;
183-
grad_avg.mul_(alpha).add_(grad, alpha: 1 - alpha);
184-
avg = state.square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt_().add_(eps);
190+
grad_avg.mul_(alpha_scalar).add_(grad, alpha: alpha_bar_scalar);
191+
avg = state.square_avg.addcmul(grad_avg, grad_avg, value: negative_one_scalar).sqrt_().add_(eps_scalar);
185192
} else {
186-
avg = state.square_avg.sqrt().add_(eps);
193+
avg = state.square_avg.sqrt().add_(eps_scalar);
187194
}
188195

189-
if (momentum > 0) {
196+
if (need_momentum) {
190197
var buf = state.momentum_buffer;
191-
buf.mul_(momentum).addcdiv_(grad, avg);
192-
param.add_(buf, alpha: -lr);
198+
buf.mul_(momentum_scalar).addcdiv_(grad, avg);
199+
param.add_(buf, alpha: negative_lr_scalar);
193200
} else {
194-
param.addcdiv_(grad, avg, -lr);
201+
param.addcdiv_(grad, avg, negative_lr_scalar);
195202
}
196203
}
197204
}, closure);

0 commit comments

Comments
 (0)