Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit d376d16

Browse files
committed
DDIM Scheduler
1 parent c93cab5 commit d376d16

File tree

3 files changed

+235
-1
lines changed

3 files changed

+235
-1
lines changed

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ protected static IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions
281281
SchedulerType.LMS => new LMSScheduler(options),
282282
SchedulerType.EulerAncestral => new EulerAncestralScheduler(options),
283283
SchedulerType.DDPM => new DDPMScheduler(options),
284+
SchedulerType.DDIM => new DDIMScheduler(options),
284285
_ => default
285286
};
286287
}

OnnxStack.StableDiffusion/Enums/SchedulerType.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ public enum SchedulerType
1111
EulerAncestral = 1,
1212

1313
[Display(Name = "DDPM")]
14-
DDPM = 3
14+
DDPM = 3,
15+
16+
[Display(Name = "DDIM")]
17+
DDIM = 4
1518
}
1619
}
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.StableDiffusion.Config;
3+
using OnnxStack.StableDiffusion.Enums;
4+
using OnnxStack.StableDiffusion.Helpers;
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
9+
namespace OnnxStack.StableDiffusion.Schedulers
10+
{
11+
internal class DDIMScheduler : SchedulerBase
12+
{
13+
private float[] _alphasCumProd;
14+
private float _finalAlphaCumprod;
15+
16+
/// <summary>
/// Initializes a new instance of the <see cref="DDIMScheduler"/> class
/// using a default-constructed <see cref="SchedulerOptions"/>.
/// </summary>
public DDIMScheduler() : this(new SchedulerOptions()) { }
21+
22+
/// <summary>
/// Initializes a new instance of the <see cref="DDIMScheduler"/> class.
/// </summary>
/// <param name="options">The scheduler options, forwarded to the base scheduler constructor.</param>
public DDIMScheduler(SchedulerOptions options) : base(options) { }
28+
29+
30+
/// <summary>
/// Initializes this instance: builds the cumulative alpha products from the
/// beta schedule, fixes the alpha product used when there is no previous
/// timestep, and sets the initial noise sigma to 1.
/// </summary>
protected override void Initialize()
{
    // alpha_t = 1 - beta_t, and _alphasCumProd[t] = alpha_0 * alpha_1 * ... * alpha_t.
    // A single running product performs the same multiplications, in the same
    // order, as re-aggregating every prefix, but in one pass over the schedule.
    var cumulativeProducts = new List<float>();
    float runningProduct = 1.0f;
    foreach (var beta in GetBetaSchedule())
    {
        runningProduct *= 1.0f - beta;
        cumulativeProducts.Add(runningProduct);
    }
    _alphasCumProd = cumulativeProducts.ToArray();

    // Alpha product substituted in Step/GetVariance when the previous timestep
    // is negative: either 1 or the first cumulative alpha product.
    bool setAlphaToOne = true;
    _finalAlphaCumprod = setAlphaToOne
        ? 1.0f
        : _alphasCumProd.First();

    SetInitNoiseSigma(1.0f);
}
50+
51+
52+
/// <summary>
/// Builds the timestep sequence for this scheduler.
/// </summary>
/// <returns>
/// The values produced by <c>GetTimesteps()</c>, each cast to <see cref="int"/>,
/// sorted from largest to smallest.
/// </returns>
protected override int[] SetTimesteps()
{
    // Cast first, then sort on the integer value, highest timestep first.
    var descendingSteps =
        from rawStep in GetTimesteps()
        let step = (int)rawStep
        orderby step descending
        select step;

    return descendingSteps.ToArray();
}
65+
66+
67+
/// <summary>
/// Scales the input sample for the given timestep. This scheduler applies
/// no scaling, so the sample is handed back as-is.
/// </summary>
/// <param name="sample">The sample.</param>
/// <param name="timestep">The timestep (not used).</param>
/// <returns>The same <paramref name="sample"/> instance that was passed in.</returns>
public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int timestep) => sample;
77+
78+
79+
/// <summary>
/// Processes an inference step for the specified model output, producing the
/// sample for the previous timestep.
/// </summary>
/// <param name="modelOutput">The model output.</param>
/// <param name="timestep">The current timestep; used as an index into the cumulative alpha products.</param>
/// <param name="sample">The current sample.</param>
/// <param name="order">The order (not used by this scheduler).</param>
/// <returns>The computed previous sample.</returns>
/// <exception cref="System.ArgumentException">Invalid prediction_type: the configured prediction type is not Epsilon, Sample or VariablePrediction.</exception>
/// <remarks>
/// Thresholding is not implemented yet: when <c>Options.Thresholding</c> is set,
/// the predicted original sample is left untouched (and is not clipped either).
/// </remarks>
public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
{
    //# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    //# Ideally, read the DDIM paper for an in-detail understanding

    //# Notation (<variable name> -> <name in paper>)
    //# - pred_noise_t -> e_theta(x_t, t)
    //# - pred_original_sample -> f_theta(x_t, t) or x_0
    //# - std_dev_t -> sigma_t
    //# - eta -> η
    //# - pred_sample_direction -> "direction pointing to x_t"
    //# - pred_prev_sample -> "x_t-1"

    int currentTimestep = timestep;
    int previousTimestep = GetPreviousTimestep(currentTimestep);

    //# 1. compute alphas, betas
    float alphaProdT = _alphasCumProd[currentTimestep];
    float alphaProdTPrev = previousTimestep >= 0 ? _alphasCumProd[previousTimestep] : _finalAlphaCumprod;
    float betaProdT = 1f - alphaProdT;

    // Both square roots are reused by several branches below; compute them once.
    float sqrtAlphaProdT = (float)Math.Sqrt(alphaProdT);
    float sqrtBetaProdT = (float)Math.Sqrt(betaProdT);

    //# 2. compute predicted original sample from predicted noise also called
    //# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    DenseTensor<float> predEpsilon;
    DenseTensor<float> predOriginalSample;
    if (Options.PredictionType == PredictionType.Epsilon)
    {
        var sampleBeta = sample.SubtractTensors(modelOutput.MultipleTensorByFloat(sqrtBetaProdT));
        predOriginalSample = sampleBeta.DivideTensorByFloat(sqrtAlphaProdT);
        predEpsilon = modelOutput;
    }
    else if (Options.PredictionType == PredictionType.Sample)
    {
        predOriginalSample = modelOutput;
        predEpsilon = sample
            .SubtractTensors(predOriginalSample.MultipleTensorByFloat(sqrtAlphaProdT))
            .DivideTensorByFloat(sqrtBetaProdT);
    }
    else if (Options.PredictionType == PredictionType.VariablePrediction)
    {
        predOriginalSample = sample
            .MultipleTensorByFloat(sqrtAlphaProdT)
            .SubtractTensors(modelOutput.MultipleTensorByFloat(sqrtBetaProdT));
        predEpsilon = modelOutput
            .MultipleTensorByFloat(sqrtAlphaProdT)
            .AddTensors(sample.MultipleTensorByFloat(sqrtBetaProdT));
    }
    else
    {
        // Fail with a clear message instead of dereferencing unset tensors later on.
        throw new ArgumentException($"Invalid prediction_type: {Options.PredictionType}");
    }

    //# 3. Clip or threshold "predicted x_0"
    if (Options.Thresholding)
    {
        // TODO: thresholding is not implemented; the sample passes through unchanged.
        // predOriginalSample = ThresholdSample(predOriginalSample);
    }
    else if (Options.ClipSample)
    {
        predOriginalSample = predOriginalSample.Clip(-Options.ClipSampleRange, Options.ClipSampleRange);
    }

    //# 4. compute variance: "sigma_t(η)" -> see formula (16)
    //# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    // eta is fixed at 0 here, so stdDevT is 0 and no random noise is added below.
    var eta = 0f;
    var variance = GetVariance(currentTimestep, previousTimestep);
    var stdDevT = eta * (float)Math.Sqrt(variance);

    var useClippedModelOutput = false;
    if (useClippedModelOutput)
    {
        //# the pred_epsilon is always re-derived from the clipped x_0 in Glide
        predEpsilon = sample
            .SubtractTensors(predOriginalSample.MultipleTensorByFloat(sqrtAlphaProdT))
            .DivideTensorByFloat(sqrtBetaProdT);
    }

    //# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    var predSampleDirection = predEpsilon.MultipleTensorByFloat((float)Math.Sqrt(1f - alphaProdTPrev - Math.Pow(stdDevT, 2f)));

    //# 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    var prevSample = predSampleDirection.AddTensors(predOriginalSample.MultipleTensorByFloat((float)Math.Sqrt(alphaProdTPrev)));

    if (eta > 0)
    {
        prevSample = prevSample.AddTensors(CreateRandomSample(modelOutput.Dimensions).MultipleTensorByFloat(stdDevT));
    }

    return prevSample;
}
181+
182+
183+
/// <summary>
/// Blends noise into the original samples using the cumulative alpha product
/// of the first supplied timestep.
/// </summary>
/// <param name="originalSamples">The original samples.</param>
/// <param name="noise">The noise.</param>
/// <param name="timesteps">The timesteps; only the first entry is read.</param>
/// <returns><c>sqrt(1 - alphaProd) * noise + sqrt(alphaProd) * originalSamples</c>.</returns>
public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples, DenseTensor<float> noise, IReadOnlyList<int> timesteps)
{
    // Ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L456
    float cumulativeAlpha = _alphasCumProd[timesteps[0]];
    float signalScale = (float)Math.Sqrt(cumulativeAlpha);
    float noiseScale = (float)Math.Sqrt(1.0f - cumulativeAlpha);

    var scaledNoise = noise.MultipleTensorByFloat(noiseScale);
    var scaledSamples = originalSamples.MultipleTensorByFloat(signalScale);
    return scaledNoise.AddTensors(scaledSamples);
}
202+
203+
204+
/// <summary>
/// Gets the variance between the current and the previous timestep:
/// <c>((1 - alphaProdTPrev) / (1 - alphaProdT)) * (1 - alphaProdT / alphaProdTPrev)</c>.
/// </summary>
/// <param name="timestep">The current timestep; index into the cumulative alpha products.</param>
/// <param name="prevTimestep">
/// The previous timestep; when negative, the final alpha product is used in
/// place of a table lookup.
/// </param>
/// <returns>The variance.</returns>
private float GetVariance(int timestep, int prevTimestep)
{
    float alphaProdT = _alphasCumProd[timestep];

    // The previous alpha product must be looked up at prevTimestep. Indexing with
    // the current timestep makes both products equal and the variance always 0.
    float alphaProdTPrev = prevTimestep >= 0
        ? _alphasCumProd[prevTimestep]
        : _finalAlphaCumprod;

    float betaProdT = 1f - alphaProdT;
    float betaProdTPrev = 1f - alphaProdTPrev;
    float variance = (betaProdTPrev / betaProdT) * (1f - alphaProdT / alphaProdTPrev);
    return variance;
}
222+
223+
224+
/// <summary>
/// Drops the reference to the cumulative alpha table, then forwards to the
/// base class disposal.
/// </summary>
/// <param name="disposing">Passed through unchanged to the base implementation.</param>
protected override void Dispose(bool disposing)
{
    this._alphasCumProd = null;
    base.Dispose(disposing);
}
229+
}
230+
}

0 commit comments

Comments
 (0)