Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 2f6cf61

Browse files
committed
Simplify NodeMetadata access, Simplify Core API calls
1 parent 5cb4ff3 commit 2f6cf61

17 files changed

+381
-323
lines changed

OnnxStack.Core/Extensions.cs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
using Microsoft.ML.OnnxRuntime;
22
using Microsoft.ML.OnnxRuntime.Tensors;
33
using OnnxStack.Core.Config;
4+
using OnnxStack.Core.Model;
45
using System;
56
using System.Buffers;
67
using System.Collections.Concurrent;
78
using System.Collections.Generic;
89
using System.Linq;
910
using System.Numerics;
11+
using System.Runtime.CompilerServices;
1012
using System.Runtime.InteropServices;
1113

1214
namespace OnnxStack.Core
@@ -244,12 +246,12 @@ public static long[] ToLong(this int[] array)
244246
/// Creates and OrtValue form the DenseTensor and NodeMetaData provided
245247
/// </summary>
246248
/// <param name="tensor">The tensor.</param>
247-
/// <param name="nodeMetadata">The node metadata.</param>
249+
/// <param name="metadata">The metadata.</param>
248250
/// <returns></returns>
249-
public static OrtValue ToOrtValue(this DenseTensor<float> tensor, NodeMetadata nodeMetadata)
251+
public static OrtValue ToOrtValue(this DenseTensor<float> tensor, OnnxNamedMetadata metadata)
250252
{
251253
var dimensions = tensor.Dimensions.ToLong();
252-
return nodeMetadata.ElementDataType switch
254+
return metadata.Value.ElementDataType switch
253255
{
254256
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToFloat16(), dimensions),
255257
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToBFloat16(), dimensions),
@@ -259,14 +261,14 @@ public static OrtValue ToOrtValue(this DenseTensor<float> tensor, NodeMetadata n
259261

260262

261263
/// <summary>
262-
/// Creates and allocates output tensors buffer.
264+
/// Creates and allocates the output tensors buffer.
263265
/// </summary>
264-
/// <param name="nodeMetadata">The node metadata.</param>
266+
/// <param name="metadata">The metadata.</param>
265267
/// <param name="dimensions">The dimensions.</param>
266268
/// <returns></returns>
267-
public static OrtValue CreateOutputBuffer(this NodeMetadata nodeMetadata, ReadOnlySpan<int> dimensions)
269+
public static OrtValue CreateOutputBuffer(this OnnxNamedMetadata metadata, ReadOnlySpan<int> dimensions)
268270
{
269-
return OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, nodeMetadata.ElementDataType, dimensions.ToLong());
271+
return OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, metadata.Value.ElementDataType, dimensions.ToLong());
270272
}
271273

