Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 29e2973

Browse files
committed
Support the new Onnx OrtValue API
1 parent a7a555e commit 29e2973

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

OnnxStack.Core/Extensions.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Collections.Concurrent;
55
using System.Collections.Generic;
66
using System.Linq;
7+
using System.Numerics;
78

89
namespace OnnxStack.Core
910
{
@@ -169,5 +170,21 @@ public static ConcurrentDictionary<T, U> ToConcurrentDictionary<S, T, U>(this IE
169170
{
170171
return new ConcurrentDictionary<T, U>(source.ToDictionary(keySelector, elementSelector));
171172
}
173+
174+
175+
/// <summary>
/// Computes the product of all elements in the array — the total buffer length
/// implied by a dimension array (e.g. [2, 3, 4] => 24).
/// An empty array yields <c>T.One</c>, matching the scalar convention.
/// </summary>
/// <typeparam name="T">The numeric element type.</typeparam>
/// <param name="array">The dimension array.</param>
/// <returns>The product of all elements, or one when the array is empty.</returns>
/// <exception cref="System.ArgumentNullException">Thrown when <paramref name="array"/> is null.</exception>
public static T GetBufferLength<T>(this T[] array) where T : INumber<T>
{
    // Fail fast with a clear exception instead of a NullReferenceException
    // inside the loop. Fully qualified to avoid relying on a `using System;`.
    System.ArgumentNullException.ThrowIfNull(array);

    T result = T.One;
    foreach (T element in array)
    {
        result *= element;
    }
    return result;
}
172189
}
173190
}

OnnxStack.Core/Services/IOnnxModelService.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,29 @@ public interface IOnnxModelService : IDisposable
7878
Task<IDisposableReadOnlyCollection<DisposableNamedOnnxValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, IReadOnlyCollection<NamedOnnxValue> inputs);
7979

8080

81+
/// <summary>
/// Runs inference on the specified model using the OrtValue API.
/// Use when the output size is unknown: only the output names are supplied
/// and the session allocates the output values.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <param name="inputs">The named input values.</param>
/// <param name="outputs">The names of the outputs to retrieve.</param>
/// <returns>The collection of output values produced by the run.</returns>
IReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, IReadOnlyCollection<string> outputs);
90+
91+
92+
/// <summary>
/// Runs inference on the specified model asynchronously using the OrtValue API.
/// Use when the output size is known: each output buffer must be allocated
/// and correctly sized before inference is run.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <param name="inputs">The named input values.</param>
/// <param name="outputs">The named, pre-allocated output buffers to fill.</param>
/// <returns>A task that completes with the output values once inference finishes.</returns>
Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, Dictionary<string, OrtValue> outputs);
102+
103+
81104
/// <summary>
82105
/// Gets the Sessions input metadata.
83106
/// </summary>
@@ -108,5 +131,6 @@ public interface IOnnxModelService : IDisposable
108131
/// <param name="modelType">Type of model.</param>
109132
/// <returns></returns>
110133
IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType modelType);
134+
111135
}
112136
}

OnnxStack.Core/Services/OnnxModelService.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,34 @@ public IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType mode
175175
}
176176

177177

178+
/// <summary>
/// Runs inference on the specified model using the OrtValue API.
/// Use when the output size is unknown: only output names are supplied
/// and the session allocates the output values.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <param name="inputs">The named input values.</param>
/// <param name="outputs">The names of the outputs to retrieve.</param>
/// <returns>The collection of output values produced by the run.</returns>
public IReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, IReadOnlyCollection<string> outputs)
{
    // RunOptions is IDisposable; the original allocated one per call and
    // never disposed it. Run() is synchronous, so disposing on return is safe.
    using var runOptions = new RunOptions();
    return GetModelSet(model)
        .GetSession(modelType)
        .Run(runOptions, inputs, outputs);
}
190+
191+
192+
/// <summary>
/// Runs inference on the specified model asynchronously using the OrtValue API.
/// Use when the output size is known: each output buffer in
/// <paramref name="outputs"/> must be allocated and sized before the call.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="modelType">Type of the model.</param>
/// <param name="inputs">The named input values.</param>
/// <param name="outputs">The named, pre-allocated output buffers to fill.</param>
/// <returns>A task that completes with the output values once inference finishes.</returns>
public async Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, Dictionary<string, OrtValue> outputs)
{
    // RunOptions is IDisposable and was leaked by the original. The method is
    // made async so the options object is kept alive for the whole run and
    // disposed deterministically after the await; the signature is unchanged.
    using var runOptions = new RunOptions();
    return await GetModelSet(model)
        .GetSession(modelType)
        .RunAsync(runOptions, inputs.Keys, inputs.Values, outputs.Keys, outputs.Values);
}
204+
205+
178206
/// <summary>
179207
/// Runs inference on the specified model.
180208
/// </summary>

0 commit comments

Comments
 (0)