Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit c20a291

Browse files
committed
Support disabled classifier free guidance
1 parent 5aaa63c commit c20a291

File tree

6 files changed

+120
-86
lines changed

6 files changed

+120
-86
lines changed

OnnxStack.StableDiffusion/Common/IPromptService.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.StableDiffusion.Config;
23
using System.Threading.Tasks;
34

45
namespace OnnxStack.StableDiffusion.Common
56
{
67
public interface IPromptService
78
{
8-
Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, string prompt, string negativePrompt);
9+
Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions);
910
Task<int[]> DecodeTextAsync(IModelOptions model, string inputText);
1011
Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokenizedInput);
1112
}

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
6969
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
7070
{
7171
// Process prompts
72-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions.Prompt, promptOptions.NegativePrompt);
72+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, schedulerOptions);
73+
74+
// Should we perform classifier-free guidance
75+
var performGuidance = schedulerOptions.GuidanceScale > 1.0f;
7376

7477
// Get timesteps
7578
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
@@ -84,25 +87,22 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
8487
cancellationToken.ThrowIfCancellationRequested();
8588

8689
// Create input tensor.
87-
var inputTensor = scheduler.ScaleInput(latents.Duplicate(schedulerOptions.GetScaledDimension(2)), timestep);
90+
var inputLatent = performGuidance
91+
? latents.Repeat(1)
92+
: latents;
93+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
8894

8995
// Create Input Parameters
90-
var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet);
91-
var inputParameters = CreateInputParameters(
92-
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
93-
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
94-
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings));
96+
var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, timestep);
9597

9698
// Run Inference
9799
using (var inferResult = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputParameters))
98100
{
99101
var noisePred = inferResult.FirstElementAs<DenseTensor<float>>();
100102

101103
// Perform guidance
102-
if (schedulerOptions.GuidanceScale > 1.0f)
103-
{
104+
if (performGuidance)
104105
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
105-
}
106106

107107
// Scheduler Step
108108
latents = scheduler.Step(noisePred, timestep, latents);
@@ -199,6 +199,24 @@ protected DenseTensor<float> PerformGuidance(DenseTensor<float> noisePrediction,
199199
}
200200

201201

202+
/// <summary>
203+
/// Creates the Unet input parameters.
204+
/// </summary>
205+
/// <param name="model">The model.</param>
206+
/// <param name="inputTensor">The input tensor.</param>
207+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
208+
/// <param name="timestep">The timestep.</param>
209+
/// <returns></returns>
210+
protected virtual IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, int timestep)
211+
{
212+
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Unet);
213+
return CreateInputParameters(
214+
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
215+
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
216+
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings));
217+
}
218+
219+
202220
/// <summary>
203221
/// Determines whether the specified result image is not NSFW.
204222
/// </summary>
@@ -286,9 +304,9 @@ protected static IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions
286304
/// </summary>
287305
/// <param name="parameters">The parameters.</param>
288306
/// <returns></returns>
289-
protected static IReadOnlyCollection<NamedOnnxValue> CreateInputParameters(params NamedOnnxValue[] parameters)
307+
protected static IReadOnlyList<NamedOnnxValue> CreateInputParameters(params NamedOnnxValue[] parameters)
290308
{
291-
return parameters.ToList().AsReadOnly();
309+
return parameters.ToList();
292310
}
293311
}
294312
}

OnnxStack.StableDiffusion/Diffusers/InpaintDiffuser.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
4444
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
4545
{
4646
// Process prompts
47-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions.Prompt, promptOptions.NegativePrompt);
47+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, schedulerOptions);
48+
49+
// Should we perform classifier-free guidance
50+
var performGuidance = schedulerOptions.GuidanceScale > 1.0f;
4851

4952
// Get timesteps
5053
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
@@ -67,26 +70,23 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
6770
cancellationToken.ThrowIfCancellationRequested();
6871

