@@ -154,10 +154,18 @@ public override Tensor step(Func<Tensor> closure = null)
154154 var options = group . Options as Options ;
155155 var beta1 = options . beta1 . Value ;
156156 var beta2 = options . beta2 . Value ;
157+ using var beta1_scalar = beta1 . ToScalar ( ) ;
158+ using var beta2_scalar = beta2 . ToScalar ( ) ;
159+ var beta1_bar = 1 - beta1 ;
160+ var beta2_bar = 1 - beta2 ;
161+ using var beta1_bar_scalar = beta1_bar . ToScalar ( ) ;
162+ using var beta2_bar_scalar = beta2_bar . ToScalar ( ) ;
157163 var weight_decay = options . weight_decay . Value ;
164+ var need_weight_decay = weight_decay != 0 ;
165+ using var weight_decay_scalar = weight_decay . ToScalar ( ) ; // FIXME: Omit if not need_weight_decay?
158166 var amsgrad = options . amsgrad . Value ;
159167 var maximize = options . maximize . Value ;
160- var eps = options . eps . Value ;
168+ using var eps_scalar = options . eps . Value . ToScalar ( ) ;
161169 var lr = options . LearningRate . Value ;
162170
163171 foreach ( var param in group . Parameters ) {
@@ -175,25 +183,24 @@ public override Tensor step(Func<Tensor> closure = null)
175183 var bias_correction1 = 1 - Math . Pow ( beta1 , state . step ) ;
176184 var bias_correction2 = 1 - Math . Pow ( beta2 , state . step ) ;
177185
178- if ( weight_decay != 0 ) {
179- grad = grad . add ( param , alpha : weight_decay ) ;
180- }
186+ if ( need_weight_decay ) grad = grad . add ( param , alpha : weight_decay_scalar ) ;
181187
182- state . exp_avg . mul_ ( beta1 ) . add_ ( grad , alpha : 1 - beta1 ) ;
183- state . exp_avg_sq . mul_ ( beta2 ) . addcmul_ ( grad , grad . conj ( ) , value : 1 - beta2 ) ;
188+ state . exp_avg . mul_ ( beta1_scalar ) . add_ ( grad , alpha : beta1_bar_scalar ) ;
189+ state . exp_avg_sq . mul_ ( beta2_scalar ) . addcmul_ ( grad , grad . conj ( ) , value : beta2_bar_scalar ) ;
184190
185- Tensor denom = null ;
191+ Tensor denom = null ; // FIXME: Need dispose?
186192 if ( amsgrad ) {
187193 var t0 = state . max_exp_avg_sq ;
188194 state . max_exp_avg_sq = torch . maximum ( t0 , state . exp_avg_sq ) . DetachFromDisposeScope ( ) ;
189195 t0 . Dispose ( ) ;
190- denom = ( state . max_exp_avg_sq . sqrt ( ) / Math . Sqrt ( bias_correction2 ) ) . add_ ( eps ) ;
196+ denom = ( state . max_exp_avg_sq . sqrt ( ) / Math . Sqrt ( bias_correction2 ) ) . add_ ( eps_scalar ) ;
191197 } else {
192- denom = ( state . exp_avg_sq . sqrt ( ) / Math . Sqrt ( bias_correction2 ) ) . add_ ( eps ) ;
198+ denom = ( state . exp_avg_sq . sqrt ( ) / Math . Sqrt ( bias_correction2 ) ) . add_ ( eps_scalar ) ;
193199 }
194200
195201 var step_size = lr / bias_correction1 ;
196- param . addcdiv_ ( state . exp_avg , denom , value : - step_size ) ;
202+ using var negative_step_size_scalar = ( - step_size ) . ToScalar ( ) ;
203+ param . addcdiv_ ( state . exp_avg , denom , value : negative_step_size_scalar ) ;
197204 }
198205 } , closure ) ;
199206 }
0 commit comments