@@ -140,6 +140,8 @@ public override Tensor step(Func<Tensor> closure = null)
     var lambd = options.lambd.Value;
     var alpha = options.alpha.Value;
     var weight_decay = options.weight_decay.Value;
+    var need_weight_decay = weight_decay != 0;
+    using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
     var t0 = options.t0.Value;
     var lr = options.LearningRate.Value;
 
@@ -157,15 +159,19 @@ public override Tensor step(Func<Tensor> closure = null)
 
     state.step += 1;
 
-    grad = (weight_decay != 0)
-        ? grad.add(param, alpha: weight_decay)
+    grad = (need_weight_decay)
+        ? grad.add(param, alpha: weight_decay_scalar)
         : grad.alias();
 
-    param.mul_(1 - lambd * state.eta);
-    param.add_(grad, alpha: -state.eta);
+    var lambd_eta_bar = 1 - lambd * state.eta;
+    using var lambd_eta_bar_scalar = lambd_eta_bar.ToScalar();
+    param.mul_(lambd_eta_bar_scalar);
+    using var negative_eta_scalar = (-state.eta).ToScalar();
+    param.add_(grad, alpha: negative_eta_scalar);
 
     if (state.mu != 1) {
-        state.ax.add_(param.sub(state.ax).mul(state.mu));
+        using var mu_scalar = state.mu.ToScalar();
+        state.ax.add_(param.sub(state.ax).mul(mu_scalar));
     } else {
         state.ax.copy_(param);
     }
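
The update math is unchanged: the parameter is still scaled by `1 - lambd * state.eta`, stepped by `-state.eta` along the (optionally weight-decayed) gradient, and averaged into `ax` with weight `mu`. What changes is that every implicit double-to-`Scalar` conversion is hoisted into an explicit `using var ... ToScalar()` declaration, so each native `Scalar` handle is disposed deterministically when the step scope ends instead of lingering until finalization. A minimal sketch of the pattern (the class name, tensor, and 0.99 factor are illustrative, not from the PR):

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Illustrative sketch; assumes TorchSharp, where Scalar : IDisposable and
// double has a ToScalar() extension (both relied on by the diff above).
class ScalarDisposalSketch
{
    static void Main()
    {
        var param = ones(3);

        // Implicit conversion: each call allocates a temporary native Scalar
        // that is only reclaimed when the GC finalizes it.
        param.mul_(0.99);

        // The pattern adopted in the diff: convert once, name the handle,
        // and dispose it deterministically at the end of the scope.
        using (var factor = 0.99.ToScalar())
        {
            param.mul_(factor);
        }
    }
}
```

As the FIXME notes, `weight_decay_scalar` is still created even when `need_weight_decay` is false; since `Scalar` is a small native wrapper, the change appears to accept one spare allocation per step rather than branch around the `using` declaration.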