
Commit 4cc71c6

Update NAdam.step.
* Update src/TorchSharp/Optimizers/NAdam.cs.
  + Update NAdam.step.
    - Declare TorchSharp.Scalar values explicitly.
    - Add FIXME for the possibly unused weight_decay_scalar.
    - Cache the weight_decay != 0 check.
    - Add FIXME for denom possibly not being disposed.
1 parent 68f00b2 commit 4cc71c6
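
Why the explicit Scalar declarations: a TorchSharp tensor op that takes a Scalar will implicitly convert a double argument into a fresh native Scalar on each call, and that Scalar is otherwise only reclaimed by the finalizer. This commit hoists those conversions out of the per-parameter loop, converting each constant once with ToScalar() and disposing it deterministically via `using var`. A minimal sketch of the pattern, with a hypothetical ScaleAll helper (illustrative only, not the NAdam source):

    using TorchSharp;
    using static TorchSharp.torch;

    var tensors = new[] { rand(3), rand(3) };
    ScaleAll(tensors, 0.9);

    static void ScaleAll(Tensor[] tensors, double beta1)
    {
        // Before: each mul_ call implicitly converts `beta1` into a new
        // native Scalar that lingers until finalization.
        //   foreach (var t in tensors) t.mul_(beta1);

        // After: convert once, reuse across the loop, dispose when the
        // method returns.
        using var beta1_scalar = beta1.ToScalar();
        foreach (var t in tensors)
            t.mul_(beta1_scalar);
    }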

File tree: 1 file changed (+21, -8 lines)

src/TorchSharp/Optimizers/NAdam.cs (21 additions, 8 deletions)
@@ -147,8 +147,16 @@ public override Tensor step(Func<Tensor> closure = null)
     var options = group.Options as Options;
     var beta1 = options.beta1.Value;
     var beta2 = options.beta2.Value;
-    var eps = options.eps.Value;
+    using var beta1_scalar = beta1.ToScalar();
+    using var beta2_scalar = beta2.ToScalar();
+    var beta1_bar = 1 - beta1;
+    var beta2_bar = 1 - beta2;
+    using var beta1_bar_scalar = beta1_bar.ToScalar();
+    using var beta2_bar_scalar = beta2_bar.ToScalar();
+    using var eps_scalar = options.eps.Value.ToScalar();
     var weight_decay = options.weight_decay.Value;
+    var need_weight_decay = weight_decay != 0;
+    using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
     var momentum_decay = options.momentum_decay.Value;
     var lr = options.LearningRate.Value;
 
@@ -166,9 +174,10 @@ public override Tensor step(Func<Tensor> closure = null)
     var exp_avg_sq = state.exp_avg_sq;
 
     var bias_correction2 = 1 - Math.Pow(beta2, state.step);
+    using var bias_correction2_scalar = bias_correction2.ToScalar();
 
-    grad = (weight_decay != 0)
-        ? grad.add(param, alpha: weight_decay)
+    grad = (need_weight_decay)
+        ? grad.add(param, alpha: weight_decay_scalar)
         : grad.alias();
 
     var mu = beta1 * (1.0 - 0.5 * Math.Pow(0.96, state.step * momentum_decay));
@@ -177,13 +186,17 @@ public override Tensor step(Func<Tensor> closure = null)
     var mu_product = state.mu_product * mu;
     var mu_product_next = mu_product * mu_next;
 
-    exp_avg.mul_(beta1).add_(grad, alpha: 1 - beta1);
-    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value: 1 - beta2);
+    exp_avg.mul_(beta1_scalar).add_(grad, alpha: beta1_bar_scalar);
+    exp_avg_sq.mul_(beta2_scalar).addcmul_(grad, grad, value: beta2_bar_scalar);
 
-    var denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(eps);
+    var denom = exp_avg_sq.div(bias_correction2_scalar).sqrt_().add_(eps_scalar); // FIXME: Need dispose?
 
-    param.addcdiv_(grad, denom, value: -lr * (1 - mu) / (1 - mu_product));
-    param.addcdiv_(exp_avg, denom, value: -lr * mu_next / (1 - mu_product_next));
+    var scaled_lr = lr * (1 - mu) / (1 - mu_product);
+    using var negative_scaled_scalar = (-scaled_lr).ToScalar();
+    param.addcdiv_(grad, denom, value: negative_scaled_scalar);
+    var scaled_lr_next = lr * mu_next / (1 - mu_product_next);
+    using var negative_scaled_lr_next_scalar = (-scaled_lr_next).ToScalar();
+    param.addcdiv_(exp_avg, denom, value: negative_scaled_lr_next_scalar);
 
     state.mu_product = mu_product;
 }
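
The second FIXME concerns `denom`: it is a fresh tensor allocated per parameter on every step and is never explicitly disposed, so it stays alive until the finalizer runs. One way this could be addressed (a sketch only, assuming TorchSharp's torch.NewDisposeScope(), which disposes the tensors created inside the scope when the scope itself is disposed; variable names are the ones from the diff above):

    using (var scope = torch.NewDisposeScope()) {
        var denom = exp_avg_sq.div(bias_correction2_scalar).sqrt_().add_(eps_scalar);
        param.addcdiv_(grad, denom, value: negative_scaled_scalar);
        param.addcdiv_(exp_avg, denom, value: negative_scaled_lr_next_scalar);
    }   // denom is reclaimed here instead of waiting for the GC/finalizer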

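For context, a hedged usage sketch showing where this step function runs, assuming TorchSharp exposes an optim.NAdam factory for this optimizer (the exact signature and parameter names should be checked against the current TorchSharp API):

    using TorchSharp;
    using static TorchSharp.torch;

    var model = nn.Linear(10, 1);
    var optimizer = optim.NAdam(model.parameters(), lr: 0.002);

    using var x = rand(32, 10);
    using var y = rand(32, 1);

    optimizer.zero_grad();
    using var prediction = model.forward(x);
    using var loss = nn.functional.mse_loss(prediction, y);
    loss.backward();
    optimizer.step();   // runs the NAdam.step changed in this commit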