Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 5d2aa91

Browse files
committed
Sigmoid BetaSchedule support, refactor SchedulerBase
1 parent 9c6d3a0 commit 5d2aa91

File tree

5 files changed

+174
-230
lines changed

5 files changed

+174
-230
lines changed

OnnxStack.StableDiffusion/Enums/BetaScheduleType.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ public enum BetaScheduleType
44
{
55
Linear = 0,
66
ScaledLinear = 1,
7-
SquaredCosCapV2 = 2
7+
SquaredCosCapV2 = 2,
8+
Sigmoid = 3
89
}
910
}

OnnxStack.StableDiffusion/Schedulers/DDPMScheduler.cs

Lines changed: 41 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ namespace OnnxStack.StableDiffusion.Schedulers
1111
{
1212
internal class DDPMScheduler : SchedulerBase
1313
{
14-
private float[] _betas;
15-
private List<float> _alphasCumulativeProducts;
14+
private float[] _alphasCumProd;
1615

1716
/// <summary>
1817
/// Initializes a new instance of the <see cref="DDPMScheduler"/> class.
@@ -33,43 +32,13 @@ public DDPMScheduler(SchedulerOptions options) : base(options) { }
3332
/// </summary>
3433
protected override void Initialize()
3534
{
36-
var alphas = new List<float>();
37-
if (Options.TrainedBetas != null)
38-
{
39-
_betas = Options.TrainedBetas.ToArray();
40-
}
41-
else if (Options.BetaSchedule == BetaScheduleType.Linear)
42-
{
43-
_betas = np.linspace(Options.BetaStart, Options.BetaEnd, Options.TrainTimesteps).ToArray<float>();
44-
}
45-
else if (Options.BetaSchedule == BetaScheduleType.ScaledLinear)
46-
{
47-
// This schedule is very specific to the latent diffusion model.
48-
_betas = np.power(np.linspace(MathF.Sqrt(Options.BetaStart), MathF.Sqrt(Options.BetaEnd), Options.TrainTimesteps), 2).ToArray<float>();
49-
}
50-
else if (Options.BetaSchedule == BetaScheduleType.SquaredCosCapV2)
51-
{
52-
// Glide cosine schedule
53-
_betas = GetBetasForAlphaBar();
54-
}
55-
//else if (betaSchedule == "sigmoid")
56-
//{
57-
// // GeoDiff sigmoid schedule
58-
// var betas = np.linspace(-6, 6, numTrainTimesteps);
59-
// Betas = (np.multiply(np.exp(betas), (betaEnd - betaStart)) + betaStart).ToArray<float>();
60-
//}
61-
62-
63-
for (int i = 0; i < Options.TrainTimesteps; i++)
64-
{
65-
alphas.Add(1.0f - _betas[i]);
66-
}
35+
_alphasCumProd = null;
6736

68-
_alphasCumulativeProducts = new List<float> { alphas[0] };
69-
for (int i = 1; i < Options.TrainTimesteps; i++)
70-
{
71-
_alphasCumulativeProducts.Add(_alphasCumulativeProducts[i - 1] * alphas[i]);
72-
}
37+
var betas = GetBetaSchedule();
38+
var alphas = betas.Select(beta => 1.0f - beta);
39+
_alphasCumProd = alphas
40+
.Select((alpha, i) => alphas.Take(i + 1).Aggregate((a, b) => a * b))
41+
.ToArray();
7342

7443
SetInitNoiseSigma(1.0f);
7544
}
@@ -82,29 +51,8 @@ protected override void Initialize()
8251
protected override int[] SetTimesteps()
8352
{
8453
// Create timesteps based on the specified strategy
85-
NDArray timestepsArray = null;
86-
if (Options.TimestepSpacing == TimestepSpacingType.Linspace)
87-
{
88-
timestepsArray = np.linspace(0, Options.TrainTimesteps - 1, Options.InferenceSteps);
89-
timestepsArray = np.around(timestepsArray)["::1"];
90-
}
91-
else if (Options.TimestepSpacing == TimestepSpacingType.Leading)
92-
{
93-
var stepRatio = Options.TrainTimesteps / Options.InferenceSteps;
94-
timestepsArray = np.arange(0, (float)Options.InferenceSteps) * stepRatio;
95-
timestepsArray = np.around(timestepsArray)["::1"];
96-
timestepsArray += Options.StepsOffset;
97-
}
98-
else if (Options.TimestepSpacing == TimestepSpacingType.Trailing)
99-
{
100-
var stepRatio = Options.TrainTimesteps / (Options.InferenceSteps - 1);
101-
timestepsArray = np.arange((float)Options.TrainTimesteps, 0, -stepRatio)["::-1"];
102-
timestepsArray = np.around(timestepsArray);
103-
timestepsArray -= 1;
104-
}
105-
106-
return timestepsArray
107-
.ToArray<float>()
54+
var timesteps = GetTimesteps();
55+
return timesteps
10856
.Select(x => (int)x)
10957
.OrderByDescending(x => x)
11058
.ToArray();
@@ -139,8 +87,8 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
13987
int previousTimestep = GetPreviousTimestep(currentTimestep);
14088

14189
//# 1. compute alphas, betas
142-
float alphaProdT = _alphasCumulativeProducts[currentTimestep];
143-
float alphaProdTPrev = previousTimestep >= 0 ? _alphasCumulativeProducts[previousTimestep] : 1f;
90+
float alphaProdT = _alphasCumProd[currentTimestep];
91+
float alphaProdTPrev = previousTimestep >= 0 ? _alphasCumProd[previousTimestep] : 1f;
14492
float betaProdT = 1 - alphaProdT;
14593
float betaProdTPrev = 1 - alphaProdTPrev;
14694
float currentAlphaT = alphaProdT / alphaProdTPrev;
@@ -161,27 +109,7 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
161109

162110
//# 2. compute predicted original sample from predicted noise also called
163111
//# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
164-
DenseTensor<float> predOriginalSample = null;
165-
if (Options.PredictionType == PredictionType.Epsilon)
166-
{
167-
//pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
168-
var sampleBeta = sample.SubtractTensors(modelOutput.MultipleTensorByFloat((float)Math.Sqrt(betaProdT)));
169-
predOriginalSample = sampleBeta.DivideTensorByFloat((float)Math.Sqrt(alphaProdT), sampleBeta.Dimensions);
170-
}
171-
else if (Options.PredictionType == PredictionType.Sample)
172-
{
173-
predOriginalSample = modelOutput;
174-
}
175-
else if (Options.PredictionType == PredictionType.VariablePrediction)
176-
{
177-
// pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
178-
var alphaSqrt = (float)Math.Sqrt(alphaProdT);
179-
var betaSqrt = (float)Math.Sqrt(betaProdT);
180-
predOriginalSample = sample
181-
.MultipleTensorByFloat(alphaSqrt)
182-
.SubtractTensors(modelOutput.MultipleTensorByFloat(betaSqrt));
183-
}
184-
112+
var predOriginalSample = GetPredictedSample(modelOutput, sample, alphaProdT, betaProdT);
185113

186114
//# 3. Clip or threshold "predicted x_0"
187115
if (Options.Thresholding)
@@ -234,6 +162,31 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
234162
}
235163

236164

165+
private DenseTensor<float> GetPredictedSample(DenseTensor<float> modelOutput, DenseTensor<float> sample, float alphaProdT, float betaProdT)
166+
{
167+
DenseTensor<float> predOriginalSample = null;
168+
if (Options.PredictionType == PredictionType.Epsilon)
169+
{
170+
//pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
171+
var sampleBeta = sample.SubtractTensors(modelOutput.MultipleTensorByFloat((float)Math.Sqrt(betaProdT)));
172+
predOriginalSample = sampleBeta.DivideTensorByFloat((float)Math.Sqrt(alphaProdT), sampleBeta.Dimensions);
173+
}
174+
else if (Options.PredictionType == PredictionType.Sample)
175+
{
176+
predOriginalSample = modelOutput;
177+
}
178+
else if (Options.PredictionType == PredictionType.VariablePrediction)
179+
{
180+
// pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
181+
var alphaSqrt = (float)Math.Sqrt(alphaProdT);
182+
var betaSqrt = (float)Math.Sqrt(betaProdT);
183+
predOriginalSample = sample
184+
.MultipleTensorByFloat(alphaSqrt)
185+
.SubtractTensors(modelOutput.MultipleTensorByFloat(betaSqrt));
186+
}
187+
return predOriginalSample;
188+
}
189+
237190
/// <summary>
238191
/// Adds noise to the sample.
239192
/// </summary>
@@ -245,7 +198,7 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
245198
{
246199
// Ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L456
247200
int timestep = timesteps[0];
248-
float alphaProd = _alphasCumulativeProducts[timestep];
201+
float alphaProd = _alphasCumProd[timestep];
249202
float sqrtAlpha = (float)Math.Sqrt(alphaProd);
250203
float sqrtOneMinusAlpha = (float)Math.Sqrt(1.0f - alphaProd);
251204

@@ -263,8 +216,8 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
263216
private float GetVariance(int timestep, float predictedVariance = 0f)
264217
{
265218
int prevTimestep = GetPreviousTimestep(timestep);
266-
float alphaProdT = _alphasCumulativeProducts[timestep];
267-
float alphaProdTPrev = prevTimestep >= 0 ? _alphasCumulativeProducts[prevTimestep] : 1.0f;
219+
float alphaProdT = _alphasCumProd[timestep];
220+
float alphaProdTPrev = prevTimestep >= 0 ? _alphasCumProd[prevTimestep] : 1.0f;
268221
float currentBetaT = 1 - alphaProdT / alphaProdTPrev;
269222

270223
// For t > 0, compute predicted variance βt
@@ -384,8 +337,7 @@ private bool IsVarianceTypeLearned()
384337

385338
protected override void Dispose(bool disposing)
386339
{
387-
_betas = null;
388-
_alphasCumulativeProducts = null;
340+
_alphasCumProd = null;
389341
base.Dispose(disposing);
390342
}
391343
}

OnnxStack.StableDiffusion/Schedulers/EulerAncestralScheduler.cs

Lines changed: 7 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -34,38 +34,15 @@ public EulerAncestralScheduler(SchedulerOptions schedulerOptions) : base(schedul
3434
protected override void Initialize()
3535
{
3636
_sigmas = null;
37-
var betas = Enumerable.Empty<float>();
38-
if (!Options.TrainedBetas.IsNullOrEmpty())
39-
{
40-
betas = Options.TrainedBetas;
41-
}
42-
else if (Options.BetaSchedule == BetaScheduleType.Linear)
43-
{
44-
betas = np.linspace(Options.BetaStart, Options.BetaEnd, Options.TrainTimesteps).ToArray<float>();
45-
}
46-
else if (Options.BetaSchedule == BetaScheduleType.ScaledLinear)
47-
{
48-
var start = (float)Math.Sqrt(Options.BetaStart);
49-
var end = (float)Math.Sqrt(Options.BetaEnd);
50-
betas = np.linspace(start, end, Options.TrainTimesteps)
51-
.ToArray<float>()
52-
.Select(x => x * x);
53-
}
54-
else if (Options.BetaSchedule == BetaScheduleType.SquaredCosCapV2)
55-
{
56-
betas = GetBetasForAlphaBar();
57-
}
5837

38+
var betas = GetBetaSchedule();
5939
var alphas = betas.Select(beta => 1.0f - beta);
6040
var alphaCumProd = alphas.Select((alpha, i) => alphas.Take(i + 1).Aggregate((a, b) => a * b));
6141
_sigmas = alphaCumProd
6242
.Select(alpha_prod => (float)Math.Sqrt((1 - alpha_prod) / alpha_prod))
6343
.ToArray();
6444

65-
var maxSigma = _sigmas.Max();
66-
var initNoiseSigma = Options.TimestepSpacing == TimestepSpacingType.Linspace || Options.TimestepSpacing == TimestepSpacingType.Trailing
67-
? maxSigma
68-
: (float)Math.Sqrt(maxSigma * maxSigma + 1);
45+
var initNoiseSigma = GetInitNoiseSigma(_sigmas);
6946
SetInitNoiseSigma(initNoiseSigma);
7047
}
7148

@@ -76,29 +53,8 @@ protected override void Initialize()
7653
/// <returns></returns>
7754
protected override int[] SetTimesteps()
7855
{
79-
NDArray timestepsArray = null;
80-
if (Options.TimestepSpacing == TimestepSpacingType.Linspace)
81-
{
82-
timestepsArray = np.linspace(0, Options.TrainTimesteps - 1, Options.InferenceSteps);
83-
timestepsArray = np.around(timestepsArray)["::1"];
84-
}
85-
else if (Options.TimestepSpacing == TimestepSpacingType.Leading)
86-
{
87-
var stepRatio = Options.TrainTimesteps / Options.InferenceSteps;
88-
timestepsArray = np.arange(0, (float)Options.InferenceSteps) * stepRatio;
89-
timestepsArray = np.around(timestepsArray)["::1"];
90-
timestepsArray += Options.StepsOffset;
91-
}
92-
else if (Options.TimestepSpacing == TimestepSpacingType.Trailing)
93-
{
94-
var stepRatio = Options.TrainTimesteps / (Options.InferenceSteps - 1);
95-
timestepsArray = np.arange(Options.TrainTimesteps, 0f, -stepRatio)["::-1"];
96-
timestepsArray = np.around(timestepsArray);
97-
timestepsArray -= 1;
98-
}
99-
10056
var sigmas = _sigmas.ToArray();
101-
var timesteps = timestepsArray.ToArray<float>();
57+
var timesteps = GetTimesteps();
10258
var log_sigmas = np.log(sigmas).ToArray<float>();
10359
var range = np.arange(0, (float)_sigmas.Length).ToArray<float>();
10460
sigmas = Interpolate(timesteps, range, _sigmas);
@@ -135,9 +91,7 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
13591
sigma = (float)Math.Sqrt(Math.Pow(sigma, 2) + 1);
13692

13793
// Divide sample tensor shape {2,4,(H/8),(W/8)} by sigma
138-
sample = sample.DivideTensorByFloat(sigma, sample.Dimensions);
139-
140-
return sample;
94+
return sample.DivideTensorByFloat(sigma, sample.Dimensions);
14195
}
14296

14397

@@ -155,23 +109,7 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
155109
var sigma = _sigmas[stepIndex];
156110

157111
// 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
158-
DenseTensor<float> predOriginalSample = null;
159-
if (Options.PredictionType == PredictionType.Epsilon)
160-
{
161-
predOriginalSample = sample.SubtractTensors(modelOutput.MultipleTensorByFloat(sigma));
162-
}
163-
else if (Options.PredictionType == PredictionType.VariablePrediction)
164-
{
165-
var sigmaSqrt = (float)Math.Sqrt(sigma * sigma + 1);
166-
predOriginalSample = sample.DivideTensorByFloat(sigmaSqrt)
167-
.AddTensors(modelOutput.MultipleTensorByFloat(-sigma / sigmaSqrt));
168-
}
169-
else if (Options.PredictionType == PredictionType.Sample)
170-
{
171-
//prediction_type not implemented yet: sample
172-
predOriginalSample = modelOutput.ToDenseTensor();
173-
}
174-
112+
var predOriginalSample = GetPredictedSample(modelOutput, sample, sigma);
175113

176114
var sigmaFrom = _sigmas[stepIndex];
177115
var sigmaTo = _sigmas[stepIndex + 1];
@@ -195,6 +133,8 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
195133
}
196134

197135

136+
137+
198138
/// <summary>
199139
/// Adds noise to the sample.
200140
/// </summary>

0 commit comments

Comments (0)