@@ -147,8 +147,16 @@ public override Tensor step(Func<Tensor> closure = null)
147147 var options = group . Options as Options ;
148148 var beta1 = options . beta1 . Value ;
149149 var beta2 = options . beta2 . Value ;
150- var eps = options . eps . Value ;
150+ using var beta1_scalar = beta1 . ToScalar ( ) ;
151+ using var beta2_scalar = beta2 . ToScalar ( ) ;
152+ var beta1_bar = 1 - beta1 ;
153+ var beta2_bar = 1 - beta2 ;
154+ using var beta1_bar_scalar = beta1_bar . ToScalar ( ) ;
155+ using var beta2_bar_scalar = beta2_bar . ToScalar ( ) ;
156+ using var eps_scalar = options . eps . Value . ToScalar ( ) ;
151157 var weight_decay = options . weight_decay . Value ;
158+ var need_weight_decay = weight_decay != 0 ;
159+ using var weight_decay_scalar = weight_decay . ToScalar ( ) ; // FIXME: Skip allocating this scalar when need_weight_decay is false?
152160 var momentum_decay = options . momentum_decay . Value ;
153161 var lr = options . LearningRate . Value ;
154162
@@ -166,9 +174,10 @@ public override Tensor step(Func<Tensor> closure = null)
166174 var exp_avg_sq = state . exp_avg_sq ;
167175
168176 var bias_correction2 = 1 - Math . Pow ( beta2 , state . step ) ;
177+ using var bias_correction2_scalar = bias_correction2 . ToScalar ( ) ;
169178
170- grad = ( weight_decay != 0 )
171- ? grad . add ( param , alpha : weight_decay )
179+ grad = ( need_weight_decay )
180+ ? grad . add ( param , alpha : weight_decay_scalar )
172181 : grad . alias ( ) ;
173182
174183 var mu = beta1 * ( 1.0 - 0.5 * Math . Pow ( 0.96 , state . step * momentum_decay ) ) ;
@@ -177,13 +186,17 @@ public override Tensor step(Func<Tensor> closure = null)
177186 var mu_product = state . mu_product * mu ;
178187 var mu_product_next = mu_product * mu_next ;
179188
180- exp_avg . mul_ ( beta1 ) . add_ ( grad , alpha : 1 - beta1 ) ;
181- exp_avg_sq . mul_ ( beta2 ) . addcmul_ ( grad , grad , value : 1 - beta2 ) ;
189+ exp_avg . mul_ ( beta1_scalar ) . add_ ( grad , alpha : beta1_bar_scalar ) ;
190+ exp_avg_sq . mul_ ( beta2_scalar ) . addcmul_ ( grad , grad , value : beta2_bar_scalar ) ;
182191
183- var denom = exp_avg_sq . div ( bias_correction2 ) . sqrt_ ( ) . add_ ( eps ) ;
192+ var denom = exp_avg_sq . div ( bias_correction2_scalar ) . sqrt_ ( ) . add_ ( eps_scalar ) ; // FIXME: Does this newly allocated intermediate tensor need to be disposed after the addcdiv_ calls below?
184193
185- param . addcdiv_ ( grad , denom , value : - lr * ( 1 - mu ) / ( 1 - mu_product ) ) ;
186- param . addcdiv_ ( exp_avg , denom , value : - lr * mu_next / ( 1 - mu_product_next ) ) ;
194+ var scaled_lr = lr * ( 1 - mu ) / ( 1 - mu_product ) ;
195+ using var negative_scaled_scalar = ( - scaled_lr ) . ToScalar ( ) ;
196+ param . addcdiv_ ( grad , denom , value : negative_scaled_scalar ) ;
197+ var scaled_lr_next = lr * mu_next / ( 1 - mu_product_next ) ;
198+ using var negative_scaled_lr_next_scalar = ( - scaled_lr_next ) . ToScalar ( ) ;
199+ param . addcdiv_ ( exp_avg , denom , value : negative_scaled_lr_next_scalar ) ;
187200
188201 state . mu_product = mu_product ;
189202 }
0 commit comments