Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 87bcdb8

Browse files
committed
LCM Scheduler prototype
1 parent 0ca1152 commit 87bcdb8

File tree

1 file changed

+216
-0
lines changed

1 file changed

+216
-0
lines changed
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.StableDiffusion.Config;
3+
using OnnxStack.StableDiffusion.Enums;
4+
using OnnxStack.StableDiffusion.Helpers;
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
9+
namespace OnnxStack.StableDiffusion.Schedulers
{
    /// <summary>
    /// Latent Consistency Model (LCM) scheduler.
    /// Reference: Latent Consistency Models, https://arxiv.org/abs/2310.04378
    /// </summary>
    internal class LCMScheduler : SchedulerBase
    {
        // Cumulative products of (1 - beta) over the training schedule.
        private float[] _alphasCumProd;
        // Alpha product used when stepping past the first trained timestep.
        private float _finalAlphaCumprod;
        // Number of steps the linearly-spaced training schedule is built from.
        private int _originalInferenceSteps;


        /// <summary>
        /// Initializes a new instance of the <see cref="LCMScheduler"/> class with default options.
        /// </summary>
        public LCMScheduler() : this(new SchedulerOptions()) { }

        /// <summary>
        /// Initializes a new instance of the <see cref="LCMScheduler"/> class.
        /// </summary>
        /// <param name="options">The scheduler options.</param>
        public LCMScheduler(SchedulerOptions options) : base(options) { }


        /// <summary>
        /// Initializes the scheduler state: cumulative alpha products,
        /// final alpha, and the training-schedule step count.
        /// </summary>
        protected override void Initialize()
        {
            var betas = GetBetaSchedule().ToArray();

            // Single-pass running product: _alphasCumProd[i] = prod_{j<=i} (1 - beta[j]).
            // (The previous Take(i + 1).Aggregate(...) re-enumerated a lazy query per
            // element, making this O(n^2) over TrainTimesteps.)
            _alphasCumProd = new float[betas.Length];
            float cumProd = 1f;
            for (int i = 0; i < betas.Length; i++)
            {
                cumProd *= 1f - betas[i];
                _alphasCumProd[i] = cumProd;
            }

            // set_alpha_to_one = true: treat the step before the first trained timestep
            // as fully un-noised; otherwise fall back to the first cumulative product.
            bool setAlphaToOne = true;
            _finalAlphaCumprod = setAlphaToOne
                ? 1.0f
                : _alphasCumProd[0];

            // The default number of inference steps used to generate a linearly-spaced
            // timestep schedule, from which we will ultimately take `InferenceSteps`
            // evenly spaced timesteps to form the final timestep schedule.
            _originalInferenceSteps = Options.InferenceSteps;

            SetInitNoiseSigma(1.0f);
        }


        /// <summary>
        /// Builds the LCM inference timestep schedule.
        /// Currently, only linear spacing is supported.
        /// </summary>
        /// <returns>The selected timesteps in descending order.</returns>
        protected override int[] SetTimesteps()
        {
            // LCM training steps schedule: linearly spaced over the training range.
            var timeIncrement = (float)Options.TrainTimesteps / _originalInferenceSteps;
            var lcmOriginTimesteps = Enumerable.Range(1, _originalInferenceSteps)
                .Select(x => x * timeIncrement - 1f)
                .ToArray();

            // Integer skipping step, matching the diffusers reference
            // (len(lcm_origin_timesteps) // num_inference_steps). The previous
            // float modulo (index % skippingStep == 0) silently dropped timesteps
            // whenever the ratio was non-integral, yielding fewer than
            // InferenceSteps entries.
            int skippingStep = Math.Max(1, lcmOriginTimesteps.Length / Options.InferenceSteps);

            // LCM inference steps schedule: every skippingStep-th training
            // timestep, truncated to InferenceSteps, descending.
            return lcmOriginTimesteps
                .Where((t, index) => index % skippingStep == 0)
                .Take(Options.InferenceSteps)
                .Select(x => (int)x)
                .OrderByDescending(x => x)
                .ToArray();
        }


        /// <summary>
        /// Scales the input sample. LCM requires no input scaling, so the
        /// sample is returned unchanged.
        /// </summary>
        /// <param name="sample">The sample.</param>
        /// <param name="timestep">The timestep.</param>
        /// <returns>The unmodified sample.</returns>
        public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int timestep)
        {
            return sample;
        }


