Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit d67ea08

Browse files
committed
Merge branch 'master' into LCM_Inpaint
2 parents 6f0adce + bb5ed73 commit d67ea08

File tree

11 files changed

+341
-165
lines changed

11 files changed

+341
-165
lines changed

OnnxStack.Core/Extensions.cs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
using Microsoft.ML.OnnxRuntime;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
23
using OnnxStack.Core.Config;
34
using System;
45
using System.Collections.Concurrent;
56
using System.Collections.Generic;
67
using System.Linq;
8+
using System.Numerics;
79

810
namespace OnnxStack.Core
911
{
@@ -169,5 +171,60 @@ public static ConcurrentDictionary<T, U> ToConcurrentDictionary<S, T, U>(this IE
169171
{
170172
return new ConcurrentDictionary<T, U>(source.ToDictionary(keySelector, elementSelector));
171173
}
174+
175+
176+
/// <summary>
177+
/// Gets the full product of a dimension
178+
/// </summary>
179+
/// <param name="array">The dimension array.</param>
180+
/// <returns></returns>
181+
public static T GetBufferLength<T>(this T[] array) where T : INumber<T>
182+
{
183+
T result = T.One;
184+
foreach (T element in array)
185+
{
186+
result *= element;
187+
}
188+
return result;
189+
}
190+
191+
192+
/// <summary>
193+
/// Gets the full product of a dimension
194+
/// </summary>
195+
/// <param name="array">The dimension array.</param>
196+
/// <returns></returns>
197+
public static T GetBufferLength<T>(this ReadOnlySpan<T> array) where T : INumber<T>
198+
{
199+
T result = T.One;
200+
foreach (T element in array)
201+
{
202+
result *= element;
203+
}
204+
return result;
205+
}
206+
207+
208+
public static long[] ToLong(this ReadOnlySpan<int> array)
209+
{
210+
return Array.ConvertAll(array.ToArray(), Convert.ToInt64);
211+
}
212+
213+
public static int[] ToInt(this long[] array)
214+
{
215+
return Array.ConvertAll(array, Convert.ToInt32);
216+
}
217+
218+
public static long[] ToLong(this int[] array)
219+
{
220+
return Array.ConvertAll(array, Convert.ToInt64);
221+
}
222+
223+
224+
public static OrtValue ToOrtValue<T>(this DenseTensor<T> tensor) where T : unmanaged
225+
{
226+
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong());
227+
}
228+
172229
}
173230
}

OnnxStack.Core/Services/IOnnxModelService.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,29 @@ public interface IOnnxModelService : IDisposable
7878
Task<IDisposableReadOnlyCollection<DisposableNamedOnnxValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, IReadOnlyCollection<NamedOnnxValue> inputs);
7979

8080

81+
/// <summary>
82+
/// Runs the inference. Use when the output size is unknown.
83+
/// </summary>
84+
/// <param name="model">The model.</param>
85+
/// <param name="modelType">Type of the model.</param>
86+
/// <param name="inputs">The inputs.</param>
87+
/// <param name="outputs">The outputs.</param>
88+
/// <returns></returns>
89+
IReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, IReadOnlyCollection<string> outputs);
90+
91+
92+
/// <summary>
93+
/// Runs the inference asynchronously. Use when the output size is known.
94+
/// Output buffer size must be known and set before inference is run
95+
/// </summary>
96+
/// <param name="model">The model.</param>
97+
/// <param name="modelType">Type of the model.</param>
98+
/// <param name="inputs">The inputs.</param>
99+
/// <param name="outputs">The outputs.</param>
100+
/// <returns></returns>
101+
Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, Dictionary<string, OrtValue> outputs);
102+
103+
81104
/// <summary>
82105
/// Gets the Sessions input metadata.
83106
/// </summary>
@@ -108,5 +131,6 @@ public interface IOnnxModelService : IDisposable
108131
/// <param name="modelType">Type of model.</param>
109132
/// <returns></returns>
110133
IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType modelType);
134+
111135
}
112136
}

OnnxStack.Core/Services/OnnxModelService.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,34 @@ public IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType mode
175175
}
176176

177177

