Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 990ea25

Browse files
committed
SDXL Pipeline scaffold
1 parent 5d35a84 commit 990ea25

File tree

7 files changed

+296
-0
lines changed

7 files changed

+296
-0
lines changed

OnnxStack.Console/appsettings.json

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,54 @@
172172
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\vae_decoder\\model.onnx"
173173
}
174174
]
175+
},
176+
{
177+
"Name": "DreamShaper XL",
178+
"IsEnabled": true,
179+
"PadTokenId": 49407,
180+
"BlankTokenId": 49407,
181+
"TokenizerLimit": 77,
182+
"EmbeddingsLength": 2816,
183+
"ScaleFactor": 0.13025,
184+
"PipelineType": "StableDiffusionXL",
185+
"Diffusers": [
186+
"TextToImage"
187+
],
188+
"DeviceId": 0,
189+
"InterOpNumThreads": 0,
190+
"IntraOpNumThreads": 0,
191+
"ExecutionMode": "ORT_SEQUENTIAL",
192+
"ExecutionProvider": "DirectML",
193+
"ModelConfigurations": [
194+
{
195+
"Type": "Tokenizer",
196+
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\tokenizer\\model.onnx"
197+
},
198+
{
199+
"Type": "Tokenizer2",
200+
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\tokenizer2\\model.onnx"
201+
},
202+
{
203+
"Type": "Unet",
204+
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\unet\\model.onnx"
205+
},
206+
{
207+
"Type": "TextEncoder",
208+
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\text_encoder\\model.onnx"
209+
},
210+
{
211+
"Type": "TextEncoder2",
212+
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\text_encoder2\\model.onnx"
213+
},
214+
{
215+
"Type": "VaeEncoder",
216+
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\vae_encoder\\model.onnx"
217+
},
218+
{
219+
"Type": "VaeDecoder",
220+
"OnnxModelPath": "D:\\Repositories\\dreamshaper-xl-1-0-Olive-Onnx\\vae_decoder\\model.onnx"
221+
}
222+
]
175223
}
176224
]
177225
}

