Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit d68b75f

Browse files
committed
Batch prompt decoding support
1 parent c20a291 commit d68b75f

File tree

8 files changed

+192
-89
lines changed

8 files changed

+192
-89
lines changed

OnnxStack.StableDiffusion/Config/PromptOptions.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ public class PromptOptions
1616
public string NegativePrompt { get; set; }
1717
public SchedulerType SchedulerType { get; set; }
1818

19+
public int BatchCount { get; set; } = 1;
20+
1921
public InputImage InputImage { get; set; }
2022

2123
public InputImage InputImageMask { get; set; }

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 15 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,12 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
8484
var step = 0;
8585
foreach (var timestep in timesteps)
8686
{
87+
step++;
8788
cancellationToken.ThrowIfCancellationRequested();
8889

8990
// Create input tensor.
9091
var inputLatent = performGuidance
91-
? latents.Repeat(1)
92+
? latents.Repeat(2)
9293
: latents;
9394
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
9495

@@ -108,7 +109,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
108109
latents = scheduler.Step(noisePred, timestep, latents);
109110
}
110111

111-
progressCallback?.Invoke(++step, timesteps.Count);
112+
progressCallback?.Invoke(step, timesteps.Count);
112113
}
113114

114115
// Decode Latents
@@ -125,8 +126,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
125126
protected async Task<DenseTensor<float>> DecodeLatents(IModelOptions model, SchedulerOptions options, DenseTensor<float> latents)
126127
{
127128
// Scale and decode the image latents with vae.
128-
// latents = 1 / 0.18215 * latents
129-
latents = latents.MultipleTensorByFloat(1.0f / model.ScaleFactor);
129+
latents = latents.MultiplyBy(1.0f / model.ScaleFactor);
130130

131131
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
132132
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], latents));
@@ -145,57 +145,27 @@ protected async Task<DenseTensor<float>> DecodeLatents(IModelOptions model, Sche
145145
}
146146
}
147147

148+
148149
/// <summary>
149150
/// Performs classifier free guidance
150151
/// </summary>
151152
/// <param name="noisePredUncond">The noise pred.</param>
152153
/// <param name="noisePredText">The noise pred text.</param>
153154
/// <param name="guidanceScale">The guidance scale.</param>
154155
/// <returns></returns>
155-
protected DenseTensor<float> PerformGuidance(DenseTensor<float> noisePrediction, double guidanceScale)
156+
protected DenseTensor<float> PerformGuidance(DenseTensor<float> noisePrediction, float guidanceScale)
156157
{
157158
// Split Prompt and Negative Prompt predictions
158-
var (noisePredCond, noisePredUncond) = SplitPredictedNoise(noisePrediction);
159-
for (int i = 0; i < noisePredUncond.Dimensions[0]; i++)
160-
{
161-
for (int j = 0; j < noisePredUncond.Dimensions[1]; j++)
162-
{
163-
for (int k = 0; k < noisePredUncond.Dimensions[2]; k++)
164-
{
165-
for (int l = 0; l < noisePredUncond.Dimensions[3]; l++)
166-
{
167-
noisePredUncond[i, j, k, l] = noisePredUncond[i, j, k, l] + (float)guidanceScale * (noisePredCond[i, j, k, l] - noisePredUncond[i, j, k, l]);
168-
}
169-
}
170-
}
171-
}
172-
return noisePredUncond;
173-
}
174-
175-
176-
/// <summary>
177-
/// Splits the predicted noise.
178-
/// </summary>
179-
/// <param name="predictedNoiseSample">The predicted noise sample.</param>
180-
/// <returns>Split Prompt and Negative Prompt predictions</returns>
181-
protected (DenseTensor<float> noisePredCond, DenseTensor<float> noisePredUncond) SplitPredictedNoise(DenseTensor<float> predictedNoiseSample)
182-
{
183-
var dimensions = predictedNoiseSample.Dimensions.ToArray();
159+
var dimensions = noisePrediction.Dimensions.ToArray();
184160
dimensions[0] /= 2;
185-
var noisePredCond = new DenseTensor<float>(dimensions);
186-
var noisePredUncond = new DenseTensor<float>(dimensions);
187-
for (int j = 0; j < 4; j++)
188-
{
189-
for (int k = 0; k < dimensions[2]; k++)
190-
{
191-
for (int l = 0; l < dimensions[3]; l++)
192-
{
193-
noisePredUncond[0, j, k, l] = predictedNoiseSample[0, j, k, l];
194-
noisePredCond[0, j, k, l] = predictedNoiseSample[0, j + 4, k, l];
195-
}
196-
}
197-
}
198-
return (noisePredCond, noisePredUncond);
161+
162+
var length = (int)noisePrediction.Length / 2;
163+
var noisePredCond = new DenseTensor<float>(noisePrediction.Buffer[length..], dimensions);
164+
var noisePredUncond = new DenseTensor<float>(noisePrediction.Buffer[..length], dimensions);
165+
return noisePredUncond
166+
.Add(noisePredCond
167+
.Subtract(noisePredUncond)
168+
.MultiplyBy(guidanceScale));
199169
}
200170

201171

OnnxStack.StableDiffusion/Diffusers/ImageDiffuser.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public sealed class ImageDiffuser : DiffuserBase
2222
/// <param name="configuration">The configuration.</param>
2323
/// <param name="onnxModelService">The onnx model service.</param>
2424
public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService)
25-
:base(onnxModelService, promptService)
25+
: base(onnxModelService, promptService)
2626
{
2727
}
2828

