Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit ef6bb14

Browse files
committed
Add Control model support, Add Control model outputs to Unet inputs
1 parent 4ea907c commit ef6bb14

File tree

4 files changed

+112
-17
lines changed

4 files changed

+112
-17
lines changed

OnnxStack.Core/Config/OnnxModelType.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public enum OnnxModelType
99
TextEncoder2 = 21,
1010
VaeEncoder = 30,
1111
VaeDecoder = 40,
12+
Control = 50,
1213
Upscaler = 1000
1314
}
1415
}

OnnxStack.StableDiffusion/Diffusers/ControlNet/ControlNetDiffuser.cs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
using Microsoft.Extensions.Logging;
22
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
36
using OnnxStack.Core.Services;
47
using OnnxStack.StableDiffusion.Common;
58
using OnnxStack.StableDiffusion.Config;
69
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
711
using OnnxStack.StableDiffusion.Models;
812
using OnnxStack.StableDiffusion.Schedulers.StableDiffusion;
913
using System;
14+
using System.Diagnostics;
15+
using System.Linq;
1016
using System.Threading;
1117
using System.Threading.Tasks;
1218

@@ -29,22 +35,7 @@ public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService pro
2935
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.ControlNet;
3036

3137

32-
/// <summary>
33-
/// Called on each Scheduler step.
34-
/// </summary>
35-
/// <param name="modelOptions">The model options.</param>
36-
/// <param name="promptOptions">The prompt options.</param>
37-
/// <param name="schedulerOptions">The scheduler options.</param>
38-
/// <param name="promptEmbeddings">The prompt embeddings.</param>
39-
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
40-
/// <param name="progressCallback">The progress callback.</param>
41-
/// <param name="cancellationToken">The cancellation token.</param>
42-
/// <returns></returns>
43-
/// <exception cref="System.NotImplementedException"></exception>
44-
protected override Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
45-
{
46-
throw new NotImplementedException();
47-
}
38+
4839

4940

5041
/// <summary>

OnnxStack.StableDiffusion/Diffusers/ControlNet/ImageDiffuser.cs

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Microsoft.Extensions.Logging;
2-
using Microsoft.ML.OnnxRuntime;
32
using Microsoft.ML.OnnxRuntime.Tensors;
43
using OnnxStack.Core;
54
using OnnxStack.Core.Config;
@@ -9,10 +8,13 @@
98
using OnnxStack.StableDiffusion.Config;
109
using OnnxStack.StableDiffusion.Enums;
1110
using OnnxStack.StableDiffusion.Helpers;
11+
using OnnxStack.StableDiffusion.Models;
1212
using SixLabors.ImageSharp;
1313
using System;
1414
using System.Collections.Generic;
15+
using System.Diagnostics;
1516
using System.Linq;
17+
using System.Threading;
1618
using System.Threading.Tasks;
1719

1820
namespace OnnxStack.StableDiffusion.Diffusers.ControlNet
@@ -35,6 +37,106 @@ public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptSe
3537
/// </summary>
3638
public override DiffuserType DiffuserType => DiffuserType.ImageToImage;
3739

40+
/// <summary>
41+
/// Called on each Scheduler step.
42+
/// </summary>
43+
/// <param name="modelOptions">The model options.</param>
44+
/// <param name="promptOptions">The prompt options.</param>
45+
/// <param name="schedulerOptions">The scheduler options.</param>
46+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
47+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
48+
/// <param name="progressCallback">The progress callback.</param>
49+
/// <param name="cancellationToken">The cancellation token.</param>
50+
/// <returns></returns>
51+
/// <exception cref="System.NotImplementedException"></exception>
52+
protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
53+
{
54+
// Get Scheduler
55+
using (var scheduler = GetScheduler(schedulerOptions))
56+
{
57+
// Get timesteps
58+
var timesteps = GetTimesteps(schedulerOptions, scheduler);
59+
60+
// Create latent sample
61+
var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
62+
63+
// Get Model metadata
64+
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
65+
66+
// Get Model metadata
67+
var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Control);
68+
69+
// TODO: do we need to pre-process?
70+
var controlImage = promptOptions.InputImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width });
71+
72+
// Loop though the timesteps
73+
var step = 0;
74+
foreach (var timestep in timesteps)
75+
{
76+
step++;
77+
var stepTime = Stopwatch.GetTimestamp();
78+
cancellationToken.ThrowIfCancellationRequested();
79+
80+
// Create input tensor.
81+
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
82+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
83+
var timestepTensor = CreateTimestepTensor(timestep);
84+
85+
var outputChannels = performGuidance ? 2 : 1;
86+
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
87+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
88+
{
89+
inferenceParameters.AddInputTensor(inputTensor);
90+
inferenceParameters.AddInputTensor(timestepTensor);
91+
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
92+
93+
// ControlNet
94+
using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
95+
{
96+
controlNetParameters.AddInputTensor(inputTensor);
97+
controlNetParameters.AddInputTensor(timestepTensor);
98+
controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
99+
controlNetParameters.AddInputTensor(controlImage);
100+
foreach (var item in controlNetMetadata.Outputs)
101+
controlNetParameters.AddOutputBuffer();
102+
103+
var controlNetResults = _onnxModelService.RunInference(modelOptions, OnnxModelType.Control, controlNetParameters);
104+
if (controlNetResults.IsNullOrEmpty())
105+
throw new Exception("Control model produced no result");
106+
107+
// Add ControlNet outputs to Unet input
108+
foreach (var item in controlNetResults)
109+
inferenceParameters.AddInputTensor(item.ToDenseTensor());
110+
}
111+
112+
113+
// Add output buffer
114+
inferenceParameters.AddOutputBuffer(outputDimension);
115+
116+
// Unet
117+
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
118+
using (var result = results.First())
119+
{
120+
var noisePred = result.ToDenseTensor();
121+
122+
// Perform guidance
123+
if (performGuidance)
124+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
125+
126+
// Scheduler Step
127+
latents = scheduler.Step(noisePred, timestep, latents).Result;
128+
}
129+
}
130+
131+
ReportProgress(progressCallback, step, timesteps.Count, latents);
132+
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
133+
}
134+
135+
// Decode Latents
136+
return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
137+
}
138+
}
139+
38140

39141
/// <summary>
40142
/// Gets the timesteps.

OnnxStack.StableDiffusion/Extensions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ public static SchedulerType[] GetSchedulerTypes(this DiffuserPipelineType pipeli
9999
{
100100
switch (pipelineType)
101101
{
102+
case DiffuserPipelineType.ControlNet:
102103
case DiffuserPipelineType.StableDiffusion:
103104
case DiffuserPipelineType.StableDiffusionXL:
104105
return new[]

0 commit comments

Comments
 (0)