Skip to content

Commit 37e7ddb

Browse files
committed
Centralize use of OrtAllocator and OrtMemoryInfo
1 parent de0ec98 commit 37e7ddb

File tree

12 files changed

+121
-136
lines changed

12 files changed

+121
-136
lines changed

TensorStack.Common/Common/ModelMetadata.cs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
33
using Microsoft.ML.OnnxRuntime;
4+
using Microsoft.ML.OnnxRuntime.Tensors;
45
using System.Collections.Generic;
56
using System.Linq;
67

@@ -12,16 +13,24 @@ public sealed record ModelMetadata
1213
/// Initializes a new instance of the <see cref="ModelMetadata"/> class.
1314
/// </summary>
1415
/// <param name="session">The session.</param>
15-
public ModelMetadata(InferenceSession session)
16+
public ModelMetadata(InferenceSession session, OrtAllocator allocator)
1617
{
18+
Allocator = allocator;
1719
Inputs = session.InputMetadata
1820
.Select(NamedMetadata.Create)
1921
.ToList();
2022
Outputs = session.OutputMetadata
2123
.Select(NamedMetadata.Create)
2224
.ToList();
25+
26+
OutputElementType = Outputs[0].ElementType;
2327
}
2428

29+
/// <summary>
30+
/// Gets the default allocator.
31+
/// </summary>
32+
public OrtAllocator Allocator { get; }
33+
2534
/// <summary>
2635
/// Gets or sets the inputs.
2736
/// </summary>
@@ -31,5 +40,10 @@ public ModelMetadata(InferenceSession session)
3140
/// Gets or sets the outputs.
3241
/// </summary>
3342
public IReadOnlyList<NamedMetadata> Outputs { get; }
43+
44+
/// <summary>
45+
/// Gets the type of the data.
46+
/// </summary>
47+
public TensorElementType OutputElementType { get; }
3448
}
3549
}

TensorStack.Common/Common/NamedMetadata.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
33
using Microsoft.ML.OnnxRuntime;
4+
using Microsoft.ML.OnnxRuntime.Tensors;
45
using System;
56
using System.Collections.Generic;
67

