@@ -129,10 +129,15 @@ public override Tensor step(Func<Tensor> closure = null)
 
                 var options = group.Options as Options;
                 var rho = options.rho.Value;
-                var eps = options.eps.Value;
+                using var rho_scalar = rho.ToScalar();
+                using var rho_bar_scalar = (1 - rho).ToScalar();
+                using var eps_scalar = options.eps.Value.ToScalar();
                 var weight_decay = options.weight_decay.Value;
+                var need_weight_decay = (weight_decay != 0);
+                using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
                 var maximize = options.maximize.Value;
                 var lr = options.LearningRate.Value;
+                using var negative_lr_scalar = (-lr).ToScalar();
 
                 foreach (var param in group.Parameters) {
 
@@ -149,17 +154,17 @@ public override Tensor step(Func<Tensor> closure = null)
                     var square_avg = state.square_avg;
                     var acc_delta = state.acc_delta;
 
-                    grad = (weight_decay != 0)
-                        ? grad.add(param, alpha: weight_decay)
+                    grad = (need_weight_decay)
+                        ? grad.add(param, alpha: weight_decay_scalar)
                         : grad.alias();
 
-                    square_avg.mul_(rho).addcmul_(grad, grad, 1 - rho);
+                    square_avg.mul_(rho_scalar).addcmul_(grad, grad, rho_bar_scalar);
 
-                    var std = square_avg.add(eps).sqrt_();
-                    var delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad);
+                    var std = square_avg.add(eps_scalar).sqrt_();
+                    var delta = acc_delta.add(eps_scalar).sqrt_().div_(std).mul_(grad);
 
-                    param.add_(delta, alpha: -lr);
-                    acc_delta.mul_(rho).addcmul_(delta, delta, 1 - rho);
+                    param.add_(delta, alpha: negative_lr_scalar);
+                    acc_delta.mul_(rho_scalar).addcmul_(delta, delta, rho_bar_scalar);
                 }
             }, closure);
         }
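
The diff hoists the double-to-Scalar conversions (rho, 1 - rho, eps, weight_decay, -lr) out of the per-parameter loop, so each hyper-parameter is converted to a native Scalar once per parameter group and disposed deterministically via `using var`, instead of being converted implicitly on every tensor call. Below is a minimal standalone sketch of that pattern, not the library's own code: the `AdadeltaSketch` class, the `ApplyAdadelta` method, and the tuple-of-tensors parameter are illustrative assumptions, while the Scalar-taking tensor calls mirror the ones used in the diff.

// Sketch only: illustrates hoisting Scalar conversions out of an Adadelta-style
// update loop, assuming TorchSharp's ToScalar() extension and Scalar overloads
// as used in the diff above. ApplyAdadelta and its argument shape are hypothetical.
using System;
using System.Collections.Generic;
using TorchSharp;
using static TorchSharp.torch;

static class AdadeltaSketch
{
    public static void ApplyAdadelta(
        IEnumerable<(Tensor param, Tensor grad, Tensor square_avg, Tensor acc_delta)> parameters,
        double lr = 1.0, double rho = 0.9, double eps = 1e-6)
    {
        // Convert each hyper-parameter to a Scalar once; reuse it for every parameter.
        using var rho_scalar = rho.ToScalar();
        using var rho_bar_scalar = (1 - rho).ToScalar();
        using var eps_scalar = eps.ToScalar();
        using var negative_lr_scalar = (-lr).ToScalar();

        foreach (var (param, grad, square_avg, acc_delta) in parameters) {
            // E[g^2] <- rho * E[g^2] + (1 - rho) * g^2
            square_avg.mul_(rho_scalar).addcmul_(grad, grad, rho_bar_scalar);

            // delta = sqrt(E[dx^2] + eps) / sqrt(E[g^2] + eps) * g
            // Intermediates are disposed eagerly here since this sketch has no DisposeScope.
            using var std = square_avg.add(eps_scalar).sqrt_();
            using var delta = acc_delta.add(eps_scalar).sqrt_().div_(std).mul_(grad);

            // x <- x - lr * delta;  E[dx^2] <- rho * E[dx^2] + (1 - rho) * delta^2
            param.add_(delta, alpha: negative_lr_scalar);
            acc_delta.mul_(rho_scalar).addcmul_(delta, delta, rho_bar_scalar);
        }
    }
}

The design point is that every implicit double-to-Scalar conversion allocates a native handle; creating the Scalars once per group and disposing them with `using var` keeps the hot per-parameter loop allocation-free, which is exactly what the change above does inside step().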