178+
/// <summary>
179+
/// Runs inference on the specified model.
180+
/// </summary>
181+
/// <param name="modelType">Type of the model.</param>
182+
/// <param name="inputs">The inputs.</param>
183+
/// <returns></returns>
184+
public IReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, IReadOnlyCollection<string> outputs)
185+
{
186+
return GetModelSet(model)
187+
.GetSession(modelType)
188+
.Run(new RunOptions(), inputs, outputs);
189+
}
190+
191+
192+
/// <summary>
193+
/// Runs inference on the specified model.
194+
/// </summary>
195+
/// <param name="modelType">Type of the model.</param>
196+
/// <param name="inputs">The inputs.</param>
197+
/// <returns></returns>
198+
public Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, Dictionary<string, OrtValue> outputs)
199+
{
200+
return GetModelSet(model)
201+
.GetSession(modelType)
202+
.RunAsync(new RunOptions(), inputs.Keys, inputs.Values, outputs.Keys, outputs.Values);
203+
}
204+
205+
178206
/// <summary>
179207
/// Runs inference on the specified model.
180208
/// </summary>

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,20 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
217217
foreach (var image in images)
218218
{
219219
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
220-
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], image));
220+
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeDecoder);
221221

222-
// Run inference.
223-
using (var inferResult = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputParameters))
222+
var outputDim = new[] { 1, 3, options.Height, options.Width };
223+
var outputBuffer = new DenseTensor<float>(outputDim);
224+
using (var inputTensorValue = image.ToOrtValue())
225+
using (var outputTensorValue = outputBuffer.ToOrtValue())
224226
{
225-
var resultTensor = inferResult.FirstElementAs<DenseTensor<float>>();
226-
imageTensors.Add(resultTensor.ToDenseTensor());
227+
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
228+
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
229+
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputs, outputs);
230+
using (var imageResult = results.First())
231+
{
232+
imageTensors.Add(outputBuffer);
233+
}
227234
}
228235
}
229236

@@ -236,19 +243,19 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
236243

237244

238245
/// <summary>
239-
/// Creates the timestep NamedOnnxValue based on its NodeMetadata type.
246+
/// Creates the timestep OrtValue based on its NodeMetadata type.
240247
/// </summary>
241248
/// <param name="nodeMetadata">The node metadata.</param>
242249
/// <param name="timestepInputName">Name of the timestep input.</param>
243250
/// <param name="timestep">The timestep.</param>
244251
/// <returns></returns>
245-
protected static NamedOnnxValue CreateTimestepNamedOnnxValue(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata, string timestepInputName, int timestep)
252+
protected static OrtValue CreateTimestepNamedOrtValue(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata, string timestepInputName, int timestep)
246253
{
247254
// Some models support Long or Float, could be more but for now just support these 2
248255
var timestepMetaData = nodeMetadata[timestepInputName];
249256
return timestepMetaData.ElementDataType == TensorElementType.Int64
250-
? NamedOnnxValue.CreateFromTensor(timestepInputName, new DenseTensor<long>(new long[] { timestep }, new int[] { 1 }))
251-
: NamedOnnxValue.CreateFromTensor(timestepInputName, new DenseTensor<float>(new float[] { timestep }, new int[] { 1 }));
257+
? OrtValue.CreateTensorValueFromMemory(new long[] { timestep }, new long[] { 1 })
258+
: OrtValue.CreateTensorValueFromMemory(new float[] { timestep }, new long[] { 1 });
252259
}
253260

254261

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using System.Collections.Generic;
1313
using System.Linq;
1414
using System.Threading.Tasks;
15+
using OnnxStack.Core;
1516

1617
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
1718
{
@@ -55,24 +56,32 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
5556
/// <param name="options">The options.</param>
5657
/// <param name="scheduler">The scheduler.</param>
5758
/// <returns></returns>
58-
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
59+
protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
5960
{
60-
// Image input, decode, add noise, return as latent 0
6161
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
6262
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder);
63-
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageTensor));
64-
using (var inferResult = _onnxModelService.RunInference(model, OnnxModelType.VaeEncoder, inputParameters))
63+
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder);
64+
65+
//TODO: Model Config, Channels
66+
var outputBuffer = new DenseTensor<float>(options.GetScaledDimension());
67+
using (var inputTensorValue = imageTensor.ToOrtValue())
68+
using (var outputTensorValue = outputBuffer.ToOrtValue())
6569
{
66-
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
67-
var scaledSample = sample
68-
.Add(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
69-
.MultiplyBy(model.ScaleFactor);
70+
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
71+
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
72+
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
73+
using (var result = results.First())
74+
{
75+
var scaledSample = outputBuffer
76+
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
77+
.MultiplyBy(model.ScaleFactor);
7078

71-
var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
72-
if (prompt.BatchCount > 1)
73-
return Task.FromResult(noisySample.Repeat(prompt.BatchCount));
79+
var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
80+
if (prompt.BatchCount > 1)
81+
return noisySample.Repeat(prompt.BatchCount);
7482

75-
return Task.FromResult(noisySample);
83+
return noisySample;
84+
}
7685
}
7786
}
7887
}

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
using System.Collections.Generic;
1414
using System.Diagnostics;
1515
using System.Linq;
16-
using System.Runtime.CompilerServices;
1716
using System.Threading;
1817
using System.Threading.Tasks;
1918

