
Commit 9c6d3a0

Variable Prediction support for EulerAncestralScheduler

1 parent 2de080e

File tree

1 file changed (+27 −16 lines)


OnnxStack.StableDiffusion/Schedulers/EulerAncestralScheduler.cs

Lines changed: 27 additions & 16 deletions
@@ -33,17 +33,15 @@ public EulerAncestralScheduler(SchedulerOptions schedulerOptions) : base(schedul
         /// </summary>
         protected override void Initialize()
         {
+            _sigmas = null;
             var betas = Enumerable.Empty<float>();
             if (!Options.TrainedBetas.IsNullOrEmpty())
             {
                 betas = Options.TrainedBetas;
             }
             else if (Options.BetaSchedule == BetaScheduleType.Linear)
             {
-                var steps = Options.TrainTimesteps - 1;
-                var delta = Options.BetaStart + (Options.BetaEnd - Options.BetaStart);
-                betas = Enumerable.Range(0, Options.TrainTimesteps)
-                    .Select(i => delta * i / steps);
+                betas = np.linspace(Options.BetaStart, Options.BetaEnd, Options.TrainTimesteps).ToArray<float>();
             }
             else if (Options.BetaSchedule == BetaScheduleType.ScaledLinear)
             {
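The Linear branch previously built the schedule by hand, but its interpolation started at zero rather than BetaStart; the commit swaps it for np.linspace (presumably the NumSharp binding), which produces TrainTimesteps evenly spaced betas from BetaStart to BetaEnd inclusive. A minimal LINQ-only sketch of the same schedule, assuming Options.TrainTimesteps is at least 2 (illustrative, not part of the commit):

    // Hypothetical equivalent without the np.linspace dependency:
    // evenly spaced values from BetaStart to BetaEnd, endpoints included.
    var count = Options.TrainTimesteps;
    var betas = Enumerable.Range(0, count)
        .Select(i => Options.BetaStart + (Options.BetaEnd - Options.BetaStart) * i / (count - 1));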
@@ -58,16 +56,12 @@ protected override void Initialize()
                 betas = GetBetasForAlphaBar();
             }
 
+            var alphas = betas.Select(beta => 1.0f - beta);
+            var alphaCumProd = alphas.Select((alpha, i) => alphas.Take(i + 1).Aggregate((a, b) => a * b));
+            _sigmas = alphaCumProd
+                .Select(alpha_prod => (float)Math.Sqrt((1 - alpha_prod) / alpha_prod))
+                .ToArray();
 
-            var alphas = betas.Select(beta => 1 - beta);
-            var cumulativeProduct = alphas.Select((alpha, i) => alphas.Take(i + 1).Aggregate((a, b) => a * b));
-
-            // Create _sigmas as a list and reverse it
-            _sigmas = cumulativeProduct
-                .Select(alpha_prod => (float)Math.Sqrt((1 - alpha_prod) / alpha_prod))
-                .ToArray();
-
-            // standard deviation of the initial noise distrubution
             var maxSigma = _sigmas.Max();
             var initNoiseSigma = Options.TimestepSpacing == TimestepSpacingType.Linspace || Options.TimestepSpacing == TimestepSpacingType.Trailing
                 ? maxSigma
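The relocated block derives the noise levels directly in Initialize: alpha_t = 1 - beta_t, alphaCumProd is the running product of the alphas, and each sigma is sqrt((1 - alphaCumProd) / alphaCumProd). The Take/Aggregate form recomputes every prefix product, so a rough single-pass sketch of the same computation (illustrative only, not part of the commit) looks like this:

    // Single-pass cumulative product over the betas (assumes `betas` has been materialized).
    var betaArray = betas.ToArray();
    var sigmas = new float[betaArray.Length];
    var alphaCumProd = 1f;
    for (int i = 0; i < betaArray.Length; i++)
    {
        alphaCumProd *= 1f - betaArray[i];                                 // running product of alphas
        sigmas[i] = (float)Math.Sqrt((1f - alphaCumProd) / alphaCumProd);  // sigma for timestep i
    }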
@@ -115,8 +109,9 @@ protected override int[] SetTimesteps()
                 timesteps = SigmaToTimestep(sigmas, log_sigmas);
             }
 
-            // add 0.000 to the end of the result
-            _sigmas = sigmas.Append(0.000f).ToArray();
+            _sigmas = sigmas
+                .Append(0.000f)
+                .ToArray();
 
             return timesteps.Select(x => (int)x)
                 .OrderByDescending(x => x)
@@ -160,7 +155,23 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
             var sigma = _sigmas[stepIndex];
 
             // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-            var predOriginalSample = sample.SubtractTensors(modelOutput.MultipleTensorByFloat(sigma));
+            DenseTensor<float> predOriginalSample = null;
+            if (Options.PredictionType == PredictionType.Epsilon)
+            {
+                predOriginalSample = sample.SubtractTensors(modelOutput.MultipleTensorByFloat(sigma));
+            }
+            else if (Options.PredictionType == PredictionType.VariablePrediction)
+            {
+                var sigmaSqrt = (float)Math.Sqrt(sigma * sigma + 1);
+                predOriginalSample = sample.DivideTensorByFloat(sigmaSqrt)
+                    .AddTensors(modelOutput.MultipleTensorByFloat(-sigma / sigmaSqrt));
+            }
+            else if (Options.PredictionType == PredictionType.Sample)
+            {
+                // prediction_type not implemented yet: sample
+                predOriginalSample = modelOutput.ToDenseTensor();
+            }
+
 
             var sigmaFrom = _sigmas[stepIndex];
             var sigmaTo = _sigmas[stepIndex + 1];
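The Step change is the heart of the commit: the predicted original sample x_0 now depends on the model's prediction type. For Epsilon the model output is the noise, so x_0 = sample - sigma * output; for VariablePrediction (v-prediction) both terms are rescaled by sqrt(sigma^2 + 1), giving x_0 = sample / sqrt(sigma^2 + 1) - (sigma / sqrt(sigma^2 + 1)) * output; Sample is left as a pass-through and flagged as not implemented. A hedged usage sketch, assuming SchedulerOptions exposes a settable PredictionType property as its use in the diff suggests:

    // Illustrative only: select v-prediction for a model trained with that objective
    // (e.g. Stable Diffusion 2.x 768 checkpoints); property names assumed from the diff.
    var options = new SchedulerOptions
    {
        PredictionType = PredictionType.VariablePrediction
    };
    var scheduler = new EulerAncestralScheduler(options);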
