Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 97263ed

Browse files
committed
Try Fix LCM 1 and 2 step inference
1 parent 7ac3b93 commit 97263ed

File tree

2 files changed

+34
-21
lines changed

2 files changed

+34
-21
lines changed

OnnxStack.StableDiffusion/Schedulers/LatentConsistency/LCMScheduler.cs

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,19 @@ protected override int[] SetTimesteps()
6767
var timeIncrement = Options.TrainTimesteps / Options.OriginalInferenceSteps;
6868

6969
//# LCM Training Steps Schedule
70-
var lcmOriginTimesteps = Enumerable.Range(1, Options.OriginalInferenceSteps)
71-
.Select(x => x * timeIncrement - 1f)
70+
var lcmOriginTimesteps = Enumerable.Range(0, Options.OriginalInferenceSteps)
71+
.Select(x => x * (timeIncrement - 1))
7272
.ToArray();
7373

7474
var skippingStep = lcmOriginTimesteps.Length / Options.InferenceSteps;
7575

7676
// LCM Inference Steps Schedule
77-
return lcmOriginTimesteps
77+
var steps = lcmOriginTimesteps
7878
.Where((t, index) => index % skippingStep == 0)
7979
.Take(Options.InferenceSteps)
80-
.Select(x => (int)x)
8180
.OrderByDescending(x => x)
8281
.ToArray();
82+
return steps;
8383
}
8484

8585

@@ -111,19 +111,20 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
111111

112112
// 1. get previous step value
113113
int prevIndex = Timesteps.IndexOf(currentTimestep) + 1;
114-
int previousTimestep = prevIndex < Timesteps.Count
115-
? Timesteps[prevIndex]
114+
int previousTimestep = prevIndex < Timesteps.Count
115+
? Timesteps[prevIndex]
116116
: currentTimestep;
117117

118118
//# 2. compute alphas, betas
119119
float alphaProdT = _alphasCumProd[currentTimestep];
120-
float alphaProdTPrev = previousTimestep >= 0
121-
? _alphasCumProd[previousTimestep]
120+
float alphaProdTPrev = previousTimestep >= 0
121+
? _alphasCumProd[previousTimestep]
122122
: _finalAlphaCumprod;
123123
float betaProdT = 1f - alphaProdT;
124124
float betaProdTPrev = 1f - alphaProdTPrev;
125-
float alphaSqrt = MathF.Sqrt(alphaProdT);
126-
float betaSqrt = MathF.Sqrt(betaProdT);
125+
126+
float alphaProdTSqrt = MathF.Sqrt(alphaProdT);
127+
float betaProdTSqrt = MathF.Sqrt(betaProdT);
127128
float betaProdTPrevSqrt = MathF.Sqrt(betaProdTPrev);
128129
float alphaProdTPrevSqrt = MathF.Sqrt(alphaProdTPrev);
129130

@@ -137,8 +138,8 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
137138
if (Options.PredictionType == PredictionType.Epsilon)
138139
{
139140
predOriginalSample = sample
140-
.SubtractTensors(modelOutput.MultiplyTensorByFloat(betaSqrt))
141-
.DivideTensorByFloat(alphaSqrt);
141+
.SubtractTensors(modelOutput.MultiplyTensorByFloat(betaProdTSqrt))
142+
.DivideTensorByFloat(alphaProdTSqrt);
142143
}
143144
else if (Options.PredictionType == PredictionType.Sample)
144145
{
@@ -147,8 +148,8 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
147148
else if (Options.PredictionType == PredictionType.VariablePrediction)
148149
{
149150
predOriginalSample = sample
150-
.MultiplyTensorByFloat(alphaSqrt)
151-
.SubtractTensors(modelOutput.MultiplyTensorByFloat(betaSqrt));
151+
.MultiplyTensorByFloat(alphaProdTSqrt)
152+
.SubtractTensors(modelOutput.MultiplyTensorByFloat(betaProdTSqrt));
152153
}
153154

154155

@@ -163,13 +164,22 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
163164

164165

165166
//# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
166-
var prevSample = Timesteps.Count > 1
167-
? CreateRandomSample(modelOutput.Dimensions)
167+
//# Noise is not used on the final timestep of the timestep schedule.
168+
//# This also means that noise is not used for one-step sampling.
169+
if (Timesteps.IndexOf(currentTimestep) != Options.InferenceSteps - 1)
170+
{
171+
var noise = CreateRandomSample(modelOutput.Dimensions);
172+
predOriginalSample = noise
168173
.MultiplyTensorByFloat(betaProdTPrevSqrt)
169-
.AddTensors(denoised.MultiplyTensorByFloat(alphaProdTPrevSqrt))
170-
: denoised;
174+
.AddTensors(denoised.MultiplyTensorByFloat(alphaProdTPrevSqrt));
175+
}
176+
else
177+
{
178+
predOriginalSample = denoised;
179+
}
180+
171181

172-
return new SchedulerStepResult(prevSample, denoised);
182+
return new SchedulerStepResult(predOriginalSample, denoised);
173183
}
174184

175185

@@ -203,10 +213,12 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
203213
{
204214
//self.sigma_data = 0.5 # Default: 0.5
205215
var sigmaData = 0.5f;
216+
var timestepScaling = 10f;
217+
var scaledTimestep = timestepScaling * timestep;
206218

207-
float c = MathF.Pow(timestep / 0.1f, 2f) + MathF.Pow(sigmaData, 2f);
219+
float c = MathF.Pow(scaledTimestep, 2f) + MathF.Pow(sigmaData, 2f);
208220
float cSkip = MathF.Pow(sigmaData, 2f) / c;
209-
float cOut = timestep / 0.1f / MathF.Pow(c, 0.5f);
221+
float cOut = scaledTimestep / MathF.Pow(c, 0.5f);
210222
return (cSkip, cOut);
211223
}
212224

OnnxStack.UI/UserControls/SchedulerControl.xaml.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ private void OnModelChanged(ModelOptionsModel model)
122122
{
123123
SchedulerOptions.OriginalInferenceSteps = 50;
124124
SchedulerOptions.InferenceSteps = 6;
125+
SchedulerOptions.GuidanceScale = 1f;
125126
}
126127

127128

0 commit comments

Comments
 (0)