Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 5aaa63c

Browse files
committed
Move scheduler specific code from TensorHelper to DiffuserBase
1 parent 5d2aa91 commit 5aaa63c

File tree

7 files changed

+85
-82
lines changed

7 files changed

+85
-82
lines changed

OnnxStack.Console/Examples/StableDebug.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ public async Task RunAsync()
5858
OutputHelpers.WriteConsole("Generating Image...", ConsoleColor.Green);
5959
await GenerateImage(model, promptOptions, schedulerOptions);
6060
}
61+
break;
6162
}
6263
}
6364

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
101101
// Perform guidance
102102
if (schedulerOptions.GuidanceScale > 1.0f)
103103
{
104-
var (noisePredUncond, noisePredText) = noisePred.SplitTensor(schedulerOptions.GetScaledDimension());
105-
noisePred = noisePredUncond.PerformGuidance(noisePredText, schedulerOptions.GuidanceScale);
104+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
106105
}
107106

108107
// Scheduler Step
@@ -146,6 +145,59 @@ protected async Task<DenseTensor<float>> DecodeLatents(IModelOptions model, Sche
146145
}
147146
}
148147

148+
/// <summary>
149+
/// Performs classifier free guidance
150+
/// </summary>
151+
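/// <param name="noisePrediction">The combined noise prediction containing both the prompt and negative prompt predictions.</param>
152+
/// <remarks>The prediction is split into its prompt and negative prompt parts; the guided result is written into the negative prompt part, which is then returned.</remarks>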
153+
/// <param name="guidanceScale">The guidance scale.</param>
154+
/// <returns>The noise prediction after classifier free guidance has been applied.</returns>
155+
protected DenseTensor<float> PerformGuidance(DenseTensor<float> noisePrediction, double guidanceScale)
156+
{
157+
// Split Prompt and Negative Prompt predictions
158+
var (noisePredCond, noisePredUncond) = SplitPredictedNoise(noisePrediction);
159+
for (int i = 0; i < noisePredUncond.Dimensions[0]; i++)
160+
{
161+
for (int j = 0; j < noisePredUncond.Dimensions[1]; j++)
162+
{
163+
for (int k = 0; k < noisePredUncond.Dimensions[2]; k++)
164+
{
165+
for (int l = 0; l < noisePredUncond.Dimensions[3]; l++)
166+
{
167+
noisePredUncond[i, j, k, l] = noisePredUncond[i, j, k, l] + (float)guidanceScale * (noisePredCond[i, j, k, l] - noisePredUncond[i, j, k, l]);
168+
}
169+
}
170+
}
171+
}
172+
return noisePredUncond;
173+
}
174+
175+
176+
/// <summary>
177+
/// Splits the predicted noise.
178+
/// </summary>
179+
/// <param name="predictedNoiseSample">The predicted noise sample.</param>
180+
/// <returns>Split Prompt and Negative Prompt predictions</returns>
181+
protected (DenseTensor<float> noisePredCond, DenseTensor<float> noisePredUncond) SplitPredictedNoise(DenseTensor<float> predictedNoiseSample)
182+
{
183+
var dimensions = predictedNoiseSample.Dimensions.ToArray();
184+
dimensions[0] /= 2;
185+
var noisePredCond = new DenseTensor<float>(dimensions);
186+
var noisePredUncond = new DenseTensor<float>(dimensions);
187+
for (int j = 0; j < 4; j++)
188+
{
189+
for (int k = 0; k < dimensions[2]; k++)
190+
{
191+
for (int l = 0; l < dimensions[3]; l++)
192+
{
193+
noisePredUncond[0, j, k, l] = predictedNoiseSample[0, j, k, l];
194+
noisePredCond[0, j, k, l] = predictedNoiseSample[0, j + 4, k, l];
195+
}
196+
}
197+
}
198+
return (noisePredCond, noisePredUncond);
199+
}
200+
149201

150202
/// <summary>
151203
/// Determines whether the specified result image is not NSFW.
@@ -211,6 +263,7 @@ protected static DenseTensor<float> ClipImageFeatureExtractor(SchedulerOptions o
211263
}
212264

213265

266+
214267
/// <summary>
215268
/// Gets the scheduler.
216269
/// </summary>

OnnxStack.StableDiffusion/Diffusers/InpaintDiffuser.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
8585
// Perform guidance
8686
if (schedulerOptions.GuidanceScale > 1.0f)
8787
{
88-
var (noisePredUncond, noisePredText) = noisePred.SplitTensor(schedulerOptions.GetScaledDimension());
89-
noisePred = noisePredUncond.PerformGuidance(noisePredText, schedulerOptions.GuidanceScale);
88+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
9089
}
9190

9291
// Scheduler Step

OnnxStack.StableDiffusion/Diffusers/InpaintLegacyDiffuser.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
7878
// Perform guidance
7979
if (schedulerOptions.GuidanceScale > 1.0f)
8080
{
81-
var (noisePredUncond, noisePredText) = noisePred.SplitTensor(schedulerOptions.GetScaledDimension());
82-
noisePred = noisePredUncond.PerformGuidance(noisePredText, schedulerOptions.GuidanceScale);
81+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
8382
}
8483

8584
// Scheduler Step

OnnxStack.StableDiffusion/Helpers/TensorHelper.cs

Lines changed: 6 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
namespace OnnxStack.StableDiffusion.Helpers
66
{
7+
8+
/// <summary>
9+
/// TODO: Optimization - every function in here copies its tensor, but not all of them need to.
10+
/// There are probably good memory/CPU gains to be had if separate mutating and non-mutating sets of functions were created.
11+
/// </summary>
712
public static class TensorHelper
813
{
914
/// <summary>
@@ -104,34 +109,7 @@ public static DenseTensor<float> AddTensors(this DenseTensor<float> tensor, Dens
104109
}
105110

106111

107-
/// <summary>
108-
/// Splits the tensor.
109-
/// </summary>
110-
/// <param name="tensorToSplit">The tensor to split.</param>
111-
/// <param name="dimensions">The dimensions.</param>
112-
/// <param name="scaledHeight">Height of the scaled.</param>
113-
/// <param name="scaledWidth">Width of the scaled.</param>
114-
/// <returns></returns>
115-
public static (DenseTensor<float> noisePredUncond, DenseTensor<float> noisePredText) SplitTensor(this DenseTensor<float> tensor, ReadOnlySpan<int> dimensions)
116-
{
117-
var tensor1 = new DenseTensor<float>(dimensions);
118-
var tensor2 = new DenseTensor<float>(dimensions);
119-
for (int i = 0; i < 1; i++)
120-
{
121-
for (int j = 0; j < 4; j++)
122-
{
123-
for (int k = 0; k < dimensions[2]; k++)
124-
{
125-
for (int l = 0; l < dimensions[3]; l++)
126-
{
127-
tensor1[i, j, k, l] = tensor[i, j, k, l];
128-
tensor2[i, j, k, l] = tensor[i, j + 4, k, l];
129-
}
130-
}
131-
}
132-
}
133-
return (tensor1, tensor2);
134-
}
112+
135113

136114

137115
/// <summary>
@@ -220,31 +198,6 @@ public static DenseTensor<float> ReorderTensor(this DenseTensor<float> tensor, R
220198
}
221199

222200

223-
/// <summary>
224-
/// Performs classifier free guidance
225-
/// </summary>
226-
/// <param name="noisePred">The noise pred.</param>
227-
/// <param name="noisePredText">The noise pred text.</param>
228-
/// <param name="guidanceScale">The guidance scale.</param>
229-
/// <returns></returns>
230-
public static DenseTensor<float> PerformGuidance(this DenseTensor<float> noisePred, DenseTensor<float> noisePredText, double guidanceScale)
231-
{
232-
for (int i = 0; i < noisePred.Dimensions[0]; i++)
233-
{
234-
for (int j = 0; j < noisePred.Dimensions[1]; j++)
235-
{
236-
for (int k = 0; k < noisePred.Dimensions[2]; k++)
237-
{
238-
for (int l = 0; l < noisePred.Dimensions[3]; l++)
239-
{
240-
noisePred[i, j, k, l] = noisePred[i, j, k, l] + (float)guidanceScale * (noisePredText[i, j, k, l] - noisePred[i, j, k, l]);
241-
}
242-
}
243-
}
244-
}
245-
return noisePred;
246-
}
247-
248201

249202
/// <summary>
250203
/// Clips the specified Tensor values to the specified minimum/maximum.

OnnxStack.StableDiffusion/Schedulers/DDPMScheduler.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
22
using NumSharp;
3+
using OnnxStack.Core;
4+
using OnnxStack.StableDiffusion.Config;
5+
using OnnxStack.StableDiffusion.Enums;
6+
using OnnxStack.StableDiffusion.Helpers;
37
using OnnxStack.StableDiffusion.Config;
48
using OnnxStack.StableDiffusion.Enums;
59
using OnnxStack.StableDiffusion.Helpers;
@@ -89,10 +93,10 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
8993
//# 1. compute alphas, betas
9094
float alphaProdT = _alphasCumProd[currentTimestep];
9195
float alphaProdTPrev = previousTimestep >= 0 ? _alphasCumProd[previousTimestep] : 1f;
92-
float betaProdT = 1 - alphaProdT;
93-
float betaProdTPrev = 1 - alphaProdTPrev;
96+
float betaProdT = 1f - alphaProdT;
97+
float betaProdTPrev = 1f - alphaProdTPrev;
9498
float currentAlphaT = alphaProdT / alphaProdTPrev;
95-
float currentBetaT = 1 - currentAlphaT;
99+
float currentBetaT = 1f - currentAlphaT;
96100

97101
float predictedVariance = 0;
98102

@@ -143,7 +147,8 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
143147
var varianceNoise = CreateRandomSample(modelOutput.Dimensions);
144148
if (Options.VarianceType == VarianceType.FixedSmallLog)
145149
{
146-
variance = varianceNoise.MultipleTensorByFloat(GetVariance(currentTimestep, predictedVariance));
150+
var v = GetVariance(currentTimestep, predictedVariance);
151+
variance = varianceNoise.MultipleTensorByFloat(v);
147152
}
148153
else if (Options.VarianceType == VarianceType.LearnedRange)
149154
{

OnnxStack.StableDiffusion/Services/StableDiffusionService.cs

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
using SixLabors.ImageSharp;
88
using SixLabors.ImageSharp.PixelFormats;
99
using System;
10+
using System.Collections.Concurrent;
1011
using System.Collections.Generic;
12+
using System.Collections.Immutable;
1113
using System.IO;
1214
using System.Threading;
1315
using System.Threading.Tasks;
@@ -20,13 +22,9 @@ namespace OnnxStack.StableDiffusion.Services
2022
/// <seealso cref="OnnxStack.StableDiffusion.Common.IStableDiffusionService" />
2123
public sealed class StableDiffusionService : IStableDiffusionService
2224
{
23-
private readonly IDiffuser _textDiffuser;
24-
private readonly IDiffuser _imageDiffuser;
25-
private readonly IDiffuser _inpaintDiffuser;
26-
private readonly IDiffuser _inpaintLegacyDiffuser;
2725
private readonly IOnnxModelService _onnxModelService;
2826
private readonly StableDiffusionConfig _configuration;
29-
27+
private readonly IDictionary<DiffuserType, IDiffuser> _diffusers;
3028

3129
/// <summary>
3230
/// Initializes a new instance of the <see cref="StableDiffusionService"/> class.
@@ -36,10 +34,11 @@ public StableDiffusionService(StableDiffusionConfig configuration, IOnnxModelSer
3634
{
3735
_configuration = configuration;
3836
_onnxModelService = onnxModelService;
39-
_textDiffuser = new TextDiffuser(onnxModelService, promptService);
40-
_imageDiffuser = new ImageDiffuser(onnxModelService, promptService);
41-
_inpaintDiffuser = new InpaintDiffuser(onnxModelService, promptService);
42-
_inpaintLegacyDiffuser = new InpaintLegacyDiffuser(onnxModelService, promptService);
37+
_diffusers = new ConcurrentDictionary<DiffuserType, IDiffuser>();
38+
_diffusers.Add(DiffuserType.TextToImage, new TextDiffuser(onnxModelService, promptService));
39+
_diffusers.Add(DiffuserType.ImageToImage, new ImageDiffuser(onnxModelService, promptService));
40+
_diffusers.Add(DiffuserType.ImageInpaint, new InpaintDiffuser(onnxModelService, promptService));
41+
_diffusers.Add(DiffuserType.ImageInpaintLegacy, new InpaintLegacyDiffuser(onnxModelService, promptService));
4342
}
4443

4544

@@ -92,7 +91,7 @@ public bool IsModelLoaded(IModelOptions modelOptions)
9291
/// <returns>The diffusion result as <see cref="DenseTensor{float}"/></returns>
9392
public async Task<DenseTensor<float>> GenerateAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
9493
{
95-
return await RunAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false);
94+
return await DiffuseAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false);
9695
}
9796

9897

@@ -144,16 +143,10 @@ public async Task<Stream> GenerateAsStreamAsync(IModelOptions model, PromptOptio
144143
}
145144

146145

147-
private async Task<DenseTensor<float>> RunAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progress = null, CancellationToken cancellationToken = default)
146+
private async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progress = null, CancellationToken cancellationToken = default)
148147
{
149-
return promptOptions.DiffuserType switch
150-
{
151-
DiffuserType.TextToImage => await _textDiffuser.DiffuseAsync(modelOptions, promptOptions, schedulerOptions, progress, cancellationToken),
152-
DiffuserType.ImageToImage => await _imageDiffuser.DiffuseAsync(modelOptions, promptOptions, schedulerOptions, progress, cancellationToken),
153-
DiffuserType.ImageInpaint => await _inpaintDiffuser.DiffuseAsync(modelOptions, promptOptions, schedulerOptions, progress, cancellationToken),
154-
DiffuserType.ImageInpaintLegacy => await _inpaintLegacyDiffuser.DiffuseAsync(modelOptions, promptOptions, schedulerOptions, progress, cancellationToken),
155-
_ => throw new NotImplementedException()
156-
};
148+
return await _diffusers[promptOptions.DiffuserType]
149+
.DiffuseAsync(modelOptions, promptOptions, schedulerOptions, progress, cancellationToken);
157150
}
158151
}
159152
}

0 commit comments

Comments
 (0)