@@ -33,17 +33,15 @@ public EulerAncestralScheduler(SchedulerOptions schedulerOptions) : base(schedul
3333 /// </summary>
3434 protected override void Initialize ( )
3535 {
36+ _sigmas = null ;
3637 var betas = Enumerable . Empty < float > ( ) ;
3738 if ( ! Options . TrainedBetas . IsNullOrEmpty ( ) )
3839 {
3940 betas = Options . TrainedBetas ;
4041 }
4142 else if ( Options . BetaSchedule == BetaScheduleType . Linear )
4243 {
43- var steps = Options . TrainTimesteps - 1 ;
44- var delta = Options . BetaStart + ( Options . BetaEnd - Options . BetaStart ) ;
45- betas = Enumerable . Range ( 0 , Options . TrainTimesteps )
46- . Select ( i => delta * i / steps ) ;
44+ betas = np . linspace ( Options . BetaStart , Options . BetaEnd , Options . TrainTimesteps ) . ToArray < float > ( ) ;
4745 }
4846 else if ( Options . BetaSchedule == BetaScheduleType . ScaledLinear )
4947 {
@@ -58,16 +56,12 @@ protected override void Initialize()
5856 betas = GetBetasForAlphaBar ( ) ;
5957 }
6058
59+ var alphas = betas . Select ( beta => 1.0f - beta ) ;
60+ var alphaCumProd = alphas . Select ( ( alpha , i ) => alphas . Take ( i + 1 ) . Aggregate ( ( a , b ) => a * b ) ) ;
61+ _sigmas = alphaCumProd
62+ . Select ( alpha_prod => ( float ) Math . Sqrt ( ( 1 - alpha_prod ) / alpha_prod ) )
63+ . ToArray ( ) ;
6164
62- var alphas = betas . Select ( beta => 1 - beta ) ;
63- var cumulativeProduct = alphas . Select ( ( alpha , i ) => alphas . Take ( i + 1 ) . Aggregate ( ( a , b ) => a * b ) ) ;
64-
65- // Create _sigmas as a list and reverse it
66- _sigmas = cumulativeProduct
67- . Select ( alpha_prod => ( float ) Math . Sqrt ( ( 1 - alpha_prod ) / alpha_prod ) )
68- . ToArray ( ) ;
69-
70- // standard deviation of the initial noise distrubution
7165 var maxSigma = _sigmas . Max ( ) ;
7266 var initNoiseSigma = Options . TimestepSpacing == TimestepSpacingType . Linspace || Options . TimestepSpacing == TimestepSpacingType . Trailing
7367 ? maxSigma
@@ -115,8 +109,9 @@ protected override int[] SetTimesteps()
115109 timesteps = SigmaToTimestep ( sigmas , log_sigmas ) ;
116110 }
117111
118- // add 0.000 to the end of the result
119- _sigmas = sigmas . Append ( 0.000f ) . ToArray ( ) ;
112+ _sigmas = sigmas
113+ . Append ( 0.000f )
114+ . ToArray ( ) ;
120115
121116 return timesteps . Select ( x => ( int ) x )
122117 . OrderByDescending ( x => x )
@@ -160,7 +155,23 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
160155 var sigma = _sigmas [ stepIndex ] ;
161156
162157 // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
163- var predOriginalSample = sample . SubtractTensors ( modelOutput . MultipleTensorByFloat ( sigma ) ) ;
158+ DenseTensor < float > predOriginalSample = null ;
159+ if ( Options . PredictionType == PredictionType . Epsilon )
160+ {
161+ predOriginalSample = sample . SubtractTensors ( modelOutput . MultipleTensorByFloat ( sigma ) ) ;
162+ }
163+ else if ( Options . PredictionType == PredictionType . VariablePrediction )
164+ {
165+ var sigmaSqrt = ( float ) Math . Sqrt ( sigma * sigma + 1 ) ;
166+ predOriginalSample = sample . DivideTensorByFloat ( sigmaSqrt )
167+ . AddTensors ( modelOutput . MultipleTensorByFloat ( - sigma / sigmaSqrt ) ) ;
168+ }
169+ else if ( Options . PredictionType == PredictionType . Sample )
170+ {
171+ //prediction_type not implemented yet: sample
172+ predOriginalSample = modelOutput . ToDenseTensor ( ) ;
173+ }
174+
164175
165176 var sigmaFrom = _sigmas [ stepIndex ] ;
166177 var sigmaTo = _sigmas [ stepIndex + 1 ] ;
0 commit comments