Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 603f3c1

Browse files
committed
Simplify inference input/output parameters
1 parent 6c00064 commit 603f3c1

File tree

10 files changed

+120
-108
lines changed

10 files changed

+120
-108
lines changed

OnnxStack.Core/Extensions.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ public static OrtValue ToOrtValue(this DenseTensor<float> tensor, OnnxNamedMetad
253253
var dimensions = tensor.Dimensions.ToLong();
254254
return metadata.Value.ElementDataType switch
255255
{
256+
TensorElementType.Int64 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToLong(), dimensions),
256257
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToFloat16(), dimensions),
257258
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToBFloat16(), dimensions),
258259
_ => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, dimensions)
@@ -370,5 +371,16 @@ internal static Memory<float> ToFloat(this ReadOnlySpan<BFloat16> inputMemory)
370371

371372
return floatArray.AsMemory();
372373
}
374+
375+
376+
/// <summary>
377+
/// Converts to long.
378+
/// </summary>
379+
/// <param name="inputMemory">The input memory.</param>
380+
/// <returns></returns>
381+
internal static Memory<long> ToLong(this Memory<float> inputMemory)
382+
{
383+
return Array.ConvertAll(inputMemory.ToArray(), Convert.ToInt64).AsMemory();
384+
}
373385
}
374386
}

OnnxStack.Core/Model/OnnxInferenceParameters.cs

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
using Microsoft.ML.OnnxRuntime;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
23
using System;
34
using System.Collections.Generic;
45

56
namespace OnnxStack.Core.Model
67
{
78
public class OnnxInferenceParameters : IDisposable
89
{
9-
private RunOptions _runOptions;
10-
private OnnxValueCollection _inputs;
11-
private OnnxValueCollection _outputs;
10+
private readonly RunOptions _runOptions;
11+
private readonly OnnxMetadata _metadata;
12+
private readonly OnnxValueCollection _inputs;
13+
private readonly OnnxValueCollection _outputs;
1214

1315
/// <summary>
1416
/// Initializes a new instance of the <see cref="OnnxInferenceParameters"/> class.
1517
/// </summary>
16-
public OnnxInferenceParameters()
18+
public OnnxInferenceParameters(OnnxMetadata metadata)
1719
{
20+
_metadata = metadata;
1821
_runOptions = new RunOptions();
1922
_inputs = new OnnxValueCollection();
2023
_outputs = new OnnxValueCollection();
@@ -26,9 +29,20 @@ public OnnxInferenceParameters()
2629
/// </summary>
2730
/// <param name="metaData">The meta data.</param>
2831
/// <param name="value">The value.</param>
29-
public void AddInput(OnnxNamedMetadata metaData, OrtValue value)
32+
public void AddInput(OrtValue value)
3033
{
31-
_inputs.Add(metaData, value);
34+
_inputs.Add(GetNextInputMetadata(), value);
35+
}
36+
37+
38+
/// <summary>
39+
/// Adds the input tensor.
40+
/// </summary>
41+
/// <param name="value">The value.</param>
42+
public void AddInputTensor(DenseTensor<float> value)
43+
{
44+
var metaData = GetNextInputMetadata();
45+
_inputs.Add(metaData, value.ToOrtValue(metaData));
3246
}
3347

3448

@@ -37,19 +51,30 @@ public void AddInput(OnnxNamedMetadata metaData, OrtValue value)
3751
/// </summary>
3852
/// <param name="metaData">The meta data.</param>
3953
/// <param name="value">The value.</param>
40-
public void AddOutput(OnnxNamedMetadata metaData, OrtValue value)
54+
public void AddOutput(OrtValue value)
55+
{
56+
_outputs.Add(GetNextOutputMetadata(), value);
57+
}
58+
59+
60+
/// <summary>
61+
/// Adds the output buffer.
62+
/// </summary>
63+
/// <param name="bufferDimension">The buffer dimension.</param>
64+
public void AddOutputBuffer(ReadOnlySpan<int> bufferDimension)
4165
{
42-
_outputs.Add(metaData, value);
66+
var metadata = GetNextOutputMetadata();
67+
_outputs.Add(metadata, metadata.CreateOutputBuffer(bufferDimension));
4368
}
4469

4570

4671
/// <summary>
4772
/// Adds an output parameter with unknown output size.
4873
/// </summary>
4974
/// <param name="metaData">The meta data.</param>
50-
public void AddOutput(OnnxNamedMetadata metaData)
75+
public void AddOutput()
5176
{
52-
_outputs.AddName(metaData);
77+
_outputs.AddName(GetNextOutputMetadata());
5378
}
5479

5580

@@ -103,5 +128,21 @@ public void Dispose()
103128
_outputs?.Dispose();
104129
_runOptions?.Dispose();
105130
}
131+
132+
private OnnxNamedMetadata GetNextInputMetadata()
133+
{
134+
if (_inputs.Names.Count >= _metadata.Inputs.Count)
135+
throw new ArgumentOutOfRangeException($"Too Many Inputs - No Metadata found for input index {_inputs.Names.Count}");
136+
137+
return _metadata.Inputs[_inputs.Names.Count];
138+
}
139+
140+
private OnnxNamedMetadata GetNextOutputMetadata()
141+
{
142+
if (_outputs.Names.Count >= _metadata.Outputs.Count)
143+
throw new ArgumentOutOfRangeException($"Too Many Outputs - No Metadata found for output index {_outputs.Names.Count}");
144+
145+
return _metadata.Outputs[_outputs.Names.Count];
146+
}
106147
}
107148
}

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -214,15 +214,12 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(IModelOption
214214
// Scale and decode the image latents with vae.
215215
latents = latents.MultiplyBy(1.0f / model.ScaleFactor);
216216

217-
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeDecoder);
218-
var inputMetadata = metadata.Inputs[0];
219-
var outputMetadata = metadata.Outputs[0];
220-
221217
var outputDim = new[] { 1, 3, options.Height, options.Width };
222-
using (var inferenceParameters = new OnnxInferenceParameters())
218+
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeDecoder);
219+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
223220
{
224-
inferenceParameters.AddInput(inputMetadata, latents.ToOrtValue(outputMetadata));
225-
inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDim));
221+
inferenceParameters.AddInputTensor(latents);
222+
inferenceParameters.AddOutputBuffer(outputDim);
226223

