
Commit 997c679

Update RAdam.step.
* Update src/TorchSharp/Optimizers/RAdam.cs.
  + Update RAdam.step.
    - Declare TorchSharp.Scalar explicitly.
    - Add FIXME for possible unused weight_decay_scalar.
    - Cache weight_decay != 0.
    - Add FIXME for possible torch.Tensor.sub_ use.
    - Add FIXME for possible no dispose for torch.Tensor:
      - bias_corrected_exp_avg
      - t6
      - adaptive_lr and its intermediates and derived values
    - Add FIXME for possible no dispose on param.add_ if rho_t > 5.
1 parent 4cc71c6 commit 997c679
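The substance of the change is that the per-step double values (beta1, beta2, eps, weight_decay, lr, -1.0) are now converted to TorchSharp.Scalar once, up front, with `using` declarations, rather than implicitly inside each tensor call. A minimal before/after sketch of that pattern, using only names that appear in the diff; the wrapper class and method below are illustrative, not part of the TorchSharp source:

using TorchSharp;
using static TorchSharp.torch;

internal static class ScalarLifetimeSketch
{
    public static void UpdateExpAvg(Tensor exp_avg, Tensor grad, double beta1)
    {
        // Before: exp_avg.mul_(beta1).add_(grad, alpha: 1 - beta1);
        // Each double argument was converted implicitly, creating a native
        // Scalar per call that is otherwise left to the garbage collector.

        // After: the Scalars are declared once and disposed deterministically.
        using var beta1_scalar = beta1.ToScalar();
        using var beta1_bar_scalar = (1 - beta1).ToScalar();
        exp_avg.mul_(beta1_scalar).add_(grad, alpha: beta1_bar_scalar);
    }
}

In the committed code the Scalars are created once per parameter group, before the foreach over parameters, so the conversion cost is paid once per group rather than once per parameter.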


src/TorchSharp/Optimizers/RAdam.cs

Lines changed: 20 additions & 10 deletions
@@ -141,9 +141,19 @@ public override Tensor step(Func<Tensor> closure = null)
     var options = group.Options as Options;
     var beta1 = options.beta1.Value;
     var beta2 = options.beta2.Value;
-    var eps = options.eps.Value;
+    using var beta1_scalar = beta1.ToScalar();
+    using var beta2_scalar = beta2.ToScalar();
+    var beta1_bar = 1 - beta1;
+    var beta2_bar = 1 - beta2;
+    using var beta1_bar_scalar = beta1_bar.ToScalar();
+    using var beta2_bar_scalar = beta2_bar.ToScalar();
+    using var eps_scalar = options.eps.Value.ToScalar();
     var weight_decay = options.weight_decay.Value;
+    var need_weight_decay = weight_decay != 0;
+    using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
     var lr = options.LearningRate.Value;
+    using var lr_scalar = lr.ToScalar();
+    using var negative_one_scalar = (-1.0).ToScalar(); // FIXME: Use torch.Tensor.sub_ instead?

     foreach (var param in group.Parameters) {

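The `negative_one_scalar` declaration above carries a FIXME asking whether `torch.Tensor.sub_` could be used instead, which would avoid allocating a Scalar for -1.0 at all. A hedged sketch of that alternative, not the committed code; `param` and `t6` stand in for the tensors from the diff:

using TorchSharp;
using static TorchSharp.torch;

internal static class SubInPlaceSketch
{
    public static void ApplyUpdate(Tensor param, Tensor t6)
    {
        // In-place subtraction, equivalent to param.add_(t6, alpha: -1.0),
        // with no Scalar needed for the -1 factor.
        param.sub_(t6);
    }
}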
@@ -161,27 +171,27 @@ public override Tensor step(Func<Tensor> closure = null)
         var bias_correction1 = 1 - Math.Pow(beta1, state.step);
         var bias_correction2 = 1 - Math.Pow(beta2, state.step);

-        grad = (weight_decay != 0)
-            ? grad.add(param, alpha: weight_decay)
+        grad = (need_weight_decay)
+            ? grad.add(param, alpha: weight_decay_scalar)
             : grad.alias();

-        exp_avg.mul_(beta1).add_(grad, alpha: 1 - beta1);
-        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value: 1 - beta2);
+        exp_avg.mul_(beta1_scalar).add_(grad, alpha: beta1_bar_scalar);
+        exp_avg_sq.mul_(beta2_scalar).addcmul_(grad, grad, value: beta2_bar_scalar);

-        var bias_corrected_exp_avg = exp_avg / bias_correction1;
+        var bias_corrected_exp_avg = exp_avg / bias_correction1; // FIXME: Need dispose?

         var rho_inf = 2 / (1 - beta2) - 1;
         var rho_t = rho_inf - 2 * state.step * Math.Pow(beta2, state.step) / bias_correction2;

-        var t6 = bias_corrected_exp_avg * lr;
+        var t6 = bias_corrected_exp_avg.mul(lr_scalar); // FIXME: Need dispose?

         if (rho_t > 5) {
             var rect = Math.Sqrt((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t));
-            var adaptive_lr = Math.Sqrt(bias_correction2) / exp_avg_sq.sqrt().add_(eps);
+            var adaptive_lr = Math.Sqrt(bias_correction2) / exp_avg_sq.sqrt().add_(eps_scalar); // FIXME: Need dispose?

-            param.add_(t6 * lr * adaptive_lr * rect, alpha: -1.0);
+            param.add_(t6 * lr * adaptive_lr * rect, alpha: negative_one_scalar); // FIXME: Need dispose? Use inplace ops?
         } else {
-            param.add_(t6, alpha: -1.0);
+            param.add_(t6, alpha: negative_one_scalar);
         }
     }
 }, closure);

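The remaining FIXMEs ask whether the intermediate tensors (`bias_corrected_exp_avg`, `t6`, `adaptive_lr`, and the update expression passed to `param.add_`) need explicit disposal. Below is a hedged sketch of one way they could be scoped with `using` declarations inside the parameter loop; the helper and its signature are hypothetical, not the committed fix, and the chained multiplications would still produce temporaries that something like `torch.NewDisposeScope()` around the loop body could collect instead.

using System;
using TorchSharp;
using static TorchSharp.torch;

internal static class RAdamDisposeSketch
{
    // Hypothetical helper: binds the tensors flagged "FIXME: Need dispose?"
    // to `using` declarations so their native storage is freed each iteration.
    public static void ApplyStep(
        Tensor param, Tensor exp_avg, Tensor exp_avg_sq,
        double bias_correction1, double bias_correction2,
        double rho_inf, double rho_t, double lr,
        Scalar lr_scalar, Scalar eps_scalar, Scalar negative_one_scalar)
    {
        using var bias_corrected_exp_avg = exp_avg / bias_correction1;
        using var t6 = bias_corrected_exp_avg.mul(lr_scalar);

        if (rho_t > 5) {
            var rect = Math.Sqrt((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t));
            // sqrt() allocates the tensor that add_ then mutates in place.
            using var denom = exp_avg_sq.sqrt().add_(eps_scalar);
            using var adaptive_lr = Math.Sqrt(bias_correction2) / denom;
            // Materializing the update lets it be disposed as well; the
            // chained multiplications still create short-lived temporaries.
            using var update = t6 * lr * adaptive_lr * rect;
            param.add_(update, alpha: negative_one_scalar);
        } else {
            param.add_(t6, alpha: negative_one_scalar);
        }
    }
}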