This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 5cb4ff3

Add some single input/output overloads for RunInference

1 parent ab48d4b

8 files changed: +97 -75 lines changed

OnnxStack.Core/Services/IOnnxModelService.cs

Lines changed: 18 additions & 9 deletions

@@ -61,32 +61,41 @@ public interface IOnnxModelService : IDisposable
 
 
         /// <summary>
-        /// Runs inference on the specified model.
+        /// Runs the inference Use when output size is unknown
         /// </summary>
+        /// <param name="model">The model.</param>
         /// <param name="modelType">Type of the model.</param>
-        /// <param name="inputs">The inputs.</param>
+        /// <param name="inputName">Name of the input.</param>
+        /// <param name="inputValue">The input value.</param>
+        /// <param name="outputName">Name of the output.</param>
+        /// <param name="outputValue">The output value.</param>
         /// <returns></returns>
-        IDisposableReadOnlyCollection<DisposableNamedOnnxValue> RunInference(IOnnxModel model, OnnxModelType modelType, IReadOnlyCollection<NamedOnnxValue> inputs);
+        IDisposableReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, string inputName, OrtValue inputValue, string outputName);
 
 
         /// <summary>
-        /// Runs inference on the specified model.asynchronously.
+        /// Runs the inference Use when output size is unknown
         /// </summary>
+        /// <param name="model">The model.</param>
         /// <param name="modelType">Type of the model.</param>
         /// <param name="inputs">The inputs.</param>
+        /// <param name="outputs">The outputs.</param>
         /// <returns></returns>
-        Task<IDisposableReadOnlyCollection<DisposableNamedOnnxValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, IReadOnlyCollection<NamedOnnxValue> inputs);
+        IDisposableReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, IReadOnlyCollection<string> outputs);
 
 
         /// <summary>
-        /// Runs the inference Use when output size is unknown
+        /// Runs the inference asynchronously, Use when output size is known
+        /// Output buffer size must be known and set before inference is run
         /// </summary>
         /// <param name="model">The model.</param>
         /// <param name="modelType">Type of the model.</param>
-        /// <param name="inputs">The inputs.</param>
-        /// <param name="outputs">The outputs.</param>
+        /// <param name="inputName">Name of the input.</param>
+        /// <param name="inputValue">The input value.</param>
+        /// <param name="outputName">Name of the output.</param>
+        /// <param name="outputValue">The output value.</param>
         /// <returns></returns>
-        IReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, IReadOnlyCollection<string> outputs);
+        Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, string inputName, OrtValue inputValue, string outputName, OrtValue outputValue);
 
 
         /// <summary>
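
Usage sketch for the new single input/output overload (illustrative only: the EncodeSingle helper, its parameters, and the choice of VaeEncoder are assumptions, not part of this commit; the OnnxStack extension methods used are the ones visible in the diffuser diffs below):

    using System.Linq;
    using Microsoft.ML.OnnxRuntime;
    using Microsoft.ML.OnnxRuntime.Tensors;

    // Hypothetical call site: one named input, one named output; the service
    // builds the single-entry input/output collections internally, so the
    // per-call dictionary boilerplate disappears (output size unknown up front).
    public static DenseTensor<float> EncodeSingle(IOnnxModelService service, IOnnxModel model, DenseTensor<float> imageTensor)
    {
        var inputName = service.GetInputNames(model, OnnxModelType.VaeEncoder)[0];
        var outputName = service.GetOutputNames(model, OnnxModelType.VaeEncoder)[0];
        var inputMetaData = service.GetInputMetadata(model, OnnxModelType.VaeEncoder)[inputName];

        using (var inputValue = imageTensor.ToOrtValue(inputMetaData))  // OnnxStack extension, as used in the diffusers below
        using (var results = service.RunInference(model, OnnxModelType.VaeEncoder, inputName, inputValue, outputName))
        {
            return results.First().ToDenseTensor();                     // OnnxStack extension, as used in the diffusers below
        }
    }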

OnnxStack.Core/Services/OnnxModelService.cs

Lines changed: 52 additions & 34 deletions

@@ -4,7 +4,6 @@
 using System;
 using System.Collections.Concurrent;
 using System.Collections.Generic;
-using System.Linq;
 using System.Threading.Tasks;
 
 namespace OnnxStack.Core.Services
@@ -104,102 +103,119 @@ public Task<bool> IsEnabledAsync(IOnnxModel model, OnnxModelType modelType)
 
 
         /// <summary>
-        /// Runs inference on the specified model.
+        /// Runs the inference (Use when output size is unknown)
         /// </summary>
+        /// <param name="model">The model.</param>
         /// <param name="modelType">Type of the model.</param>
-        /// <param name="inputs">The inputs.</param>
+        /// <param name="inputName">Name of the input.</param>
+        /// <param name="inputValue">The input value.</param>
+        /// <param name="outputName">Name of the output.</param>
         /// <returns></returns>
-        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> RunInference(IOnnxModel model, OnnxModelType modelType, IReadOnlyCollection<NamedOnnxValue> inputs)
+        public IDisposableReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, string inputName, OrtValue inputValue, string outputName)
         {
-            return RunInternal(model, modelType, inputs);
+            var inputs = new Dictionary<string, OrtValue> { { inputName, inputValue } };
+            var outputs = new List<string> { outputName };
+            return RunInference(model, modelType, inputs, outputs);
         }
 
 
         /// <summary>
-        /// Runs inference on the specified model asynchronously(ish).
+        /// Runs the inference (Use when output size is unknown)
         /// </summary>
+        /// <param name="model">The model.</param>
         /// <param name="modelType">Type of the model.</param>
         /// <param name="inputs">The inputs.</param>
+        /// <param name="outputs">The outputs.</param>
         /// <returns></returns>
-        public async Task<IDisposableReadOnlyCollection<DisposableNamedOnnxValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, IReadOnlyCollection<NamedOnnxValue> inputs)
+        public IDisposableReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, IReadOnlyCollection<string> outputs)
         {
-            return await Task.Run(() => RunInternal(model, modelType, inputs)).ConfigureAwait(false);
+            return GetModelSet(model)
+                .GetSession(modelType)
+                .Run(new RunOptions(), inputs, outputs);
         }
 
 
         /// <summary>
-        /// Gets the input metadata.
+        /// Runs the inference asynchronously, (Use when output size is known)
+        /// Output buffer size must be known and set before inference is run
         /// </summary>
+        /// <param name="model">The model.</param>
         /// <param name="modelType">Type of the model.</param>
+        /// <param name="inputName">Name of the input.</param>
+        /// <param name="inputValue">The input value.</param>
+        /// <param name="outputName">Name of the output.</param>
+        /// <param name="outputValue">The output value.</param>
         /// <returns></returns>
-        /// <exception cref="System.NotImplementedException"></exception>
-        public IReadOnlyDictionary<string, NodeMetadata> GetInputMetadata(IOnnxModel model, OnnxModelType modelType)
+        public Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, string inputName, OrtValue inputValue, string outputName, OrtValue outputValue)
         {
-            return InputMetadataInternal(model, modelType);
+            var inputs = new Dictionary<string, OrtValue> { { inputName, inputValue } };
+            var outputs = new Dictionary<string, OrtValue> { { outputName, outputValue } };
+            return RunInferenceAsync(model, modelType, inputs, outputs);
         }
 
 
         /// <summary>
-        /// Gets the input names.
+        /// Runs the inference asynchronously, (Use when output size is known)
+        /// Output buffer size must be known and set before inference is run
         /// </summary>
+        /// <param name="model">The model.</param>
         /// <param name="modelType">Type of the model.</param>
+        /// <param name="inputs">The inputs.</param>
+        /// <param name="outputs">The outputs.</param>
         /// <returns></returns>
-        /// <exception cref="System.NotImplementedException"></exception>
-        public IReadOnlyList<string> GetInputNames(IOnnxModel model, OnnxModelType modelType)
+        public Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, Dictionary<string, OrtValue> outputs)
         {
-            return InputNamesInternal(model, modelType);
+            return GetModelSet(model)
+                .GetSession(modelType)
+                .RunAsync(new RunOptions(), inputs.Keys, inputs.Values, outputs.Keys, outputs.Values);
         }
 
 
         /// <summary>
-        /// Gets the output metadata.
+        /// Gets the input metadata.
         /// </summary>
         /// <param name="modelType">Type of the model.</param>
         /// <returns></returns>
         /// <exception cref="System.NotImplementedException"></exception>
-        public IReadOnlyDictionary<string, NodeMetadata> GetOutputMetadata(IOnnxModel model, OnnxModelType modelType)
+        public IReadOnlyDictionary<string, NodeMetadata> GetInputMetadata(IOnnxModel model, OnnxModelType modelType)
         {
-            return OutputMetadataInternal(model, modelType);
+            return InputMetadataInternal(model, modelType);
         }
 
 
         /// <summary>
-        /// Gets the output names.
+        /// Gets the input names.
         /// </summary>
         /// <param name="modelType">Type of the model.</param>
         /// <returns></returns>
         /// <exception cref="System.NotImplementedException"></exception>
-        public IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType modelType)
+        public IReadOnlyList<string> GetInputNames(IOnnxModel model, OnnxModelType modelType)
        {
-            return OutputNamesInternal(model, modelType);
+            return InputNamesInternal(model, modelType);
        }
 
 
         /// <summary>
-        /// Runs inference on the specified model.
+        /// Gets the output metadata.
         /// </summary>
         /// <param name="modelType">Type of the model.</param>
-        /// <param name="inputs">The inputs.</param>
         /// <returns></returns>
-        public IReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, IReadOnlyCollection<string> outputs)
+        /// <exception cref="System.NotImplementedException"></exception>
+        public IReadOnlyDictionary<string, NodeMetadata> GetOutputMetadata(IOnnxModel model, OnnxModelType modelType)
        {
-            return GetModelSet(model)
-                .GetSession(modelType)
-                .Run(new RunOptions(), inputs, outputs);
+            return OutputMetadataInternal(model, modelType);
        }
 
 
         /// <summary>
-        /// Runs inference on the specified model.
+        /// Gets the output names.
         /// </summary>
         /// <param name="modelType">Type of the model.</param>
-        /// <param name="inputs">The inputs.</param>
         /// <returns></returns>
-        public Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, Dictionary<string, OrtValue> outputs)
+        /// <exception cref="System.NotImplementedException"></exception>
+        public IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType modelType)
        {
-            return GetModelSet(model)
-                .GetSession(modelType)
-                .RunAsync(new RunOptions(), inputs.Keys, inputs.Values, outputs.Keys, outputs.Values);
+            return OutputNamesInternal(model, modelType);
        }
 
 
@@ -334,5 +350,7 @@ public void Dispose()
                 onnxModelSet?.Dispose();
             }
         }
+
+
     }
 }
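
Corresponding sketch for the fixed-size asynchronous path (illustrative only: the DecodeSingleAsync helper and its int[] dimension parameter are assumptions, not part of this commit), mirroring the pattern the diffuser call sites below adopt: look up output metadata, pre-allocate the output buffer, then bind both values by name:

    using System.Linq;
    using System.Threading.Tasks;
    using Microsoft.ML.OnnxRuntime;
    using Microsoft.ML.OnnxRuntime.Tensors;

    // Hypothetical call site: the output shape is known up front, so the caller
    // owns both buffers; the service wraps them in single-entry dictionaries and
    // forwards to the dictionary-based RunInferenceAsync overload above.
    public static async Task<DenseTensor<float>> DecodeSingleAsync(IOnnxModelService service, IOnnxModel model, DenseTensor<float> latents, int[] outputDimension)
    {
        var inputName = service.GetInputNames(model, OnnxModelType.VaeDecoder)[0];
        var outputName = service.GetOutputNames(model, OnnxModelType.VaeDecoder)[0];
        var outputMetaData = service.GetOutputMetadata(model, OnnxModelType.VaeDecoder)[outputName];

        using (var inputValue = latents.ToOrtValue(outputMetaData))                   // OnnxStack extension, as used in the diffusers below
        using (var outputValue = outputMetaData.CreateOutputBuffer(outputDimension))  // OnnxStack extension, as used in the diffusers below
        {
            var results = await service.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputName, inputValue, outputName, outputValue);
            using (results.First())
            {
                // The caller-owned output buffer now holds the inference result.
                return outputValue.ToDenseTensor();
            }
        }
    }

Binding a pre-allocated output buffer lets the caller reuse and dispose it deterministically rather than receiving a fresh allocation per run; the diffuser call sites below all follow this pattern.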

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 1 addition & 3 deletions

@@ -222,9 +222,7 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(IModelOption
             using (var inputTensorValue = latents.ToOrtValue(outputTensorMetaData))
             using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDim))
             {
-                var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
-                var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
-                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputs, outputs);
+                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputNames[0], inputTensorValue, outputNames[0], outputTensorValue);
                 using (var imageResult = results.First())
                 {
                     _logger?.LogEnd("Latents decoded", timestamp);

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 1 addition & 3 deletions

@@ -69,9 +69,7 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti
             using (var inputTensorValue = imageTensor.ToOrtValue(outputTensorMetaData))
             using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension))
             {
-                var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
-                var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
-                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
+                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputNames[0], inputTensorValue, outputNames[0], outputTensorValue);
                 using (var result = results.First())
                 {
                     var outputResult = outputTensorValue.ToDenseTensor();

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs

Lines changed: 1 addition & 3 deletions

@@ -71,9 +71,7 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti
             using (var inputTensorValue = imageTensor.ToOrtValue(outputTensorMetaData))
             using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension))
             {
-                var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
-                var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
-                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
+                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputNames[0], inputTensorValue, outputNames[0], outputTensorValue);
                 using (var result = results.First())
                 {
                     var outputResult = outputTensorValue.ToDenseTensor();

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs

Lines changed: 19 additions & 9 deletions

@@ -14,6 +14,7 @@
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
+using System.Reflection;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -66,7 +67,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
             var maskImage = PrepareMask(modelOptions, promptOptions, schedulerOptions);
 
             // Create Masked Image Latents
-            var maskedImage = PrepareImageMask(modelOptions, promptOptions, schedulerOptions);
+            var maskedImage = await PrepareImageMask(modelOptions, promptOptions, schedulerOptions);
 
             // Get Model metadata
             var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet);
@@ -179,7 +180,7 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
         /// <param name="schedulerOptions">The scheduler options.</param>
         /// <param name="scheduler">The scheduler.</param>
         /// <returns></returns>
-        private DenseTensor<float> PrepareImageMask(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
+        private async Task<DenseTensor<float>> PrepareImageMask(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
         {
             using (var image = promptOptions.InputImage.ToImage())
             using (var mask = promptOptions.InputImageMask.ToImage())
@@ -227,15 +228,24 @@ private DenseTensor<float> PrepareImageMask(IModelOptions modelOptions, PromptOp
 
                 // Encode the image
                 var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.VaeEncoder);
-                var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageMaskedTensor));
-                using (var inferResult = _onnxModelService.RunInference(modelOptions, OnnxModelType.VaeEncoder, inputParameters))
+                var outputNames = _onnxModelService.GetOutputNames(modelOptions, OnnxModelType.VaeEncoder);
+                var outputMetaData = _onnxModelService.GetOutputMetadata(modelOptions, OnnxModelType.VaeEncoder);
+                var outputTensorMetaData = outputMetaData[outputNames[0]];
+
+                var outputDimension = schedulerOptions.GetScaledDimension();
+                using (var inputTensorValue = imageTensor.ToOrtValue(outputTensorMetaData))
+                using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension))
                 {
-                    var sample = inferResult.FirstElementAs<DenseTensor<float>>();
-                    var scaledSample = sample.MultiplyBy(modelOptions.ScaleFactor);
-                    if (schedulerOptions.GuidanceScale > 1f)
-                        scaledSample = scaledSample.Repeat(2);
+                    var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.VaeEncoder, inputNames[0], inputTensorValue, outputNames[0], outputTensorValue);
+                    using (var result = results.First())
+                    {
+                        var sample = outputTensorValue.ToDenseTensor();
+                        var scaledSample = sample.MultiplyBy(modelOptions.ScaleFactor);
+                        if (schedulerOptions.GuidanceScale > 1f)
+                            scaledSample = scaledSample.Repeat(2);
 
-                    return scaledSample;
+                        return scaledSample;
+                    }
                 }
             }
         }

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs

Lines changed: 1 addition & 3 deletions

@@ -169,9 +169,7 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti
             using (var inputTensorValue = imageTensor.ToOrtValue(outputBufferMetaData))
             using (var outputTensorValue = outputBufferMetaData.CreateOutputBuffer(outputDimensions))
             {
-                var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
-                var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
-                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
+                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputNames[0], inputTensorValue, outputNames[0], outputTensorValue);
                 using (var result = results.First())
                 {
                     var outputResult = outputTensorValue.ToDenseTensor();
