Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 7d32b89

Browse files
committed
ConditioningScale added; support for non-normalized image input
1 parent ef6bb14 commit 7d32b89

File tree

8 files changed

+73
-88
lines changed

8 files changed

+73
-88
lines changed

OnnxStack.Core/Config/OnnxModelType.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ public enum OnnxModelType
1010
VaeEncoder = 30,
1111
VaeDecoder = 40,
1212
Control = 50,
13-
Upscaler = 1000
13+
Annotation = 100,
14+
Upscaler = 1000,
1415
}
1516
}

OnnxStack.Core/Extensions/OrtValueExtensions.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ public static OrtValue ToOrtValue(this DenseTensor<long> tensor, OnnxNamedMetada
5959
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong());
6060
}
6161

62+
public static OrtValue ToOrtValue(this DenseTensor<double> tensor, OnnxNamedMetadata metadata)
63+
{
64+
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong());
65+
}
66+
6267

6368
/// <summary>
6469
/// Creates and allocates the output tensors buffer.

OnnxStack.Core/Model/OnnxInferenceParameters.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ public void AddInputTensor(DenseTensor<float> value)
4545
_inputs.Add(metaData, value.ToOrtValue(metaData));
4646
}
4747

48+
public void AddInputTensor(DenseTensor<double> value)
49+
{
50+
var metaData = GetNextInputMetadata();
51+
_inputs.Add(metaData, value.ToOrtValue(metaData));
52+
}
53+
4854

4955
/// <summary>
5056
/// Adds the input tensor.

OnnxStack.StableDiffusion/Config/SchedulerOptions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ public record SchedulerOptions
8383

8484
public float AestheticScore { get; set; } = 6f;
8585
public float AestheticNegativeScore { get; set; } = 2.5f;
86+
public float ConditioningScale { get; set; } = 0.7f;
8687

8788
public bool IsKarrasScheduler
8889
{

OnnxStack.StableDiffusion/Diffusers/ControlNet/ControlNetDiffuser.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,15 @@ public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService pro
3535
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.ControlNet;
3636

3737

38-
38+
/// <summary>
39+
/// Creates the Conditioning Scale tensor.
40+
/// </summary>
41+
/// <param name="conditioningScale">The conditioningScale.</param>
42+
/// <returns></returns>
43+
protected static DenseTensor<double> CreateConditioningScaleTensor(float conditioningScale)
44+
{
45+
return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
46+
}
3947

4048

4149
/// <summary>

OnnxStack.StableDiffusion/Diffusers/ControlNet/ImageDiffuser.cs renamed to OnnxStack.StableDiffusion/Diffusers/ControlNet/TextDiffuser.cs

Lines changed: 29 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
using OnnxStack.StableDiffusion.Enums;
1010
using OnnxStack.StableDiffusion.Helpers;
1111
using OnnxStack.StableDiffusion.Models;
12-
using SixLabors.ImageSharp;
1312
using System;
1413
using System.Collections.Generic;
1514
using System.Diagnostics;
@@ -19,14 +18,14 @@
1918

2019
namespace OnnxStack.StableDiffusion.Diffusers.ControlNet
2120
{
22-
public sealed class ImageDiffuser : ControlNetDiffuser
21+
public sealed class TextDiffuser : ControlNetDiffuser
2322
{
2423
/// <summary>
25-
/// Initializes a new instance of the <see cref="ImageDiffuser"/> class.
24+
/// Initializes a new instance of the <see cref="TextDiffuser"/> class.
2625
/// </summary>
2726
/// <param name="configuration">The configuration.</param>
2827
/// <param name="onnxModelService">The onnx model service.</param>
29-
public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<ImageDiffuser> logger)
28+
public TextDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<TextDiffuser> logger)
3029
: base(onnxModelService, promptService, logger)
3130
{
3231
}
@@ -65,9 +64,9 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
6564

6665
// Get Model metadata
6766
var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Control);
68-
69-
// TODO: do we need to pre-process?
70-
var controlImage = promptOptions.InputImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width });
67+
68+
// Control Image
69+
var controlImage = promptOptions.InputImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false);
7170

