@@ -245,28 +245,24 @@ protected int GetPreviousTimestep(int timestep)
245245 /// <returns></returns>
246246 protected float [ ] GetBetasForAlphaBar ( )
247247 {
248- var betas = new float [ _options . TrainTimesteps ] ;
249-
250248 Func < float , float > alphaBarFn = null ;
251249 if ( _options . AlphaTransformType == AlphaTransformType . Cosine )
252250 {
253- alphaBarFn = t => ( float ) Math . Pow ( Math . Cos ( ( t + 0.008 ) / 1.008 * Math . PI / 2.0 ) , 2.0 ) ;
251+ alphaBarFn = t => ( float ) Math . Pow ( Math . Cos ( ( t + 0.008f ) / 1.008f * Math . PI / 2.0f ) , 2.0f ) ;
254252 }
255253 else if ( _options . AlphaTransformType == AlphaTransformType . Exponential )
256254 {
257- alphaBarFn = t => ( float ) Math . Exp ( t * - 12.0 ) ;
255+ alphaBarFn = t => ( float ) Math . Exp ( t * - 12.0f ) ;
258256 }
259257
260- for ( int i = 0 ; i < _options . TrainTimesteps ; i ++ )
261- {
262- float t1 = ( float ) i / _options . TrainTimesteps ;
263- float t2 = ( float ) ( i + 1 ) / _options . TrainTimesteps ;
264- float alphaT1 = alphaBarFn ( t1 ) ;
265- float alphaT2 = alphaBarFn ( t2 ) ;
266- float beta = Math . Min ( 1 - alphaT2 / alphaT1 , _options . MaximumBeta ) ;
267- betas [ i ] = ( float ) Math . Max ( beta , 0.0001 ) ;
268- }
269- return betas ;
258+ return Enumerable
259+ . Range ( 0 , _options . TrainTimesteps )
260+ . Select ( i =>
261+ {
262+ var t1 = ( float ) i / _options . TrainTimesteps ;
263+ var t2 = ( float ) ( i + 1 ) / _options . TrainTimesteps ;
264+ return Math . Min ( 1f - alphaBarFn ( t2 ) / alphaBarFn ( t1 ) , _options . MaximumBeta ) ;
265+ } ) . ToArray ( ) ;
270266 }
271267
272268
0 commit comments