Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 951ef37

Browse files
committed
InstaFlow pipeline added, TextToImage
1 parent ce10893 commit 951ef37

File tree

8 files changed

+376
-2
lines changed

8 files changed

+376
-2
lines changed

OnnxStack.Console/appsettings.json

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,46 @@
132132
"OnnxModelPath": "D:\\Repositories\\photon\\vae_decoder\\model.onnx"
133133
}
134134
]
135+
},
136+
{
137+
"Name": "InstaFlow",
138+
"IsEnabled": true,
139+
"PadTokenId": 49407,
140+
"BlankTokenId": 49407,
141+
"TokenizerLimit": 77,
142+
"EmbeddingsLength": 768,
143+
"ScaleFactor": 0.18215,
144+
"PipelineType": "InstaFlow",
145+
"Diffusers": [
146+
"TextToImage"
147+
],
148+
"DeviceId": 0,
149+
"InterOpNumThreads": 0,
150+
"IntraOpNumThreads": 0,
151+
"ExecutionMode": "ORT_SEQUENTIAL",
152+
"ExecutionProvider": "DirectML",
153+
"ModelConfigurations": [
154+
{
155+
"Type": "Tokenizer",
156+
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\tokenizer\\model.onnx"
157+
},
158+
{
159+
"Type": "Unet",
160+
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\unet\\model.onnx"
161+
},
162+
{
163+
"Type": "TextEncoder",
164+
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\text_encoder\\model.onnx"
165+
},
166+
{
167+
"Type": "VaeEncoder",
168+
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\vae_encoder\\model.onnx"
169+
},
170+
{
171+
"Type": "VaeDecoder",
172+
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\vae_decoder\\model.onnx"
173+
}
174+
]
135175
}
136176
]
137177
}
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
6+
using OnnxStack.Core.Services;
7+
using OnnxStack.StableDiffusion.Common;
8+
using OnnxStack.StableDiffusion.Config;
9+
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
11+
using OnnxStack.StableDiffusion.Schedulers.InstaFlow;
12+
using System;
13+
using System.Diagnostics;
14+
using System.Linq;
15+
using System.Threading;
16+
using System.Threading.Tasks;
17+
18+
namespace OnnxStack.StableDiffusion.Diffusers.InstaFlow
19+
{
20+
public abstract class InstaFlowDiffuser : DiffuserBase, IDiffuser
21+
{
22+
/// <summary>
23+
/// Initializes a new instance of the <see cref="InstaFlowDiffuser"/> class.
24+
/// </summary>
25+
/// <param name="configuration">The configuration.</param>
26+
/// <param name="onnxModelService">The onnx model service.</param>
27+
public InstaFlowDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<InstaFlowDiffuser> logger)
28+
: base(onnxModelService, promptService, logger) { }
29+
30+
31+
/// <summary>
32+
/// Gets the type of the pipeline.
33+
/// </summary>
34+
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.InstaFlow;
35+
36+
37+
/// <summary>
38+
/// Runs the scheduler steps.
39+
/// </summary>
40+
/// <param name="modelOptions">The model options.</param>
41+
/// <param name="promptOptions">The prompt options.</param>
42+
/// <param name="schedulerOptions">The scheduler options.</param>
43+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
44+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
45+
/// <param name="progressCallback">The progress callback.</param>
46+
/// <param name="cancellationToken">The cancellation token.</param>
47+
/// <returns></returns>
48+
protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
49+
{
50+
// Get Scheduler
51+
using (var scheduler = GetScheduler(schedulerOptions))
52+
{
53+
// Get timesteps
54+
var timesteps = GetTimesteps(schedulerOptions, scheduler);
55+
56+
// Create latent sample
57+
var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
58+
59+
// Get Model metadata
60+
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
61+
62+
// Get the distilled Timestep
63+
var distilledTimestep = 1.0f / timesteps.Count;
64+
65+
// Loop though the timesteps
66+
var step = 0;
67+
foreach (var timestep in timesteps)
68+
{
69+
step++;
70+
var stepTime = Stopwatch.GetTimestamp();
71+
cancellationToken.ThrowIfCancellationRequested();
72+
73+
// Create input tensor.
74+
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
75+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
76+
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
77+
78+
var outputChannels = performGuidance ? 2 : 1;
79+
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
80+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
81+
{
82+
inferenceParameters.AddInputTensor(inputTensor);
83+
inferenceParameters.AddInputTensor(timestepTensor);
84+
inferenceParameters.AddInputTensor(promptEmbeddings);
85+
inferenceParameters.AddOutputBuffer(outputDimension);
86+
87+
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
88+
using (var result = results.First())
89+
{
90+
var noisePred = result.ToDenseTensor();
91+
92+
// Perform guidance
93+
if (performGuidance)
94+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
95+
96+
// Scheduler Step
97+
latents = scheduler.Step(noisePred, timestep, latents).Result;
98+
99+
latents = noisePred
100+
.MultiplyTensorByFloat(distilledTimestep)
101+
.AddTensors(latents);
102+
}
103+
}
104+
105+
progressCallback?.Invoke(step, timesteps.Count);
106+
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
107+
}
108+
109+
// Decode Latents
110+
return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
111+
}
112+
}
113+
114+
115+
/// <summary>
116+
/// Creates the timestep tensor.
117+
/// </summary>
118+
/// <param name="latents">The latents.</param>
119+
/// <param name="timestep">The timestep.</param>
120+
/// <returns></returns>
121+
private DenseTensor<float> CreateTimestepTensor(DenseTensor<float> latents, int timestep)
122+
{
123+
var timestepTensor = new DenseTensor<float>(new[] { latents.Dimensions[0] });
124+
timestepTensor.Fill(timestep);
125+
return timestepTensor;
126+
}
127+
128+
129+
/// <summary>
130+
/// Gets the scheduler.
131+
/// </summary>
132+
/// <param name="options">The options.</param>
133+
/// <param name="schedulerConfig">The scheduler configuration.</param>
134+
/// <returns></returns>
135+
protected override IScheduler GetScheduler(SchedulerOptions options)
136+
{
137+
return options.SchedulerType switch
138+
{
139+
SchedulerType.InstaFlow => new InstaFlowScheduler(options),
140+
_ => default
141+
};
142+
}
143+
}
144+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core.Services;
4+
using OnnxStack.StableDiffusion.Common;
5+
using OnnxStack.StableDiffusion.Config;
6+
using OnnxStack.StableDiffusion.Enums;
7+
using System.Collections.Generic;
8+
using System.Threading.Tasks;
9+
10+
namespace OnnxStack.StableDiffusion.Diffusers.InstaFlow
11+
{
12+
public sealed class TextDiffuser : InstaFlowDiffuser
13+
{
14+
/// <summary>
15+
/// Initializes a new instance of the <see cref="TextDiffuser"/> class.
16+
/// </summary>
17+
/// <param name="configuration">The configuration.</param>
18+
/// <param name="onnxModelService">The onnx model service.</param>
19+
public TextDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<InstaFlowDiffuser> logger)
20+
: base(onnxModelService, promptService, logger)
21+
{
22+
}
23+
24+
25+
/// <summary>
26+
/// Gets the type of the diffuser.
27+
/// </summary>
28+
public override DiffuserType DiffuserType => DiffuserType.TextToImage;
29+
30+
31+
/// <summary>
32+
/// Gets the timesteps.
33+
/// </summary>
34+
/// <param name="prompt">The prompt.</param>
35+
/// <param name="options">The options.</param>
36+
/// <param name="scheduler">The scheduler.</param>
37+
/// <returns></returns>
38+
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
39+
{
40+
return scheduler.Timesteps;
41+
}
42+
43+
44+
/// <summary>
45+
/// Prepares the latents for inference.
46+
/// </summary>
47+
/// <param name="prompt">The prompt.</param>
48+
/// <param name="options">The options.</param>
49+
/// <param name="scheduler">The scheduler.</param>
50+
/// <returns></returns>
51+
protected override Task<DenseTensor<float>> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
52+
{
53+
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
54+
}
55+
}
56+
}

