@@ -67,19 +67,19 @@ protected override int[] SetTimesteps()
6767 var timeIncrement = Options . TrainTimesteps / Options . OriginalInferenceSteps ;
6868
6969 //# LCM Training Steps Schedule
70- var lcmOriginTimesteps = Enumerable . Range ( 1 , Options . OriginalInferenceSteps )
71- . Select ( x => x * timeIncrement - 1f )
70+ var lcmOriginTimesteps = Enumerable . Range ( 0 , Options . OriginalInferenceSteps )
71+ . Select ( x => x * ( timeIncrement - 1 ) )
7272 . ToArray ( ) ;
7373
7474 var skippingStep = lcmOriginTimesteps . Length / Options . InferenceSteps ;
7575
7676 // LCM Inference Steps Schedule
77- return lcmOriginTimesteps
77+ var steps = lcmOriginTimesteps
7878 . Where ( ( t , index ) => index % skippingStep == 0 )
7979 . Take ( Options . InferenceSteps )
80- . Select ( x => ( int ) x )
8180 . OrderByDescending ( x => x )
8281 . ToArray ( ) ;
82+ return steps ;
8383 }
8484
8585
@@ -111,19 +111,20 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
111111
112112 // 1. get previous step value
113113 int prevIndex = Timesteps . IndexOf ( currentTimestep ) + 1 ;
114- int previousTimestep = prevIndex < Timesteps . Count
115- ? Timesteps [ prevIndex ]
114+ int previousTimestep = prevIndex < Timesteps . Count
115+ ? Timesteps [ prevIndex ]
116116 : currentTimestep ;
117117
118118 //# 2. compute alphas, betas
119119 float alphaProdT = _alphasCumProd [ currentTimestep ] ;
120- float alphaProdTPrev = previousTimestep >= 0
121- ? _alphasCumProd [ previousTimestep ]
120+ float alphaProdTPrev = previousTimestep >= 0
121+ ? _alphasCumProd [ previousTimestep ]
122122 : _finalAlphaCumprod ;
123123 float betaProdT = 1f - alphaProdT ;
124124 float betaProdTPrev = 1f - alphaProdTPrev ;
125- float alphaSqrt = MathF . Sqrt ( alphaProdT ) ;
126- float betaSqrt = MathF . Sqrt ( betaProdT ) ;
125+
126+ float alphaProdTSqrt = MathF . Sqrt ( alphaProdT ) ;
127+ float betaProdTSqrt = MathF . Sqrt ( betaProdT ) ;
127128 float betaProdTPrevSqrt = MathF . Sqrt ( betaProdTPrev ) ;
128129 float alphaProdTPrevSqrt = MathF . Sqrt ( alphaProdTPrev ) ;
129130
@@ -137,8 +138,8 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
137138 if ( Options . PredictionType == PredictionType . Epsilon )
138139 {
139140 predOriginalSample = sample
140- . SubtractTensors ( modelOutput . MultiplyTensorByFloat ( betaSqrt ) )
141- . DivideTensorByFloat ( alphaSqrt ) ;
141+ . SubtractTensors ( modelOutput . MultiplyTensorByFloat ( betaProdTSqrt ) )
142+ . DivideTensorByFloat ( alphaProdTSqrt ) ;
142143 }
143144 else if ( Options . PredictionType == PredictionType . Sample )
144145 {
@@ -147,8 +148,8 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
147148 else if ( Options . PredictionType == PredictionType . VariablePrediction )
148149 {
149150 predOriginalSample = sample
150- . MultiplyTensorByFloat ( alphaSqrt )
151- . SubtractTensors ( modelOutput . MultiplyTensorByFloat ( betaSqrt ) ) ;
151+ . MultiplyTensorByFloat ( alphaProdTSqrt )
152+ . SubtractTensors ( modelOutput . MultiplyTensorByFloat ( betaProdTSqrt ) ) ;
152153 }
153154
154155
@@ -163,13 +164,22 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
163164
164165
165166 //# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
166- var prevSample = Timesteps . Count > 1
167- ? CreateRandomSample ( modelOutput . Dimensions )
167+ //# Noise is not used on the final timestep of the timestep schedule.
168+ //# This also means that noise is not used for one-step sampling.
169+ if ( Timesteps . IndexOf ( currentTimestep ) != Options . InferenceSteps - 1 )
170+ {
171+ var noise = CreateRandomSample ( modelOutput . Dimensions ) ;
172+ predOriginalSample = noise
168173 . MultiplyTensorByFloat ( betaProdTPrevSqrt )
169- . AddTensors ( denoised . MultiplyTensorByFloat ( alphaProdTPrevSqrt ) )
170- : denoised ;
174+ . AddTensors ( denoised . MultiplyTensorByFloat ( alphaProdTPrevSqrt ) ) ;
175+ }
176+ else
177+ {
178+ predOriginalSample = denoised ;
179+ }
180+
171181
172- return new SchedulerStepResult ( prevSample , denoised ) ;
182+ return new SchedulerStepResult ( predOriginalSample , denoised ) ;
173183 }
174184
175185
@@ -203,10 +213,12 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
203213 {
204214 //self.sigma_data = 0.5 # Default: 0.5
205215 var sigmaData = 0.5f ;
216+ var timestepScaling = 10f ;
217+ var scaledTimestep = timestepScaling * timestep ;
206218
207- float c = MathF . Pow ( timestep / 0.1f , 2f ) + MathF . Pow ( sigmaData , 2f ) ;
219+ float c = MathF . Pow ( scaledTimestep , 2f ) + MathF . Pow ( sigmaData , 2f ) ;
208220 float cSkip = MathF . Pow ( sigmaData , 2f ) / c ;
209- float cOut = timestep / 0.1f / MathF . Pow ( c , 0.5f ) ;
221+ float cOut = scaledTimestep / MathF . Pow ( c , 0.5f ) ;
210222 return ( cSkip , cOut ) ;
211223 }
212224
0 commit comments