OnnxStack.Core/Config/OnnxModelType.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ public enum OnnxModelType
44
{
55
Unet = 0,
66
Tokenizer = 10,
7+
Tokenizer2 = 11,
78
TextEncoder = 20,
9+
TextEncoder2 = 21,
810
VaeEncoder = 30,
911
VaeDecoder = 40,
1012
}
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
6+
using OnnxStack.Core.Services;
7+
using OnnxStack.StableDiffusion.Common;
8+
using OnnxStack.StableDiffusion.Config;
9+
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
11+
using OnnxStack.StableDiffusion.Schedulers.StableDiffusion;
12+
using System;
13+
using System.Diagnostics;
14+
using System.Linq;
15+
using System.Threading;
16+
using System.Threading.Tasks;
17+
18+
namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL
19+
{
20+
public abstract class StableDiffusionXLDiffuser : DiffuserBase
21+
{
22+
/// <summary>
23+
/// Initializes a new instance of the <see cref="StableDiffusionXLDiffuser"/> class.
24+
/// </summary>
25+
/// <param name="configuration">The configuration.</param>
26+
/// <param name="onnxModelService">The onnx model service.</param>
27+
public StableDiffusionXLDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<StableDiffusionXLDiffuser> logger)
28+
: base(onnxModelService, promptService, logger) { }
29+
30+
31+
/// <summary>
32+
/// Gets the type of the pipeline.
33+
/// </summary>
34+
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.StableDiffusionXL;
35+
36+
37+
/// <summary>
38+
/// Runs the scheduler steps.
39+
/// </summary>
40+
/// <param name="modelOptions">The model options.</param>
41+
/// <param name="promptOptions">The prompt options.</param>
42+
/// <param name="schedulerOptions">The scheduler options.</param>
43+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
44+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
45+
/// <param name="progressCallback">The progress callback.</param>
46+
/// <param name="cancellationToken">The cancellation token.</param>
47+
/// <returns></returns>
48+
protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
49+
{
50+
// Get Scheduler
51+
using (var scheduler = GetScheduler(schedulerOptions))
52+
{
53+
// Get timesteps
54+
var timesteps = GetTimesteps(schedulerOptions, scheduler);
55+
56+
// Create latent sample
57+
var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
58+
59+
// Get Model metadata
60+
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
61+
62+
// Loop though the timesteps
63+
var step = 0;
64+
foreach (var timestep in timesteps)
65+
{
66+
step++;
67+
var stepTime = Stopwatch.GetTimestamp();
68+
cancellationToken.ThrowIfCancellationRequested();
69+
70+
// Create input tensor.
71+
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
72+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
73+
var timestepTensor = CreateTimestepTensor(timestep);
74+
75+
var outputChannels = performGuidance ? 2 : 1;
76+
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
77+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
78+
{
79+
inferenceParameters.AddInputTensor(inputTensor);
80+
inferenceParameters.AddInputTensor(timestepTensor);
81+
inferenceParameters.AddInputTensor(promptEmbeddings);
82+
inferenceParameters.AddOutputBuffer(outputDimension);
83+
84+
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
85+
using (var result = results.First())
86+
{
87+
var noisePred = result.ToDenseTensor();
88+
89+
// Perform guidance
90+
if (performGuidance)
91+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
92+
93+
// Scheduler Step
94+
latents = scheduler.Step(noisePred, timestep, latents).Result;
95+
}
96+
}
97+
98+
progressCallback?.Invoke(step, timesteps.Count);
99+
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
100+
}
101+
102+
// Decode Latents
103+
return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
104+
}
105+
}
106+
107+
108+
/// <summary>
109+
/// Gets the scheduler.
110+
/// </summary>
111+
/// <param name="options">The options.</param>
112+
/// <param name="schedulerConfig">The scheduler configuration.</param>
113+
/// <returns></returns>
114+
protected override IScheduler GetScheduler(SchedulerOptions options)
115+
{
116+
return options.SchedulerType switch
117+
{
118+
SchedulerType.LMS => new LMSScheduler(options),
119+
SchedulerType.Euler => new EulerScheduler(options),
120+
SchedulerType.EulerAncestral => new EulerAncestralScheduler(options),
121+
SchedulerType.DDPM => new DDPMScheduler(options),
122+
SchedulerType.DDIM => new DDIMScheduler(options),
123+
SchedulerType.KDPM2 => new KDPM2Scheduler(options),
124+
_ => default
125+
};
126+
}
127+
}
128+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core.Services;
4+
using OnnxStack.StableDiffusion.Common;
5+
using OnnxStack.StableDiffusion.Config;
6+
using OnnxStack.StableDiffusion.Enums;
7+
using System.Collections.Generic;
8+
using System.Threading.Tasks;
9+
10+
namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL
11+
{
12+
public sealed class TextDiffuser : StableDiffusionXLDiffuser
13+
{
14+
/// <summary>
15+
/// Initializes a new instance of the <see cref="TextDiffuser"/> class.
16+
/// </summary>
17+
/// <param name="configuration">The configuration.</param>
18+
/// <param name="onnxModelService">The onnx model service.</param>
19+
public TextDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<StableDiffusionXLDiffuser> logger)
20+
: base(onnxModelService, promptService, logger)
21+
{
22+
}
23+
24+
25+
/// <summary>
26+
/// Gets the type of the diffuser.
27+
/// </summary>
28+
public override DiffuserType DiffuserType => DiffuserType.TextToImage;
29+
30+
31+
/// <summary>
32+
/// Gets the timesteps.
33+
/// </summary>
34+
/// <param name="prompt">The prompt.</param>
35+
/// <param name="options">The options.</param>
36+
/// <param name="scheduler">The scheduler.</param>
37+
/// <returns></returns>
38+
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
39+
{
40+
return scheduler.Timesteps;
41+
}
42+
43+
44+
/// <summary>
45+
/// Prepares the latents for inference.
46+
/// </summary>
47+
/// <param name="prompt">The prompt.</param>
48+
/// <param name="options">The options.</param>
49+
/// <param name="scheduler">The scheduler.</param>
50+
/// <returns></returns>
51+
protected override Task<DenseTensor<float>> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
52+
{
53+
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
54+
}
55+
}
56+
}