OnnxStack.StableDiffusion/Enums/SchedulerType.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ public enum SchedulerType
2222
[Display(Name = "KDPM2")]
2323
KDPM2 = 5,
2424

25-
[Display(Name = "LCM")]
26-
LCM = 20
25+
[Display(Name = "LCM")]
26+
LCM = 20,
27+
28+
[Display(Name = "InstaFlow")]
29+
InstaFlow = 21
2730
}
2831
}

OnnxStack.StableDiffusion/Extensions.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ public static SchedulerType[] GetSchedulerTypes(this DiffuserPipelineType pipeli
9999
{
100100
return pipelineType switch
101101
{
102+
DiffuserPipelineType.InstaFlow => new[]
103+
{
104+
SchedulerType.InstaFlow
105+
},
102106
DiffuserPipelineType.LatentConsistency => new[]
103107
{
104108
SchedulerType.LCM
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using Microsoft.Extensions.Logging;
2+
using OnnxStack.Core;
3+
using OnnxStack.StableDiffusion.Common;
4+
using OnnxStack.StableDiffusion.Diffusers;
5+
using OnnxStack.StableDiffusion.Enums;
6+
using System.Collections.Concurrent;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
10+
namespace OnnxStack.StableDiffusion.Pipelines
11+
{
12+
public sealed class InstaFlowPipeline : IPipeline
13+
{
14+
private readonly DiffuserPipelineType _pipelineType;
15+
private readonly ILogger<InstaFlowPipeline> _logger;
16+
private readonly ConcurrentDictionary<DiffuserType, IDiffuser> _diffusers;
17+
18+
/// <summary>
19+
/// Initializes a new instance of the <see cref="InstaFlowPipeline"/> class.
20+
/// </summary>
21+
/// <param name="onnxModelService">The onnx model service.</param>
22+
/// <param name="promptService">The prompt service.</param>
23+
public InstaFlowPipeline(IEnumerable<IDiffuser> diffusers, ILogger<InstaFlowPipeline> logger)
24+
{
25+
_logger = logger;
26+
_pipelineType = DiffuserPipelineType.InstaFlow;
27+
_diffusers = diffusers
28+
.Where(x => x.PipelineType == _pipelineType)
29+
.ToConcurrentDictionary(k => k.DiffuserType, v => v);
30+
}
31+
32+
33+
/// <summary>
34+
/// Gets the type of the pipeline.
35+
/// </summary>
36+
public DiffuserPipelineType PipelineType => _pipelineType;
37+
38+
39+
/// <summary>
40+
/// Gets the diffusers.
41+
/// </summary>
42+
public ConcurrentDictionary<DiffuserType, IDiffuser> Diffusers => _diffusers;
43+
44+
45+
/// <summary>
46+
/// Gets the diffuser.
47+
/// </summary>
48+
/// <param name="diffuserType">Type of the diffuser.</param>
49+
/// <returns></returns>
50+
public IDiffuser GetDiffuser(DiffuserType diffuserType)
51+
{
52+
_diffusers.TryGetValue(diffuserType, out var diffuser);
53+
return diffuser;
54+
}
55+
}
56+
}

OnnxStack.StableDiffusion/Registration.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public static void AddOnnxStackStableDiffusion(this IServiceCollection serviceCo
3232
//Pipelines
3333
serviceCollection.AddSingleton<IPipeline, StableDiffusionPipeline>();
3434
serviceCollection.AddSingleton<IPipeline, LatentConsistencyPipeline>();
35+
serviceCollection.AddSingleton<IPipeline, InstaFlowPipeline>();
3536

3637
//StableDiffusion
3738
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.StableDiffusion.TextDiffuser>();
@@ -43,6 +44,9 @@ public static void AddOnnxStackStableDiffusion(this IServiceCollection serviceCo
4344
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.TextDiffuser>();
4445
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.ImageDiffuser>();
4546
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.InpaintLegacyDiffuser>();
47+
48+
//InstaFlow
49+
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.InstaFlow.TextDiffuser>();
4650
}
4751

4852

0 commit comments

Comments
 (0)