11using Microsoft . ML . OnnxRuntime . Tensors ;
2+ using OnnxStack . Core ;
23using OnnxStack . StableDiffusion . Config ;
34using OnnxStack . StableDiffusion . Enums ;
45using OnnxStack . StableDiffusion . Helpers ;
@@ -109,13 +110,23 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
109110 int currentTimestep = timestep ;
110111
111112 // 1. get previous step value
112- int previousTimestep = GetPreviousTimestep ( currentTimestep ) ;
113+ int prevIndex = Timesteps . IndexOf ( currentTimestep ) + 1 ;
114+ int previousTimestep = prevIndex < Timesteps . Count
115+ ? Timesteps [ prevIndex ]
116+ : currentTimestep ;
113117
114118 //# 2. compute alphas, betas
115119 float alphaProdT = _alphasCumProd [ currentTimestep ] ;
116- float alphaProdTPrev = previousTimestep >= 0 ? _alphasCumProd [ previousTimestep ] : _finalAlphaCumprod ;
120+ float alphaProdTPrev = previousTimestep >= 0
121+ ? _alphasCumProd [ previousTimestep ]
122+ : _finalAlphaCumprod ;
117123 float betaProdT = 1f - alphaProdT ;
118124 float betaProdTPrev = 1f - alphaProdTPrev ;
125+ float alphaSqrt = MathF . Sqrt ( alphaProdT ) ;
126+ float betaSqrt = MathF . Sqrt ( betaProdT ) ;
127+ float betaProdTPrevSqrt = MathF . Sqrt ( betaProdTPrev ) ;
128+ float alphaProdTPrevSqrt = MathF . Sqrt ( alphaProdTPrev ) ;
129+
119130
120131 // 3.Get scalings for boundary conditions
121132 ( float cSkip , float cOut ) = GetBoundaryConditionScalings ( currentTimestep ) ;
@@ -125,17 +136,16 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
125136 DenseTensor < float > predOriginalSample = null ;
126137 if ( Options . PredictionType == PredictionType . Epsilon )
127138 {
128- var sampleBeta = sample . SubtractTensors ( modelOutput . MultipleTensorByFloat ( ( float ) Math . Sqrt ( betaProdT ) ) ) ;
129- predOriginalSample = sampleBeta . DivideTensorByFloat ( ( float ) Math . Sqrt ( alphaProdT ) ) ;
139+ predOriginalSample = sample
140+ . SubtractTensors ( modelOutput . MultipleTensorByFloat ( betaSqrt ) )
141+ . DivideTensorByFloat ( alphaSqrt ) ;
130142 }
131143 else if ( Options . PredictionType == PredictionType . Sample )
132144 {
133145 predOriginalSample = modelOutput ;
134146 }
135147 else if ( Options . PredictionType == PredictionType . VariablePrediction )
136148 {
137- var alphaSqrt = ( float ) Math . Sqrt ( alphaProdT ) ;
138- var betaSqrt = ( float ) Math . Sqrt ( betaProdT ) ;
139149 predOriginalSample = sample
140150 . MultipleTensorByFloat ( alphaSqrt )
141151 . SubtractTensors ( modelOutput . MultipleTensorByFloat ( betaSqrt ) ) ;
@@ -155,8 +165,8 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
155165 //# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
156166 var prevSample = Timesteps . Count > 1
157167 ? CreateRandomSample ( modelOutput . Dimensions )
158- . MultipleTensorByFloat ( MathF . Sqrt ( betaProdTPrev ) )
159- . AddTensors ( denoised . MultipleTensorByFloat ( MathF . Sqrt ( alphaProdTPrev ) ) )
168+ . MultipleTensorByFloat ( betaProdTPrevSqrt )
169+ . AddTensors ( denoised . MultipleTensorByFloat ( alphaProdTPrevSqrt ) )
160170 : denoised ;
161171
162172 return new SchedulerStepResult ( prevSample , denoised ) ;
@@ -175,8 +185,8 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
175185 // Ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L456
176186 int timestep = timesteps [ 0 ] ;
177187 float alphaProd = _alphasCumProd [ timestep ] ;
178- float sqrtAlpha = ( float ) Math . Sqrt ( alphaProd ) ;
179- float sqrtOneMinusAlpha = ( float ) Math . Sqrt ( 1.0f - alphaProd ) ;
188+ float sqrtAlpha = MathF . Sqrt ( alphaProd ) ;
189+ float sqrtOneMinusAlpha = MathF . Sqrt ( 1.0f - alphaProd ) ;
180190
181191 return noise
182192 . MultipleTensorByFloat ( sqrtOneMinusAlpha )
0 commit comments