This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 6c00064

OnnxInferenceParameters to handle input output disposal
1 parent 2f6cf61 commit 6c00064

11 files changed: +83 -84 lines changed

OnnxStack.Core/Model/OnnxInferenceParameters.cs

Lines changed: 20 additions & 1 deletion
@@ -1,10 +1,12 @@
 using Microsoft.ML.OnnxRuntime;
+using System;
 using System.Collections.Generic;

 namespace OnnxStack.Core.Model
 {
-    public class OnnxInferenceParameters
+    public class OnnxInferenceParameters : IDisposable
     {
+        private RunOptions _runOptions;
         private OnnxValueCollection _inputs;
         private OnnxValueCollection _outputs;

@@ -13,6 +15,7 @@ public class OnnxInferenceParameters
         /// </summary>
         public OnnxInferenceParameters()
         {
+            _runOptions = new RunOptions();
             _inputs = new OnnxValueCollection();
             _outputs = new OnnxValueCollection();
         }
@@ -50,6 +53,11 @@ public void AddOutput(OnnxNamedMetadata metaData)
         }


+        /// <summary>
+        /// Gets the run options.
+        /// </summary>
+        public RunOptions RunOptions => _runOptions;
+
         /// <summary>
         /// Gets the input names.
         /// </summary>
@@ -84,5 +92,16 @@ public void AddOutput(OnnxNamedMetadata metaData)
         /// Gets the output name values.
         /// </summary>
         public IReadOnlyDictionary<string, OrtValue> OutputNameValues => _outputs.NameValues;
+
+
+        /// <summary>
+        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+        /// </summary>
+        public void Dispose()
+        {
+            _inputs?.Dispose();
+            _outputs?.Dispose();
+            _runOptions?.Dispose();
+        }
     }
 }
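
With OnnxInferenceParameters now owning a RunOptions instance and disposing its inputs and outputs, call sites only need to wrap the parameters object itself in a using block. A minimal sketch of the calling pattern, mirroring the diffuser changes below (model, latents, inputMetadata, outputMetadata, outputDim and _onnxModelService are placeholders for the fields used at the real call sites):

    using (var inferenceParameters = new OnnxInferenceParameters())
    {
        // Ownership of the OrtValues passes to the parameters object;
        // Dispose() releases them together with the shared RunOptions.
        inferenceParameters.AddInput(inputMetadata, latents.ToOrtValue(outputMetadata));
        inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDim));

        var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inferenceParameters);
        using (var result = results.First())
        {
            var imageTensor = result.ToDenseTensor();
            // ... convert imageTensor to the final image ...
        }
    }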

OnnxStack.Core/Model/OnnxValueCollection.cs

Lines changed: 12 additions & 2 deletions
@@ -1,9 +1,10 @@
 using Microsoft.ML.OnnxRuntime;
+using System;
 using System.Collections.Generic;

 namespace OnnxStack.Core.Model
 {
-    public class OnnxValueCollection
+    public class OnnxValueCollection : IDisposable
     {
         private readonly List<OnnxNamedMetadata> _metaData;
         private readonly Dictionary<string, OrtValue> _values;
@@ -41,7 +42,6 @@ public void AddName(OnnxNamedMetadata metaData)
             _values.Add(metaData.Name, default);
         }

-
         /// <summary>
         /// Gets the names.
         /// </summary>
@@ -58,5 +58,15 @@ public void AddName(OnnxNamedMetadata metaData)
         /// Gets the name values.
         /// </summary>
         public IReadOnlyDictionary<string, OrtValue> NameValues => _values;
+
+
+        /// <summary>
+        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+        /// </summary>
+        public void Dispose()
+        {
+            foreach (var ortValue in _values.Values)
+                ortValue?.Dispose();
+        }
     }
 }

OnnxStack.Core/Services/OnnxModelService.cs

Lines changed: 2 additions & 2 deletions
@@ -156,7 +156,7 @@ private IDisposableReadOnlyCollection<OrtValue> RunInferenceInternal(IOnnxModel
         {
             return GetModelSet(model)
                 .GetSession(modelType)
-                .Run(new RunOptions(), parameters.InputNameValues, parameters.OutputNames);
+                .Run(parameters.RunOptions, parameters.InputNameValues, parameters.OutputNames);
         }
@@ -171,7 +171,7 @@ private Task<IReadOnlyCollection<OrtValue>> RunInferenceInternalAsync(IOnnxModel
         {
             return GetModelSet(model)
                 .GetSession(modelType)
-                .RunAsync(new RunOptions(), parameters.InputNames, parameters.InputValues, parameters.OutputNames, parameters.OutputValues);
+                .RunAsync(parameters.RunOptions, parameters.InputNames, parameters.InputValues, parameters.OutputNames, parameters.OutputValues);
         }
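
Both the synchronous and asynchronous paths now reuse the caller-owned parameters.RunOptions instead of allocating a new, never-disposed RunOptions per call. As a side effect, callers can configure run-level options once on the shared instance; a small sketch, assuming the standard Microsoft.ML.OnnxRuntime RunOptions.LogId property (the value shown is illustrative):

    using (var inferenceParameters = new OnnxInferenceParameters())
    {
        // The same RunOptions instance is handed to session.Run/RunAsync and is
        // released by OnnxInferenceParameters.Dispose().
        inferenceParameters.RunOptions.LogId = "unet-step";
        // ... AddInput / AddOutput and RunInferenceAsync as in the sketch above ...
    }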

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 3 additions & 5 deletions
@@ -219,12 +219,10 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(IModelOption
             var outputMetadata = metadata.Outputs[0];

             var outputDim = new[] { 1, 3, options.Height, options.Width };
-            using (var inputTensorValue = latents.ToOrtValue(outputMetadata))
-            using (var outputTensorValue = outputMetadata.CreateOutputBuffer(outputDim))
+            using (var inferenceParameters = new OnnxInferenceParameters())
             {
-                var inferenceParameters = new OnnxInferenceParameters();
-                inferenceParameters.AddInput(inputMetadata, inputTensorValue);
-                inferenceParameters.AddOutput(outputMetadata, outputTensorValue);
+                inferenceParameters.AddInput(inputMetadata, latents.ToOrtValue(outputMetadata));
+                inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDim));

                 var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inferenceParameters);
                 using (var imageResult = results.First())

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 4 additions & 6 deletions
@@ -67,17 +67,15 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti

             //TODO: Model Config, Channels
             var outputDimension = options.GetScaledDimension();
-            using (var inputTensorValue = imageTensor.ToOrtValue(outputMetadata))
-            using (var outputTensorValue = outputMetadata.CreateOutputBuffer(outputDimension))
+            using (var inferenceParameters = new OnnxInferenceParameters())
             {
-                var inferenceParameters = new OnnxInferenceParameters();
-                inferenceParameters.AddInput(inputMetadata, inputTensorValue);
-                inferenceParameters.AddOutput(outputMetadata, outputTensorValue);
+                inferenceParameters.AddInput(inputMetadata, imageTensor.ToOrtValue(outputMetadata));
+                inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));

                 var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
                 using (var result = results.First())
                 {
-                    var outputResult = outputTensorValue.ToDenseTensor();
+                    var outputResult = result.ToDenseTensor();
                     var scaledSample = outputResult
                         .Add(scheduler.CreateRandomSample(outputDimension, options.InitialNoiseLevel))
                         .MultiplyBy(model.ScaleFactor);

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 7 additions & 12 deletions
@@ -125,23 +125,18 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio

                 var outputChannels = 1;
                 var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
-                using (var outputTensorValue = outputMetadata.CreateOutputBuffer(outputDimension))
-                using (var inputTensorValue = inputTensor.ToOrtValue(outputMetadata))
-                using (var promptTensorValue = promptEmbeddings.ToOrtValue(outputMetadata))
-                using (var guidanceTensorValue = guidanceEmbeddings.ToOrtValue(outputMetadata))
-                using (var timestepTensorValue = CreateTimestepNamedOrtValue(timestepMetadata, timestep))
+                using (var inferenceParameters = new OnnxInferenceParameters())
                 {
-                    var inferenceParameters = new OnnxInferenceParameters();
-                    inferenceParameters.AddInput(inputMetadata, inputTensorValue);
-                    inferenceParameters.AddInput(timestepMetadata, timestepTensorValue);
-                    inferenceParameters.AddInput(promptMetadata, promptTensorValue);
-                    inferenceParameters.AddInput(guidanceMetadata, guidanceTensorValue);
-                    inferenceParameters.AddOutput(outputMetadata, outputTensorValue);
+                    inferenceParameters.AddInput(inputMetadata, inputTensor.ToOrtValue(outputMetadata));
+                    inferenceParameters.AddInput(timestepMetadata, CreateTimestepNamedOrtValue(timestepMetadata, timestep));
+                    inferenceParameters.AddInput(promptMetadata, promptEmbeddings.ToOrtValue(outputMetadata));
+                    inferenceParameters.AddInput(guidanceMetadata, guidanceEmbeddings.ToOrtValue(outputMetadata));
+                    inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));

                     var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
                     using (var result = results.First())
                     {
-                        var noisePred = outputTensorValue.ToDenseTensor();
+                        var noisePred = result.ToDenseTensor();

                         // Scheduler Step
                         var schedulerResult = scheduler.Step(noisePred, timestep, latents);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs

Lines changed: 4 additions & 6 deletions
@@ -69,17 +69,15 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti

             //TODO: Model Config, Channels
             var outputDimension = options.GetScaledDimension();
-            using (var inputTensorValue = imageTensor.ToOrtValue(outputMetadata))
-            using (var outputTensorValue = outputMetadata.CreateOutputBuffer(outputDimension))
+            using (var inferenceParameters = new OnnxInferenceParameters())
             {
-                var inferenceParameters = new OnnxInferenceParameters();
-                inferenceParameters.AddInput(inputMetadata, inputTensorValue);
-                inferenceParameters.AddOutput(outputMetadata, outputTensorValue);
+                inferenceParameters.AddInput(inputMetadata, imageTensor.ToOrtValue(outputMetadata));
+                inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));

                 var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
                 using (var result = results.First())
                 {
-                    var outputResult = outputTensorValue.ToDenseTensor();
+                    var outputResult = result.ToDenseTensor();
                     var scaledSample = outputResult
                         .Add(scheduler.CreateRandomSample(outputDimension, options.InitialNoiseLevel))
                         .MultiplyBy(model.ScaleFactor);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs

Lines changed: 10 additions & 16 deletions
@@ -92,21 +92,17 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio

                 var outputChannels = performGuidance ? 2 : 1;
                 var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
-                using (var outputTensorValue = outputMetadata.CreateOutputBuffer(outputDimension))
-                using (var inputTensorValue = inputTensor.ToOrtValue(outputMetadata))
-                using (var promptTensorValue = promptEmbeddings.ToOrtValue(outputMetadata))
-                using (var timestepTensorValue = CreateTimestepNamedOrtValue(timestepMetadata, timestep))
+                using (var inferenceParameters = new OnnxInferenceParameters())
                 {
-                    var inferenceParameters = new OnnxInferenceParameters();
-                    inferenceParameters.AddInput(inputMetadata, inputTensorValue);
-                    inferenceParameters.AddInput(timestepMetadata, timestepTensorValue);
-                    inferenceParameters.AddInput(promptMetadata, promptTensorValue);
-                    inferenceParameters.AddOutput(outputMetadata, outputTensorValue);
+                    inferenceParameters.AddInput(inputMetadata, inputTensor.ToOrtValue(outputMetadata));
+                    inferenceParameters.AddInput(timestepMetadata, CreateTimestepNamedOrtValue(timestepMetadata, timestep));
+                    inferenceParameters.AddInput(promptMetadata, promptEmbeddings.ToOrtValue(outputMetadata));
+                    inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));

                     var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
                     using (var result = results.First())
                     {
-                        var noisePred = outputTensorValue.ToDenseTensor();
+                        var noisePred = result.ToDenseTensor();

                         // Perform guidance
                         if (performGuidance)
@@ -230,17 +226,15 @@ private async Task<DenseTensor<float>> PrepareImageMask(IModelOptions modelOptio
             var outputMetadata = metadata.Outputs[0];

             var outputDimension = schedulerOptions.GetScaledDimension();
-            using (var inputTensorValue = imageTensor.ToOrtValue(outputMetadata))
-            using (var outputTensorValue = outputMetadata.CreateOutputBuffer(outputDimension))
+            using (var inferenceParameters = new OnnxInferenceParameters())
             {
-                var inferenceParameters = new OnnxInferenceParameters();
-                inferenceParameters.AddInput(inputMetadata, inputTensorValue);
-                inferenceParameters.AddOutput(outputMetadata, outputTensorValue);
+                inferenceParameters.AddInput(inputMetadata, imageTensor.ToOrtValue(outputMetadata));
+                inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));

                 var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.VaeEncoder, inferenceParameters);
                 using (var result = results.First())
                 {
-                    var sample = outputTensorValue.ToDenseTensor();
+                    var sample = result.ToDenseTensor();
                     var scaledSample = sample.MultiplyBy(modelOptions.ScaleFactor);
                     if (schedulerOptions.GuidanceScale > 1f)
                         scaledSample = scaledSample.Repeat(2);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs

Lines changed: 10 additions & 16 deletions
@@ -90,21 +90,17 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio

                 var outputChannels = performGuidance ? 2 : 1;
                 var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
-                using (var outputTensorValue = outputMetadata.CreateOutputBuffer(outputDimension))
-                using (var inputTensorValue = inputTensor.ToOrtValue(outputMetadata))
-                using (var promptTensorValue = promptEmbeddings.ToOrtValue(outputMetadata))
-                using (var timestepTensorValue = CreateTimestepNamedOrtValue(timestepMetadata, timestep))
+                using (var inferenceParameters = new OnnxInferenceParameters())
                 {
-                    var inferenceParameters = new OnnxInferenceParameters();
-                    inferenceParameters.AddInput(inputMetadata, inputTensorValue);
-                    inferenceParameters.AddInput(timestepMetadata, timestepTensorValue);
-                    inferenceParameters.AddInput(promptMetadata, promptTensorValue);
-                    inferenceParameters.AddOutput(outputMetadata, outputTensorValue);
+                    inferenceParameters.AddInput(inputMetadata, inputTensor.ToOrtValue(outputMetadata));
+                    inferenceParameters.AddInput(timestepMetadata, CreateTimestepNamedOrtValue(timestepMetadata, timestep));
+                    inferenceParameters.AddInput(promptMetadata, promptEmbeddings.ToOrtValue(outputMetadata));
+                    inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));

                     var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
                     using (var result = results.First())
                     {
-                        var noisePred = outputTensorValue.ToDenseTensor();
+                        var noisePred = result.ToDenseTensor();

                         // Perform guidance
                         if (performGuidance)
@@ -163,17 +159,15 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti

             //TODO: Model Config, Channels
             var outputDimensions = options.GetScaledDimension();
-            using (var inputTensorValue = imageTensor.ToOrtValue(outputMetadata))
-            using (var outputTensorValue = outputMetadata.CreateOutputBuffer(outputDimensions))
+            using (var inferenceParameters = new OnnxInferenceParameters())
             {
-                var inferenceParameters = new OnnxInferenceParameters();
-                inferenceParameters.AddInput(inputMetadata, inputTensorValue);
-                inferenceParameters.AddOutput(outputMetadata, outputTensorValue);
+                inferenceParameters.AddInput(inputMetadata, imageTensor.ToOrtValue(outputMetadata));
+                inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimensions));

                 var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
                 using (var result = results.First())
                 {
-                    var outputResult = outputTensorValue.ToDenseTensor();
+                    var outputResult = result.ToDenseTensor();
                     var scaledSample = outputResult
                         .Add(scheduler.CreateRandomSample(outputDimensions, options.InitialNoiseLevel))
                         .MultiplyBy(model.ScaleFactor);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs

Lines changed: 5 additions & 9 deletions
@@ -80,16 +80,12 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio

                 var outputChannels = performGuidance ? 2 : 1;
                 var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
-                using (var outputTensorValue = outputMetadata.CreateOutputBuffer(outputDimension))
-                using (var inputTensorValue = inputTensor.ToOrtValue(outputMetadata))
-                using (var promptTensorValue = promptEmbeddings.ToOrtValue(outputMetadata))
-                using (var timestepTensorValue = CreateTimestepNamedOrtValue(timestepMetadata, timestep))
+                using (var inferenceParameters = new OnnxInferenceParameters())
                 {
-                    var inferenceParameters = new OnnxInferenceParameters();
-                    inferenceParameters.AddInput(inputMetadata, inputTensorValue);
-                    inferenceParameters.AddInput(timestepMetadata, timestepTensorValue);
-                    inferenceParameters.AddInput(promptMetadata, promptTensorValue);
-                    inferenceParameters.AddOutput(outputMetadata, outputTensorValue);
+                    inferenceParameters.AddInput(inputMetadata, inputTensor.ToOrtValue(outputMetadata));
+                    inferenceParameters.AddInput(timestepMetadata, CreateTimestepNamedOrtValue(timestepMetadata, timestep));
+                    inferenceParameters.AddInput(promptMetadata, promptEmbeddings.ToOrtValue(outputMetadata));
+                    inferenceParameters.AddOutput(outputMetadata, outputMetadata.CreateOutputBuffer(outputDimension));

                     var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
                     using (var result = results.First())
