Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit bb5ed73

Browse files
committed
Move Unet to OrtValue API
1 parent 45950f1 commit bb5ed73

File tree

8 files changed

+152
-122
lines changed

8 files changed

+152
-122
lines changed

OnnxStack.Core/Extensions.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Microsoft.ML.OnnxRuntime;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
23
using OnnxStack.Core.Config;
34
using System;
45
using System.Collections.Concurrent;
@@ -218,5 +219,12 @@ public static long[] ToLong(this int[] array)
218219
{
219220
return Array.ConvertAll(array, Convert.ToInt64);
220221
}
222+
223+
224+
public static OrtValue ToOrtValue<T>(this DenseTensor<T> tensor) where T : unmanaged
225+
{
226+
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong());
227+
}
228+
221229
}
222230
}

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -221,15 +221,15 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
221221

222222
var outputDim = new[] { 1, 3, options.Height, options.Width };
223223
var outputBuffer = new DenseTensor<float>(outputDim);
224-
using (var inputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, image.Buffer, image.Dimensions.ToLong()))
225-
using (var outputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, outputBuffer.Buffer, outputDim.ToLong()))
224+
using (var inputTensorValue = image.ToOrtValue())
225+
using (var outputTensorValue = outputBuffer.ToOrtValue())
226226
{
227227
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
228228
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
229229
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputs, outputs);
230230
using (var imageResult = results.First())
231231
{
232-
imageTensors.Add(images.Length == 1 ? outputBuffer : outputBuffer.ToDenseTensor());
232+
imageTensors.Add(outputBuffer);
233233
}
234234
}
235235
}
@@ -243,19 +243,19 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
243243

244244

245245
/// <summary>
246-
/// Creates the timestep NamedOnnxValue based on its NodeMetadata type.
246+
/// Creates the timestep OrtValue based on its NodeMetadata type.
247247
/// </summary>
248248
/// <param name="nodeMetadata">The node metadata.</param>
249249
/// <param name="timestepInputName">Name of the timestep input.</param>
250250
/// <param name="timestep">The timestep.</param>
251251
/// <returns></returns>
252-
protected static NamedOnnxValue CreateTimestepNamedOnnxValue(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata, string timestepInputName, int timestep)
252+
protected static OrtValue CreateTimestepNamedOrtValue(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata, string timestepInputName, int timestep)
253253
{
254254
// Some models support Long or Float, could be more but for now just support these 2
255255
var timestepMetaData = nodeMetadata[timestepInputName];
256256
return timestepMetaData.ElementDataType == TensorElementType.Int64
257-
? NamedOnnxValue.CreateFromTensor(timestepInputName, new DenseTensor<long>(new long[] { timestep }, new int[] { 1 }))
258-
: NamedOnnxValue.CreateFromTensor(timestepInputName, new DenseTensor<float>(new float[] { timestep }, new int[] { 1 }));
257+
? OrtValue.CreateTensorValueFromMemory(new long[] { timestep }, new long[] { 1 })
258+
: OrtValue.CreateTensorValueFromMemory(new float[] { timestep }, new long[] { 1 });
259259
}
260260

261261

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,9 @@ protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions m
6363
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder);
6464

6565
//TODO: Model Config, Channels
66-
var outputDim = options.GetScaledDimension();
67-
var outputBuffer = new DenseTensor<float>(outputDim);
68-
using (var inputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, imageTensor.Buffer, imageTensor.Dimensions.ToLong()))
69-
using (var outputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, outputBuffer.Buffer, outputDim.ToLong()))
66+
var outputBuffer = new DenseTensor<float>(options.GetScaledDimension());
67+
using (var inputTensorValue = imageTensor.ToOrtValue())
68+
using (var outputTensorValue = outputBuffer.ToOrtValue())
7069
{
7170
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
7271
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
using System.Collections.Generic;
1414
using System.Diagnostics;
1515
using System.Linq;
16-
using System.Runtime.CompilerServices;
1716
using System.Threading;
1817
using System.Threading.Tasks;
1918

@@ -63,7 +62,7 @@ public override Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions
6362
/// <param name="progressCallback">The progress callback.</param>
6463
/// <param name="cancellationToken">The cancellation token.</param>
6564
/// <returns></returns>
66-
public override IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
65+
public override IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default)
6766
{
6867
// LCM does not support negative prompting
6968
promptOptions.NegativePrompt = string.Empty;
@@ -104,6 +103,11 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
104103
// Denoised result
105104
DenseTensor<float> denoised = null;
106105

106+
// Get Model metadata
107+
var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet);
108+
var outputNames = _onnxModelService.GetOutputNames(modelOptions, OnnxModelType.Unet);
109+
var inputMetaData = _onnxModelService.GetInputMetadata(modelOptions, OnnxModelType.Unet);
110+
107111
// Loop though the timesteps
108112
var step = 0;
109113
foreach (var timestep in timesteps)
@@ -115,19 +119,33 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
115119
// Create input tensor.
116120
var inputTensor = scheduler.ScaleInput(latents, timestep);
117121