OnnxStack.StableDiffusion/Enums/DiffuserPipelineType.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
public enum DiffuserPipelineType
44
{
55
StableDiffusion = 0,
6+
StableDiffusionXL = 1,
67
LatentConsistency = 10,
78
InstaFlow = 11,
89
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using Microsoft.Extensions.Logging;
2+
using OnnxStack.Core;
3+
using OnnxStack.StableDiffusion.Common;
4+
using OnnxStack.StableDiffusion.Diffusers;
5+
using OnnxStack.StableDiffusion.Enums;
6+
using System.Collections.Concurrent;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
10+
namespace OnnxStack.StableDiffusion.Pipelines
11+
{
12+
public sealed class StableDiffusionXLPipeline : IPipeline
13+
{
14+
private readonly DiffuserPipelineType _pipelineType;
15+
private readonly ILogger<StableDiffusionXLPipeline> _logger;
16+
private readonly ConcurrentDictionary<DiffuserType, IDiffuser> _diffusers;
17+
18+
19+
/// <summary>
20+
/// Initializes a new instance of the <see cref="StableDiffusionXLPipeline"/> class.
21+
/// </summary>
22+
/// <param name="onnxModelService">The onnx model service.</param>
23+
/// <param name="promptService">The prompt service.</param>
24+
public StableDiffusionXLPipeline(IEnumerable<IDiffuser> diffusers, ILogger<StableDiffusionXLPipeline> logger)
25+
{
26+
_logger = logger;
27+
_pipelineType = DiffuserPipelineType.StableDiffusionXL;
28+
_diffusers = diffusers
29+
.Where(x => x.PipelineType == _pipelineType)
30+
.ToConcurrentDictionary(k => k.DiffuserType, v => v);
31+
}
32+
33+
34+
/// <summary>
35+
/// Gets the type of the pipeline.
36+
/// </summary>
37+
public DiffuserPipelineType PipelineType => _pipelineType;
38+
39+
40+
/// <summary>
41+
/// Gets the diffusers.
42+
/// </summary>
43+
public ConcurrentDictionary<DiffuserType, IDiffuser> Diffusers => _diffusers;
44+
45+
46+
/// <summary>
47+
/// Gets the diffuser.
48+
/// </summary>
49+
/// <param name="diffuserType">Type of the diffuser.</param>
50+
/// <returns></returns>
51+
public IDiffuser GetDiffuser(DiffuserType diffuserType)
52+
{
53+
_diffusers.TryGetValue(diffuserType, out var diffuser);
54+
return diffuser;
55+
}
56+
}
57+
}

OnnxStack.StableDiffusion/Registration.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public static void AddOnnxStackStableDiffusion(this IServiceCollection serviceCo
3131

3232
//Pipelines
3333
serviceCollection.AddSingleton<IPipeline, StableDiffusionPipeline>();
34+
serviceCollection.AddSingleton<IPipeline, StableDiffusionXLPipeline>();
3435
serviceCollection.AddSingleton<IPipeline, LatentConsistencyPipeline>();
3536
serviceCollection.AddSingleton<IPipeline, InstaFlowPipeline>();
3637

@@ -40,6 +41,9 @@ public static void AddOnnxStackStableDiffusion(this IServiceCollection serviceCo
4041
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.StableDiffusion.InpaintDiffuser>();
4142
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.StableDiffusion.InpaintLegacyDiffuser>();
4243

44+
//StableDiffusionXL
45+
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.StableDiffusionXL.TextDiffuser>();
46+
4347
//LatentConsistency
4448
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.TextDiffuser>();
4549
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.ImageDiffuser>();

0 commit comments

Comments
 (0)