7271
// Loop through the timesteps
7372
var step = 0;
@@ -81,6 +80,8 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
8180
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
8281
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
8382
var timestepTensor = CreateTimestepTensor(timestep);
83+
var controlImageTensor = performGuidance ? controlImage.Repeat(2) : controlImage;
84+
var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);
8485

8586
var outputChannels = performGuidance ? 2 : 1;
8687
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
@@ -97,34 +98,35 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
9798
controlNetParameters.AddInputTensor(timestepTensor);
9899
controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
99100
controlNetParameters.AddInputTensor(controlImage);
101+
controlNetParameters.AddInputTensor(conditioningScale);
102+
103+
// Optimization: Pre-allocate device buffers for outputs
100104
foreach (var item in controlNetMetadata.Outputs)
101105
controlNetParameters.AddOutputBuffer();
102106

107+
// ControlNet inference
103108
var controlNetResults = _onnxModelService.RunInference(modelOptions, OnnxModelType.Control, controlNetParameters);
104-
if (controlNetResults.IsNullOrEmpty())
105-
throw new Exception("Control model produced no result");
106109

107110
// Add ControlNet outputs to Unet input
108111
foreach (var item in controlNetResults)
109-
inferenceParameters.AddInputTensor(item.ToDenseTensor());
110-
}
112+
inferenceParameters.AddInput(item);
111113

114+
// Add output buffer
115+
inferenceParameters.AddOutputBuffer(outputDimension);
112116

113-
// Add output buffer
114-
inferenceParameters.AddOutputBuffer(outputDimension);
115-
116-
// Unet
117-
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
118-
using (var result = results.First())
119-
{
120-
var noisePred = result.ToDenseTensor();
117+
// Unet inference
118+
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
119+
using (var result = results.First())
120+
{
121+
var noisePred = result.ToDenseTensor();
121122

122-
// Perform guidance
123-
if (performGuidance)
124-
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
123+
// Perform guidance
124+
if (performGuidance)
125+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
125126

126-
// Scheduler Step
127-
latents = scheduler.Step(noisePred, timestep, latents).Result;
127+
// Scheduler Step
128+
latents = scheduler.Step(noisePred, timestep, latents).Result;
129+
}
128130
}
129131
}
130132

@@ -137,50 +139,14 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
137139
}
138140
}
139141

140-
141-
/// <summary>
142-
/// Gets the timesteps.
143-
/// </summary>
144-
/// <param name="prompt">The prompt.</param>
145-
/// <param name="options">The options.</param>
146-
/// <param name="scheduler">The scheduler.</param>
147-
/// <returns></returns>
148142
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
149143
{
150-
var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
151-
var start = Math.Max(options.InferenceSteps - inittimestep, 0);
152-
return scheduler.Timesteps.Skip(start).ToList();
144+
return scheduler.Timesteps;
153145
}
154146

155-
156-
/// <summary>
157-
/// Prepares the latents for inference.
158-
/// </summary>
159-
/// <param name="prompt">The prompt.</param>
160-
/// <param name="options">The options.</param>
161-
/// <param name="scheduler">The scheduler.</param>
162-
/// <returns></returns>
163-
protected override async Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
147+
protected override Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
164148
{
165-
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
166-
167-
//TODO: Model Config, Channels
168-
var outputDimension = options.GetScaledDimension();
169-
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
170-
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
171-
{
172-
inferenceParameters.AddInputTensor(imageTensor);
173-
inferenceParameters.AddOutputBuffer(outputDimension);
174-
175-
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
176-
using (var result = results.First())
177-
{
178-
var outputResult = result.ToDenseTensor();
179-
var scaledSample = outputResult.MultiplyBy(model.ScaleFactor);
180-
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
181-
}
182-
}
149+
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
183150
}
184-
185151
}
186152
}

OnnxStack.StableDiffusion/Helpers/ImageHelpers.cs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
22
using OnnxStack.Core.Image;
3-
using OnnxStack.StableDiffusion.Models;
43
using SixLabors.ImageSharp;
54
using SixLabors.ImageSharp.PixelFormats;
65
using SixLabors.ImageSharp.Processing;
76
using System;
87
using System.IO;
9-
using System.Threading.Tasks;
108