@@ -9,6 +10,8 @@ namespace TensorStack.Common
910
public sealed record NamedMetadata(string Name, NodeMetadata Value)
1011
{
1112
public ReadOnlySpan<int> Dimensions => Value.Dimensions;
13+
public Type DataType => Value.ElementType;
14+
public TensorElementType ElementType => Value.ElementDataType;
1215

1316
/// <summary>
1417
/// Creates the specified metadata.

TensorStack.Common/ExecutionProvider.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@ namespace TensorStack.Common
88
public class ExecutionProvider
99
{
1010
private readonly string _name;
11+
private readonly OrtMemoryInfo _memoryInfo;
12+
1113
private readonly Func<ModelConfig, SessionOptions> _sessionOptionsFactory;
1214

13-
public ExecutionProvider(string name, Func<ModelConfig, SessionOptions> sessionOptionsFactory)
15+
public ExecutionProvider(string name, OrtMemoryInfo memoryInfo, Func<ModelConfig, SessionOptions> sessionOptionsFactory)
1416
{
1517
_name = name;
18+
_memoryInfo = memoryInfo;
1619
_sessionOptionsFactory = sessionOptionsFactory;
1720
}
1821

1922
public string Name => _name;
23+
public OrtMemoryInfo MemoryInfo => _memoryInfo;
2024

2125
public SessionOptions CreateSession(ModelConfig modelConfig)
2226
{

TensorStack.Common/Extensions/OrtExtensions.cs

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ public static class OrtExtensions
1818
/// </summary>
1919
/// <param name="metadata">The input metadata.</param>
2020
/// <param name="tensor">The tensor value.</param>
21-
public static OrtValue CreateTensorOrtValue<T>(this NamedMetadata metadata, TensorSpan<T> tensor) where T : unmanaged, INumber<T>
21+
public static OrtValue CreateTensorOrtValue<T>(this NamedMetadata metadata, OrtMemoryInfo memoryInfo, TensorSpan<T> tensor) where T : unmanaged, INumber<T>
2222
{
23-
return CreateOrtValue(metadata, tensor);
23+
return CreateOrtValue(metadata, tensor, memoryInfo);
2424
}
2525

2626

@@ -40,9 +40,9 @@ public static OrtValue CreateTensorOrtValue(this NamedMetadata metadata, TensorS
4040
/// </summary>
4141
/// <param name="metadata">The input metadata.</param>
4242
/// <param name="tensor">The tensor value.</param>
43-
public static OrtValue CreateTensorOrtValue(this NamedMetadata metadata, TensorSpan<bool> tensor)
43+
public static OrtValue CreateTensorOrtValue(this NamedMetadata metadata, OrtMemoryInfo memoryInfo, TensorSpan<bool> tensor)
4444
{
45-
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, new Memory<bool>(tensor.Span.ToArray()), tensor.Dimensions.ToLong());
45+
return OrtValue.CreateTensorValueFromMemory(memoryInfo, new Memory<bool>(tensor.Span.ToArray()), tensor.Dimensions.ToLong());
4646
}
4747

4848

@@ -51,9 +51,9 @@ public static OrtValue CreateTensorOrtValue(this NamedMetadata metadata, TensorS
5151
/// </summary>
5252
/// <param name="metadata">The input metadata.</param>
5353
/// <param name="tensor">The tensor value.</param>
54-
public static OrtValue CreateTensorOrtValue(this NamedMetadata metadata, TensorSpan<byte> tensor)
54+
public static OrtValue CreateTensorOrtValue(this NamedMetadata metadata, OrtMemoryInfo memoryInfo, TensorSpan<byte> tensor)
5555
{
56-
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, new Memory<byte>(tensor.Span.ToArray()), tensor.Dimensions.ToLong());
56+
return OrtValue.CreateTensorValueFromMemory(memoryInfo, new Memory<byte>(tensor.Span.ToArray()), tensor.Dimensions.ToLong());
5757
}
5858

5959

@@ -63,9 +63,9 @@ public static OrtValue CreateTensorOrtValue(this NamedMetadata metadata, TensorS
6363
/// <typeparam name="T">The type of input value</typeparam>
6464
/// <param name="metadata">The input metadata.</param>
6565
/// <param name="value">The value.</param>
66-
public static OrtValue CreateScalarOrtValue<T>(this NamedMetadata metadata, T value) where T : unmanaged, INumber<T>
66+
public static OrtValue CreateScalarOrtValue<T>(this NamedMetadata metadata, OrtMemoryInfo memoryInfo, T value) where T : unmanaged, INumber<T>
6767
{
68-
return metadata.CreateTensorOrtValue(new TensorSpan<T>([value], [1]));
68+
return metadata.CreateTensorOrtValue(memoryInfo, new TensorSpan<T>([value], [1]));
6969
}
7070

7171

@@ -74,7 +74,7 @@ public static OrtValue CreateScalarOrtValue<T>(this NamedMetadata metadata, T va
7474
/// </summary>
7575
/// <param name="metadata">The input metadata.</param>
7676
/// <param name="value">The value.</param>
77-
public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, string value)
77+
public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, OrtMemoryInfo memoryInfo, string value)
7878
{
7979
return metadata.CreateTensorOrtValue(new TensorSpan<string>([value], [1]));
8080
}
@@ -85,9 +85,9 @@ public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, string
8585
/// </summary>
8686
/// <param name="metadata">The input metadata.</param>
8787
/// <param name="value">The value.</param>
88-
public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, bool value)
88+
public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, OrtMemoryInfo memoryInfo, bool value)
8989
{
90-
return metadata.CreateTensorOrtValue(new TensorSpan<bool>([value], [1]));
90+
return metadata.CreateTensorOrtValue(memoryInfo, new TensorSpan<bool>([value], [1]));
9191
}
9292

9393

@@ -96,9 +96,9 @@ public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, bool va
9696
/// </summary>
9797
/// <param name="metadata">The input metadata.</param>
9898
/// <param name="value">The value.</param>
99-
public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, byte value)
99+
public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, OrtMemoryInfo memoryInfo, byte value)
100100
{
101-
return metadata.CreateTensorOrtValue(new TensorSpan<byte>([value], [1]));
101+
return metadata.CreateTensorOrtValue(memoryInfo, new TensorSpan<byte>([value], [1]));
102102
}
103103

104104

