Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 92cb75c

Browse files
committed
Make ControlNet diffuser based not pipeline based
1 parent 7d32b89 commit 92cb75c

File tree

10 files changed

+173
-207
lines changed

10 files changed

+173
-207
lines changed

OnnxStack.Core/Config/OnnxModelType.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ public enum OnnxModelType
99
TextEncoder2 = 21,
1010
VaeEncoder = 30,
1111
VaeDecoder = 40,
12-
Control = 50,
13-
Annotation = 100,
12+
ControlNet = 50,
1413
Upscaler = 1000,
1514
}
1615
}

OnnxStack.StableDiffusion/Diffusers/ControlNet/ControlNetDiffuser.cs

Lines changed: 0 additions & 69 deletions
This file was deleted.

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/AnimateDiffuser.cs

Lines changed: 0 additions & 57 deletions
This file was deleted.

OnnxStack.StableDiffusion/Diffusers/ControlNet/TextDiffuser.cs renamed to OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ControlNetDiffuser.cs

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,38 @@
99
using OnnxStack.StableDiffusion.Enums;
1010
using OnnxStack.StableDiffusion.Helpers;
1111
using OnnxStack.StableDiffusion.Models;
12+
using OnnxStack.StableDiffusion.Schedulers.StableDiffusion;
1213
using System;
1314
using System.Collections.Generic;
1415
using System.Diagnostics;
1516
using System.Linq;
1617
using System.Threading;
1718
using System.Threading.Tasks;
1819

19-
namespace OnnxStack.StableDiffusion.Diffusers.ControlNet
20+
namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusion
2021
{
21-
public sealed class TextDiffuser : ControlNetDiffuser
22+
public class ControlNetDiffuser : DiffuserBase
2223
{
2324
/// <summary>
24-
/// Initializes a new instance of the <see cref="TextDiffuser"/> class.
25+
/// Initializes a new instance of the <see cref="ControlNetDiffuser"/> class.
2526
/// </summary>
2627
/// <param name="configuration">The configuration.</param>
2728
/// <param name="onnxModelService">The onnx model service.</param>
28-
public TextDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<TextDiffuser> logger)
29-
: base(onnxModelService, promptService, logger)
30-
{
31-
}
29+
public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<ControlNetDiffuser> logger)
30+
: base(onnxModelService, promptService, logger) { }
31+
32+
33+
/// <summary>
34+
/// Gets the type of the pipeline.
35+
/// </summary>
36+
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.StableDiffusion;
3237

3338

3439
/// <summary>
3540
/// Gets the type of the diffuser.
3641
/// </summary>
37-
public override DiffuserType DiffuserType => DiffuserType.ImageToImage;
42+
public override DiffuserType DiffuserType => DiffuserType.ControlNet;
43+
3844

3945
/// <summary>
4046
/// Called on each Scheduler step.
@@ -47,7 +53,7 @@ public TextDiffuser(IOnnxModelService onnxModelService, IPromptService promptSer
4753
/// <param name="progressCallback">The progress callback.</param>
4854
/// <param name="cancellationToken">The cancellation token.</param>
4955
/// <returns></returns>
50-
/// <exception cref="System.NotImplementedException"></exception>
56+
/// <exception cref="NotImplementedException"></exception>
5157
protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
5258
{
5359
// Get Scheduler
@@ -63,10 +69,10 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
6369
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
6470

6571
// Get Model metadata
66-
var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Control);
67-
72+
var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.ControlNet);
73+
6874
// Control Image
69-
var controlImage = promptOptions.InputImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false);
75+
var controlImage = PrepareControlImage(promptOptions, schedulerOptions);
7076

7177
// Loop though the timesteps
7278
var step = 0;
@@ -98,14 +104,15 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
98104
controlNetParameters.AddInputTensor(timestepTensor);
99105
controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
100106
controlNetParameters.AddInputTensor(controlImage);
101-
controlNetParameters.AddInputTensor(conditioningScale);
107+
if (controlNetMetadata.Inputs.Count == 5)
108+
controlNetParameters.AddInputTensor(conditioningScale);
102109

103110
// Optimization: Pre-allocate device buffers for inputs
104111
foreach (var item in controlNetMetadata.Outputs)
105112
controlNetParameters.AddOutputBuffer();
106113

107114
// ControlNet inference
108-
var controlNetResults = _onnxModelService.RunInference(modelOptions, OnnxModelType.Control, controlNetParameters);
115+
var controlNetResults = _onnxModelService.RunInference(modelOptions, OnnxModelType.ControlNet, controlNetParameters);
109116

110117
// Add ControlNet outputs to Unet input
111118
foreach (var item in controlNetResults)
@@ -139,14 +146,75 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
139146
}
140147
}
141148

149+
150+
/// <summary>
151+
/// Gets the timesteps.
152+
/// </summary>
153+
/// <param name="options">The options.</param>
154+
/// <param name="scheduler">The scheduler.</param>
155+
/// <returns></returns>
142156
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
143157
{
144158
return scheduler.Timesteps;
145159
}
146160

