@@ -142,8 +142,14 @@ public override Tensor step(Func<Tensor> closure = null)
142142 var options = group . Options as Options ;
143143 var beta1 = options . beta1 . Value ;
144144 var beta2 = options . beta2 . Value ;
145- var eps = options . eps . Value ;
145+ using var beta1_scalar = beta1 . ToScalar ( ) ;
146+ using var beta2_scalar = beta2 . ToScalar ( ) ;
147+ var beta1_bar = 1 - beta1 ;
148+ using var beta1_bar_scalar = beta1_bar . ToScalar ( ) ;
149+ using var eps_scalar = options . eps . Value . ToScalar ( ) ;
146150 var weight_decay = options . weight_decay . Value ;
151+ var need_weight_decay = weight_decay != 0 ;
152+ using var weight_decay_scalar = weight_decay . ToScalar ( ) ; // FIXME: Omit if not need_weight_decay?
147153 var lr = options . LearningRate . Value ;
148154
149155 foreach ( var param in group . Parameters ) {
@@ -161,21 +167,22 @@ public override Tensor step(Func<Tensor> closure = null)
161167 var exp_avg = state . exp_avg ;
162168 var exp_inf = state . exp_inf ;
163169
164- grad = ( weight_decay != 0 )
165- ? grad . add ( param , alpha : weight_decay )
170+ grad = ( need_weight_decay )
171+ ? grad . add ( param , alpha : weight_decay_scalar )
166172 : grad . alias ( ) ;
167173
168- exp_avg . mul_ ( beta1 ) . add_ ( grad , alpha : 1 - beta1 ) ;
174+ exp_avg . mul_ ( beta1_scalar ) . add_ ( grad , alpha : beta1_bar_scalar ) ;
169175
170176 var norm_buf = torch . cat ( new Tensor [ ] {
171- exp_inf . mul_ ( beta2 ) . unsqueeze ( 0 ) ,
172- grad . abs ( ) . add_ ( eps ) . unsqueeze_ ( 0 )
177+ exp_inf . mul_ ( beta2_scalar ) . unsqueeze ( 0 ) ,
178+ grad . abs ( ) . add_ ( eps_scalar ) . unsqueeze_ ( 0 )
173179 } , 0 ) ;
174180
175- torch . amax ( norm_buf , new long [ ] { 0 } , false , exp_inf ) ;
181+ torch . amax ( norm_buf , new long [ ] { 0 } , false , exp_inf ) ; // FIXME: CA1806?
176182
177183 var clr = lr / ( 1 - Math . Pow ( beta1 , state . step ) ) ;
178- param . addcdiv_ ( exp_avg , exp_inf , value : - clr ) ;
184+ using var negative_clr_scalar = ( - clr ) . ToScalar ( ) ;
185+ param . addcdiv_ ( exp_avg , exp_inf , value : negative_clr_scalar ) ;
179186 }
180187 } , closure ) ;
181188 }
0 commit comments