@@ -139,9 +139,12 @@ public override Tensor step(Func<Tensor> closure = null)
     var options = group.Options as Options;
     var lr_decay = options.lr_decay.Value;
     var weight_decay = options.weight_decay.Value;
-    var eps = options.eps.Value;
-    var initial_accumulator_value = options.initial_accumulator_value.Value;
+    var need_weight_decay = weight_decay != 0;
+    using var weight_decay_scalar = weight_decay.ToScalar(); // FIXME: Omit if not need_weight_decay?
+    using var eps_scalar = options.eps.Value.ToScalar();
+    var initial_accumulator_value = options.initial_accumulator_value.Value; // FIXME: Unused?
     var lr = options.LearningRate.Value;
+    using var one_scalar = 1 .ToScalar();
 
     foreach (var param in group.Parameters) {
 
@@ -153,20 +156,19 @@ public override Tensor step(Func<Tensor> closure = null)
 
         state.step += 1;
 
-        if (weight_decay != 0) {
-            grad = grad.add(param, alpha: weight_decay);
-        }
+        if (need_weight_decay) grad = grad.add(param, alpha: weight_decay_scalar);
 
         var clr = lr / (1 + (state.step - 1) * lr_decay);
+        using var negative_clr_scalar = (-clr).ToScalar();
 
         if (grad.is_sparse)
             throw new NotImplementedException("Adagrad optimization over sparse parameters");
         if (torch.is_complex(grad))
             throw new NotImplementedException("Adagrad optimization over complex parameters");
 
-        state.sum.addcmul_(grad, grad, value: 1);
-        var std = state.sum.sqrt().add_(eps);
-        param.addcdiv_(grad, std, value: -clr);
+        state.sum.addcmul_(grad, grad, value: one_scalar);
+        var std = state.sum.sqrt().add_(eps_scalar);
+        param.addcdiv_(grad, std, value: negative_clr_scalar);
     }
 
 
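For readers skimming the diff: the change pre-converts the `double` hyperparameters to TorchSharp `Scalar` values once per parameter group and reuses them inside the per-parameter loop, instead of letting each `add`/`add_`/`addcmul_`/`addcdiv_` call convert on the fly. Below is a minimal standalone sketch of that pattern, not the committed code; the tensor shapes, the `eps` value, and the loop structure are illustrative assumptions.

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Illustrative parameters and gradients (assumptions, not TorchSharp internals).
var parameters = new[] { randn(10), randn(10) };
var grads = new[] { randn(10), randn(10) };
double clr = 0.01;

// Hoist the double -> Scalar conversions out of the hot loop and dispose them once.
using var negative_clr_scalar = (-clr).ToScalar();
using var eps_scalar = (1e-10).ToScalar();

for (int i = 0; i < parameters.Length; i++) {
    // Reuse the same Scalar instances for every parameter update.
    using var std = grads[i].mul(grads[i]).sqrt_().add_(eps_scalar);
    parameters[i].addcdiv_(grads[i], std, value: negative_clr_scalar);
}
```

The committed hunks above apply the same idea via `weight_decay_scalar`, `eps_scalar`, `one_scalar`, and `negative_clr_scalar`; the remaining FIXMEs (skipping the weight-decay conversion when `need_weight_decay` is false, and the apparently unused `initial_accumulator_value`) are left open by this commit.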