161+
162+
/// <summary>
163+
/// Prepares the input latents.
164+
/// </summary>
165+
/// <param name="model">The model.</param>
166+
/// <param name="prompt">The prompt.</param>
167+
/// <param name="options">The options.</param>
168+
/// <param name="scheduler">The scheduler.</param>
169+
/// <param name="timesteps">The timesteps.</param>
170+
/// <returns></returns>
147171
protected override Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
148172
{
149173
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
150174
}
175+
176+
177+
/// <summary>
178+
/// Creates the Conditioning Scale tensor.
179+
/// </summary>
180+
/// <param name="conditioningScale">The conditioningScale.</param>
181+
/// <returns></returns>
182+
protected static DenseTensor<double> CreateConditioningScaleTensor(float conditioningScale)
183+
{
184+
return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
185+
}
186+
187+
188+
/// <summary>
189+
/// Prepares the control image.
190+
/// </summary>
191+
/// <param name="promptOptions">The prompt options.</param>
192+
/// <param name="schedulerOptions">The scheduler options.</param>
193+
/// <returns></returns>
194+
protected DenseTensor<float> PrepareControlImage(PromptOptions promptOptions, SchedulerOptions schedulerOptions)
195+
{
196+
return promptOptions.InputImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false);
197+
}
198+
199+
200+
/// <summary>
201+
/// Gets the scheduler.
202+
/// </summary>
203+
/// <param name="options">The options.</param>
204+
/// <param name="schedulerConfig">The scheduler configuration.</param>
205+
/// <returns></returns>
206+
protected override IScheduler GetScheduler(SchedulerOptions options)
207+
{
208+
return options.SchedulerType switch
209+
{
210+
SchedulerType.LMS => new LMSScheduler(options),
211+
SchedulerType.Euler => new EulerScheduler(options),
212+
SchedulerType.EulerAncestral => new EulerAncestralScheduler(options),
213+
SchedulerType.DDPM => new DDPMScheduler(options),
214+
SchedulerType.DDIM => new DDIMScheduler(options),
215+
SchedulerType.KDPM2 => new KDPM2Scheduler(options),
216+
_ => default
217+
};
218+
}
151219
}
152220
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
6+
using OnnxStack.Core.Services;
7+
using OnnxStack.StableDiffusion.Common;
8+
using OnnxStack.StableDiffusion.Config;
9+
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
11+
using SixLabors.ImageSharp;
12+
using System;
13+
using System.Collections.Generic;
14+
using System.Linq;
15+
using System.Threading.Tasks;
16+
17+
namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusion
18+
{
19+
public sealed class ControlNetImageDiffuser : ControlNetDiffuser
20+
{
21+
/// <summary>
22+
/// Initializes a new instance of the <see cref="ControlNetImageDiffuser"/> class.
23+
/// </summary>
24+
/// <param name="configuration">The configuration.</param>
25+
/// <param name="onnxModelService">The onnx model service.</param>
26+
public ControlNetImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<ControlNetDiffuser> logger)
27+
: base(onnxModelService, promptService, logger)
28+
{
29+
}
30+
31+
32+
/// <summary>
33+
/// Gets the type of the diffuser.
34+
/// </summary>
35+
public override DiffuserType DiffuserType => DiffuserType.ControlNetImage;
36+
37+
38+
/// <summary>
39+
/// Gets the timesteps.
40+
/// </summary>
41+
/// <param name="prompt">The prompt.</param>
42+
/// <param name="options">The options.</param>
43+
/// <param name="scheduler">The scheduler.</param>
44+
/// <returns></returns>
45+
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
46+
{
47+
var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
48+
var start = Math.Max(options.InferenceSteps - inittimestep, 0);
49+
return scheduler.Timesteps.Skip(start).ToList();
50+
}
51+
52+
53+
/// <summary>
54+
/// Prepares the latents for inference.
55+
/// </summary>
56+
/// <param name="prompt">The prompt.</param>
57+
/// <param name="options">The options.</param>
58+
/// <param name="scheduler">The scheduler.</param>
59+
/// <returns></returns>
60+
protected override async Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
61+
{
62+
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
63+
64+
//TODO: Model Config, Channels
65+
var outputDimension = options.GetScaledDimension();
66+
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
67+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
68+
{
69+
inferenceParameters.AddInputTensor(imageTensor);
70+
inferenceParameters.AddOutputBuffer(outputDimension);
71+
72+
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
73+
using (var result = results.First())
74+
{
75+
var outputResult = result.ToDenseTensor();
76+
var scaledSample = outputResult.MultiplyBy(model.ScaleFactor);
77+
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
78+
}
79+
}
80+
}
81+
82+
}
83+
}

OnnxStack.StableDiffusion/Enums/DiffuserPipelineType.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ public enum DiffuserPipelineType
66
StableDiffusionXL = 1,
77
LatentConsistency = 10,
88
LatentConsistencyXL = 11,
9-
ControlNet = 20,
109
InstaFlow = 30,
1110
}
1211
}

0 commit comments

Comments
 (0)