@@ -63,7 +62,7 @@ public override Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions
6362
/// <param name="progressCallback">The progress callback.</param>
6463
/// <param name="cancellationToken">The cancellation token.</param>
6564
/// <returns></returns>
66-
public override IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
65+
public override IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default)
6766
{
6867
// LCM does not support negative prompting
6968
promptOptions.NegativePrompt = string.Empty;
@@ -104,6 +103,11 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
104103
// Denoised result
105104
DenseTensor<float> denoised = null;
106105

106+
// Get Model metadata
107+
var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet);
108+
var outputNames = _onnxModelService.GetOutputNames(modelOptions, OnnxModelType.Unet);
109+
var inputMetaData = _onnxModelService.GetInputMetadata(modelOptions, OnnxModelType.Unet);
110+
107111
// Loop though the timesteps
108112
var step = 0;
109113
foreach (var timestep in timesteps)
@@ -115,19 +119,33 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
115119
// Create input tensor.
116120
var inputTensor = scheduler.ScaleInput(latents, timestep);
117121

118-
// Create Input Parameters
119-
var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, guidanceEmbeddings, timestep);
120-
121-
// Run Inference
122-
using (var inferResult = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputParameters))
122+
var outputBuffer = new DenseTensor<float>(schedulerOptions.GetScaledDimension());
123+
using (var outputTensorValue = outputBuffer.ToOrtValue())
124+
using (var inputTensorValue = inputTensor.ToOrtValue())
125+
using (var timestepOrtValue = CreateTimestepNamedOrtValue(inputMetaData, inputNames[1], timestep))
126+
using (var promptTensorValue = promptEmbeddings.ToOrtValue())
127+
using (var guidanceTensorValue = guidanceEmbeddings.ToOrtValue())
123128
{
124-
var noisePred = inferResult.FirstElementAs<DenseTensor<float>>();
125-
126-
// Scheduler Step
127-
var schedulerResult = scheduler.Step(noisePred, timestep, latents);
128-
129-
latents = schedulerResult.Result;
130-
denoised = schedulerResult.SampleData;
129+
var inputs = new Dictionary<string, OrtValue>
130+
{
131+
{ inputNames[0], inputTensorValue },
132+
{ inputNames[1], timestepOrtValue },
133+
{ inputNames[2], promptTensorValue },
134+
{ inputNames[3], guidanceTensorValue }
135+
};
136+
137+
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
138+
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputs, outputs);
139+
using (var result = results.First())
140+
{
141+
var noisePred = outputBuffer;
142+
143+
// Scheduler Step
144+
var schedulerResult = scheduler.Step(noisePred, timestep, latents);
145+
146+
latents = schedulerResult.Result;
147+
denoised = schedulerResult.SampleData;
148+
}
131149
}
132150

133151
progressCallback?.Invoke(step, timesteps.Count);
@@ -140,27 +158,6 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
140158
}
141159

142160

143-
/// <summary>
144-
/// Creates the Unet input parameters.
145-
/// </summary>
146-
/// <param name="model">The model.</param>
147-
/// <param name="inputTensor">The input tensor.</param>
148-
/// <param name="promptEmbeddings">The prompt embeddings.</param>
149-
/// <param name="timestep">The timestep.</param>
150-
/// <returns></returns>
151-
protected IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, DenseTensor<float> guidanceEmbeddings, int timestep)
152-
{
153-
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Unet);
154-
var inputMetaData = _onnxModelService.GetInputMetadata(model, OnnxModelType.Unet);
155-
var timestepNamedOnnxValue = CreateTimestepNamedOnnxValue(inputMetaData, inputNames[1], timestep);
156-
return CreateInputParameters(
157-
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
158-
timestepNamedOnnxValue,
159-
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings),
160-
NamedOnnxValue.CreateFromTensor(inputNames[3], guidanceEmbeddings));
161-
}
162-
163-
164161
/// <summary>
165162
/// Gets the scheduler.
166163
/// </summary>

0 commit comments

Comments
 (0)