@@ -59,11 +59,15 @@ protected override DenseTensor<float> PrepareLatents(IModelOptions model, Prompt
5959
using (var inferResult = _onnxModelService.RunInference(model, OnnxModelType.VaeEncoder, inputParameters))
6060
{
6161
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
62-
var noisySample = sample
63-
.AddTensors(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
64-
.MultipleTensorByFloat(model.ScaleFactor);
65-
var noise = scheduler.CreateRandomSample(sample.Dimensions);
66-
return scheduler.AddNoise(noisySample, noise, timesteps);
62+
var scaledSample = sample
63+
.Add(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
64+
.MultiplyBy(model.ScaleFactor);
65+
66+
var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
67+
if (prompt.BatchCount > 1)
68+
return noisySample.Repeat(prompt.BatchCount);
69+
70+
return noisySample;
6771
}
6872
}
6973

OnnxStack.StableDiffusion/Diffusers/InpaintDiffuser.cs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,10 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
5656
var latents = PrepareLatents(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
5757

5858
// Create Image Mask
59-
var maskImage = PrepareMask(modelOptions, promptOptions, schedulerOptions)
60-
.Duplicate(schedulerOptions.GetScaledDimension(2, 1));
59+
var maskImage = PrepareMask(modelOptions, promptOptions, schedulerOptions);
6160

6261
// Create Masked Image Latents
63-
var maskedImage = PrepareImageMask(modelOptions, promptOptions, schedulerOptions)
64-
.Duplicate(schedulerOptions.GetScaledDimension(2));
62+
var maskedImage = PrepareImageMask(modelOptions, promptOptions, schedulerOptions);
6563

6664
// Loop through the timesteps
6765
var step = 0;
@@ -71,7 +69,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
7169

7270
// Create input tensor.
7371
var inputLatent = performGuidance
74-
? latents.Repeat(1)
72+
? latents.Repeat(2)
7573
: latents;
7674
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
7775
inputTensor = ConcatenateLatents(inputTensor, maskedImage, maskImage);
@@ -136,7 +134,14 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
136134
}
137135
});
138136

139-
return imageTensor.MultipleTensorByFloat(modelOptions.ScaleFactor);
137+
imageTensor = imageTensor.MultiplyBy(modelOptions.ScaleFactor);
138+
if (promptOptions.BatchCount > 1)
139+
imageTensor = imageTensor.Repeat(promptOptions.BatchCount);
140+
141+
if (schedulerOptions.GuidanceScale > 1f)
142+
imageTensor = imageTensor.Repeat(2);
143+
144+
return imageTensor;
140145
}
141146
}
142147

@@ -200,7 +205,13 @@ private DenseTensor<float> PrepareImageMask(IModelOptions modelOptions, PromptOp
200205
using (var inferResult = _onnxModelService.RunInference(modelOptions, OnnxModelType.VaeEncoder, inputParameters))
201206
{
202207
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
203-
var scaledSample = sample.MultipleTensorByFloat(modelOptions.ScaleFactor);
208+
var scaledSample = sample.MultiplyBy(modelOptions.ScaleFactor);
209+
if (promptOptions.BatchCount > 1)
210+
scaledSample = scaledSample.Repeat(promptOptions.BatchCount);
211+
212+
if (schedulerOptions.GuidanceScale > 1f)
213+
scaledSample = scaledSample.Repeat(2);
214+
204215
return scaledSample;
205216
}
206217
}
@@ -230,7 +241,7 @@ protected override IReadOnlyList<int> GetTimesteps(PromptOptions prompt, Schedul
230241
/// <returns></returns>
231242
protected override DenseTensor<float> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
232243
{
233-
return scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma);
244+
return scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma);
234245
}
235246

236247

OnnxStack.StableDiffusion/Diffusers/InpaintLegacyDiffuser.cs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
5656

5757
// Add noise to original latent
5858
var latents = scheduler.AddNoise(latentsOriginal, noise, timesteps);
59-
59+
6060
// Loop through the timesteps
6161
var step = 0;
6262
foreach (var timestep in timesteps)
@@ -65,7 +65,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
6565

6666
// Create input tensor.
6767
var inputLatent = performGuidance
68-
? latents.Repeat(1)
68+
? latents.Repeat(2)
6969
: latents;
7070
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
7171

@@ -131,10 +131,15 @@ protected override DenseTensor<float> PrepareLatents(IModelOptions model, Prompt
131131
using (var inferResult = _onnxModelService.RunInference(model, OnnxModelType.VaeEncoder, inputParameters))
132132
{
133133
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
134-
var noisySample = sample
135-
.AddTensors(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
136-
.MultipleTensorByFloat(model.ScaleFactor);
137-
return noisySample;
134+
var scaledSample = sample
135+
.Add(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
136+
.MultiplyBy(model.ScaleFactor)
137+
.ToDenseTensor();
138+
139+
if (prompt.BatchCount > 1)
140+
return scaledSample.Repeat(prompt.BatchCount);
141+
142+
return scaledSample;
138143
}
139144
}
140145

@@ -170,6 +175,10 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
170175
}
171176
}
172177
});
178+
179+
if (promptOptions.BatchCount > 1)
180+
return maskTensor.Repeat(promptOptions.BatchCount);
181+
173182
return maskTensor;
174183
}
175184
}

OnnxStack.StableDiffusion/Diffusers/TextDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ protected override IReadOnlyList<int> GetTimesteps(PromptOptions prompt, Schedul
4242
/// <returns></returns>
4343
protected override DenseTensor<float> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
4444
{
45-
return scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma);
45+
return scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma);
4646
}
4747
}
4848
}

0 commit comments

Comments
 (0)