        /// <summary>
        /// Processes an inference step for the specified model output.
        /// Latent Consistency Models paper: https://arxiv.org/abs/2310.04378
        /// </summary>
        /// <param name="modelOutput">The model output.</param>
        /// <param name="timestep">The timestep.</param>
        /// <param name="sample">The sample.</param>
        /// <param name="order">The order (unused by LCM).</param>
        /// <returns>The denoised previous sample.</returns>
        /// <exception cref="NotSupportedException">Thrown for an unhandled <see cref="PredictionType"/>.</exception>
        public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
        {
            int currentTimestep = timestep;

            // 1. get previous step value
            int previousTimestep = GetPreviousTimestep(currentTimestep);

            // 2. compute alphas, betas
            float alphaProdT = _alphasCumProd[currentTimestep];
            float alphaProdTPrev = previousTimestep >= 0 ? _alphasCumProd[previousTimestep] : _finalAlphaCumprod;
            float betaProdT = 1f - alphaProdT;
            float betaProdTPrev = 1f - alphaProdTPrev;

            // 3. get scalings for boundary conditions
            (float cSkip, float cOut) = GetBoundaryConditionScalings(currentTimestep);

            // 4. compute predicted original sample ("predicted x_0", formula (15)
            //    of https://arxiv.org/pdf/2006.11239.pdf) from the model output.
            DenseTensor<float> predOriginalSample;
            if (Options.PredictionType == PredictionType.Epsilon)
            {
                var sampleBeta = sample.SubtractTensors(modelOutput.MultipleTensorByFloat((float)Math.Sqrt(betaProdT)));
                predOriginalSample = sampleBeta.DivideTensorByFloat((float)Math.Sqrt(alphaProdT));
            }
            else if (Options.PredictionType == PredictionType.Sample)
            {
                predOriginalSample = modelOutput;
            }
            else if (Options.PredictionType == PredictionType.VariablePrediction)
            {
                var alphaSqrt = (float)Math.Sqrt(alphaProdT);
                var betaSqrt = (float)Math.Sqrt(betaProdT);
                predOriginalSample = sample
                    .MultipleTensorByFloat(alphaSqrt)
                    .SubtractTensors(modelOutput.MultipleTensorByFloat(betaSqrt));
            }
            else
            {
                // Fail loudly instead of dereferencing a null sample below.
                throw new NotSupportedException($"PredictionType {Options.PredictionType} is not supported by {nameof(LCMScheduler)}");
            }

            // 5. Clip or threshold "predicted x_0"
            // TODO: OnnxStack does not yet support Threshold and Clipping

            // 6. Denoise model output using boundary conditions
            var denoised = sample
                .MultipleTensorByFloat(cSkip)
                .AddTensors(predOriginalSample.MultipleTensorByFloat(cOut));

            // 7. Sample and inject noise z ~ N(0, I) for MultiStep inference;
            //    the final step returns the denoised sample directly.
            var prevSample = Timesteps.Count > 1
                ? CreateRandomSample(modelOutput.Dimensions)
                    .MultipleTensorByFloat(MathF.Sqrt(betaProdTPrev))
                    .AddTensors(denoised.MultipleTensorByFloat(MathF.Sqrt(alphaProdTPrev)))
                : denoised;

            return prevSample;
        }


        /// <summary>
        /// Adds noise to the sample (forward diffusion).
        /// Ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L456
        /// </summary>
        /// <param name="originalSamples">The original samples.</param>
        /// <param name="noise">The noise.</param>
        /// <param name="timesteps">The timesteps (only the first is used).</param>
        /// <returns>The noised sample.</returns>
        public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples, DenseTensor<float> noise, IReadOnlyList<int> timesteps)
        {
            int timestep = timesteps[0];
            float alphaProd = _alphasCumProd[timestep];
            float sqrtAlpha = (float)Math.Sqrt(alphaProd);
            float sqrtOneMinusAlpha = (float)Math.Sqrt(1.0f - alphaProd);

            return noise
                .MultipleTensorByFloat(sqrtOneMinusAlpha)
                .AddTensors(originalSamples.MultipleTensorByFloat(sqrtAlpha));
        }


        /// <summary>
        /// Gets the consistency-model boundary condition scalings c_skip and c_out
        /// for the given timestep (sigma_data = 0.5, timestep scaling t / 0.1).
        /// </summary>
        /// <param name="timestep">The timestep.</param>
        /// <returns>The (cSkip, cOut) scaling pair.</returns>
        public (float cSkip, float cOut) GetBoundaryConditionScalings(float timestep)
        {
            // self.sigma_data = 0.5  # Default: 0.5
            const float sigmaData = 0.5f;

            float scaledTimestep = timestep / 0.1f;
            float denominator = scaledTimestep * scaledTimestep + sigmaData * sigmaData;

            // c_skip = sigma_data**2 / ((t / 0.1)**2 + sigma_data**2)
            float cSkip = (sigmaData * sigmaData) / denominator;

            // c_out = (t / 0.1) / ((t / 0.1)**2 + sigma_data**2)**0.5
            // NOTE: the reference raises the denominator to the power 0.5 (square
            // root); the previous port multiplied the quotient by 0.5 instead.
            float cOut = scaledTimestep / MathF.Sqrt(denominator);

            return (cSkip, cOut);
        }


        /// <summary>
        /// Releases unmanaged and - optionally - managed resources.
        /// </summary>
        /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
        protected override void Dispose(bool disposing)
        {
            _alphasCumProd = null;
            base.Dispose(disposing);
        }
    }
}

0 commit comments

Comments
 (0)