272274

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using System.Collections.Generic;
3+
4+
namespace OnnxStack.Core.Model
5+
{
6+
public class OnnxInferenceParameters
7+
{
8+
private OnnxValueCollection _inputs;
9+
private OnnxValueCollection _outputs;
10+
11+
/// <summary>
12+
/// Initializes a new instance of the <see cref="OnnxInferenceParameters"/> class.
13+
/// </summary>
14+
public OnnxInferenceParameters()
15+
{
16+
_inputs = new OnnxValueCollection();
17+
_outputs = new OnnxValueCollection();
18+
}
19+
20+
21+
/// <summary>
22+
/// Adds an input parameter.
23+
/// </summary>
24+
/// <param name="metaData">The meta data.</param>
25+
/// <param name="value">The value.</param>
26+
public void AddInput(OnnxNamedMetadata metaData, OrtValue value)
27+
{
28+
_inputs.Add(metaData, value);
29+
}
30+
31+
32+
/// <summary>
33+
/// Adds an output parameter with known output size.
34+
/// </summary>
35+
/// <param name="metaData">The meta data.</param>
36+
/// <param name="value">The value.</param>
37+
public void AddOutput(OnnxNamedMetadata metaData, OrtValue value)
38+
{
39+
_outputs.Add(metaData, value);
40+
}
41+
42+
43+
/// <summary>
44+
/// Adds an output parameter with unknown output size.
45+
/// </summary>
46+
/// <param name="metaData">The meta data.</param>
47+
public void AddOutput(OnnxNamedMetadata metaData)
48+
{
49+
_outputs.AddName(metaData);
50+
}
51+
52+
53+
/// <summary>
54+
/// Gets the input names.
55+
/// </summary>
56+
public IReadOnlyCollection<string> InputNames => _inputs.Names;
57+
58+
59+
/// <summary>
60+
/// Gets the output names.
61+
/// </summary>
62+
public IReadOnlyCollection<string> OutputNames => _outputs.Names;
63+
64+
65+
/// <summary>
66+
/// Gets the input values.
67+
/// </summary>
68+
public IReadOnlyCollection<OrtValue> InputValues => _inputs.Values;
69+
70+
71+
/// <summary>
72+
/// Gets the output values.
73+
/// </summary>
74+
public IReadOnlyCollection<OrtValue> OutputValues => _outputs.Values;
75+
76+
77+
/// <summary>
78+
/// Gets the input name values.
79+
/// </summary>
80+
public IReadOnlyDictionary<string, OrtValue> InputNameValues => _inputs.NameValues;
81+
82+
83+
/// <summary>
84+
/// Gets the output name values.
85+
/// </summary>
86+
public IReadOnlyDictionary<string, OrtValue> OutputNameValues => _outputs.NameValues;
87+
}
88+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using System.Collections.Generic;
2+
3+
namespace OnnxStack.Core.Model
4+
{
5+
public record OnnxMetadata
6+
{
7+
/// <summary>
8+
/// Gets or sets the inputs.
9+
/// </summary>
10+
public IReadOnlyList<OnnxNamedMetadata> Inputs { get; set; }
11+
12+
/// <summary>
13+
/// Gets or sets the outputs.
14+
/// </summary>
15+
public IReadOnlyList<OnnxNamedMetadata> Outputs { get; set; }
16+
}
17+
}

OnnxStack.Core/Model/OnnxModelAdapter.cs

Lines changed: 0 additions & 13 deletions
This file was deleted.

OnnxStack.Core/Model/OnnxModelSession.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using Microsoft.Extensions.Configuration;
2-
using Microsoft.ML.OnnxRuntime;
1+
using Microsoft.ML.OnnxRuntime;
32
using OnnxStack.Core.Config;
43
using System;
54
using System.IO;
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using System.Collections.Generic;
3+
4+
namespace OnnxStack.Core.Model
5+
{
6+
public record OnnxNamedMetadata(string Name, NodeMetadata Value)
7+
{
8+
internal static OnnxNamedMetadata Create(KeyValuePair<string, NodeMetadata> metadata)
9+
{
10+
return new OnnxNamedMetadata(metadata.Key, metadata.Value);
11+
}
12+
}
13+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using System.Collections.Generic;
3+
4+
namespace OnnxStack.Core.Model
5+
{
6+
public class OnnxValueCollection
7+
{
8+
private readonly List<OnnxNamedMetadata> _metaData;
9+
private readonly Dictionary<string, OrtValue> _values;
10+
11+
12+
/// <summary>
13+
/// Initializes a new instance of the <see cref="OnnxValueCollection"/> class.
14+
/// </summary>
15+
public OnnxValueCollection()
16+
{
17+
_metaData = new List<OnnxNamedMetadata>();
18+
_values = new Dictionary<string, OrtValue>();
19+
}
20+
21+
22+
/// <summary>
23+
/// Adds the specified OnnxMetadata and OrtValue
24+
/// </summary>
25+
/// <param name="metaData">The meta data.</param>
26+
/// <param name="value">The value.</param>
27+
public void Add(OnnxNamedMetadata metaData, OrtValue value)
28+
{
29+
_metaData.Add(metaData);
30+
_values.Add(metaData.Name, value);
31+
}
32+
33+
34+
/// <summary>
35+
/// Adds the name only.
36+
/// </summary>
37+
/// <param name="metaData">The meta data.</param>
38+
public void AddName(OnnxNamedMetadata metaData)
39+
{
40+
_metaData.Add(metaData);
41+
_values.Add(metaData.Name, default);
42+
}
43+
44+
45+
/// <summary>
46+
/// Gets the names.
47+
/// </summary>
48+
public IReadOnlyCollection<string> Names => _values.Keys;
49+
50+
51+
/// <summary>
52+
/// Gets the values.
53+
/// </summary>
54+
public IReadOnlyCollection<OrtValue> Values => _values.Values;
55+
56+
57+
/// <summary>
58+
/// Gets the name values.
59+
/// </summary>
60+
public IReadOnlyDictionary<string, OrtValue> NameValues => _values;
61+
}
62+
}

OnnxStack.Core/Services/IOnnxModelService.cs

Lines changed: 6 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,6 @@ public interface IOnnxModelService : IDisposable
6060
Task<bool> IsEnabledAsync(IOnnxModel model, OnnxModelType modelType);
6161

6262

63-
/// <summary>
64-
/// Runs the inference Use when output size is unknown
65-
/// </summary>
66-
/// <param name="model">The model.</param>
67-
/// <param name="modelType">Type of the model.</param>
68-
/// <param name="inputName">Name of the input.</param>
69-
/// <param name="inputValue">The input value.</param>
70-
/// <param name="outputName">Name of the output.</param>
71-
/// <param name="outputValue">The output value.</param>
72-
/// <returns></returns>
73-
IDisposableReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, string inputName, OrtValue inputValue, string outputName);
74-
75-
7663
/// <summary>
7764
/// Runs the inference Use when output size is unknown
7865
/// </summary>
@@ -81,21 +68,7 @@ public interface IOnnxModelService : IDisposable
8168
/// <param name="inputs">The inputs.</param>
8269
/// <param name="outputs">The outputs.</param>
8370
/// <returns></returns>
84-
IDisposableReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, IReadOnlyCollection<string> outputs);
85-
86-
87-
/// <summary>
88-
/// Runs the inference asynchronously, Use when output size is known
89-
/// Output buffer size must be known and set before inference is run
90-
/// </summary>
91-
/// <param name="model">The model.</param>
92-
/// <param name="modelType">Type of the model.</param>
93-
/// <param name="inputName">Name of the input.</param>
94-
/// <param name="inputValue">The input value.</param>
95-
/// <param name="outputName">Name of the output.</param>
96-
/// <param name="outputValue">The output value.</param>
97-
/// <returns></returns>
98-
Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, string inputName, OrtValue inputValue, string outputName, OrtValue outputValue);
71+
IDisposableReadOnlyCollection<OrtValue> RunInference(IOnnxModel model, OnnxModelType modelType, OnnxInferenceParameters parameters);
9972

