Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 6f0adce

Browse files
committed
Merge branch 'master' into LCM_Inpaint
2 parents 00e1215 + a7a555e commit 6f0adce

File tree

10 files changed

+339
-442
lines changed

10 files changed

+339
-442
lines changed
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Config;
using OnnxStack.Core.Services;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Helpers;
using OnnxStack.StableDiffusion.Models;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace OnnxStack.StableDiffusion.Diffusers
{
    /// <summary>
    /// Base class for all diffusers: owns the shared diffusion loop (prompt encoding,
    /// scheduler stepping, guidance, VAE decode) and defers pipeline-specific pieces
    /// (scheduler creation, timesteps, latent preparation, the per-step UNet call)
    /// to derived classes.
    /// </summary>
    public abstract class DiffuserBase : IDiffuser
    {
        protected readonly IPromptService _promptService;
        protected readonly IOnnxModelService _onnxModelService;
        protected readonly ILogger<DiffuserBase> _logger;


        /// <summary>
        /// Initializes a new instance of the <see cref="DiffuserBase"/> class.
        /// </summary>
        /// <param name="onnxModelService">The onnx model service.</param>
        /// <param name="promptService">The prompt service.</param>
        /// <param name="logger">The logger.</param>
        public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<DiffuserBase> logger)
        {
            _logger = logger;
            _promptService = promptService;
            _onnxModelService = onnxModelService;
        }

        /// <summary>
        /// Gets the type of the diffuser.
        /// </summary>
        public abstract DiffuserType DiffuserType { get; }

        /// <summary>
        /// Gets the type of the pipeline.
        /// </summary>
        public abstract DiffuserPipelineType PipelineType { get; }

        /// <summary>
        /// Gets the scheduler.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns></returns>
        protected abstract IScheduler GetScheduler(SchedulerOptions options);

        /// <summary>
        /// Gets the timesteps.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <returns></returns>
        protected abstract IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler);


        /// <summary>
        /// Prepares the input latents.
        /// </summary>
        /// <param name="model">The model.</param>
        /// <param name="prompt">The prompt.</param>
        /// <param name="options">The options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <param name="timesteps">The timesteps.</param>
        /// <returns></returns>
        protected abstract Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps);


        /// <summary>
        /// Called on each Scheduler step.
        /// </summary>
        /// <param name="modelOptions">The model options.</param>
        /// <param name="promptOptions">The prompt options.</param>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <param name="promptEmbeddings">The prompt embeddings.</param>
        /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
        protected abstract Task<DenseTensor<float>> SchedulerStep(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);


        /// <summary>
        /// Runs the stable diffusion loop.
        /// </summary>
        /// <param name="modelOptions">The model options.</param>
        /// <param name="promptOptions">The prompt options.</param>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The final image tensor produced by the scheduler loop.</returns>
        public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            // Create random seed if none was set
            schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();

            var diffuseTime = _logger?.LogBegin("Begin...");
            _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}");

            // Check guidance
            var performGuidance = ShouldPerformGuidance(schedulerOptions);

            // Process prompts
            var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);

            // Run Scheduler steps
            var schedulerResult = await SchedulerStep(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);

            _logger?.LogEnd($"End", diffuseTime);

            return schedulerResult;
        }


        /// <summary>
        /// Runs the stable diffusion batch loop, yielding one result per generated scheduler-option variant.
        /// </summary>
        /// <param name="modelOptions">The model options.</param>
        /// <param name="promptOptions">The prompt options.</param>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <param name="batchOptions">The batch options.</param>
        /// <param name="progressCallback">The progress callback (batchIndex, batchCount, step, steps).</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
        public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Create random seed if none was set
            schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();

            var diffuseBatchTime = _logger?.LogBegin("Begin...");
            _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {schedulerOptions.SchedulerType}");

            // Check guidance
            var performGuidance = ShouldPerformGuidance(schedulerOptions);

            // Process prompts
            var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);

            // Generate batch options
            var batchSchedulerOptions = BatchGenerator.GenerateBatch(modelOptions, batchOptions, schedulerOptions);

            var batchIndex = 1;
            // The lambda captures batchIndex so per-step progress reports the current batch position.
            var schedulerCallback = (int step, int steps) => progressCallback?.Invoke(batchIndex, batchSchedulerOptions.Count, step, steps);
            foreach (var batchSchedulerOption in batchSchedulerOptions)
            {
                yield return new BatchResult(batchSchedulerOption, await SchedulerStep(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken));
                batchIndex++;
            }

            _logger?.LogEnd($"End", diffuseBatchTime);
        }


        /// <summary>
        /// Check if we should run classifier-free guidance.
        /// </summary>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <returns><c>true</c> when the guidance scale exceeds 1; otherwise <c>false</c>.</returns>
        protected virtual bool ShouldPerformGuidance(SchedulerOptions schedulerOptions)
        {
            return schedulerOptions.GuidanceScale > 1f;
        }


        /// <summary>
        /// Performs classifier free guidance.
        /// </summary>
        /// <param name="noisePrediction">The combined noise prediction; the first half of the batch dimension is the unconditional prediction, the second half the text-conditioned prediction.</param>
        /// <param name="guidanceScale">The guidance scale.</param>
        /// <returns>uncond + guidanceScale * (cond - uncond)</returns>
        protected virtual DenseTensor<float> PerformGuidance(DenseTensor<float> noisePrediction, float guidanceScale)
        {
            // Split Prompt and Negative Prompt predictions
            var dimensions = noisePrediction.Dimensions.ToArray();
            dimensions[0] /= 2;

            var length = (int)noisePrediction.Length / 2;
            var noisePredCond = new DenseTensor<float>(noisePrediction.Buffer[length..], dimensions);
            var noisePredUncond = new DenseTensor<float>(noisePrediction.Buffer[..length], dimensions);
            return noisePredUncond
                .Add(noisePredCond
                .Subtract(noisePredUncond)
                .MultiplyBy(guidanceScale));
        }


        /// <summary>
        /// Decodes the latents.
        /// </summary>
        /// <param name="model">The model.</param>
        /// <param name="prompt">The prompt.</param>
        /// <param name="options">The options.</param>
        /// <param name="latents">The latents.</param>
        /// <returns></returns>
        protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, DenseTensor<float> latents)
        {
            var timestamp = _logger?.LogBegin("Begin...");

            // Scale and decode the image latents with vae.
            latents = latents.MultiplyBy(1.0f / model.ScaleFactor);

            // Decode each batch item separately through the VAE decoder.
            var images = prompt.BatchCount > 1
                ? latents.Split(prompt.BatchCount)
                : new[] { latents };
            var imageTensors = new List<DenseTensor<float>>();
            foreach (var image in images)
            {
                var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
                var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], image));

                // Run inference.
                using (var inferResult = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputParameters))
                {
                    var resultTensor = inferResult.FirstElementAs<DenseTensor<float>>();
                    imageTensors.Add(resultTensor.ToDenseTensor());
                }
            }

            var result = prompt.BatchCount > 1
                ? imageTensors.Join()
                : imageTensors.FirstOrDefault();
            _logger?.LogEnd("End", timestamp);
            return result;
        }


        /// <summary>
        /// Creates the timestep NamedOnnxValue based on its NodeMetadata type.
        /// </summary>
        /// <param name="nodeMetadata">The node metadata.</param>
        /// <param name="timestepInputName">Name of the timestep input.</param>
        /// <param name="timestep">The timestep.</param>
        /// <returns></returns>
        protected static NamedOnnxValue CreateTimestepNamedOnnxValue(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata, string timestepInputName, int timestep)
        {
            // Some models support Long or Float, could be more but for now just support these 2
            var timestepMetaData = nodeMetadata[timestepInputName];
            return timestepMetaData.ElementDataType == TensorElementType.Int64
                ? NamedOnnxValue.CreateFromTensor(timestepInputName, new DenseTensor<long>(new long[] { timestep }, new int[] { 1 }))
                : NamedOnnxValue.CreateFromTensor(timestepInputName, new DenseTensor<float>(new float[] { timestep }, new int[] { 1 }));
        }


        /// <summary>
        /// Helper for creating the input parameters.
        /// </summary>
        /// <param name="parameters">The parameters.</param>
        /// <returns></returns>
        protected static IReadOnlyList<NamedOnnxValue> CreateInputParameters(params NamedOnnxValue[] parameters)
        {
            return parameters.ToList();
        }
    }
}

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
using OnnxStack.Core.Services;
66
using OnnxStack.StableDiffusion.Common;
77
using OnnxStack.StableDiffusion.Config;
8-
using OnnxStack.StableDiffusion.Diffusers.StableDiffusion;
98
using OnnxStack.StableDiffusion.Enums;
109
using OnnxStack.StableDiffusion.Helpers;
1110
using SixLabors.ImageSharp;
1211
using System;
1312
using System.Collections.Generic;
1413
using System.Linq;
14+
using System.Threading.Tasks;
1515

