Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 14eae97

Browse files
committed
Add session NodeMetadata/name accessors; replace hard-coded input names with names queried from the session
1 parent ed1b704 commit 14eae97

File tree

4 files changed

+136
-7
lines changed

4 files changed

+136
-7
lines changed

OnnxStack.Core/Services/IOnnxModelService.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.ML.OnnxRuntime;
22
using OnnxStack.Core.Config;
3+
using OnnxStack.Core.Model;
34
using System;
45
using System.Collections.Generic;
56
using System.Threading.Tasks;
@@ -49,5 +50,37 @@ public interface IOnnxModelService : IDisposable
4950
/// <param name="inputs">The inputs.</param>
5051
/// <returns></returns>
5152
Task<IDisposableReadOnlyCollection<DisposableNamedOnnxValue>> RunInferenceAsync(OnnxModelType modelType, IReadOnlyCollection<NamedOnnxValue> inputs);
53+
54+
55+
/// <summary>
56+
/// Gets the Sessions input metadata.
57+
/// </summary>
58+
/// <param name="modelType">Type of model.</param>
59+
/// <returns></returns>
60+
IReadOnlyDictionary<string, NodeMetadata> GetInputMetadata(OnnxModelType modelType);
61+
62+
63+
/// <summary>
64+
/// Gets the Sessions input names.
65+
/// </summary>
66+
/// <param name="modelType">Type of model.</param>
67+
/// <returns></returns>
68+
IReadOnlyList<string> GetInputNames(OnnxModelType modelType);
69+
70+
71+
/// <summary>
72+
/// Gets the Sessions output metadata.
73+
/// </summary>
74+
/// <param name="modelType">Type of model.</param>
75+
/// <returns></returns>
76+
IReadOnlyDictionary<string, NodeMetadata> GetOutputMetadata(OnnxModelType modelType);
77+
78+
79+
/// <summary>
80+
/// Gets the Sessions output metadata names.
81+
/// </summary>
82+
/// <param name="modelType">Type of model.</param>
83+
/// <returns></returns>
84+
IReadOnlyList<string> GetOutputNames(OnnxModelType modelType);
5285
}
5386
}

OnnxStack.Core/Services/OnnxModelService.cs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,54 @@ public async Task<IDisposableReadOnlyCollection<DisposableNamedOnnxValue>> RunIn
8181
}
8282

8383

84+
/// <summary>
85+
/// Gets the input metadata.
86+
/// </summary>
87+
/// <param name="modelType">Type of the model.</param>
88+
/// <returns></returns>
89+
/// <exception cref="System.NotImplementedException"></exception>
90+
public IReadOnlyDictionary<string, NodeMetadata> GetInputMetadata(OnnxModelType modelType)
91+
{
92+
return InputMetadataInternal(modelType);
93+
}
94+
95+
96+
/// <summary>
97+
/// Gets the input names.
98+
/// </summary>
99+
/// <param name="modelType">Type of the model.</param>
100+
/// <returns></returns>
101+
/// <exception cref="System.NotImplementedException"></exception>
102+
public IReadOnlyList<string> GetInputNames(OnnxModelType modelType)
103+
{
104+
return InputNamesInternal(modelType);
105+
}
106+
107+
108+
/// <summary>
109+
/// Gets the output metadata.
110+
/// </summary>
111+
/// <param name="modelType">Type of the model.</param>
112+
/// <returns></returns>
113+
/// <exception cref="System.NotImplementedException"></exception>
114+
public IReadOnlyDictionary<string, NodeMetadata> GetOutputMetadata(OnnxModelType modelType)
115+
{
116+
return OutputMetadataInternal(modelType);
117+
}
118+
119+
120+
/// <summary>
121+
/// Gets the output names.
122+
/// </summary>
123+
/// <param name="modelType">Type of the model.</param>
124+
/// <returns></returns>
125+
/// <exception cref="System.NotImplementedException"></exception>
126+
public IReadOnlyList<string> GetOutputNames(OnnxModelType modelType)
127+
{
128+
return OutputNamesInternal(modelType);
129+
}
130+
131+
84132
/// <summary>
85133
/// Runs inference on the specified model.
86134
/// </summary>
@@ -93,6 +141,48 @@ private IDisposableReadOnlyCollection<DisposableNamedOnnxValue> RunInternal(Onnx
93141
}
94142

95143