6972
// Create input tensor.
70-
var inputTensor = scheduler.ScaleInput(latents.Duplicate(schedulerOptions.GetScaledDimension(2)), timestep);
73+
var inputLatent = performGuidance
74+
? latents.Repeat(1)
75+
: latents;
76+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
7177
inputTensor = ConcatenateLatents(inputTensor, maskedImage, maskImage);
7278

7379
// Create Input Parameters
74-
var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet);
75-
var inputParameters = CreateInputParameters(
76-
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
77-
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
78-
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings));
80+
var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, timestep);
7981

8082
// Run Inference
8183
using (var inferResult = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputParameters))
8284
{
8385
var noisePred = inferResult.FirstElementAs<DenseTensor<float>>();
8486

8587
// Perform guidance
86-
if (schedulerOptions.GuidanceScale > 1.0f)
87-
{
88+
if (performGuidance)
8889
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
89-
}
9090

9191
// Scheduler Step
9292
latents = scheduler.Step(noisePred, timestep, latents);

OnnxStack.StableDiffusion/Diffusers/InpaintLegacyDiffuser.cs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
3737
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
3838
{
3939
// Process prompts
40-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions.Prompt, promptOptions.NegativePrompt);
40+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, schedulerOptions);
41+
42+
// Should we perform classifier-free guidance
43+
var performGuidance = schedulerOptions.GuidanceScale > 1.0f;
4144

4245
// Get timesteps
4346
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
@@ -53,33 +56,30 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
5356

5457
// Add noise to original latent
5558
var latents = scheduler.AddNoise(latentsOriginal, noise, timesteps);
56-
59+
5760
// Loop though the timesteps
5861
var step = 0;
5962
foreach (var timestep in timesteps)
6063
{
6164
cancellationToken.ThrowIfCancellationRequested();
6265

6366
// Create input tensor.
64-
var inputTensor = scheduler.ScaleInput(latents.Duplicate(schedulerOptions.GetScaledDimension(2)), timestep);
67+
var inputLatent = performGuidance
68+
? latents.Repeat(1)
69+
: latents;
70+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
6571

6672
// Create Input Parameters
67-
var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet);
68-
var inputParameters = CreateInputParameters(
69-
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
70-
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
71-
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings));
73+
var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, timestep);
7274

