Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit b15ceb2

Browse files
committed
Tidy up scheduler layout
1 parent 5de2ceb commit b15ceb2

File tree

13 files changed

+54
-68
lines changed

13 files changed

+54
-68
lines changed

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using OnnxStack.StableDiffusion.Enums;
88
using OnnxStack.StableDiffusion.Helpers;
99
using OnnxStack.StableDiffusion.Schedulers;
10+
using OnnxStack.StableDiffusion.Schedulers.StableDiffusion;
1011
using System;
1112
using System.Collections.Generic;
1213
using System.Linq;
@@ -106,7 +107,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
106107
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
107108

108109
// Scheduler Step
109-
latents = scheduler.Step(noisePred, timestep, latents).PreviousSample;
110+
latents = scheduler.Step(noisePred, timestep, latents).Result;
110111
}
111112

112113
progressCallback?.Invoke(step, timesteps.Count);

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using OnnxStack.StableDiffusion.Enums;
88
using OnnxStack.StableDiffusion.Helpers;
99
using OnnxStack.StableDiffusion.Schedulers;
10+
using OnnxStack.StableDiffusion.Schedulers.LatentConsistency;
1011
using System;
1112
using System.Collections.Generic;
1213
using System.Linq;
@@ -84,8 +85,8 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
8485
// Scheduler Step
8586
var schedulerResult = scheduler.Step(noisePred, timestep, latents);
8687

87-
latents = schedulerResult.PreviousSample;
88-
denoised = schedulerResult.ExtraSample;
88+
latents = schedulerResult.Result;
89+
denoised = schedulerResult.SampleData;
8990
}
9091

9192
progressCallback?.Invoke(step, timesteps.Count);
@@ -147,39 +148,20 @@ protected override IScheduler GetScheduler(PromptOptions prompt, SchedulerOption
147148
/// <returns></returns>
148149
public DenseTensor<float> GetGuidanceScaleEmbedding(float guidance, int embeddingDim = 256)
149150
{
150-
// TODO:
151-
//assert len(w.shape) == 1
152-
//w = w * 1000.0
153-
154-
//half_dim = embedding_dim // 2
155-
//emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
156-
//emb = torch.exp(torch.arange(half_dim, dtype = dtype) * -emb)
157-
//emb = w.to(dtype)[:, None] * emb[None, :]
158-
//emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim = 1)
159-
//if embedding_dim % 2 == 1: # zero pad
160-
// emb = torch.nn.functional.pad(emb, (0, 1))
161-
//assert emb.shape == (w.shape[0], embedding_dim)
162-
//return emb
163-
164-
var w = guidance - 1f;
165-
166-
var half_dim = embeddingDim / 2;
167-
168-
var log = MathF.Log(10000.0f) / (half_dim - 1);
169-
170-
var emb = Enumerable.Range(0, half_dim)
151+
var scale = guidance - 1f;
152+
var halfDim = embeddingDim / 2;
153+
float log = MathF.Log(10000.0f) / (halfDim - 1);
154+
var emb = Enumerable.Range(0, halfDim)
171155
.Select(x => MathF.Exp(x * -log))
172156
.ToArray();
173157
var embSin = emb.Select(MathF.Sin).ToArray();
174158
var embCos = emb.Select(MathF.Cos).ToArray();
175-
176-
DenseTensor<float> result = new DenseTensor<float>(new[] { 1, 2 * half_dim });
177-
for (int i = 0; i < half_dim; i++)
159+
var result = new DenseTensor<float>(new[] { 1, 2 * halfDim });
160+
for (int i = 0; i < halfDim; i++)
178161
{
179162
result[0, i] = embSin[i];
180-
result[0, i + half_dim] = embCos[i];
163+
result[0, i + halfDim] = embCos[i];
181164
}
182-
183165
return result;
184166
}
185167
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
8787
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
8888

8989
// Scheduler Step
90-
latents = scheduler.Step(noisePred, timestep, latents).PreviousSample;
90+
latents = scheduler.Step(noisePred, timestep, latents).Result;
9191
}
9292

9393
progress?.Invoke(++step, timesteps.Count);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
8282
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
8383

8484
// Scheduler Step
85-
var steplatents = scheduler.Step(noisePred, timestep, latents).PreviousSample;
85+
var steplatents = scheduler.Step(noisePred, timestep, latents).Result;
8686

8787
// Add noise to original latent
8888
var initLatentsProper = scheduler.AddNoise(latentsOriginal, noise, new[] { timestep });

OnnxStack.StableDiffusion/Schedulers/LCMScheduler.cs renamed to OnnxStack.StableDiffusion/Schedulers/LatentConsistency/LCMScheduler.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
using System.Collections.Generic;
77
using System.Linq;
88