144+
145+
/// <summary>
146+
/// Gets the Sessions input metadata.
147+
/// </summary>
148+
/// <param name="modelType">Type of model.</param>
149+
/// <returns></returns>
150+
private IReadOnlyDictionary<string, NodeMetadata> InputMetadataInternal(OnnxModelType modelType)
151+
{
152+
return _onnxModelSet.GetSession(modelType).InputMetadata;
153+
}
154+
155+
/// <summary>
156+
/// Gets the Sessions input names.
157+
/// </summary>
158+
/// <param name="modelType">Type of model.</param>
159+
/// <returns></returns>
160+
private IReadOnlyList<string> InputNamesInternal(OnnxModelType modelType)
161+
{
162+
return _onnxModelSet.GetSession(modelType).InputNames;
163+
}
164+
165+
/// <summary>
166+
/// Gets the Sessions output metadata.
167+
/// </summary>
168+
/// <param name="modelType">Type of model.</param>
169+
/// <returns></returns>
170+
private IReadOnlyDictionary<string, NodeMetadata> OutputMetadataInternal(OnnxModelType modelType)
171+
{
172+
return _onnxModelSet.GetSession(modelType).OutputMetadata;
173+
}
174+
175+
/// <summary>
176+
/// Gets the Sessions output metadata names.
177+
/// </summary>
178+
/// <param name="modelType">Type of model.</param>
179+
/// <returns></returns>
180+
private IReadOnlyList<string> OutputNamesInternal(OnnxModelType modelType)
181+
{
182+
return _onnxModelSet.GetSession(modelType).OutputNames;
183+
}
184+
185+
96186
/// <summary>
97187
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
98188
/// </summary>

OnnxStack.StableDiffusion/Services/PromptService.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ public async Task<int[]> DecodeTextAsync(string inputText)
7272
return Array.Empty<int>();
7373

7474
// Create input tensor.
75+
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.Tokenizer);
7576
var inputTensor = new DenseTensor<string>(new string[] { inputText }, new int[] { 1 });
7677
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor("string_input", inputTensor));
7778

@@ -92,6 +93,7 @@ public async Task<int[]> DecodeTextAsync(string inputText)
9293
public async Task<float[]> EncodeTokensAsync(int[] tokenizedInput)
9394
{
9495
// Create input tensor.
96+
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.TextEncoder);
9597
var inputTensor = TensorHelper.CreateTensor(tokenizedInput, new[] { 1, tokenizedInput.Length });
9698
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor("input_ids", inputTensor));
9799

OnnxStack.StableDiffusion/Services/SchedulerService.cs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ public async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, Sche
7070
// Create input tensor.
7171
var inputTensor = scheduler.ScaleInput(latentSample.Duplicate(schedulerOptions.GetScaledDimension(2)), timestep);
7272

73+
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.Unet);
7374
var inputParameters = CreateInputParameters(
74-
NamedOnnxValue.CreateFromTensor("encoder_hidden_states", promptEmbeddings),
75-
NamedOnnxValue.CreateFromTensor("sample", inputTensor),
76-
NamedOnnxValue.CreateFromTensor("timestep", new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })));
75+
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
76+
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
77+
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings));
7778

7879
// Run Inference
7980
using (var inferResult = await _onnxModelService.RunInferenceAsync(OnnxModelType.Unet, inputParameters))
@@ -128,7 +129,8 @@ private DenseTensor<float> PrepareLatents(PromptOptions prompt, SchedulerOptions
128129

129130
// Image input, decode, add noise, return as latent 0
130131
var imageTensor = prompt.InputImage.ToDenseTensor(options.Width, options.Height);
131-
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor("sample", imageTensor));
132+
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.VaeEncoder);
133+
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageTensor));
132134
using (var inferResult = _onnxModelService.RunInference(OnnxModelType.VaeEncoder, inputParameters))
133135
{
134136
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
@@ -153,7 +155,8 @@ private async Task<DenseTensor<float>> DecodeLatents(SchedulerOptions options, D
153155
// latents = 1 / 0.18215 * latents
154156
latents = latents.MultipleTensorByFloat(1.0f / _configuration.ScaleFactor);
155157

156-
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor("latent_sample", latents));
158+
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.VaeDecoder);
159+
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], latents));
157160

158161
// Run inference.
159162
using (var inferResult = await _onnxModelService.RunInferenceAsync(OnnxModelType.VaeDecoder, inputParameters))
@@ -184,10 +187,11 @@ private async Task<bool> IsImageSafe(SchedulerOptions options, DenseTensor<float
184187
var inputTensor = ClipImageFeatureExtractor(options, resultImage);
185188

186189
//images input
190+
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.SafetyChecker);
187191
var inputImagesTensor = inputTensor.ReorderTensor(new[] { 1, 224, 224, 3 });
188192
var inputParameters = CreateInputParameters(
189-
NamedOnnxValue.CreateFromTensor("clip_input", inputTensor),
190-
NamedOnnxValue.CreateFromTensor("images", inputImagesTensor));
193+
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
194+
NamedOnnxValue.CreateFromTensor(inputNames[1], inputImagesTensor));
191195

192196
// Run session and send the input data in to get inference output.
193197
using (var inferResult = await _onnxModelService.RunInferenceAsync(OnnxModelType.SafetyChecker, inputParameters))

0 commit comments

Comments
 (0)