1616
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
1717
{
@@ -23,9 +23,7 @@ public sealed class ImageDiffuser : LatentConsistencyDiffuser
2323
/// <param name="configuration">The configuration.</param>
2424
/// <param name="onnxModelService">The onnx model service.</param>
2525
public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<LatentConsistencyDiffuser> logger)
26-
: base(onnxModelService, promptService, logger)
27-
{
28-
}
26+
: base(onnxModelService, promptService, logger) { }
2927

3028

3129
/// <summary>
@@ -41,7 +39,7 @@ public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptSe
4139
/// <param name="options">The options.</param>
4240
/// <param name="scheduler">The scheduler.</param>
4341
/// <returns></returns>
44-
protected override IReadOnlyList<int> GetTimesteps(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler)
42+
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
4543
{
4644
// Image2Image we narrow step the range by the Strength
4745
var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
@@ -57,7 +55,7 @@ protected override IReadOnlyList<int> GetTimesteps(PromptOptions prompt, Schedul
5755
/// <param name="options">The options.</param>
5856
/// <param name="scheduler">The scheduler.</param>
5957
/// <returns></returns>
60-
protected override DenseTensor<float> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
58+
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
6159
{
6260
// Image input, decode, add noise, return as latent 0
6361
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
@@ -72,9 +70,9 @@ protected override DenseTensor<float> PrepareLatents(IModelOptions model, Prompt
7270

7371
var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
7472
if (prompt.BatchCount > 1)
75-
return noisySample.Repeat(prompt.BatchCount);
73+
return Task.FromResult(noisySample.Repeat(prompt.BatchCount));
7674

77-
return noisySample;
75+
return Task.FromResult(noisySample);
7876
}
7977
}
8078
}

0 commit comments

Comments
 (0)