@@ -11,8 +11,7 @@ namespace OnnxStack.StableDiffusion.Schedulers
1111{
1212 internal class DDPMScheduler : SchedulerBase
1313 {
14- private float [ ] _betas ;
15- private List < float > _alphasCumulativeProducts ;
14+ private float [ ] _alphasCumProd ;
1615
1716 /// <summary>
1817 /// Initializes a new instance of the <see cref="DDPMScheduler"/> class.
@@ -33,43 +32,13 @@ public DDPMScheduler(SchedulerOptions options) : base(options) { }
3332 /// </summary>
3433 protected override void Initialize ( )
3534 {
36- var alphas = new List < float > ( ) ;
37- if ( Options . TrainedBetas != null )
38- {
39- _betas = Options . TrainedBetas . ToArray ( ) ;
40- }
41- else if ( Options . BetaSchedule == BetaScheduleType . Linear )
42- {
43- _betas = np . linspace ( Options . BetaStart , Options . BetaEnd , Options . TrainTimesteps ) . ToArray < float > ( ) ;
44- }
45- else if ( Options . BetaSchedule == BetaScheduleType . ScaledLinear )
46- {
47- // This schedule is very specific to the latent diffusion model.
48- _betas = np . power ( np . linspace ( MathF . Sqrt ( Options . BetaStart ) , MathF . Sqrt ( Options . BetaEnd ) , Options . TrainTimesteps ) , 2 ) . ToArray < float > ( ) ;
49- }
50- else if ( Options . BetaSchedule == BetaScheduleType . SquaredCosCapV2 )
51- {
52- // Glide cosine schedule
53- _betas = GetBetasForAlphaBar ( ) ;
54- }
55- //else if (betaSchedule == "sigmoid")
56- //{
57- // // GeoDiff sigmoid schedule
58- // var betas = np.linspace(-6, 6, numTrainTimesteps);
59- // Betas = (np.multiply(np.exp(betas), (betaEnd - betaStart)) + betaStart).ToArray<float>();
60- //}
61-
62-
63- for ( int i = 0 ; i < Options . TrainTimesteps ; i ++ )
64- {
65- alphas . Add ( 1.0f - _betas [ i ] ) ;
66- }
35+ _alphasCumProd = null ;
6736
68- _alphasCumulativeProducts = new List < float > { alphas [ 0 ] } ;
69- for ( int i = 1 ; i < Options . TrainTimesteps ; i ++ )
70- {
71- _alphasCumulativeProducts . Add ( _alphasCumulativeProducts [ i - 1 ] * alphas [ i ] ) ;
72- }
37+ var betas = GetBetaSchedule ( ) ;
38+ var alphas = betas . Select ( beta => 1.0f - beta ) ;
39+ _alphasCumProd = alphas
40+ . Select ( ( alpha , i ) => alphas . Take ( i + 1 ) . Aggregate ( ( a , b ) => a * b ) )
41+ . ToArray ( ) ;
7342
7443 SetInitNoiseSigma ( 1.0f ) ;
7544 }
@@ -82,29 +51,8 @@ protected override void Initialize()
8251 protected override int [ ] SetTimesteps ( )
8352 {
8453 // Create timesteps based on the specified strategy
85- NDArray timestepsArray = null ;
86- if ( Options . TimestepSpacing == TimestepSpacingType . Linspace )
87- {
88- timestepsArray = np . linspace ( 0 , Options . TrainTimesteps - 1 , Options . InferenceSteps ) ;
89- timestepsArray = np . around ( timestepsArray ) [ "::1" ] ;
90- }
91- else if ( Options . TimestepSpacing == TimestepSpacingType . Leading )
92- {
93- var stepRatio = Options . TrainTimesteps / Options . InferenceSteps ;
94- timestepsArray = np . arange ( 0 , ( float ) Options . InferenceSteps ) * stepRatio ;
95- timestepsArray = np . around ( timestepsArray ) [ "::1" ] ;
96- timestepsArray += Options . StepsOffset ;
97- }
98- else if ( Options . TimestepSpacing == TimestepSpacingType . Trailing )
99- {
100- var stepRatio = Options . TrainTimesteps / ( Options . InferenceSteps - 1 ) ;
101- timestepsArray = np . arange ( ( float ) Options . TrainTimesteps , 0 , - stepRatio ) [ "::-1" ] ;
102- timestepsArray = np . around ( timestepsArray ) ;
103- timestepsArray -= 1 ;
104- }
105-
106- return timestepsArray
107- . ToArray < float > ( )
54+ var timesteps = GetTimesteps ( ) ;
55+ return timesteps
10856 . Select ( x => ( int ) x )
10957 . OrderByDescending ( x => x )
11058 . ToArray ( ) ;
@@ -139,8 +87,8 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
13987 int previousTimestep = GetPreviousTimestep ( currentTimestep ) ;
14088
14189 //# 1. compute alphas, betas
142- float alphaProdT = _alphasCumulativeProducts [ currentTimestep ] ;
143- float alphaProdTPrev = previousTimestep >= 0 ? _alphasCumulativeProducts [ previousTimestep ] : 1f ;
90+ float alphaProdT = _alphasCumProd [ currentTimestep ] ;
91+ float alphaProdTPrev = previousTimestep >= 0 ? _alphasCumProd [ previousTimestep ] : 1f ;
14492 float betaProdT = 1 - alphaProdT ;
14593 float betaProdTPrev = 1 - alphaProdTPrev ;
14694 float currentAlphaT = alphaProdT / alphaProdTPrev ;
@@ -161,27 +109,7 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
161109
162110 //# 2. compute predicted original sample from predicted noise also called
163111 //# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
164- DenseTensor < float > predOriginalSample = null ;
165- if ( Options . PredictionType == PredictionType . Epsilon )
166- {
167- //pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
168- var sampleBeta = sample . SubtractTensors ( modelOutput . MultipleTensorByFloat ( ( float ) Math . Sqrt ( betaProdT ) ) ) ;
169- predOriginalSample = sampleBeta . DivideTensorByFloat ( ( float ) Math . Sqrt ( alphaProdT ) , sampleBeta . Dimensions ) ;
170- }
171- else if ( Options . PredictionType == PredictionType . Sample )
172- {
173- predOriginalSample = modelOutput ;
174- }
175- else if ( Options . PredictionType == PredictionType . VariablePrediction )
176- {
177- // pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
178- var alphaSqrt = ( float ) Math . Sqrt ( alphaProdT ) ;
179- var betaSqrt = ( float ) Math . Sqrt ( betaProdT ) ;
180- predOriginalSample = sample
181- . MultipleTensorByFloat ( alphaSqrt )
182- . SubtractTensors ( modelOutput . MultipleTensorByFloat ( betaSqrt ) ) ;
183- }
184-
112+ var predOriginalSample = GetPredictedSample ( modelOutput , sample , alphaProdT , betaProdT ) ;
185113
186114 //# 3. Clip or threshold "predicted x_0"
187115 if ( Options . Thresholding )
@@ -234,6 +162,31 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
234162 }
235163
236164
165+ private DenseTensor < float > GetPredictedSample ( DenseTensor < float > modelOutput , DenseTensor < float > sample , float alphaProdT , float betaProdT )
166+ {
167+ DenseTensor < float > predOriginalSample = null ;
168+ if ( Options . PredictionType == PredictionType . Epsilon )
169+ {
170+ //pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
171+ var sampleBeta = sample . SubtractTensors ( modelOutput . MultipleTensorByFloat ( ( float ) Math . Sqrt ( betaProdT ) ) ) ;
172+ predOriginalSample = sampleBeta . DivideTensorByFloat ( ( float ) Math . Sqrt ( alphaProdT ) , sampleBeta . Dimensions ) ;
173+ }
174+ else if ( Options . PredictionType == PredictionType . Sample )
175+ {
176+ predOriginalSample = modelOutput ;
177+ }
178+ else if ( Options . PredictionType == PredictionType . VariablePrediction )
179+ {
180+ // pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
181+ var alphaSqrt = ( float ) Math . Sqrt ( alphaProdT ) ;
182+ var betaSqrt = ( float ) Math . Sqrt ( betaProdT ) ;
183+ predOriginalSample = sample
184+ . MultipleTensorByFloat ( alphaSqrt )
185+ . SubtractTensors ( modelOutput . MultipleTensorByFloat ( betaSqrt ) ) ;
186+ }
187+ return predOriginalSample ;
188+ }
189+
237190 /// <summary>
238191 /// Adds noise to the sample.
239192 /// </summary>
@@ -245,7 +198,7 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
245198 {
246199 // Ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L456
247200 int timestep = timesteps [ 0 ] ;
248- float alphaProd = _alphasCumulativeProducts [ timestep ] ;
201+ float alphaProd = _alphasCumProd [ timestep ] ;
249202 float sqrtAlpha = ( float ) Math . Sqrt ( alphaProd ) ;
250203 float sqrtOneMinusAlpha = ( float ) Math . Sqrt ( 1.0f - alphaProd ) ;
251204
@@ -263,8 +216,8 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
263216 private float GetVariance ( int timestep , float predictedVariance = 0f )
264217 {
265218 int prevTimestep = GetPreviousTimestep ( timestep ) ;
266- float alphaProdT = _alphasCumulativeProducts [ timestep ] ;
267- float alphaProdTPrev = prevTimestep >= 0 ? _alphasCumulativeProducts [ prevTimestep ] : 1.0f ;
219+ float alphaProdT = _alphasCumProd [ timestep ] ;
220+ float alphaProdTPrev = prevTimestep >= 0 ? _alphasCumProd [ prevTimestep ] : 1.0f ;
268221 float currentBetaT = 1 - alphaProdT / alphaProdTPrev ;
269222
270223 // For t > 0, compute predicted variance βt
@@ -384,8 +337,7 @@ private bool IsVarianceTypeLearned()
384337
385338 protected override void Dispose ( bool disposing )
386339 {
387- _betas = null ;
388- _alphasCumulativeProducts = null ;
340+ _alphasCumProd = null ;
389341 base . Dispose ( disposing ) ;
390342 }
391343 }
0 commit comments