@@ -141,9 +141,19 @@ public override Tensor step(Func<Tensor> closure = null)
141141 var options = group . Options as Options ;
142142 var beta1 = options . beta1 . Value ;
143143 var beta2 = options . beta2 . Value ;
144- var eps = options . eps . Value ;
144+ using var beta1_scalar = beta1 . ToScalar ( ) ;
145+ using var beta2_scalar = beta2 . ToScalar ( ) ;
146+ var beta1_bar = 1 - beta1 ;
147+ var beta2_bar = 1 - beta2 ;
148+ using var beta1_bar_scalar = beta1_bar . ToScalar ( ) ;
149+ using var beta2_bar_scalar = beta2_bar . ToScalar ( ) ;
150+ using var eps_scalar = options . eps . Value . ToScalar ( ) ;
145151 var weight_decay = options . weight_decay . Value ;
152+ var need_weight_decay = weight_decay != 0 ;
153+ using var weight_decay_scalar = weight_decay . ToScalar ( ) ; // FIXME: Omit if not need_weight_decay?
146154 var lr = options . LearningRate . Value ;
155+ using var lr_scalar = lr . ToScalar ( ) ;
156+ using var negative_one_scalar = ( - 1.0 ) . ToScalar ( ) ; // FIXME: Use torch.Tensor.sub_ instead?
147157
148158 foreach ( var param in group . Parameters ) {
149159
@@ -161,27 +171,27 @@ public override Tensor step(Func<Tensor> closure = null)
161171 var bias_correction1 = 1 - Math . Pow ( beta1 , state . step ) ;
162172 var bias_correction2 = 1 - Math . Pow ( beta2 , state . step ) ;
163173
164- grad = ( weight_decay != 0 )
165- ? grad . add ( param , alpha : weight_decay )
174+ grad = ( need_weight_decay )
175+ ? grad . add ( param , alpha : weight_decay_scalar )
166176 : grad . alias ( ) ;
167177
168- exp_avg . mul_ ( beta1 ) . add_ ( grad , alpha : 1 - beta1 ) ;
169- exp_avg_sq . mul_ ( beta2 ) . addcmul_ ( grad , grad , value : 1 - beta2 ) ;
178+ exp_avg . mul_ ( beta1_scalar ) . add_ ( grad , alpha : beta1_bar_scalar ) ;
179+ exp_avg_sq . mul_ ( beta2_scalar ) . addcmul_ ( grad , grad , value : beta2_bar_scalar ) ;
170180
171- var bias_corrected_exp_avg = exp_avg / bias_correction1 ;
181+ var bias_corrected_exp_avg = exp_avg / bias_correction1 ; // FIXME: Need dispose?
172182
173183 var rho_inf = 2 / ( 1 - beta2 ) - 1 ;
174184 var rho_t = rho_inf - 2 * state . step * Math . Pow ( beta2 , state . step ) / bias_correction2 ;
175185
176- var t6 = bias_corrected_exp_avg * lr ;
186+ var t6 = bias_corrected_exp_avg . mul ( lr_scalar ) ; // FIXME: Need dispose?
177187
178188 if ( rho_t > 5 ) {
179189 var rect = Math . Sqrt ( ( rho_t - 4 ) * ( rho_t - 2 ) * rho_inf / ( ( rho_inf - 4 ) * ( rho_inf - 2 ) * rho_t ) ) ;
180- var adaptive_lr = Math . Sqrt ( bias_correction2 ) / exp_avg_sq . sqrt ( ) . add_ ( eps ) ;
190+ var adaptive_lr = Math . Sqrt ( bias_correction2 ) / exp_avg_sq . sqrt ( ) . add_ ( eps_scalar ) ; // FIXME: Need dispose?
181191
182- param . add_ ( t6 * lr * adaptive_lr * rect , alpha : - 1.0 ) ;
192+ param . add_ ( t6 * lr * adaptive_lr * rect , alpha : negative_one_scalar ) ; // FIXME: Need dispose? Use inplace ops?
183193 } else {
184- param . add_ ( t6 , alpha : - 1.0 ) ;
194+ param . add_ ( t6 , alpha : negative_one_scalar ) ;
185195 }
186196 }
187197 } , closure ) ;
0 commit comments