119
namespace OnnxStack.StableDiffusion.Helpers
1210
{
@@ -55,18 +53,18 @@ public static Stream ToImageStream(this DenseTensor<float> imageTensor)
5553
/// <param name="imageData">The image data.</param>
5654
/// <param name="dimensions">The dimensions.</param>
5755
/// <returns></returns>
58-
public static DenseTensor<float> ToDenseTensor(this InputImage imageData, ReadOnlySpan<int> dimensions)
56+
public static DenseTensor<float> ToDenseTensor(this InputImage imageData, ReadOnlySpan<int> dimensions, bool normalize = true)
5957
{
6058
if (!string.IsNullOrEmpty(imageData.ImageBase64))
61-
return TensorFromBase64(imageData.ImageBase64, dimensions);
59+
return TensorFromBase64(imageData.ImageBase64, dimensions, normalize);
6260
if (imageData.ImageBytes != null)
63-
return TensorFromBytes(imageData.ImageBytes, dimensions);
61+
return TensorFromBytes(imageData.ImageBytes, dimensions, normalize);
6462
if (imageData.ImageStream != null)
65-
return TensorFromStream(imageData.ImageStream, dimensions);
63+
return TensorFromStream(imageData.ImageStream, dimensions, normalize);
6664
if (imageData.ImageTensor != null)
6765
return imageData.ImageTensor.ToDenseTensor(); // Note: Tensor Copy // TODO: Reshape to dimensions
6866

69-
return TensorFromImage(imageData.Image, dimensions);
67+
return TensorFromImage(imageData.Image, dimensions, normalize);
7068
}
7169

7270

@@ -147,12 +145,12 @@ public static void TensorToImageDebug2(DenseTensor<float> imageTensor, string fi
147145
/// <param name="image">The image.</param>
148146
/// <param name="dimensions">The dimensions.</param>
149147
/// <returns></returns>
150-
public static DenseTensor<float> TensorFromImage(Image<Rgba32> image, ReadOnlySpan<int> dimensions)
148+
public static DenseTensor<float> TensorFromImage(Image<Rgba32> image, ReadOnlySpan<int> dimensions, bool normalize)
151149
{
152150
using (image)
153151
{
154152
Resize(image, dimensions);
155-
return ProcessPixels(image, dimensions);
153+
return ProcessPixels(image, dimensions, normalize);
156154
}
157155
}
158156

@@ -164,12 +162,12 @@ public static DenseTensor<float> TensorFromImage(Image<Rgba32> image, ReadOnlySp
164162
/// <param name="width">The width.</param>
165163
/// <param name="height">The height.</param>
166164
/// <returns></returns>
167-
public static DenseTensor<float> TensorFromFile(string filename, ReadOnlySpan<int> dimensions)
165+
public static DenseTensor<float> TensorFromFile(string filename, ReadOnlySpan<int> dimensions, bool normalize)
168166
{
169167
using (var image = Image.Load<Rgba32>(filename))
170168
{
171169
Resize(image, dimensions);
172-
return ProcessPixels(image, dimensions);
170+
return ProcessPixels(image, dimensions, normalize);
173171
}
174172
}
175173

@@ -180,12 +178,12 @@ public static DenseTensor<float> TensorFromFile(string filename, ReadOnlySpan<in
180178
/// <param name="base64Image">The base64 image.</param>
181179
/// <param name="width">The width.</param>
182180
/// <param name="height">The height.</param>
183-
public static DenseTensor<float> TensorFromBase64(string base64Image, ReadOnlySpan<int> dimensions)
181+
public static DenseTensor<float> TensorFromBase64(string base64Image, ReadOnlySpan<int> dimensions, bool normalize)
184182
{
185183
using (var image = Image.Load<Rgba32>(Convert.FromBase64String(base64Image.Split(',')[1])))
186184
{
187185
Resize(image, dimensions);
188-
return ProcessPixels(image, dimensions);
186+
return ProcessPixels(image, dimensions, normalize);
189187
}
190188
}
191189

@@ -198,12 +196,12 @@ public static DenseTensor<float> TensorFromBase64(string base64Image, ReadOnlySp
198196
/// <param name="height">The height.</param>
199197
/// <returns>
200198
/// </returns>
201-
public static DenseTensor<float> TensorFromBytes(byte[] imageBytes, ReadOnlySpan<int> dimensions)
199+
public static DenseTensor<float> TensorFromBytes(byte[] imageBytes, ReadOnlySpan<int> dimensions, bool normalize)
202200
{
203201
using (var image = Image.Load<Rgba32>(imageBytes))
204202
{
205203
Resize(image, dimensions);
206-
return ProcessPixels(image, dimensions);
204+
return ProcessPixels(image, dimensions, normalize);
207205
}
208206
}
209207

@@ -215,12 +213,12 @@ public static DenseTensor<float> TensorFromBytes(byte[] imageBytes, ReadOnlySpan
215213
/// <param name="width">The width.</param>
216214
/// <param name="height">The height.</param>
217215
/// <returns></returns>
218-
public static DenseTensor<float> TensorFromStream(Stream imageStream, ReadOnlySpan<int> dimensions)
216+
public static DenseTensor<float> TensorFromStream(Stream imageStream, ReadOnlySpan<int> dimensions, bool normalize)
219217
{
220218
using (var image = Image.Load<Rgba32>(imageStream))
221219
{
222220
Resize(image, dimensions);
223-
return ProcessPixels(image, dimensions);
221+
return ProcessPixels(image, dimensions, normalize);
224222
}
225223
}
226224

@@ -231,7 +229,7 @@ public static DenseTensor<float> TensorFromStream(Stream imageStream, ReadOnlySp
231229
/// <param name="image">The image.</param>
232230
/// <param name="dimensions">The dimensions.</param>
233231
/// <returns></returns>
234-
private static DenseTensor<float> ProcessPixels(Image<Rgba32> image, ReadOnlySpan<int> dimensions)
232+
private static DenseTensor<float> ProcessPixels(Image<Rgba32> image, ReadOnlySpan<int> dimensions, bool normalize)
235233
{
236234
var width = dimensions[3];
237235
var height = dimensions[2];
@@ -244,11 +242,11 @@ private static DenseTensor<float> ProcessPixels(Image<Rgba32> image, ReadOnlySpa
244242
for (int y = 0; y < height; y++)
245243
{
246244
var pixelSpan = img.GetRowSpan(y);
247-
imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f) * 2.0f - 1.0f;
248-
imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f) * 2.0f - 1.0f;
249-
imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f) * 2.0f - 1.0f;
245+
imageArray[0, 0, y, x] = normalize ? (pixelSpan[x].R / 255.0f) * 2.0f - 1.0f : (pixelSpan[x].R / 255.0f);
246+
imageArray[0, 1, y, x] = normalize ? (pixelSpan[x].G / 255.0f) * 2.0f - 1.0f : (pixelSpan[x].G / 255.0f);
247+
imageArray[0, 2, y, x] = normalize ? (pixelSpan[x].B / 255.0f) * 2.0f - 1.0f : (pixelSpan[x].B / 255.0f);
250248
if (channels == 4)
251-
imageArray[0, 3, y, x] = (pixelSpan[x].A / 255.0f) * 2.0f - 1.0f;
249+
imageArray[0, 3, y, x] = normalize ? (pixelSpan[x].A / 255.0f) * 2.0f - 1.0f : (pixelSpan[x].A / 255.0f);
252250
}
253251
}
254252
});

OnnxStack.StableDiffusion/Registration.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ private static void RegisterServices(this IServiceCollection serviceCollection)
8282
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.InstaFlow.TextDiffuser>();
8383

8484
//ControlNet
85-
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.ControlNet.ImageDiffuser>();
85+
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.ControlNet.TextDiffuser>();
8686
}
8787

8888

0 commit comments

Comments
 (0)