227224
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inferenceParameters);
228225
using (var imageResult = results.First())
@@ -241,16 +238,9 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(IModelOption
241238
/// <param name="timestepInputName">Name of the timestep input.</param>
242239
/// <param name="timestep">The timestep.</param>
243240
/// <returns></returns>
244-
protected static OrtValue CreateTimestepNamedOrtValue(OnnxNamedMetadata timestepMetaData, int timestep)
241+
protected static DenseTensor<float> CreateTimestepTensor(int timestep)
245242
{
246-
var dimension = new long[] { 1 };
247-
return timestepMetaData.Value.ElementDataType switch
248-
{
249-
TensorElementType.Int64 => OrtValue.CreateTensorValueFromMemory(new long[] { timestep }, dimension),
250-
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(new Float16[] { (Float16)timestep }, dimension),
251-
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(new BFloat16[] { (BFloat16)timestep }, dimension),
252-
_ => OrtValue.CreateTensorValueFromMemory(new float[] { timestep }, dimension) // TODO: Default to Float32 for now
253-
};
243+
return TensorHelper.CreateTensor(new float[] { timestep }, new int[] { 1 });
254244
}
255245

256246

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,13 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti
6161
{
6262
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
6363

64-
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
65-
var inputMetadata = metadata.Inputs[0];
66-
var outputMetadata = metadata.Outputs[0];
67-
6864
//TODO: Model Config, Channels
6965
var outputDimension = options.GetScaledDimension();
70-
using (var inferenceParameters = new OnnxInferenceParameters())
66+
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
67+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
7168
{
72-
inferenceParameters.AddInput(inputMetadata, imageTensor.ToOrtValue(outputMetadata));
73-
inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));
69+
inferenceParameters.AddInputTensor(imageTensor);
70+
inferenceParameters.AddOutputBuffer(outputDimension);
7471

7572
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
7673
using (var result = results.First())

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
106106

107107
// Get Model metadata
108108
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
109-
var outputMetadata = metadata.Outputs[0];
110-
var inputMetadata = metadata.Inputs[0];
111-
var timestepMetadata = metadata.Inputs[1];
112-
var promptMetadata = metadata.Inputs[2];
113-
var guidanceMetadata = metadata.Inputs[3];
114-
109+
115110
// Loop though the timesteps
116111
var step = 0;
117112
foreach (var timestep in timesteps)
@@ -122,17 +117,18 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
122117

123118
// Create input tensor.
124119
var inputTensor = scheduler.ScaleInput(latents, timestep);
120+
var timestepTensor = CreateTimestepTensor(timestep);
125121

126122
var outputChannels = 1;
127123
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
128-
using (var inferenceParameters = new OnnxInferenceParameters())
124+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
129125
{
130-
inferenceParameters.AddInput(inputMetadata, inputTensor.ToOrtValue(outputMetadata));
131-
inferenceParameters.AddInput(timestepMetadata, CreateTimestepNamedOrtValue(timestepMetadata, timestep));
132-
inferenceParameters.AddInput(promptMetadata, promptEmbeddings.ToOrtValue(outputMetadata));
133-
inferenceParameters.AddInput(guidanceMetadata, guidanceEmbeddings.ToOrtValue(outputMetadata));
134-
inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));
135-
126+
inferenceParameters.AddInputTensor(inputTensor);
127+
inferenceParameters.AddInputTensor(timestepTensor);
128+
inferenceParameters.AddInputTensor(promptEmbeddings);
129+
inferenceParameters.AddInputTensor(guidanceEmbeddings);
130+
inferenceParameters.AddOutputBuffer(outputDimension);
131+
136132
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
137133
using (var result = results.First())
138134
{

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,13 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti
6363
{
6464
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
6565

66-
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
67-
var inputMetadata = metadata.Inputs[0];
68-
var outputMetadata = metadata.Outputs[0];
69-
7066
//TODO: Model Config, Channels
7167
var outputDimension = options.GetScaledDimension();
72-
using (var inferenceParameters = new OnnxInferenceParameters())
68+
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
69+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
7370
{
74-
inferenceParameters.AddInput(inputMetadata, imageTensor.ToOrtValue(outputMetadata));
75-
inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));
71+
inferenceParameters.AddInputTensor(imageTensor);
72+
inferenceParameters.AddOutputBuffer(outputDimension);
7673

7774
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
7875
using (var result = results.First())

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,6 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
7272

7373
// Get Model metadata
7474
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
75-
var outputMetadata = metadata.Outputs[0];
76-
var inputMetadata = metadata.Inputs[0];
77-
var timestepMetadata = metadata.Inputs[1];
78-
var promptMetadata = metadata.Inputs[2];
7975

8076
// Loop though the timesteps
8177
var step = 0;
@@ -89,15 +85,16 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
8985
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
9086
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
9187
inputTensor = ConcatenateLatents(inputTensor, maskedImage, maskImage);
88+
var timestepTensor = CreateTimestepTensor(timestep);
9289

9390
var outputChannels = performGuidance ? 2 : 1;
9491
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
95-
using (var inferenceParameters = new OnnxInferenceParameters())
92+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
9693
{
97-
inferenceParameters.AddInput(inputMetadata, inputTensor.ToOrtValue(outputMetadata));
98-
inferenceParameters.AddInput(timestepMetadata, CreateTimestepNamedOrtValue(timestepMetadata, timestep));
99-
inferenceParameters.AddInput(promptMetadata, promptEmbeddings.ToOrtValue(outputMetadata));
100-
inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));
94+
inferenceParameters.AddInputTensor(inputTensor);
95+
inferenceParameters.AddInputTensor(timestepTensor);
96+
inferenceParameters.AddInputTensor(promptEmbeddings);
97+
inferenceParameters.AddOutputBuffer(outputDimension);
10198