10073

10174
/// <summary>
@@ -107,39 +80,15 @@ public interface IOnnxModelService : IDisposable
10780
/// <param name="inputs">The inputs.</param>
10881
/// <param name="outputs">The outputs.</param>
10982
/// <returns></returns>
110-
Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, Dictionary<string, OrtValue> inputs, Dictionary<string, OrtValue> outputs);
83+
Task<IReadOnlyCollection<OrtValue>> RunInferenceAsync(IOnnxModel model, OnnxModelType modelType, OnnxInferenceParameters parameters);
11184

11285

11386
/// <summary>
114-
/// Gets the Sessions input metadata.
87+
/// Gets the model metadata.
11588
/// </summary>
116-
/// <param name="modelType">Type of model.</param>
117-
/// <returns></returns>
118-
IReadOnlyDictionary<string, NodeMetadata> GetInputMetadata(IOnnxModel model, OnnxModelType modelType);
119-
120-
121-
/// <summary>
122-
/// Gets the Sessions input names.
123-
/// </summary>
124-
/// <param name="modelType">Type of model.</param>
125-
/// <returns></returns>
126-
IReadOnlyList<string> GetInputNames(IOnnxModel model, OnnxModelType modelType);
127-
128-
129-
/// <summary>
130-
/// Gets the Sessions output metadata.
131-
/// </summary>
132-
/// <param name="modelType">Type of model.</param>
133-
/// <returns></returns>
134-
IReadOnlyDictionary<string, NodeMetadata> GetOutputMetadata(IOnnxModel model, OnnxModelType modelType);
135-
136-
137-
/// <summary>
138-
/// Gets the Sessions output metadata names.
139-
/// </summary>
140-
/// <param name="modelType">Type of model.</param>
89+
/// <param name="model">The model.</param>
90+
/// <param name="modelType">Type of the model.</param>
14191
/// <returns></returns>
142-
IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType modelType);
143-
92+
OnnxMetadata GetModelMetadata(IOnnxModel model, OnnxModelType modelType);
14493
}
14594
}

0 commit comments

Comments
 (0)