Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 3c5474d

Browse files
committed
Added DDPMWuerstchenScheduler
1 parent 20cc4ff commit 3c5474d

File tree

5 files changed

+162
-3
lines changed

5 files changed

+162
-3
lines changed

OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ protected override IScheduler GetScheduler(SchedulerOptions options)
249249
return options.SchedulerType switch
250250
{
251251
SchedulerType.DDPM => new DDPMScheduler(options),
252+
SchedulerType.DDPMWuerstchen => new DDPMWuerstchenScheduler(options),
252253
_ => default
253254
};
254255
}

OnnxStack.StableDiffusion/Enums/SchedulerType.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ public enum SchedulerType
2222
[Display(Name = "KDPM2")]
2323
KDPM2 = 5,
2424

25+
[Display(Name = "DDPMWuerstchen")]
26+
DDPMWuerstchen = 6,
27+
2528
[Display(Name = "LCM")]
2629
LCM = 20,
2730

OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,16 @@ public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tok
4444
};
4545
_supportedSchedulers = new List<SchedulerType>
4646
{
47-
SchedulerType.DDPM
47+
SchedulerType.DDPM,
48+
SchedulerType.DDPMWuerstchen
4849
};
4950
_defaultSchedulerOptions = defaultSchedulerOptions ?? new SchedulerOptions
5051
{
5152
Width = 1024,
5253
Height = 1024,
5354
InferenceSteps = 20,
5455
GuidanceScale = 4f,
55-
SchedulerType = SchedulerType.DDPM
56+
SchedulerType = SchedulerType.DDPMWuerstchen
5657
};
5758
}
5859

