This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 5de2ceb

LCM TextToImage implemented
1 parent 76e34b4 commit 5de2ceb

File tree

3 files changed: +194 −10 lines changed
OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core.Config;
using OnnxStack.Core.Services;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Helpers;
using OnnxStack.StableDiffusion.Schedulers;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
{
    public sealed class TextDiffuser : DiffuserBase
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="TextDiffuser"/> class.
        /// </summary>
        /// <param name="onnxModelService">The onnx model service.</param>
        /// <param name="promptService">The prompt service.</param>
        public TextDiffuser(IOnnxModelService onnxModelService, IPromptService promptService)
            : base(onnxModelService, promptService)
        {
        }


        /// <summary>
        /// Runs the LCM diffusion loop and returns the decoded image tensor.
        /// </summary>
        public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            // Create a random seed if none was set
            schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();

            // LCM does not support classifier-free guidance; the guidance scale
            // is passed to the UNet as an embedding instead (see below)
            var guidance = schedulerOptions.GuidanceScale;
            schedulerOptions.GuidanceScale = 0f;

            // LCM does not support negative prompting
            promptOptions.NegativePrompt = string.Empty;

            // Get Scheduler
            using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
            {
                // Process prompts
                var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, schedulerOptions);

                // Get timesteps
                var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);

                // Create latent sample
                var latents = PrepareLatents(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

                // Get Guidance Scale Embedding
                var guidanceEmbeddings = GetGuidanceScaleEmbedding(guidance);

                // Denoised result
                DenseTensor<float> denoised = null;

                // Loop through the timesteps
                var step = 0;
                foreach (var timestep in timesteps)
                {
                    step++;
                    cancellationToken.ThrowIfCancellationRequested();

                    // Create input tensor.
                    var inputTensor = scheduler.ScaleInput(latents, timestep);

                    // Create Input Parameters
                    var inputMeta = _onnxModelService.GetInputMetadata(modelOptions, OnnxModelType.Unet);
                    var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet);
                    var inputParameters = CreateInputParameters(
                        NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
                        NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
                        NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings),
                        NamedOnnxValue.CreateFromTensor(inputNames[3], guidanceEmbeddings));

                    // Run Inference
                    using (var inferResult = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputParameters))
                    {
                        var noisePred = inferResult.FirstElementAs<DenseTensor<float>>();

                        // Scheduler Step
                        var schedulerResult = scheduler.Step(noisePred, timestep, latents);

                        latents = schedulerResult.PreviousSample;
                        denoised = schedulerResult.ExtraSample;
                    }

                    progressCallback?.Invoke(step, timesteps.Count);
                }

                // Decode Latents
                return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, denoised);
            }
        }


        /// <summary>
        /// Gets the timesteps.
        /// </summary>
        /// <param name="prompt">The prompt.</param>
        /// <param name="options">The options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <returns></returns>
        protected override IReadOnlyList<int> GetTimesteps(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler)
        {
            return scheduler.Timesteps;
        }


        /// <summary>
        /// Prepares the latents for inference.
        /// </summary>
        /// <param name="model">The model options.</param>
        /// <param name="prompt">The prompt.</param>
        /// <param name="options">The options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <param name="timesteps">The timesteps.</param>
        /// <returns></returns>
        protected override DenseTensor<float> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
        {
            return scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma);
        }


        /// <summary>
        /// Gets the scheduler.
        /// </summary>
        /// <param name="prompt">The prompt.</param>
        /// <param name="options">The options.</param>
        /// <returns></returns>
        protected override IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions options)
        {
            return prompt.SchedulerType switch
            {
                SchedulerType.LCM => new LCMScheduler(options),
                _ => default
            };
        }


        /// <summary>
        /// Gets the guidance scale embedding.
        /// </summary>
        /// <param name="guidance">The guidance scale.</param>
        /// <param name="embeddingDim">The embedding dim.</param>
        /// <returns></returns>
        public DenseTensor<float> GetGuidanceScaleEmbedding(float guidance, int embeddingDim = 256)
        {
            // Python reference (diffusers):
            //assert len(w.shape) == 1
            //w = w * 1000.0
            //half_dim = embedding_dim // 2
            //emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
            //emb = torch.exp(torch.arange(half_dim, dtype = dtype) * -emb)
            //emb = w.to(dtype)[:, None] * emb[None, :]
            //emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim = 1)
            //if embedding_dim % 2 == 1: # zero pad
            //    emb = torch.nn.functional.pad(emb, (0, 1))
            //assert emb.shape == (w.shape[0], embedding_dim)
            //return emb

            // TODO: odd embedding dimensions are not zero-padded yet

            // Scale the conditioning as in the reference above: w = (guidance - 1) * 1000
            var w = (guidance - 1f) * 1000f;
            var half_dim = embeddingDim / 2;
            var log = MathF.Log(10000.0f) / (half_dim - 1);
            var emb = Enumerable.Range(0, half_dim)
                .Select(x => w * MathF.Exp(x * -log))
                .ToArray();
            var embSin = emb.Select(MathF.Sin).ToArray();
            var embCos = emb.Select(MathF.Cos).ToArray();

            // First half sin, second half cos
            var result = new DenseTensor<float>(new[] { 1, 2 * half_dim });
            for (int i = 0; i < half_dim; i++)
            {
                result[0, i] = embSin[i];
                result[0, i + half_dim] = embCos[i];
            }

            return result;
        }
    }
}
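
For reference, the sinusoidal arithmetic in GetGuidanceScaleEmbedding can be exercised on its own. A minimal console sketch (the GuidanceEmbeddingDemo class and the small embeddingDim are illustrative, not part of this commit; the w = (guidance - 1) * 1000 scaling follows the Python reference quoted above):

using System;
using System.Linq;

public static class GuidanceEmbeddingDemo
{
    // Mirrors the sin/cos Fourier embedding built by GetGuidanceScaleEmbedding
    public static float[] Embed(float guidance, int embeddingDim = 8)
    {
        var w = (guidance - 1f) * 1000f;
        var halfDim = embeddingDim / 2;
        var log = MathF.Log(10000.0f) / (halfDim - 1);
        var freqs = Enumerable.Range(0, halfDim)
            .Select(x => w * MathF.Exp(x * -log))
            .ToArray();
        // First half sin, second half cos, matching the tensor layout above
        return freqs.Select(MathF.Sin).Concat(freqs.Select(MathF.Cos)).ToArray();
    }

    public static void Main()
    {
        // e.g. a guidance scale of 8.0, with a small dim for readability
        Console.WriteLine(string.Join(", ", Embed(8.0f)));
    }
}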

OnnxStack.StableDiffusion/Pipelines/LatentConsistency.cs

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 using OnnxStack.Core.Services;
 using OnnxStack.StableDiffusion.Common;
 using OnnxStack.StableDiffusion.Diffusers;
+using OnnxStack.StableDiffusion.Diffusers.LatentConsistency;
 using OnnxStack.StableDiffusion.Enums;
 using System.Collections.Concurrent;
 using System.Collections.Generic;
@@ -16,7 +17,7 @@ public LatentConsistencyPipeline(IOnnxModelService onnxModelService, IPromptServ
 {
     var diffusers = new Dictionary<DiffuserType, IDiffuser>
     {
-        //TODO: TextToImage and ImageToImage is supported with LCM
+        { DiffuserType.TextToImage, new TextDiffuser(onnxModelService, promptService) }
     };
     _pipelineType = DiffuserPipelineType.LatentConsistency;
     _diffusers = new ConcurrentDictionary<DiffuserType, IDiffuser>(diffusers);
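
Once registered, the pipeline can hand out the new diffuser by type. A hypothetical caller-side sketch (the surrounding variables and the TryGetValue lookup are assumptions for illustration, not an API added by this commit):

// Resolve the LCM text-to-image diffuser from the pipeline's dictionary
if (_diffusers.TryGetValue(DiffuserType.TextToImage, out var diffuser))
{
    // DiffuseAsync signature as defined in TextDiffuser above
    var image = await diffuser.DiffuseAsync(modelOptions, promptOptions, schedulerOptions);
}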

OnnxStack.StableDiffusion/Schedulers/LCMScheduler.cs

Lines changed: 6 additions & 9 deletions
@@ -54,7 +54,7 @@ protected override void Initialize()

     //The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
     //will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
-    _originalInferenceSteps = Options.InferenceSteps;
+    _originalInferenceSteps = 30;

     SetInitNoiseSigma(1.0f);
 }
@@ -68,14 +68,14 @@ protected override int[] SetTimesteps()
 {
     // LCM Timesteps Setting
     // Currently, only linear spacing is supported.
-    var timeIncrement = (float)Options.TrainTimesteps / _originalInferenceSteps;
+    var timeIncrement = Options.TrainTimesteps / _originalInferenceSteps;

     //# LCM Training Steps Schedule
     var lcmOriginTimesteps = Enumerable.Range(1, _originalInferenceSteps)
         .Select(x => x * timeIncrement - 1f)
         .ToArray();

-    var skippingStep = (float)lcmOriginTimesteps.Length / Options.InferenceSteps;
+    var skippingStep = lcmOriginTimesteps.Length / Options.InferenceSteps;

     // LCM Inference Steps Schedule
     return lcmOriginTimesteps
@@ -199,12 +199,9 @@ public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples,
     //self.sigma_data = 0.5 # Default: 0.5
     var sigmaData = 0.5f;

-    //c_skip = self.sigma_data ** 2 / ((t / 0.1) ** 2 + self.sigma_data ** 2)
-    float cSkip = MathF.Pow(sigmaData, 2f) / (MathF.Pow(timestep / 0.1f, 2f) + MathF.Pow(sigmaData, 2f));
-
-    //c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data ** 2) ** 0.5
-    float cOut = (timestep / 0.1f) / (MathF.Pow(timestep / 0.1f, 2f) + MathF.Pow(sigmaData, 2f)) * 0.5f;
-
+    float c = MathF.Pow(timestep / 0.1f, 2f) + MathF.Pow(sigmaData, 2f);
+    float cSkip = MathF.Pow(sigmaData, 2f) / c;
+    float cOut = (timestep / 0.1f) / MathF.Pow(c, 0.5f);
     return (cSkip, cOut);
 }
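
To make the two integer divisions above concrete, here is a standalone sketch of the SetTimesteps arithmetic with illustrative values (TrainTimesteps = 1000, the hardcoded 30 origin steps, 4 inference steps); the final every-skippingStep-th selection is truncated in the diff and is not reproduced here. The same program checks the refactored cSkip/cOut at t = 0, where a consistency model must return its input unchanged (cSkip = 1, cOut = 0):

using System;
using System.Linq;

public static class LcmScheduleDemo
{
    public static void Main()
    {
        int trainTimesteps = 1000, originalInferenceSteps = 30, inferenceSteps = 4;

        // Integer division now, matching Python's `//`: 1000 / 30 = 33
        var timeIncrement = trainTimesteps / originalInferenceSteps;

        // LCM training-step schedule: 32, 65, 98, ..., 989
        var lcmOriginTimesteps = Enumerable.Range(1, originalInferenceSteps)
            .Select(x => x * timeIncrement - 1f)
            .ToArray();

        // Also integer division: 30 / 4 = 7 origin steps between each kept step
        var skippingStep = lcmOriginTimesteps.Length / inferenceSteps;
        Console.WriteLine($"increment={timeIncrement}, skip={skippingStep}, last={lcmOriginTimesteps.Last()}");

        // Boundary conditions at t = 0: c = 0.25, so cSkip = 1 and cOut = 0
        float timestep = 0f, sigmaData = 0.5f;
        float c = MathF.Pow(timestep / 0.1f, 2f) + MathF.Pow(sigmaData, 2f);
        float cSkip = MathF.Pow(sigmaData, 2f) / c;
        float cOut = (timestep / 0.1f) / MathF.Pow(c, 0.5f);
        Console.WriteLine($"cSkip={cSkip}, cOut={cOut}");
    }
}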
