This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit fa71a4e

Merge pull request #22 from saddam213/LCM_Inpaint
LCM Inpaint Diffuser
2 parents abd5cf7 + cb3152a commit fa71a4e

4 files changed, +261 -2 lines changed

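For orientation, here is a minimal sketch of how the new ImageInpaintLegacy diffuser might be driven once this commit is in place. Only DiffuserType.ImageInpaintLegacy, InputImage, InputImageMask, Strength, InferenceSteps and GuidanceScale appear in the diff below; the stableDiffusionService/GenerateAsync entry point, the modelOptions/inputImage/maskImage placeholders, and the remaining property names are assumptions for illustration, not code from this commit.

// Hypothetical usage sketch - entry point and several property names are assumptions.
var promptOptions = new PromptOptions
{
    Prompt = "a cozy cabin in a snowy forest",        // assumption: prompt text property
    DiffuserType = DiffuserType.ImageInpaintLegacy,   // diffuser type added by this PR
    InputImage = inputImage,                          // image to repaint (appears in the diff)
    InputImageMask = maskImage                        // mask: opaque pixels are repainted
};

var schedulerOptions = new SchedulerOptions
{
    InferenceSteps = 8,    // LCM models typically need very few steps
    GuidanceScale = 1f,
    Strength = 0.6f        // controls how many of those steps actually run (see GetTimesteps below)
};

// Assumption: a service resolves the registered InpaintLegacyDiffuser and runs it.
var image = await stableDiffusionService.GenerateAsync(modelOptions, promptOptions, schedulerOptions);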

Assets/Templates/LCM-Dreamshaper-V7/LCM-Dreamshaper-V7-ONNX.json

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@
     "PipelineType": "LatentConsistency",
     "Diffusers": [
       "TextToImage",
-      "ImageToImage"
+      "ImageToImage",
+      "ImageInpaintLegacy"
     ],
     "ModelFiles": [
       "https://huggingface.co/TheyCallMeHex/LCM-Dreamshaper-V7-ONNX/resolve/main/tokenizer/model.onnx",
OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Config;
using OnnxStack.Core.Model;
using OnnxStack.Core.Services;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Helpers;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.Processing;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
{
    public sealed class InpaintLegacyDiffuser : LatentConsistencyDiffuser
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="InpaintLegacyDiffuser"/> class.
        /// </summary>
        /// <param name="onnxModelService">The ONNX model service.</param>
        /// <param name="promptService">The prompt service.</param>
        /// <param name="logger">The logger.</param>
        public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<LatentConsistencyDiffuser> logger)
            : base(onnxModelService, promptService, logger)
        {
        }


        /// <summary>
        /// Gets the type of the diffuser.
        /// </summary>
        public override DiffuserType DiffuserType => DiffuserType.ImageInpaintLegacy;


        /// <summary>
        /// Gets the timesteps.
        /// </summary>
        /// <param name="options">The scheduler options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <returns>The timesteps to run for this inference.</returns>
        protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
        {
            // As with Image2Image, narrow the timestep range by the Strength
            var initTimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
            var start = Math.Max(options.InferenceSteps - initTimestep, 0);
            return scheduler.Timesteps.Skip(start).ToList();
        }

        /// <summary>
        /// Runs the scheduler steps.
        /// </summary>
        /// <param name="modelOptions">The model options.</param>
        /// <param name="promptOptions">The prompt options.</param>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <param name="promptEmbeddings">The prompt embeddings.</param>
        /// <param name="performGuidance">if set to <c>true</c>, perform guidance.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The decoded image tensor.</returns>
        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            using (var scheduler = GetScheduler(schedulerOptions))
            {
                // Get timesteps
                var timesteps = GetTimesteps(schedulerOptions, scheduler);

                // Create latent sample
                var latentsOriginal = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

                // Create mask sample
                var maskImage = PrepareMask(modelOptions, promptOptions, schedulerOptions);

                // Generate some noise
                var noise = scheduler.CreateRandomSample(latentsOriginal.Dimensions);

                // Add noise to original latent
                var latents = scheduler.AddNoise(latentsOriginal, noise, timesteps);

                // Get model metadata
                var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);

                // Get guidance scale embedding
                var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);

                // Denoised result
                DenseTensor<float> denoised = null;

                // Loop through the timesteps
                var step = 0;
                foreach (var timestep in timesteps)
                {
                    step++;
                    var stepTime = Stopwatch.GetTimestamp();
                    cancellationToken.ThrowIfCancellationRequested();

                    // Create input tensor.
                    var inputTensor = scheduler.ScaleInput(latents, timestep);
                    var timestepTensor = CreateTimestepTensor(timestep);

                    var outputChannels = 1;
                    var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
                    using (var inferenceParameters = new OnnxInferenceParameters(metadata))
                    {
                        inferenceParameters.AddInputTensor(inputTensor);
                        inferenceParameters.AddInputTensor(timestepTensor);
                        inferenceParameters.AddInputTensor(promptEmbeddings);
                        inferenceParameters.AddInputTensor(guidanceEmbeddings);
                        inferenceParameters.AddOutputBuffer(outputDimension);

                        var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
                        using (var result = results.First())
                        {
                            var noisePred = result.ToDenseTensor();

                            // Scheduler step
                            var schedulerResult = scheduler.Step(noisePred, timestep, latents);

                            latents = schedulerResult.Result;
                            denoised = schedulerResult.SampleData;

                            // Re-noise the original latents to the next timestep
                            if (step < timesteps.Count - 1)
                            {
                                var noiseTimestep = timesteps[step + 1];
                                var initLatentsProper = scheduler.AddNoise(latentsOriginal, noise, new[] { noiseTimestep });

                                // Apply mask and combine
                                latents = ApplyMaskedLatents(schedulerResult.Result, initLatentsProper, maskImage);
                            }
                        }
                    }

                    progressCallback?.Invoke(step, timesteps.Count);
                    _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
                }

                // Decode latents
                return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, denoised);
            }
        }

        /// <summary>
        /// Prepares the input latents for inference.
        /// </summary>
        /// <param name="model">The model.</param>
        /// <param name="prompt">The prompt.</param>
        /// <param name="options">The options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <param name="timesteps">The timesteps.</param>
        /// <returns>The noised image latents.</returns>
        protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
        {
            // Image input: encode with the VAE, add noise, return as latent 0
            var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });

            //TODO: Model Config, Channels
            var outputDimensions = options.GetScaledDimension();
            var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
            {
                inferenceParameters.AddInputTensor(imageTensor);
                inferenceParameters.AddOutputBuffer(outputDimensions);

                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
                using (var result = results.First())
                {
                    var outputResult = result.ToDenseTensor();
                    var scaledSample = outputResult
                        .Add(scheduler.CreateRandomSample(outputDimensions, options.InitialNoiseLevel))
                        .MultiplyBy(model.ScaleFactor);

                    return scaledSample;
                }
            }
        }

        /// <summary>
        /// Prepares the mask.
        /// </summary>
        /// <param name="modelOptions">The model options.</param>
        /// <param name="promptOptions">The prompt options.</param>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <returns>The mask tensor.</returns>
        private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
        {
            using (var mask = promptOptions.InputImageMask.ToImage())
            {
                // Prepare the mask
                int width = schedulerOptions.GetScaledWidth();
                int height = schedulerOptions.GetScaledHeight();
                mask.Mutate(x => x.Grayscale());
                mask.Mutate(x => x.Resize(new Size(width, height), KnownResamplers.NearestNeighbor, true));

                // Mask tensor in NCHW layout (1 x 4 x height x width)
                var maskTensor = new DenseTensor<float>(new[] { 1, 4, height, width });
                mask.ProcessPixelRows(img =>
                {
                    for (int x = 0; x < width; x++)
                    {
                        for (int y = 0; y < height; y++)
                        {
                            var pixelSpan = img.GetRowSpan(y);
                            var value = pixelSpan[x].A / 255.0f;
                            maskTensor[0, 0, y, x] = 1f - value;
                            maskTensor[0, 1, y, x] = 0f; // Needed for shape only
                            maskTensor[0, 2, y, x] = 0f; // Needed for shape only
                            maskTensor[0, 3, y, x] = 0f; // Needed for shape only
                        }
                    }
                });

                return maskTensor;
            }
        }

        /// <summary>
        /// Applies the masked latents.
        /// </summary>
        /// <param name="latents">The denoised latents.</param>
        /// <param name="initLatentsProper">The original latents re-noised to the current timestep.</param>
        /// <param name="mask">The mask.</param>
        /// <returns>The blended latents.</returns>
        private DenseTensor<float> ApplyMaskedLatents(DenseTensor<float> latents, DenseTensor<float> initLatentsProper, DenseTensor<float> mask)
        {
            var result = new DenseTensor<float>(latents.Dimensions);
            for (int batch = 0; batch < latents.Dimensions[0]; batch++)
            {
                for (int channel = 0; channel < latents.Dimensions[1]; channel++)
                {
                    for (int height = 0; height < latents.Dimensions[2]; height++)
                    {
                        for (int width = 0; width < latents.Dimensions[3]; width++)
                        {
                            float maskValue = mask[batch, 0, height, width];
                            float latentsValue = latents[batch, channel, height, width];
                            float initLatentsProperValue = initLatentsProper[batch, channel, height, width];

                            // Blend the re-noised original latent and the denoised latent based on the mask
                            float newValue = initLatentsProperValue * maskValue + latentsValue * (1f - maskValue);
                            result[batch, channel, height, width] = newValue;
                        }
                    }
                }
            }
            return result;
        }
    }
}
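To make the Strength and mask handling above concrete, a short worked example follows; the numbers are illustrative and not taken from the commit.

// Timestep narrowing in GetTimesteps (illustrative values):
//   InferenceSteps = 8, Strength = 0.5
//   initTimestep   = Math.Min((int)(8 * 0.5), 8) = 4
//   start          = Math.Max(8 - 4, 0)          = 4
//   scheduler.Timesteps.Skip(4) => only the last 4 LCM steps are denoised,
//   so a lower Strength keeps the result closer to the noised input latents.
//
// Per-element blend in ApplyMaskedLatents:
//   result = initLatentsProper * mask + latents * (1 - mask)
// The mask is built as (1 - alpha / 255), so fully opaque mask pixels (mask = 0)
// take the freshly denoised latents (the region is repainted), while fully
// transparent pixels (mask = 1) restore the re-noised original latents.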

OnnxStack.StableDiffusion/Registration.cs

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ public static void AddOnnxStackStableDiffusion(this IServiceCollection serviceCo
     //LatentConsistency
     serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.TextDiffuser>();
     serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.ImageDiffuser>();
+    serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.InpaintLegacyDiffuser>();
 }
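Presumably the generation service picks one of these registered IDiffuser singletons by its DiffuserType (and pipeline) at run time; the resolver below is a hypothetical sketch of that pattern, not code from this commit.

// Hypothetical resolver sketch - not part of this commit.
// Assumes the usual System.Collections.Generic and System.Linq usings.
public class DiffuserResolver
{
    private readonly IEnumerable<IDiffuser> _diffusers;

    public DiffuserResolver(IEnumerable<IDiffuser> diffusers) => _diffusers = diffusers;

    // Selects a diffuser such as the LatentConsistency InpaintLegacyDiffuser registered above.
    public IDiffuser Resolve(DiffuserType type)
        => _diffusers.FirstOrDefault(d => d.DiffuserType == type);
}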
OnnxStack.UI/appsettings.json

Lines changed: 2 additions & 1 deletion
@@ -70,7 +70,8 @@
     "PipelineType": "LatentConsistency",
     "Diffusers": [
       "TextToImage",
-      "ImageToImage"
+      "ImageToImage",
+      "ImageInpaintLegacy"
     ],
     "ModelFiles": [
       "https://huggingface.co/TheyCallMeHex/LCM-Dreamshaper-V7-ONNX/resolve/main/tokenizer/model.onnx",