OnnxStack.StableDiffusion/Schedulers/SchedulerBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ protected void SetInitNoiseSigma(float initNoiseSigma)
228228
/// </summary>
229229
/// <param name="timestep">The timestep.</param>
230230
/// <returns></returns>
231-
protected int GetPreviousTimestep(int timestep)
231+
protected virtual int GetPreviousTimestep(int timestep)
232232
{
233233
return timestep - _options.TrainTimesteps / _options.InferenceSteps;
234234
}
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.Core;
3+
using OnnxStack.StableDiffusion.Config;
4+
using OnnxStack.StableDiffusion.Helpers;
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
9+
namespace OnnxStack.StableDiffusion.Schedulers.StableDiffusion
10+
{
11+
internal class DDPMWuerstchenScheduler : SchedulerBase
12+
{
13+
private float _s;
14+
private float _scaler;
15+
private float _initAlphaCumprod;
16+
17+
18+
/// <summary>
19+
/// Initializes a new instance of the <see cref="DDPMWuerstchenScheduler"/> class.
20+
/// </summary>
21+
/// <param name="stableDiffusionOptions">The stable diffusion options.</param>
22+
public DDPMWuerstchenScheduler() : this(new SchedulerOptions()) { }
23+
24+
/// <summary>
25+
/// Initializes a new instance of the <see cref="DDPMWuerstchenScheduler"/> class.
26+
/// </summary>
27+
/// <param name="stableDiffusionOptions">The stable diffusion options.</param>
28+
/// <param name="schedulerOptions">The scheduler options.</param>
29+
public DDPMWuerstchenScheduler(SchedulerOptions options) : base(options) { }
30+
31+
32+
/// <summary>
33+
/// Initializes this instance.
34+
/// </summary>
35+
protected override void Initialize()
36+
{
37+
_s = 0.008f;
38+
_scaler = 1.0f;
39+
_initAlphaCumprod = MathF.Pow(MathF.Cos(_s / (1f + _s) * MathF.PI * 0.5f), 2f);
40+
SetInitNoiseSigma(1.0f);
41+
}
42+
43+
44+
/// <summary>
45+
/// Sets the timesteps.
46+
/// </summary>
47+
/// <returns></returns>
48+
protected override int[] SetTimesteps()
49+
{
50+
// Create timesteps based on the specified strategy
51+
var timesteps = ArrayHelpers.Linspace(0, 1000, Options.InferenceSteps + 1);
52+
var x = timesteps
53+
.Skip(1)
54+
.Select(x => (int)x)
55+
.OrderByDescending(x => x)
56+
.ToArray();
57+
return x;
58+
}
59+
60+
61+
/// <summary>
62+
/// Scales the input.
63+
/// </summary>
64+
/// <param name="sample">The sample.</param>
65+
/// <param name="timestep">The timestep.</param>
66+
/// <returns></returns>
67+
public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int timestep)
68+
{
69+
return sample;
70+
}
71+
72+
73+
/// <summary>
74+
/// Processes a inference step for the specified model output.
75+
/// </summary>
76+
/// <param name="modelOutput">The model output.</param>
77+
/// <param name="timestep">The timestep.</param>
78+
/// <param name="sample">The sample.</param>
79+
/// <param name="order">The order.</param>
80+
/// <returns></returns>
81+
/// <exception cref="ArgumentException">Invalid prediction_type: {SchedulerOptions.PredictionType}</exception>
82+
/// <exception cref="NotImplementedException">DDPMScheduler Thresholding currently not implemented</exception>
83+
public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
84+
{
85+
var currentTimestep = timestep / 1000f;
86+
var previousTimestep = GetPreviousTimestep(timestep) / 1000f;
87+
88+
var alpha_cumprod = GetAlphaCumprod(currentTimestep);
89+
var alpha_cumprod_prev = GetAlphaCumprod(previousTimestep);
90+
var alpha = alpha_cumprod / alpha_cumprod_prev;
91+
92+
var predictedSample = sample
93+
.SubtractTensors(modelOutput.MultiplyTensorByFloat(1f - alpha).DivideTensorByFloat(MathF.Sqrt(1f - alpha_cumprod)))
94+
.MultiplyTensorByFloat(MathF.Sqrt(1f / alpha))
95+
.AddTensors(CreateRandomSample(modelOutput.Dimensions)
96+
.MultiplyTensorByFloat(MathF.Sqrt((1f - alpha) * (1f - alpha_cumprod_prev) / (1f - alpha_cumprod))));
97+
98+
return new SchedulerStepResult(predictedSample);
99+
}
100+
101+
102+
/// <summary>
103+
/// Adds noise to the sample.
104+
/// </summary>
105+
/// <param name="originalSamples">The original samples.</param>
106+
/// <param name="noise">The noise.</param>
107+
/// <param name="timesteps">The timesteps.</param>
108+
/// <returns></returns>
109+
public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples, DenseTensor<float> noise, IReadOnlyList<int> timesteps)
110+
{
111+
float timestep = timesteps[0] / 1000f;
112+
float alphaProd = GetAlphaCumprod(timestep);
113+
float sqrtAlpha = MathF.Sqrt(alphaProd);
114+
float sqrtOneMinusAlpha = MathF.Sqrt(1.0f - alphaProd);
115+
116+
return noise
117+
.MultiplyTensorByFloat(sqrtOneMinusAlpha)
118+
.AddTensors(originalSamples.MultiplyTensorByFloat(sqrtAlpha));
119+
}
120+
121+
122+
/// <summary>
123+
/// Gets the previous timestep.
124+
/// </summary>
125+
/// <param name="timestep">The timestep.</param>
126+
/// <returns></returns>
127+
protected override int GetPreviousTimestep(int timestep)
128+
{
129+
var index = Timesteps.IndexOf(timestep) + 1;
130+
if (index > Timesteps.Count - 1)
131+
return 0;
132+
133+
return Timesteps[index];
134+
}
135+
136+
137+
private float GetAlphaCumprod(float timestep)
138+
{
139+
if (_scaler > 1.0f)
140+
timestep = 1f - MathF.Pow(1f - timestep, _scaler);
141+
else if (_scaler < 1.0f)
142+
timestep = MathF.Pow(timestep, _scaler);
143+
144+
var alphaCumprod = MathF.Pow(MathF.Cos((timestep + _s) / (1f + _s) * MathF.PI * 0.5f), 2f) / _initAlphaCumprod;
145+
return Math.Clamp(alphaCumprod, 0.0001f, 0.9999f);
146+
}
147+
148+
149+
protected override void Dispose(bool disposing)
150+
{
151+
base.Dispose(disposing);
152+
}
153+
}
154+
}

0 commit comments

Comments
 (0)