9-
namespace OnnxStack.StableDiffusion.Schedulers
9+
namespace OnnxStack.StableDiffusion.Schedulers.LatentConsistency
1010
{
1111
internal class LCMScheduler : SchedulerBase
1212
{
@@ -197,11 +197,11 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
197197
public (float cSkip, float cOut) GetBoundaryConditionScalings(float timestep)
198198
{
199199
//self.sigma_data = 0.5 # Default: 0.5
200-
var sigmaData = 0.5f;
200+
var sigmaData = 0.1f;
201201

202-
float c = (MathF.Pow(timestep / 0.1f, 2f) + MathF.Pow(sigmaData, 2f));
202+
float c = MathF.Pow(timestep / 0.1f, 2f) + MathF.Pow(sigmaData, 2f);
203203
float cSkip = MathF.Pow(sigmaData, 2f) / c;
204-
float cOut = (timestep / 0.1f) / MathF.Pow(c, 0.5f);
204+
float cOut = timestep / 0.1f / MathF.Pow(c, 0.5f);
205205
return (cSkip, cOut);
206206
}
207207

OnnxStack.StableDiffusion/Schedulers/SchedulerBase.cs

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -443,22 +443,4 @@ protected virtual void Dispose(bool disposing)
443443

444444
#endregion
445445
}
446-
447-
448-
public class SchedulerStepResult
449-
{
450-
public SchedulerStepResult(DenseTensor<float> previousSample)
451-
{
452-
PreviousSample = previousSample;
453-
}
454-
455-
public SchedulerStepResult(DenseTensor<float> previousSample, DenseTensor<float> extraSample)
456-
{
457-
ExtraSample = extraSample;
458-
PreviousSample = previousSample;
459-
}
460-
461-
public DenseTensor<float> PreviousSample { get; set; }
462-
public DenseTensor<float> ExtraSample { get; set; }
463-
}
464446
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
3+
namespace OnnxStack.StableDiffusion.Schedulers
4+
{
5+
public class SchedulerStepResult
6+
{
7+
public SchedulerStepResult(DenseTensor<float> result)
8+
{
9+
Result = result;
10+
}
11+
12+
public SchedulerStepResult(DenseTensor<float> previousSample, DenseTensor<float> sampleData)
13+
{
14+
Result = previousSample;
15+
SampleData = sampleData;
16+
}
17+
18+
public DenseTensor<float> Result { get; set; }
19+
public DenseTensor<float> SampleData { get; set; }
20+
}
21+
}

OnnxStack.StableDiffusion/Schedulers/DDIMScheduler.cs renamed to OnnxStack.StableDiffusion/Schedulers/StableDiffusion/DDIMScheduler.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
using System.Collections.Generic;
77
using System.Linq;
88

9-
namespace OnnxStack.StableDiffusion.Schedulers
9+
namespace OnnxStack.StableDiffusion.Schedulers.StableDiffusion
1010
{
1111
internal class DDIMScheduler : SchedulerBase
1212
{
@@ -84,8 +84,8 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
8484
/// <param name="sample">The sample.</param>
8585
/// <param name="order">The order.</param>
8686
/// <returns></returns>
87-
/// <exception cref="System.ArgumentException">Invalid prediction_type: {SchedulerOptions.PredictionType}</exception>
88-
/// <exception cref="System.NotImplementedException">DDIMScheduler Thresholding currently not implemented</exception>
87+
/// <exception cref="ArgumentException">Invalid prediction_type: {SchedulerOptions.PredictionType}</exception>
88+
/// <exception cref="NotImplementedException">DDIMScheduler Thresholding currently not implemented</exception>
8989
public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
9090
{
9191
//# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
@@ -216,7 +216,7 @@ private float GetVariance(int timestep, int prevTimestep)
216216

217217
float betaProdT = 1f - alphaProdT;
218218
float betaProdTPrev = 1f - alphaProdTPrev;
219-
float variance = (betaProdTPrev / betaProdT) * (1f - alphaProdT / alphaProdTPrev);
219+
float variance = betaProdTPrev / betaProdT * (1f - alphaProdT / alphaProdTPrev);
220220
return variance;
221221
}
222222

OnnxStack.StableDiffusion/Schedulers/DDPMScheduler.cs renamed to OnnxStack.StableDiffusion/Schedulers/StableDiffusion/DDPMScheduler.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
using System.Collections.Generic;
88
using System.Linq;
99

10-
namespace OnnxStack.StableDiffusion.Schedulers
10+
namespace OnnxStack.StableDiffusion.Schedulers.StableDiffusion
1111
{
1212
internal class DDPMScheduler : SchedulerBase
1313
{
@@ -79,8 +79,8 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
7979
/// <param name="sample">The sample.</param>
8080
/// <param name="order">The order.</param>
8181
/// <returns></returns>
82-
/// <exception cref="System.ArgumentException">Invalid prediction_type: {SchedulerOptions.PredictionType}</exception>
83-
/// <exception cref="System.NotImplementedException">DDPMScheduler Thresholding currently not implemented</exception>
82+
/// <exception cref="ArgumentException">Invalid prediction_type: {SchedulerOptions.PredictionType}</exception>
83+
/// <exception cref="NotImplementedException">DDPMScheduler Thresholding currently not implemented</exception>
8484
public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
8585
{
8686
int currentTimestep = timestep;
@@ -316,7 +316,7 @@ private float CalculatePercentile(NDArray data, float percentile)
316316
var sortedIndices = np.argsort<float>(data);
317317

318318
// Calculate the index corresponding to the percentile
319-
var index = (int)Math.Ceiling((percentile / 100f) * (data.Shape[0] - 1));
319+
var index = (int)Math.Ceiling(percentile / 100f * (data.Shape[0] - 1));
320320

321321
// Retrieve the value at the calculated index
322322
var percentileValue = data[sortedIndices[index]];

OnnxStack.StableDiffusion/Schedulers/EulerAncestralScheduler.cs renamed to OnnxStack.StableDiffusion/Schedulers/StableDiffusion/EulerAncestralScheduler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
using System.Collections.Generic;
99
using System.Linq;
1010

11-
namespace OnnxStack.StableDiffusion.Schedulers
11+
namespace OnnxStack.StableDiffusion.Schedulers.StableDiffusion
1212
{
1313
public sealed class EulerAncestralScheduler : SchedulerBase
1414
{

0 commit comments

Comments
 (0)