@@ -152,11 +152,20 @@ public override Tensor step(Func<Tensor> closure = null)
152152 var options = group . Options as Options ;
153153 var maximize = options . maximize . Value ;
154154 var momentum = options . momentum . Value ;
155+ var need_momentum = momentum > 0 ;
156+ using var momentum_scalar = momentum . ToScalar ( ) ; // FIXME: Omit if not need_momentum?
155157 var alpha = options . alpha . Value ;
158+ var alpha_bar = 1 - alpha ;
159+ using var alpha_scalar = alpha . ToScalar ( ) ;
160+ using var alpha_bar_scalar = alpha_bar . ToScalar ( ) ;
156161 var weight_decay = options . weight_decay . Value ;
162+ var need_weight_decay = weight_decay != 0 ;
163+ using var weight_decay_scalar = weight_decay . ToScalar ( ) ; // FIXME: Omit if not need_weight_decay?
157164 var centered = options . centered . Value ;
158- var eps = options . eps . Value ;
165+ using var negative_one_scalar = ( - 1 ) . ToScalar ( ) ;
166+ using var eps_scalar = options . eps . Value . ToScalar ( ) ;
159167 var lr = options . LearningRate . Value ;
168+ using var negative_lr_scalar = ( - lr ) . ToScalar ( ) ;
160169
161170 foreach ( var param in group . Parameters ) {
162171
@@ -170,28 +179,26 @@ public override Tensor step(Func<Tensor> closure = null)
170179
171180 state . step += 1 ;
172181
173- if ( weight_decay != 0 ) {
174- grad = grad . add ( param , alpha : weight_decay ) ;
175- }
182+ if ( need_weight_decay ) grad = grad . add ( param , alpha : weight_decay_scalar ) ;
176183
177- state . square_avg . mul_ ( alpha ) . addcmul_ ( grad , grad , value : 1 - alpha ) ;
184+ state . square_avg . mul_ ( alpha_scalar ) . addcmul_ ( grad , grad , value : alpha_bar_scalar ) ;
178185
179- Tensor avg = null ;
186+ Tensor avg = null ; // FIXME: Need dispose?
180187
181188 if ( centered ) {
182189 var grad_avg = state . grad_avg ;
183- grad_avg . mul_ ( alpha ) . add_ ( grad , alpha : 1 - alpha ) ;
184- avg = state . square_avg . addcmul ( grad_avg , grad_avg , value : - 1 ) . sqrt_ ( ) . add_ ( eps ) ;
190+ grad_avg . mul_ ( alpha_scalar ) . add_ ( grad , alpha : alpha_bar_scalar ) ;
191+ avg = state . square_avg . addcmul ( grad_avg , grad_avg , value : negative_one_scalar ) . sqrt_ ( ) . add_ ( eps_scalar ) ;
185192 } else {
186- avg = state . square_avg . sqrt ( ) . add_ ( eps ) ;
193+ avg = state . square_avg . sqrt ( ) . add_ ( eps_scalar ) ;
187194 }
188195
189- if ( momentum > 0 ) {
196+ if ( need_momentum ) {
190197 var buf = state . momentum_buffer ;
191- buf . mul_ ( momentum ) . addcdiv_ ( grad , avg ) ;
192- param . add_ ( buf , alpha : - lr ) ;
198+ buf . mul_ ( momentum_scalar ) . addcdiv_ ( grad , avg ) ;
199+ param . add_ ( buf , alpha : negative_lr_scalar ) ;
193200 } else {
194- param . addcdiv_ ( grad , avg , - lr ) ;
201+ param . addcdiv_ ( grad , avg , negative_lr_scalar ) ;
195202 }
196203 }
197204 } , closure ) ;
0 commit comments