|
4 | 4 | using System; |
5 | 5 | using System.Collections.Concurrent; |
6 | 6 | using System.Collections.Generic; |
7 | | -using System.Linq; |
8 | 7 | using System.Threading.Tasks; |
9 | 8 |
|
10 | 9 | namespace OnnxStack.Core.Services |
@@ -104,102 +103,119 @@ public Task<bool> IsEnabledAsync(IOnnxModel model, OnnxModelType modelType) |
104 | 103 |
|
105 | 104 |
|
/// <summary>
/// Runs the inference (Use when output size is unknown)
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <param name="inputName">Name of the input.</param>
/// <param name="inputValue">The input value.</param>
/// <param name="outputName">Name of the output.</param>
/// <returns>A disposable collection containing the single inference output.</returns>
public IDisposableReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, string inputName, OrtValue inputValue, string outputName)
{
    // Delegate to the multi-input overload with single-entry collections.
    return RunInference(
        model,
        modelType,
        new Dictionary<string, OrtValue> { [inputName] = inputValue },
        new List<string> { outputName });
}
116 | 120 |
|
117 | 121 |
|
/// <summary>
/// Runs the inference (Use when output size is unknown)
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <param name="inputs">The inputs.</param>
/// <param name="outputs">The outputs.</param>
/// <returns>A disposable collection of output values allocated by the session.</returns>
public IDisposableReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, IReadOnlyCollection<string> outputs)
{
    // RunOptions wraps a native handle; dispose it after the (synchronous)
    // run completes instead of leaking one per inference call.
    using var runOptions = new RunOptions();
    return GetModelSet(model)
        .GetSession(modelType)
        .Run(runOptions, inputs, outputs);
}
128 | 136 |
|
129 | 137 |
|
/// <summary>
/// Runs the inference asynchronously, (Use when output size is known)
/// Output buffer size must be known and set before inference is run
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <param name="inputName">Name of the input.</param>
/// <param name="inputValue">The input value.</param>
/// <param name="outputName">Name of the output.</param>
/// <param name="outputValue">The output value (pre-allocated buffer).</param>
/// <returns>A task producing the inference outputs.</returns>
public Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, string inputName, OrtValue inputValue, string outputName, OrtValue outputValue)
{
    // Delegate to the multi-input overload with single-entry dictionaries.
    return RunInferenceAsync(
        model,
        modelType,
        new Dictionary<string, OrtValue> { [inputName] = inputValue },
        new Dictionary<string, OrtValue> { [outputName] = outputValue });
}
140 | 155 |
|
141 | 156 |
|
/// <summary>
/// Runs the inference asynchronously, (Use when output size is known)
/// Output buffer size must be known and set before inference is run
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <param name="inputs">The inputs.</param>
/// <param name="outputs">The outputs (pre-allocated buffers).</param>
/// <returns>A task producing the inference outputs.</returns>
public async Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, Dictionary<string, OrtValue> outputs)
{
    // RunOptions wraps a native handle and was previously leaked on every call.
    // Await inside the method so the options outlive the native run, then the
    // using declaration disposes them deterministically.
    using var runOptions = new RunOptions();
    return await GetModelSet(model)
        .GetSession(modelType)
        .RunAsync(runOptions, inputs.Keys, inputs.Values, outputs.Keys, outputs.Values)
        .ConfigureAwait(false);
}
152 | 172 |
|
153 | 173 |
|
/// <summary>
/// Gets the input metadata for the specified model session.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <returns>The session's input name to <see cref="NodeMetadata"/> map.</returns>
public IReadOnlyDictionary<string, NodeMetadata> GetInputMetadata(IOnnxModel model, OnnxModelType modelType)
{
    // Stale <exception cref="NotImplementedException"> doc removed; this
    // method simply delegates and does not throw it.
    return InputMetadataInternal(model, modelType);
}
164 | 184 |
|
165 | 185 |
|
/// <summary>
/// Gets the input names for the specified model session.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <returns>The session's ordered input names.</returns>
public IReadOnlyList<string> GetInputNames(IOnnxModel model, OnnxModelType modelType)
{
    // Stale <exception cref="NotImplementedException"> doc removed; this
    // method simply delegates and does not throw it.
    return InputNamesInternal(model, modelType);
}
176 | 196 |
|
177 | 197 |
|
/// <summary>
/// Gets the output metadata for the specified model session.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <returns>The session's output name to <see cref="NodeMetadata"/> map.</returns>
public IReadOnlyDictionary<string, NodeMetadata> GetOutputMetadata(IOnnxModel model, OnnxModelType modelType)
{
    // Stale <exception cref="NotImplementedException"> doc removed; this
    // method simply delegates and does not throw it.
    return OutputMetadataInternal(model, modelType);
}
190 | 208 |
|
191 | 209 |
|
/// <summary>
/// Gets the output names for the specified model session.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <returns>The session's ordered output names.</returns>
public IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType modelType)
{
    // Stale <exception cref="NotImplementedException"> doc removed; this
    // method simply delegates and does not throw it.
    return OutputNamesInternal(model, modelType);
}
204 | 220 |
|
205 | 221 |
|
@@ -334,5 +350,7 @@ public void Dispose() |
334 | 350 | onnxModelSet?.Dispose(); |
335 | 351 | } |
336 | 352 | } |
| 353 | + |
| 354 | + |
337 | 355 | } |
338 | 356 | } |
0 commit comments