@@ -108,9 +108,9 @@ public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, byte va
108108
/// <param name="metadata">The metadata.</param>
109109
/// <param name="dimensions">The dimensions.</param>
110110
/// <returns></returns>
111-
public static OrtValue CreateOutputBuffer(this NamedMetadata metadata, ReadOnlySpan<int> dimensions)
111+
public static OrtValue CreateOutputBuffer(this NamedMetadata metadata, OrtAllocator allocator, ReadOnlySpan<int> dimensions)
112112
{
113-
return OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, metadata.Value.ElementDataType, dimensions.ToLong());
113+
return OrtValue.CreateAllocatedTensorValue(allocator, metadata.Value.ElementDataType, dimensions.ToLong());
114114
}
115115

116116

@@ -229,9 +229,9 @@ private static Tensor<T> CreateTensor<T>(OrtValue ortValue, int[] dimensions) wh
229229
/// <typeparam name="T">The type of input value</typeparam>
230230
/// <param name="metadata">The input metadata.</param>
231231
/// <param name="tensor">The tensor input.</param>
232-
private static OrtValue CreateOrtValue<T>(NamedMetadata metadata, TensorSpan<T> tensor) where T : unmanaged, INumber<T>
232+
private static OrtValue CreateOrtValue<T>(NamedMetadata metadata, TensorSpan<T> tensor, OrtMemoryInfo memoryInfo) where T : unmanaged, INumber<T>
233233
{
234-
return CreateOrtValue(metadata.Value.ElementDataType, tensor);
234+
return CreateOrtValue(metadata.Value.ElementDataType, tensor, memoryInfo);
235235
}
236236

237237