10299
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
103100
using (var result = results.First())
@@ -221,15 +218,12 @@ private async Task<DenseTensor<float>> PrepareImageMask(IModelOptions modelOptio
221218
});
222219

223220
// Encode the image
224-
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.VaeEncoder);
225-
var inputMetadata = metadata.Inputs[0];
226-
var outputMetadata = metadata.Outputs[0];
227-
228221
var outputDimension = schedulerOptions.GetScaledDimension();
229-
using (var inferenceParameters = new OnnxInferenceParameters())
222+
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.VaeEncoder);
223+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
230224
{
231-
inferenceParameters.AddInput(inputMetadata, imageTensor.ToOrtValue(outputMetadata));
232-
inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));
225+
inferenceParameters.AddInputTensor(imageTensor);
226+
inferenceParameters.AddOutputBuffer(outputDimension);
233227

234228
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.VaeEncoder, inferenceParameters);
235229
using (var result = results.First())

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,6 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
7171

7272
// Get Model metadata
7373
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
74-
var outputMetadata = metadata.Outputs[0];
75-
var inputMetadata = metadata.Inputs[0];
76-
var timestepMetadata = metadata.Inputs[1];
77-
var promptMetadata = metadata.Inputs[2];
7874

7975
// Loop though the timesteps
8076
var step = 0;
@@ -87,15 +83,16 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
8783
// Create input tensor.
8884
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
8985
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
86+
var timestepTensor = CreateTimestepTensor(timestep);
9087

9188
var outputChannels = performGuidance ? 2 : 1;
9289
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
93-
using (var inferenceParameters = new OnnxInferenceParameters())
90+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
9491
{
95-
inferenceParameters.AddInput(inputMetadata, inputTensor.ToOrtValue(outputMetadata));
96-
inferenceParameters.AddInput(timestepMetadata, CreateTimestepNamedOrtValue(timestepMetadata, timestep));
97-
inferenceParameters.AddInput(promptMetadata, promptEmbeddings.ToOrtValue(outputMetadata));
98-
inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));
92+
inferenceParameters.AddInputTensor(inputTensor);
93+
inferenceParameters.AddInputTensor(timestepTensor);
94+
inferenceParameters.AddInputTensor(promptEmbeddings);
95+
inferenceParameters.AddOutputBuffer(outputDimension);
9996

10097
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
10198
using (var result = results.First())
@@ -153,16 +150,13 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti
153150
{
154151
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
155152

156-
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
157-
var inputMetadata = metadata.Inputs[0];
158-
var outputMetadata = metadata.Outputs[0];
159-
160153
//TODO: Model Config, Channels
161154
var outputDimensions = options.GetScaledDimension();
162-
using (var inferenceParameters = new OnnxInferenceParameters())
155+
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
156+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
163157
{
164-
inferenceParameters.AddInput(inputMetadata, imageTensor.ToOrtValue(outputMetadata));
165-
inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimensions));
158+
inferenceParameters.AddInputTensor(imageTensor);
159+
inferenceParameters.AddOutputBuffer(outputDimensions);
166160

167161
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
168162
using (var result = results.First())

0 commit comments

Comments
 (0)