Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 45950f1

Browse files
committed
VaeEncoder/VaeDecoder moved to OrtValue API
1 parent ff45ed3 commit 45950f1

File tree

5 files changed

+109
-42
lines changed

5 files changed

+109
-42
lines changed

OnnxStack.Core/Extensions.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,5 +186,37 @@ public static T GetBufferLength<T>(this T[] array) where T : INumber<T>
186186
}
187187
return result;
188188
}
189+
190+
191+
/// <summary>
192+
/// Gets the full prod of a dimension
193+
/// </summary>
194+
/// <param name="array">The dimension array.</param>
195+
/// <returns></returns>
196+
public static T GetBufferLength<T>(this ReadOnlySpan<T> array) where T : INumber<T>
197+
{
198+
T result = T.One;
199+
foreach (T element in array)
200+
{
201+
result *= element;
202+
}
203+
return result;
204+
}
205+
206+
207+
public static long[] ToLong(this ReadOnlySpan<int> array)
208+
{
209+
return Array.ConvertAll(array.ToArray(), Convert.ToInt64);
210+
}
211+
212+
public static int[] ToInt(this long[] array)
213+
{
214+
return Array.ConvertAll(array, Convert.ToInt32);
215+
}
216+
217+
public static long[] ToLong(this int[] array)
218+
{
219+
return Array.ConvertAll(array, Convert.ToInt64);
220+
}
189221
}
190222
}

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,20 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
217217
foreach (var image in images)
218218
{
219219
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
220-
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], image));
220+
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeDecoder);
221221

222-
// Run inference.
223-
using (var inferResult = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputParameters))
222+
var outputDim = new[] { 1, 3, options.Height, options.Width };
223+
var outputBuffer = new DenseTensor<float>(outputDim);
224+
using (var inputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, image.Buffer, image.Dimensions.ToLong()))
225+
using (var outputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, outputBuffer.Buffer, outputDim.ToLong()))
224226
{
225-
var resultTensor = inferResult.FirstElementAs<DenseTensor<float>>();
226-
imageTensors.Add(resultTensor.ToDenseTensor());
227+
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
228+
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
229+
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputs, outputs);
230+
using (var imageResult = results.First())
231+
{
232+
imageTensors.Add(images.Length == 1 ? outputBuffer : outputBuffer.ToDenseTensor());
233+
}
227234
}
228235
}
229236

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using System.Collections.Generic;
1313
using System.Linq;
1414
using System.Threading.Tasks;
15+
using OnnxStack.Core;
1516

1617
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
1718
{
@@ -55,24 +56,33 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
5556
/// <param name="options">The options.</param>
5657
/// <param name="scheduler">The scheduler.</param>
5758
/// <returns></returns>
58-
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
59+
protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
5960
{
60-
// Image input, decode, add noise, return as latent 0
6161
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
6262
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder);
63-
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageTensor));
64-
using (var inferResult = _onnxModelService.RunInference(model, OnnxModelType.VaeEncoder, inputParameters))
63+
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder);
64+
65+
//TODO: Model Config, Channels
66+
var outputDim = options.GetScaledDimension();
67+
var outputBuffer = new DenseTensor<float>(outputDim);
68+
using (var inputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, imageTensor.Buffer, imageTensor.Dimensions.ToLong()))
69+
using (var outputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, outputBuffer.Buffer, outputDim.ToLong()))
6570
{
66-
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
67-
var scaledSample = sample
68-
.Add(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
69-
.MultiplyBy(model.ScaleFactor);
71+
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
72+
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
73+
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
74+
using (var result = results.First())
75+
{
76+
var scaledSample = outputBuffer
77+
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
78+
.MultiplyBy(model.ScaleFactor);
7079

71-
var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
72-
if (prompt.BatchCount > 1)
73-
return Task.FromResult(noisySample.Repeat(prompt.BatchCount));
80+
var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
81+
if (prompt.BatchCount > 1)
82+
return noisySample.Repeat(prompt.BatchCount);
7483

75-
return Task.FromResult(noisySample);
84+
return noisySample;
85+
}
7686
}
7787
}
7888
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using System.Collections.Generic;
1313
using System.Linq;
1414
using System.Threading.Tasks;
15+
using OnnxStack.Core;
1516

1617
namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusion
1718
{
@@ -57,24 +58,33 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
5758
/// <param name="options">The options.</param>
5859
/// <param name="scheduler">The scheduler.</param>
5960
/// <returns></returns>
60-
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
61+
protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
6162
{
62-
// Image input, decode, add noise, return as latent 0
6363
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
6464
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder);
65-
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageTensor));
66-
using (var inferResult = _onnxModelService.RunInference(model, OnnxModelType.VaeEncoder, inputParameters))
65+
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder);
66+
67+
//TODO: Model Config, Channels
68+
var outputDim = options.GetScaledDimension();
69+
var outputBuffer = new DenseTensor<float>(outputDim);
70+
using (var inputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, imageTensor.Buffer, imageTensor.Dimensions.ToLong()))
71+
using (var outputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, outputBuffer.Buffer, outputDim.ToLong()))
6772
{
68-
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
69-
var scaledSample = sample
70-
.Add(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
71-
.MultiplyBy(model.ScaleFactor);
73+
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
74+
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
75+
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
76+
using (var result = results.First())
77+
{
78+
var scaledSample = outputBuffer
79+
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
80+
.MultiplyBy(model.ScaleFactor);
7281

73-
var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
74-
if (prompt.BatchCount > 1)
75-
return Task.FromResult(noisySample.Repeat(prompt.BatchCount));
82+
var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
83+
if (prompt.BatchCount > 1)
84+
return noisySample.Repeat(prompt.BatchCount);
7685

77-
return Task.FromResult(noisySample);
86+
return noisySample;
87+
}
7888
}
7989
}
8090

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -136,24 +136,32 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
136136
/// <param name="options">The options.</param>
137137
/// <param name="scheduler">The scheduler.</param>
138138
/// <returns></returns>
139-
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
139+
protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
140140
{
141-
// Image input, decode, add noise, return as latent 0
142-
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Width, options.Height });
141+
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
143142
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder);
144-
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageTensor));
145-
using (var inferResult = _onnxModelService.RunInference(model, OnnxModelType.VaeEncoder, inputParameters))
143+
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder);
144+
145+
//TODO: Model Config, Channels
146+
var outputDim = options.GetScaledDimension();
147+
var outputBuffer = new DenseTensor<float>(outputDim);
148+
using (var inputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, imageTensor.Buffer, imageTensor.Dimensions.ToLong()))
149+
using (var outputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, outputBuffer.Buffer, outputDim.ToLong()))
146150
{
147-
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
148-
var scaledSample = sample
149-
.Add(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
150-
.MultiplyBy(model.ScaleFactor)
151-
.ToDenseTensor();
151+
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
152+
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
153+
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
154+
using (var result = results.First())
155+
{
156+
var scaledSample = outputBuffer
157+
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
158+
.MultiplyBy(model.ScaleFactor);
152159

153-
if (prompt.BatchCount > 1)
154-
return Task.FromResult(scaledSample.Repeat(prompt.BatchCount));
160+
if (prompt.BatchCount > 1)
161+
return scaledSample.Repeat(prompt.BatchCount);
155162

156-
return Task.FromResult(scaledSample);
163+
return scaledSample;
164+
}
157165
}
158166
}
159167

0 commit comments

Comments
 (0)