@@ -242,25 +242,24 @@ private static OrtValue CreateOrtValue<T>(NamedMetadata metadata, TensorSpan<T>
242242
/// <param name="ortType">Type of the ort.</param>
243243
/// <param name="tensor">The tensor.</param>
244244
/// <returns>OrtValue.</returns>
245-
public static OrtValue CreateOrtValue<T>(OrtType ortType, TensorSpan<T> tensor) where T : unmanaged, INumber<T>
245+
private static OrtValue CreateOrtValue<T>(OrtType ortType, TensorSpan<T> tensor, OrtMemoryInfo memoryInfo) where T : unmanaged, INumber<T>
246246
{
247247
var buffer = tensor.Span;
248248
var dimensions = tensor.Dimensions.ToLong();
249-
var memoryInstance = OrtMemoryInfo.DefaultInstance;
250249
return ortType switch
251250
{
252-
OrtType.Float => OrtValue.CreateTensorValueFromMemory<float>(memoryInstance, buffer.ConvertBuffer<T, float>(), dimensions),
253-
OrtType.UInt8 => OrtValue.CreateTensorValueFromMemory<byte>(memoryInstance, buffer.ConvertBuffer<T, byte>(), dimensions),
254-
OrtType.Int8 => OrtValue.CreateTensorValueFromMemory<sbyte>(memoryInstance, buffer.ConvertBuffer<T, sbyte>(), dimensions),
255-
OrtType.UInt16 => OrtValue.CreateTensorValueFromMemory<ushort>(memoryInstance, buffer.ConvertBuffer<T, ushort>(), dimensions),
256-
OrtType.Int16 => OrtValue.CreateTensorValueFromMemory<short>(memoryInstance, buffer.ConvertBuffer<T, short>(), dimensions),
257-
OrtType.Int32 => OrtValue.CreateTensorValueFromMemory<int>(memoryInstance, buffer.ConvertBuffer<T, int>(), dimensions),
258-
OrtType.Int64 => OrtValue.CreateTensorValueFromMemory<long>(memoryInstance, buffer.ConvertBuffer<T, long>(), dimensions),
259-
OrtType.Double => OrtValue.CreateTensorValueFromMemory<double>(memoryInstance, buffer.ConvertBuffer<T, double>(), dimensions),
260-
OrtType.UInt32 => OrtValue.CreateTensorValueFromMemory<uint>(memoryInstance, buffer.ConvertBuffer<T, uint>(), dimensions),
261-
OrtType.UInt64 => OrtValue.CreateTensorValueFromMemory<ulong>(memoryInstance, buffer.ConvertBuffer<T, ulong>(), dimensions),
262-
OrtType.Float16 => OrtValue.CreateTensorValueFromMemory<Float16>(memoryInstance, buffer.ConvertBufferFloat16(), dimensions),
263-
OrtType.BFloat16 => OrtValue.CreateTensorValueFromMemory<BFloat16>(memoryInstance, buffer.ConvertBufferBFloat16(), dimensions),
251+
OrtType.Float => OrtValue.CreateTensorValueFromMemory<float>(memoryInfo, buffer.ConvertBuffer<T, float>(), dimensions),
252+
OrtType.UInt8 => OrtValue.CreateTensorValueFromMemory<byte>(memoryInfo, buffer.ConvertBuffer<T, byte>(), dimensions),
253+
OrtType.Int8 => OrtValue.CreateTensorValueFromMemory<sbyte>(memoryInfo, buffer.ConvertBuffer<T, sbyte>(), dimensions),
254+
OrtType.UInt16 => OrtValue.CreateTensorValueFromMemory<ushort>(memoryInfo, buffer.ConvertBuffer<T, ushort>(), dimensions),
255+
OrtType.Int16 => OrtValue.CreateTensorValueFromMemory<short>(memoryInfo, buffer.ConvertBuffer<T, short>(), dimensions),
256+
OrtType.Int32 => OrtValue.CreateTensorValueFromMemory<int>(memoryInfo, buffer.ConvertBuffer<T, int>(), dimensions),
257+
OrtType.Int64 => OrtValue.CreateTensorValueFromMemory<long>(memoryInfo, buffer.ConvertBuffer<T, long>(), dimensions),
258+
OrtType.Double => OrtValue.CreateTensorValueFromMemory<double>(memoryInfo, buffer.ConvertBuffer<T, double>(), dimensions),
259+
OrtType.UInt32 => OrtValue.CreateTensorValueFromMemory<uint>(memoryInfo, buffer.ConvertBuffer<T, uint>(), dimensions),
260+
OrtType.UInt64 => OrtValue.CreateTensorValueFromMemory<ulong>(memoryInfo, buffer.ConvertBuffer<T, ulong>(), dimensions),
261+
OrtType.Float16 => OrtValue.CreateTensorValueFromMemory<Float16>(memoryInfo, buffer.ConvertBufferFloat16(), dimensions),
262+
OrtType.BFloat16 => OrtValue.CreateTensorValueFromMemory<BFloat16>(memoryInfo, buffer.ConvertBufferBFloat16(), dimensions),
264263
_ => throw new NotImplementedException("Conversion is not currently implemented.")
265264
};
266265
}
@@ -271,13 +270,13 @@ public static OrtValue CreateOrtValue<T>(OrtType ortType, TensorSpan<T> tensor)
271270
/// </summary>
272271
/// <param name="original">The original.</param>
273272
/// <returns>OrtValue.</returns>
274-
public static OrtValue Clone(this OrtValue original)
273+
public static OrtValue Clone(this OrtValue original, OrtAllocator allocator)
275274
{
276275
var info = original.GetTensorTypeAndShape();
277276
return info.ElementDataType switch
278277
{
279-
OrtType.Float => original.Clone<float>(info),
280-
OrtType.Float16 => original.Clone<Float16>(info),
278+
OrtType.Float => original.Clone<float>(info, allocator),
279+
OrtType.Float16 => original.Clone<Float16>(info, allocator),
281280
_ => throw new NotSupportedException($"Unsupported element type: {info.ElementDataType}")
282281
};
283282
}
@@ -290,9 +289,9 @@ public static OrtValue Clone(this OrtValue original)
290289
/// <param name="original">The original.</param>
291290
/// <param name="info">The information.</param>
292291
/// <returns>OrtValue.</returns>
293-
public static OrtValue Clone<T>(this OrtValue original, OrtTensorTypeAndShapeInfo info) where T : unmanaged
292+
public static OrtValue Clone<T>(this OrtValue original, OrtTensorTypeAndShapeInfo info, OrtAllocator allocator) where T : unmanaged
294293
{
295-
var newValue = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, info.ElementDataType, info.Shape);
294+
var newValue = OrtValue.CreateAllocatedTensorValue(allocator, info.ElementDataType, info.Shape);
296295
var source = original.GetTensorDataAsSpan<T>();
297296
var destination = newValue.GetTensorMutableDataAsSpan<T>();
298297
source.CopyTo(destination);

0 commit comments

Comments
 (0)