
Commit 0fb172d

Update Adadelta.step.
* Update src/TorchSharp/Optimizers/Adadelta.cs.
  + Update Adadelta.step.
    - Declare TorchSharp.Scalar explicitly.
    - Add FIXME for possible unused weight_decay_scalar.
    - Cache weight_decay != 0 explicitly.
Parent commit: 06b9e45
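The motivation behind the change is easiest to see outside the diff. Below is a minimal, hypothetical sketch (not code from this commit; the tensor setup and names are illustrative only). It assumes a TorchSharp build where Scalar wraps a native handle, is IDisposable, and is implicitly created whenever a double is passed to a tensor method. Hoisting the conversion out of the per-parameter loop with `using var` allocates the wrapper once and disposes it deterministically:

using TorchSharp;
using static TorchSharp.torch;

// Hypothetical setup, for illustration only.
var parameters = new[] { zeros(3), zeros(3) };
var rho = 0.9;

// Before: `rho` is implicitly converted to a fresh Scalar on every
// mul_ call, once per parameter per step; the wrappers linger until
// finalization.
foreach (var param in parameters)
    param.mul_(rho);

// After (the pattern this commit applies): convert once, reuse the
// same Scalar inside the loop, and dispose it when the scope ends.
using (var rho_scalar = rho.ToScalar()) {
    foreach (var param in parameters)
        param.mul_(rho_scalar);
}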

1 file changed: 13 additions, 8 deletions

src/TorchSharp/Optimizers/Adadelta.cs

@@ -129,10 +129,15 @@ public override Tensor step(Func<Tensor> closure = null)

                 var options = group.Options as Options;
                 var rho = options.rho.Value;
-                var eps = options.eps.Value;
+                using var rho_scalar = rho.ToScalar();
+                using var rho_bar_scalar = (1 - rho).ToScalar();
+                using var eps_scalar = options.eps.Value.ToScalar();
                 var weight_decay = options.weight_decay.Value;
+                var need_weight_decay = (weight_decay != 0);
+                using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
                 var maximize = options.maximize.Value;
                 var lr = options.LearningRate.Value;
+                using var negative_lr_scalar = (-lr).ToScalar();

                 foreach (var param in group.Parameters) {

@@ -149,17 +154,17 @@ public override Tensor step(Func<Tensor> closure = null)
                     var square_avg = state.square_avg;
                     var acc_delta = state.acc_delta;

-                    grad = (weight_decay != 0)
-                        ? grad.add(param, alpha: weight_decay)
+                    grad = (need_weight_decay)
+                        ? grad.add(param, alpha: weight_decay_scalar)
                         : grad.alias();

-                    square_avg.mul_(rho).addcmul_(grad, grad, 1 - rho);
+                    square_avg.mul_(rho_scalar).addcmul_(grad, grad, rho_bar_scalar);

-                    var std = square_avg.add(eps).sqrt_();
-                    var delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad);
+                    var std = square_avg.add(eps_scalar).sqrt_();
+                    var delta = acc_delta.add(eps_scalar).sqrt_().div_(std).mul_(grad);

-                    param.add_(delta, alpha: -lr);
-                    acc_delta.mul_(rho).addcmul_(delta, delta, 1 - rho);
+                    param.add_(delta, alpha: negative_lr_scalar);
+                    acc_delta.mul_(rho_scalar).addcmul_(delta, delta, rho_bar_scalar);
                     }
                 }, closure);
             }
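For reference, the patched loop implements the standard Adadelta recurrences (Zeiler, 2012). In the code, square_avg holds E[g^2], acc_delta holds E[Δx^2], rho_bar_scalar is (1 − ρ), and the lr factor corresponds to negative_lr_scalar in the final update:

% Adadelta step, as computed by the loop above:
\begin{aligned}
E[g^2]_t        &= \rho\,E[g^2]_{t-1} + (1-\rho)\,g_t^2
                   && \text{(square\_avg)} \\
\Delta x_t      &= \frac{\sqrt{E[\Delta x^2]_{t-1} + \epsilon}}
                        {\sqrt{E[g^2]_t + \epsilon}}\,g_t
                   && \text{(delta)} \\
x_t             &= x_{t-1} - \mathrm{lr}\cdot\Delta x_t
                   && \text{(param.add\_ with alpha $= -\mathrm{lr}$)} \\
E[\Delta x^2]_t &= \rho\,E[\Delta x^2]_{t-1} + (1-\rho)\,(\Delta x_t)^2
                   && \text{(acc\_delta)}
\end{aligned}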
