Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit ca2b6bd

Browse files
committed
KDPM2 Scheduler
1 parent 6c72048 commit ca2b6bd

File tree

5 files changed

+254
-6
lines changed

5 files changed

+254
-6
lines changed

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ protected static IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions
283283
SchedulerType.EulerAncestral => new EulerAncestralScheduler(options),
284284
SchedulerType.DDPM => new DDPMScheduler(options),
285285
SchedulerType.DDIM => new DDIMScheduler(options),
286+
SchedulerType.KDPM2 => new KDPM2Scheduler(options),
286287
_ => default
287288
};
288289
}

OnnxStack.StableDiffusion/Enums/SchedulerType.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ public enum SchedulerType
1717
DDPM = 3,
1818

1919
[Display(Name = "DDIM")]
20-
DDIM = 4
20+
DDIM = 4,
21+
22+
[Display(Name = "KDPM2")]
23+
KDPM2 = 5
2124
}
2225
}

OnnxStack.StableDiffusion/Schedulers/DDPMScheduler.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
22
using NumSharp;
3-
using OnnxStack.Core;
4-
using OnnxStack.StableDiffusion.Config;
5-
using OnnxStack.StableDiffusion.Enums;
6-
using OnnxStack.StableDiffusion.Helpers;
73
using OnnxStack.StableDiffusion.Config;
84
using OnnxStack.StableDiffusion.Enums;
95
using OnnxStack.StableDiffusion.Helpers;
OnnxStack.StableDiffusion/Schedulers/KDPM2Scheduler.cs

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using NumSharp;
3+
using OnnxStack.StableDiffusion.Config;
4+
using OnnxStack.StableDiffusion.Enums;
5+
using OnnxStack.StableDiffusion.Helpers;
6+
using System;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
10+
namespace OnnxStack.StableDiffusion.Schedulers
11+
{
12+
internal class KDPM2Scheduler : SchedulerBase
{
    // Position in the interleaved sigma tables; advanced by one on every Step call.
    private int _stepIndex;

    // Sigma table. Built per training step in Initialize, then replaced by the
    // interleaved per-inference-step table in SetTimesteps.
    private float[] _sigmas;

    // Interleaved table of sigmas interpolated between neighbouring entries of the sigma table.
    private float[] _sigmasInterpol;

    // Cumulative product of (1 - beta) over the beta schedule.
    private float[] _alphasCumProd;

    // Sample captured on the first pass of Step; null means the next Step call is a first pass.
    private DenseTensor<float> _sample;

    /// <summary>
    /// Initializes a new instance of the <see cref="KDPM2Scheduler"/> class
    /// using a default <see cref="SchedulerOptions"/> instance.
    /// </summary>
    public KDPM2Scheduler() : this(new SchedulerOptions()) { }

    /// <summary>
    /// Initializes a new instance of the <see cref="KDPM2Scheduler"/> class.
    /// </summary>
    /// <param name="options">The scheduler options.</param>
    public KDPM2Scheduler(SchedulerOptions options) : base(options) { }


    /// <summary>
    /// Resets the step state and rebuilds the cumulative alpha products, the sigma
    /// table and the initial noise sigma from the beta schedule.
    /// </summary>
    protected override void Initialize()
    {
        _stepIndex = 0;
        _sample = null;

        // Running product: entry i is (1 - beta_0) * ... * (1 - beta_i), accumulated
        // left to right in a single pass over the schedule.
        var cumulativeProducts = new List<float>();
        var runningProduct = 1.0f;
        foreach (var beta in GetBetaSchedule())
        {
            runningProduct *= 1.0f - beta;
            cumulativeProducts.Add(runningProduct);
        }
        _alphasCumProd = cumulativeProducts.ToArray();

        // sigma = sqrt((1 - alphaProd) / alphaProd)
        _sigmas = _alphasCumProd
            .Select(alphaProd => (float)Math.Sqrt((1 - alphaProd) / alphaProd))
            .ToArray();

        var initNoiseSigma = GetInitNoiseSigma(_sigmas);
        SetInitNoiseSigma(initNoiseSigma);
    }


    /// <summary>
    /// Builds the interleaved sigma tables used by <see cref="Step"/> and returns the timesteps.
    /// </summary>
    /// <returns>
    /// The interpolated and base timesteps merged into one array, truncated to
    /// integers and sorted in descending order.
    /// </returns>
    protected override int[] SetTimesteps()
    {
        // NOTE(review): _sigmas is replaced by its interleaved form below, so this method
        // expects the table produced by Initialize - confirm the base class never calls
        // it twice without re-initializing.
        var timesteps = GetTimesteps();
        var logSigmas = np.log(_sigmas).ToArray<float>();
        var range = np.arange(0, (float)_sigmas.Length).ToArray<float>();
        var sigmas = Interpolate(timesteps, range, _sigmas);

        if (Options.UseKarrasSigmas)
        {
            sigmas = ConvertToKarras(sigmas);
            timesteps = SigmaToTimestep(sigmas, logSigmas);
        }

        // Interpolate between neighbouring sigmas before the base table is interleaved.
        var sigmasInterpol = InterpolateSigmas(sigmas);

        _sigmas = Interleave(sigmas);
        _sigmasInterpol = Interleave(sigmasInterpol);

        var timestepsInterpol = SigmaToTimestep(sigmasInterpol, logSigmas);
        return timestepsInterpol
            .Concat(timesteps)
            .Select(x => (int)x)
            .OrderByDescending(x => x)
            .ToArray();
    }


    /// <summary>
    /// Scales the input by dividing it by sqrt(sigma^2 + 1).
    /// </summary>
    /// <param name="sample">The sample.</param>
    /// <param name="timestep">The timestep (not read; the sigma is chosen by the internal step index).</param>
    /// <returns>The scaled sample.</returns>
    public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int timestep)
    {
        // First pass reads the base table, second pass reads the interpolated table.
        var sigma = _sample is null
            ? _sigmas[_stepIndex]
            : _sigmasInterpol[_stepIndex];

        sigma = (float)Math.Sqrt(Math.Pow(sigma, 2) + 1);
        return sample.DivideTensorByFloat(sigma);
    }


    /// <summary>
    /// Processes an inference step for the specified model output.
    /// </summary>
    /// <remarks>
    /// Calls alternate between two passes. The first pass (no stored sample) stores
    /// <paramref name="sample"/> and advances it towards the interpolated sigma. The
    /// second pass applies its derivative to the stored sample, advances towards the
    /// next sigma and clears the stored sample. Any prediction type other than
    /// <see cref="PredictionType.Epsilon"/> or <see cref="PredictionType.VariablePrediction"/>
    /// uses the model output directly as the predicted original sample.
    /// </remarks>
    /// <param name="modelOutput">The model output.</param>
    /// <param name="timestep">The timestep (not read; progress is tracked by the internal step index).</param>
    /// <param name="sample">The sample.</param>
    /// <param name="order">The order (not read by this scheduler).</param>
    /// <returns>The sample after this pass.</returns>
    public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
    {
        float sigma;
        float sigmaInterpol;
        float sigmaNext;
        bool isFirstPass = _sample is null;
        if (isFirstPass)
        {
            sigma = _sigmas[_stepIndex];
            sigmaInterpol = _sigmasInterpol[_stepIndex + 1];
            sigmaNext = _sigmas[_stepIndex + 1];
        }
        else
        {
            // _stepIndex was already advanced by the first pass, so look one entry back.
            sigma = _sigmas[_stepIndex - 1];
            sigmaInterpol = _sigmasInterpol[_stepIndex];
            sigmaNext = _sigmas[_stepIndex];
        }

        // Only gamma = 0 is supported, so sigmaHat is equal to sigma.
        float gamma = 0f;
        float sigmaHat = sigma * (gamma + 1f);
        var sigmaInput = isFirstPass ? sigmaHat : sigmaInterpol;

        // Predict the original sample from the model output.
        DenseTensor<float> predOriginalSample;
        if (Options.PredictionType == PredictionType.Epsilon)
        {
            predOriginalSample = sample.SubtractTensors(modelOutput.MultipleTensorByFloat(sigmaInput));
        }
        else if (Options.PredictionType == PredictionType.VariablePrediction)
        {
            var sigmaSqrt = (float)Math.Sqrt(sigmaInput * sigmaInput + 1f);
            predOriginalSample = sample.DivideTensorByFloat(sigmaSqrt)
                .AddTensors(modelOutput.MultipleTensorByFloat(-sigmaInput / sigmaSqrt));
        }
        else
        {
            predOriginalSample = modelOutput.ToDenseTensor();
        }

        float dt;
        DenseTensor<float> derivative;
        if (isFirstPass)
        {
            dt = sigmaInterpol - sigmaHat;
            derivative = sample
                .SubtractTensors(predOriginalSample)
                .DivideTensorByFloat(sigmaHat);

            // Keep a copy of the incoming sample for the second pass.
            _sample = sample.ToDenseTensor();
        }
        else
        {
            dt = sigmaNext - sigmaHat;
            derivative = sample
                .SubtractTensors(predOriginalSample)
                .DivideTensorByFloat(sigmaInterpol);

            // The update is applied to the sample stored by the first pass.
            sample = _sample;
            _sample = null;
        }

        _stepIndex += 1;
        return sample.AddTensors(derivative.MultipleTensorByFloat(dt));
    }


    /// <summary>
    /// Adds noise to the sample: <c>originalSamples + noise * sigma</c>.
    /// </summary>
    /// <param name="originalSamples">The original samples.</param>
    /// <param name="noise">The noise.</param>
    /// <param name="timesteps">The timesteps.</param>
    /// <returns>The noised samples.</returns>
    public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples, DenseTensor<float> noise, IReadOnlyList<int> timesteps)
    {
        // NOTE(review): the sigma is taken at the current step index; the timesteps
        // argument is not read - confirm this is intended.
        var sigma = _sigmas[_stepIndex];
        return noise
            .MultipleTensorByFloat(sigma)
            .AddTensors(originalSamples);
    }


    /// <summary>
    /// Interpolates the sigmas at the midpoint, in log space, between each sigma and
    /// its predecessor. The first sigma is paired with the last one (wrap-around).
    /// </summary>
    /// <param name="sigmas">The sigmas.</param>
    /// <returns>An array of the same length as <paramref name="sigmas"/>; empty when the input is empty.</returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="sigmas"/> is null.</exception>
    public float[] InterpolateSigmas(float[] sigmas)
    {
        if (sigmas is null)
            throw new ArgumentNullException(nameof(sigmas));

        var count = sigmas.Length;
        if (count == 0)
            return Array.Empty<float>();

        // rolledLogSigmas[i] = log(sigmas[(i - 1) mod count]) for i in 0..count, i.e. the
        // logs shifted right by one with the last value wrapped to the front.
        var rolledLogSigmas = new float[count + 1];
        for (int i = 0; i <= count; i++)
        {
            rolledLogSigmas[i] = (float)Math.Log(sigmas[(i + count - 1) % count]);
        }

        var lerpSigmas = new float[count];
        for (int i = 0; i < count; i++)
        {
            lerpSigmas[i] = (float)Math.Exp(rolledLogSigmas[i] + 0.5f * (rolledLogSigmas[i + 1] - rolledLogSigmas[i]));
        }
        return lerpSigmas;
    }


    /// <summary>
    /// Interleaves the specified sigmas: the first value once, every later value
    /// twice, then the last value once more.
    /// </summary>
    /// <param name="sigmas">The sigmas; must contain at least one element.</param>
    /// <returns>An array of length <c>2 * sigmas.Length</c>.</returns>
    private static float[] Interleave(float[] sigmas)
    {
        var first = sigmas.First();
        var last = sigmas.Last();
        return sigmas.Skip(1)
            .SelectMany(value => new[] { value, value })
            .Prepend(first)
            .Append(last)
            .ToArray();
    }


    /// <summary>
    /// Releases unmanaged and - optionally - managed resources.
    /// </summary>
    /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
    protected override void Dispose(bool disposing)
    {
        // Drop every buffer this scheduler owns, not just the alpha products.
        _alphasCumProd = null;
        _sigmas = null;
        _sigmasInterpol = null;
        _sample = null;
        base.Dispose(disposing);
    }
}
248+
}

OnnxStack.StableDiffusion/Schedulers/SchedulerBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ protected float[] GetBetasForAlphaBar()
276276
protected float[] Interpolate(float[] timesteps, float[] range, float[] sigmas)
277277
{
278278
// Create an output array with the same shape as timesteps
279-
var result = new float[timesteps.Length + 1];
279+
var result = new float[timesteps.Length];
280280

281281
// Loop over each element of timesteps
282282
for (int i = 0; i < timesteps.Length; i++)

0 commit comments

Comments
 (0)