Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 6c72048

Browse files
committed
Euler Scheduler
1 parent d376d16 commit 6c72048

File tree

3 files changed

+183
-1
lines changed

3 files changed

+183
-1
lines changed

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ protected static IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions
279279
return prompt.SchedulerType switch
280280
{
281281
SchedulerType.LMS => new LMSScheduler(options),
282+
SchedulerType.Euler => new EulerScheduler(options),
282283
SchedulerType.EulerAncestral => new EulerAncestralScheduler(options),
283284
SchedulerType.DDPM => new DDPMScheduler(options),
284285
SchedulerType.DDIM => new DDIMScheduler(options),

OnnxStack.StableDiffusion/Enums/SchedulerType.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@ public enum SchedulerType
77
[Display(Name = "LMS")]
88
LMS = 0,
99

10+
[Display(Name = "Euler")]
11+
Euler = 1,
12+
1013
[Display(Name = "Euler Ancestral")]
11-
EulerAncestral = 1,
14+
EulerAncestral = 2,
1215

1316
[Display(Name = "DDPM")]
1417
DDPM = 3,
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using NumSharp;
3+
using OnnxStack.Core;
4+
using OnnxStack.StableDiffusion.Config;
5+
using OnnxStack.StableDiffusion.Enums;
6+
using OnnxStack.StableDiffusion.Helpers;
7+
using System;
8+
using System.Collections.Generic;
9+
using System.Linq;
10+
11+
namespace OnnxStack.StableDiffusion.Schedulers
{
    /// <summary>
    /// Euler discrete scheduler (Karras et al. Algorithm 2 with the Euler method).
    /// Ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_euler_discrete.py
    /// </summary>
    public sealed class EulerScheduler : SchedulerBase
    {
        // Noise levels per timestep, ordered to match Timesteps; the last entry is the
        // appended terminal sigma of 0 used by the final denoising step.
        private float[] _sigmas;

        /// <summary>
        /// Initializes a new instance of the <see cref="EulerScheduler"/> class with default options.
        /// </summary>
        public EulerScheduler() : this(new SchedulerOptions()) { }

        /// <summary>
        /// Initializes a new instance of the <see cref="EulerScheduler"/> class.
        /// </summary>
        /// <param name="schedulerOptions">The scheduler options.</param>
        public EulerScheduler(SchedulerOptions schedulerOptions) : base(schedulerOptions) { }


        /// <summary>
        /// Initializes this instance: derives the per-timestep sigmas from the beta schedule
        /// and registers the initial noise sigma.
        /// </summary>
        protected override void Initialize()
        {
            var betas = GetBetaSchedule();

            // Running cumulative product of alphas (alpha_bar_t), computed in a single O(n)
            // pass. The previous implementation re-aggregated the whole prefix for every
            // element (O(n^2)) while repeatedly re-enumerating a lazy Select.
            var cumProd = 1f;
            _sigmas = betas
                .Select(beta =>
                {
                    cumProd *= 1f - beta;
                    // sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t)
                    return (float)Math.Sqrt((1f - cumProd) / cumProd);
                })
                .ToArray();

            SetInitNoiseSigma(GetInitNoiseSigma(_sigmas));
        }


        /// <summary>
        /// Sets the timesteps and interpolates the sigma schedule to match them.
        /// </summary>
        /// <returns>The inference timesteps in descending order.</returns>
        protected override int[] SetTimesteps()
        {
            var sigmas = _sigmas.ToArray();
            var timesteps = GetTimesteps();
            var log_sigmas = np.log(sigmas).ToArray<float>();
            var range = np.arange(0, (float)_sigmas.Length).ToArray<float>();

            // TODO: Implement "interpolation_type" as a scheduler option; only "linear"
            // is reachable today, the "log_linear" branch is kept for parity with diffusers.
            var interpolation_type = "linear";
            sigmas = interpolation_type == "log_linear"
                ? np.exp(np.linspace(np.log(sigmas.Last()), np.log(sigmas.First()), timesteps.Length + 1)).ToArray<float>()
                : Interpolate(timesteps, range, _sigmas);

            if (Options.UseKarrasSigmas)
            {
                // Re-space sigmas per Karras et al. and map them back to timesteps.
                sigmas = ConvertToKarras(sigmas);
                timesteps = SigmaToTimestep(sigmas, log_sigmas);
            }

            // Append the terminal sigma (0) consumed by the final Step() call.
            _sigmas = sigmas
                .Append(0.000f)
                .ToArray();

            return timesteps
                .Select(x => (int)x)
                .OrderByDescending(x => x)
                .ToArray();
        }


        /// <summary>
        /// Scales the model input by 1 / sqrt(sigma^2 + 1) for the current timestep.
        /// </summary>
        /// <param name="sample">The sample tensor, shape {2, 4, H/8, W/8}.</param>
        /// <param name="timestep">The current timestep.</param>
        /// <returns>The scaled sample.</returns>
        public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int timestep)
        {
            // Look up the sigma matching this timestep.
            int stepIndex = Timesteps.IndexOf(timestep);
            var sigma = _sigmas[stepIndex];
            sigma = (float)Math.Sqrt(Math.Pow(sigma, 2) + 1);

            return sample.DivideTensorByFloat(sigma);
        }


        /// <summary>
        /// Processes an inference step: predicts x_0 from the model output and performs
        /// one Euler integration step toward the next sigma.
        /// </summary>
        /// <param name="modelOutput">The model output.</param>
        /// <param name="timestep">The current timestep.</param>
        /// <param name="sample">The current latent sample.</param>
        /// <param name="order">The order (unused by the Euler method).</param>
        /// <returns>The denoised sample for the next step.</returns>
        public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
        {
            // TODO: Implement "extended settings for scheduler types"
            float s_churn = 0f;
            float s_tmin = 0f;
            float s_tmax = float.PositiveInfinity;
            float s_noise = 1f;

            var stepIndex = Timesteps.IndexOf(timestep);
            float sigma = _sigmas[stepIndex];

            // gamma > 0 enables stochastic "churn"; with s_churn hard-coded to 0 it is always 0.
            float gamma = s_tmin <= sigma && sigma <= s_tmax
                ? (float)Math.Min(s_churn / (_sigmas.Length - 1f), Math.Sqrt(2.0f) - 1.0f)
                : 0f;
            float sigmaHat = sigma * (1.0f + gamma);

            if (gamma > 0)
            {
                // Draw churn noise only when it is actually applied; the previous version
                // generated an unused random sample on every single step.
                var epsilon = CreateRandomSample(modelOutput.Dimensions).MultipleTensorByFloat(s_noise);
                sample = sample.AddTensors(epsilon.MultipleTensorByFloat((float)Math.Sqrt(Math.Pow(sigmaHat, 2f) - Math.Pow(sigma, 2f))));
            }

            // 1. Compute predicted original sample (x_0) from sigma-scaled predicted noise.
            var predOriginalSample = Options.PredictionType != PredictionType.Epsilon
                ? GetPredictedSample(modelOutput, sample, sigma)
                : sample.SubtractTensors(modelOutput.MultipleTensorByFloat(sigmaHat));

            // 2. Convert to an ODE derivative.
            var derivative = sample
                .SubtractTensors(predOriginalSample)
                .DivideTensorByFloat(sigmaHat);

            // 3. Euler step: sample + derivative * (sigma_next - sigma_hat).
            var delta = _sigmas[stepIndex + 1] - sigmaHat;
            return sample.AddTensors(derivative.MultipleTensorByFloat(delta));
        }


        /// <summary>
        /// Adds noise to the sample, scaled by the largest sigma among the given timesteps.
        /// </summary>
        /// <param name="originalSamples">The original samples.</param>
        /// <param name="noise">The noise.</param>
        /// <param name="timesteps">The timesteps.</param>
        /// <returns>The noised samples.</returns>
        public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples, DenseTensor<float> noise, IReadOnlyList<int> timesteps)
        {
            // Ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_euler_discrete.py (add_noise)
            var sigma = timesteps
                .Select(x => Timesteps.IndexOf(x))
                .Select(x => _sigmas[x])
                .Max();

            return noise
                .MultipleTensorByFloat(sigma)
                .AddTensors(originalSamples);
        }


        /// <summary>
        /// Releases unmanaged and - optionally - managed resources.
        /// </summary>
        /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
        protected override void Dispose(bool disposing)
        {
            _sigmas = null;
            base.Dispose(disposing);
        }
    }
}

0 commit comments

Comments
 (0)