7375
// Run Inference
7476
using (var inferResult = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputParameters))
7577
{
7678
var noisePred = inferResult.FirstElementAs<DenseTensor<float>>();
7779

7880
// Perform guidance
79-
if (schedulerOptions.GuidanceScale > 1.0f)
80-
{
81+
if (performGuidance)
8182
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
82-
}
8383

8484
// Scheduler Step
8585
var steplatents = scheduler.Step(noisePred, timestep, latents);

OnnxStack.StableDiffusion/Helpers/TensorHelper.cs

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,6 @@ public static DenseTensor<float> AddTensors(this DenseTensor<float> tensor, Dens
109109
}
110110

111111

112-
113-
114-
115112
/// <summary>
116113
/// Sums the tensors.
117114
/// </summary>
@@ -242,20 +239,9 @@ public static DenseTensor<float> Abs(this DenseTensor<float> tensor)
242239
public static DenseTensor<float> Multiply(this DenseTensor<float> tensor1, DenseTensor<float> tensor2)
243240
{
244241
var result = new DenseTensor<float>(tensor1.Dimensions);
245-
for (int batch = 0; batch < tensor1.Dimensions[0]; batch++)
242+
for (int i = 0; i < tensor1.Length; i++)
246243
{
247-
for (int channel = 0; channel < tensor1.Dimensions[1]; channel++)
248-
{
249-
for (int height = 0; height < tensor1.Dimensions[2]; height++)
250-
{
251-
for (int width = 0; width < tensor1.Dimensions[3]; width++)
252-
{
253-
var value1 = tensor1[batch, channel, height, width];
254-
var value2 = tensor2[batch, channel, height, width];
255-
result[batch, channel, height, width] = value1 * value2;
256-
}
257-
}
258-
}
244+
result.SetValue(i, tensor1.GetValue(i) * tensor2.GetValue(i));
259245
}
260246
return result;
261247
}
@@ -270,25 +256,57 @@ public static DenseTensor<float> Multiply(this DenseTensor<float> tensor1, Dense
270256
public static DenseTensor<float> Divide(this DenseTensor<float> tensor1, DenseTensor<float> tensor2)
271257
{
272258
var result = new DenseTensor<float>(tensor1.Dimensions);
273-
for (int batch = 0; batch < tensor1.Dimensions[0]; batch++)
259+
for (int i = 0; i < tensor1.Length; i++)
274260
{
275-
for (int channel = 0; channel < tensor1.Dimensions[1]; channel++)
276-
{
277-
for (int height = 0; height < tensor1.Dimensions[2]; height++)
278-
{
279-
for (int width = 0; width < tensor1.Dimensions[3]; width++)
280-
{
281-
var value1 = tensor1[batch, channel, height, width];
282-
var value2 = tensor2[batch, channel, height, width];
283-
result[batch, channel, height, width] = value1 / value2;
284-
}
285-
}
286-
}
261+
result.SetValue(i, tensor1.GetValue(i) / tensor2.GetValue(i));
287262
}
288263
return result;
289264
}
290265

291266

267+
/// <summary>
268+
/// Concatenates the specified tensors along the 0 axis.
269+
/// </summary>
270+
/// <param name="tensor1">The tensor1.</param>
271+
/// <param name="tensor2">The tensor2.</param>
272+
/// <param name="axis">The axis.</param>
273+
/// <returns></returns>
274+
/// <exception cref="System.NotImplementedException">Only axis 0 is supported</exception>
275+
public static DenseTensor<float> Concatenate(this DenseTensor<float> tensor1, DenseTensor<float> tensor2, int axis = 0)
276+
{
277+
if (axis != 0)
278+
throw new NotImplementedException("Only axis 0 is supported");
279+
280+
var dimensions = tensor1.Dimensions.ToArray();
281+
dimensions[0] += tensor2.Dimensions[0];
282+
return CreateTensor(tensor1.Concat(tensor2).ToArray(), dimensions);
283+
}
284+
285+
286+
/// <summary>
287+
/// Repeats the specified Tensor along the 0 axis.
288+
/// </summary>
289+
/// <param name="tensor1">The tensor1.</param>
290+
/// <param name="count">The count.</param>
291+
/// <param name="axis">The axis.</param>
292+
/// <returns></returns>
293+
/// <exception cref="System.NotImplementedException">Only axis 0 is supported</exception>
294+
public static DenseTensor<float> Repeat(this DenseTensor<float> tensor1, int count, int axis = 0)
295+
{
296+
if (axis != 0)
297+
throw new NotImplementedException("Only axis 0 is supported");
298+
299+
var data = tensor1.ToArray();
300+
var dimensions = tensor1.Dimensions.ToArray();
301+
for (int i = 0; i < count; i++)
302+
{
303+
dimensions[0] += tensor1.Dimensions[0];
304+
data = data.Concat(tensor1).ToArray();
305+
}
306+
return CreateTensor(data, dimensions);
307+
}
308+
309+
292310
/// <summary>
293311
/// Generate a random Tensor from a normal distribution with mean 0 and variance 1
294312
/// </summary>

OnnxStack.StableDiffusion/Services/PromptService.cs

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
using Microsoft.ML.OnnxRuntime.Tensors;
2-
using Microsoft.ML.OnnxRuntime;
3-
using OnnxStack.Core.Config;
1+
using Microsoft.ML.OnnxRuntime;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
43
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Services;
6+
using OnnxStack.StableDiffusion.Common;
7+
using OnnxStack.StableDiffusion.Config;
58
using OnnxStack.StableDiffusion.Helpers;
69
using System;
710
using System.Collections.Generic;
11+
using System.Collections.Immutable;
812
using System.Linq;
9-
using System.Text;
1013
using System.Threading.Tasks;
11-
using OnnxStack.Core.Services;
12-
using OnnxStack.StableDiffusion.Common;
13-
using System.Collections.Immutable;
1414

1515
namespace OnnxStack.StableDiffusion.Services
1616
{
@@ -35,28 +35,23 @@ public PromptService(IOnnxModelService onnxModelService)
3535
/// <param name="prompt">The prompt.</param>
3636
/// <param name="negativePrompt">The negative prompt.</param>
3737
/// <returns>Tensor containing all text embeds generated from the prompt and negative prompt</returns>
38-
public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, string prompt, string negativePrompt)
38+
public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
3939
{
4040
// Tokenize Prompt and NegativePrompt
41-
var promptTokens = await DecodeTextAsync(model, prompt);
42-
var negativePromptTokens = await DecodeTextAsync(model, negativePrompt);
41+
var promptTokens = await DecodeTextAsync(model, promptOptions.Prompt);
42+
var negativePromptTokens = await DecodeTextAsync(model, promptOptions.NegativePrompt);
4343
var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
4444

45-
Console.WriteLine($"Prompt - Length: {prompt.Length}, Tokens: {promptTokens.Length}");
46-
Console.WriteLine($"N-Prompt - Length: {negativePrompt?.Length}, Tokens: {negativePromptTokens.Length}");
47-
4845
// Generate embeds for tokens
4946
var promptEmbeddings = await GenerateEmbedsAsync(model, promptTokens, maxPromptTokenCount);
5047
var negativePromptEmbeddings = await GenerateEmbedsAsync(model, negativePromptTokens, maxPromptTokenCount);
5148

52-
// Calculate embeddings
53-
var textEmbeddings = new DenseTensor<float>(new[] { 2, promptEmbeddings.Count / model.EmbeddingsLength, model.EmbeddingsLength });
54-
for (var i = 0; i < promptEmbeddings.Count; i++)
55-
{
56-
textEmbeddings[0, i / model.EmbeddingsLength, i % model.EmbeddingsLength] = negativePromptEmbeddings[i];
57-
textEmbeddings[1, i / model.EmbeddingsLength, i % model.EmbeddingsLength] = promptEmbeddings[i];
58-
}
59-
return textEmbeddings;
49+
// If we are doing guided diffusion, concatenate the negative prompt embeddings
50+
// If not, we ignore the negative prompt embeddings
51+
if (schedulerOptions.GuidanceScale > 1)
52+
return negativePromptEmbeddings.Concatenate(promptEmbeddings);
53+
54+
return promptEmbeddings;
6055
}
6156

6257

@@ -111,7 +106,7 @@ public async Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokenize
111106
/// <param name="inputTokens">The input tokens.</param>
112107
/// <param name="minimumLength">The minimum length.</param>
113108
/// <returns></returns>
114-
private async Task<List<float>> GenerateEmbedsAsync(IModelOptions model, int[] inputTokens, int minimumLength)
109+
private async Task<DenseTensor<float>> GenerateEmbedsAsync(IModelOptions model, int[] inputTokens, int minimumLength)
115110
{
116111
// If less than minimumLength pad with blank tokens
117112
if (inputTokens.Length < minimumLength)
@@ -124,7 +119,9 @@ private async Task<List<float>> GenerateEmbedsAsync(IModelOptions model, int[] i
124119
var tokens = PadWithBlankTokens(tokenBatch, model.TokenizerLimit, model.BlankTokenValueArray);
125120
embeddings.AddRange(await EncodeTokensAsync(model, tokens.ToArray()));
126121
}
127-
return embeddings;
122+
123+
var dim = new[] { 1, embeddings.Count / model.EmbeddingsLength, model.EmbeddingsLength };
124+
return TensorHelper.CreateTensor(embeddings.ToArray(), dim);
128125
}
129126

130127

0 commit comments

Comments
 (0)