118-
// Create Input Parameters
119-
var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, guidanceEmbeddings, timestep);
120-
121-
// Run Inference
122-
using (var inferResult = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputParameters))
122+
var outputBuffer = new DenseTensor<float>(schedulerOptions.GetScaledDimension());
123+
using (var outputTensorValue = outputBuffer.ToOrtValue())
124+
using (var inputTensorValue = inputTensor.ToOrtValue())
125+
using (var timestepOrtValue = CreateTimestepNamedOrtValue(inputMetaData, inputNames[1], timestep))
126+
using (var promptTensorValue = promptEmbeddings.ToOrtValue())
127+
using (var guidanceTensorValue = guidanceEmbeddings.ToOrtValue())
123128
{
124-
var noisePred = inferResult.FirstElementAs<DenseTensor<float>>();
125-
126-
// Scheduler Step
127-
var schedulerResult = scheduler.Step(noisePred, timestep, latents);
128-
129-
latents = schedulerResult.Result;
130-
denoised = schedulerResult.SampleData;
129+
var inputs = new Dictionary<string, OrtValue>
130+
{
131+
{ inputNames[0], inputTensorValue },
132+
{ inputNames[1], timestepOrtValue },
133+
{ inputNames[2], promptTensorValue },
134+
{ inputNames[3], guidanceTensorValue }
135+
};
136+
137+
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
138+
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputs, outputs);
139+
using (var result = results.First())
140+
{
141+
var noisePred = outputBuffer;
142+
143+
// Scheduler Step
144+
var schedulerResult = scheduler.Step(noisePred, timestep, latents);
145+
146+
latents = schedulerResult.Result;
147+
denoised = schedulerResult.SampleData;
148+
}
131149
}
132150

133151
progressCallback?.Invoke(step, timesteps.Count);
@@ -140,27 +158,6 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
140158
}
141159

142160

143-
/// <summary>
144-
/// Creates the Unet input parameters.
145-
/// </summary>
146-
/// <param name="model">The model.</param>
147-
/// <param name="inputTensor">The input tensor.</param>
148-
/// <param name="promptEmbeddings">The prompt embeddings.</param>
149-
/// <param name="timestep">The timestep.</param>
150-
/// <returns></returns>
151-
protected IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, DenseTensor<float> guidanceEmbeddings, int timestep)
152-
{
153-
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Unet);
154-
var inputMetaData = _onnxModelService.GetInputMetadata(model, OnnxModelType.Unet);
155-
var timestepNamedOnnxValue = CreateTimestepNamedOnnxValue(inputMetaData, inputNames[1], timestep);
156-
return CreateInputParameters(
157-
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
158-
timestepNamedOnnxValue,
159-
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings),
160-
NamedOnnxValue.CreateFromTensor(inputNames[3], guidanceEmbeddings));
161-
}
162-
163-
164161
/// <summary>
165162
/// Gets the scheduler.
166163
/// </summary>

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,9 @@ protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions m
6565
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder);
6666

6767
//TODO: Model Config, Channels
68-
var outputDim = options.GetScaledDimension();
69-
var outputBuffer = new DenseTensor<float>(outputDim);
70-
using (var inputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, imageTensor.Buffer, imageTensor.Dimensions.ToLong()))
71-
using (var outputTensorValue = OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, outputBuffer.Buffer, outputDim.ToLong()))
68+
var outputBuffer = new DenseTensor<float>(options.GetScaledDimension());
69+
using (var inputTensorValue = imageTensor.ToOrtValue())
70+
using (var outputTensorValue = outputBuffer.ToOrtValue())
7271
{
7372
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
7473
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
using System;
1414
using System.Collections.Generic;
1515
using System.Diagnostics;
16+
using System.Linq;
1617
using System.Threading;
1718
using System.Threading.Tasks;
1819

@@ -67,6 +68,11 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
6768
// Create Masked Image Latents
6869
var maskedImage = PrepareImageMask(modelOptions, promptOptions, schedulerOptions);
6970

71+
// Get Model metadata
72+
var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet);
73+
var outputNames = _onnxModelService.GetOutputNames(modelOptions, OnnxModelType.Unet);
74+
var inputMetaData = _onnxModelService.GetInputMetadata(modelOptions, OnnxModelType.Unet);
75+
7076
// Loop though the timesteps
7177
var step = 0;
7278
foreach (var timestep in timesteps)
@@ -76,26 +82,36 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
7682
cancellationToken.ThrowIfCancellationRequested();
7783

7884
// Create input tensor.
79-
var inputLatent = performGuidance
80-
? latents.Repeat(2)
81-
: latents;
85+
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
8286
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
8387
inputTensor = ConcatenateLatents(inputTensor, maskedImage, maskImage);
8488

85-
// Create Input Parameters
86-
var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, timestep);
87-
88-
// Run Inference
89-
using (var inferResult = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputParameters))
89+
var outputBuffer = new DenseTensor<float>(schedulerOptions.GetScaledDimension());
90+
using (var outputTensorValue = outputBuffer.ToOrtValue())
91+
using (var inputTensorValue = inputTensor.ToOrtValue())
92+
using (var timestepOrtValue = CreateTimestepNamedOrtValue(inputMetaData, inputNames[1], timestep))
93+
using (var promptTensorValue = promptEmbeddings.ToOrtValue())
9094
{
91-
var noisePred = inferResult.FirstElementAs<DenseTensor<float>>();
95+
var inputs = new Dictionary<string, OrtValue>
96+
{
97+
{ inputNames[0], inputTensorValue },
98+
{ inputNames[1], timestepOrtValue },
99+
{ inputNames[2], promptTensorValue }
100+
};
101+
102+
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
103+
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputs, outputs);
104+
using (var result = results.First())
105+
{
106+
var noisePred = outputBuffer;
92107

93-
// Perform guidance
94-
if (performGuidance)
95-
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
108+
// Perform guidance
109+
if (performGuidance)
110+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
96111

97-
// Scheduler Step
98-
latents = scheduler.Step(noisePred, timestep, latents).Result;
112+
// Scheduler Step
113+
latents = scheduler.Step(noisePred, timestep, latents).Result;
114+
}
99115
}
100116

101117
progressCallback?.Invoke(step, timesteps.Count);

0 commit comments

Comments
 (0)