@@ -18,9 +18,9 @@ public static class OrtExtensions
1818 /// </summary>
1919 /// <param name="metadata">The input metadata.</param>
2020 /// <param name="tensor">The tensor value.</param>
21- public static OrtValue CreateTensorOrtValue < T > ( this NamedMetadata metadata , TensorSpan < T > tensor ) where T : unmanaged, INumber < T >
21+ public static OrtValue CreateTensorOrtValue < T > ( this NamedMetadata metadata , OrtMemoryInfo memoryInfo , TensorSpan < T > tensor ) where T : unmanaged, INumber < T >
2222 {
23- return CreateOrtValue ( metadata , tensor ) ;
23+ return CreateOrtValue ( metadata , tensor , memoryInfo ) ;
2424 }
2525
2626
@@ -40,9 +40,9 @@ public static OrtValue CreateTensorOrtValue(this NamedMetadata metadata, TensorS
4040 /// </summary>
4141 /// <param name="metadata">The input metadata.</param>
4242 /// <param name="tensor">The tensor value.</param>
43- public static OrtValue CreateTensorOrtValue ( this NamedMetadata metadata , TensorSpan < bool > tensor )
43+ public static OrtValue CreateTensorOrtValue ( this NamedMetadata metadata , OrtMemoryInfo memoryInfo , TensorSpan < bool > tensor )
4444 {
45- return OrtValue . CreateTensorValueFromMemory ( OrtMemoryInfo . DefaultInstance , new Memory < bool > ( tensor . Span . ToArray ( ) ) , tensor . Dimensions . ToLong ( ) ) ;
45+ return OrtValue . CreateTensorValueFromMemory ( memoryInfo , new Memory < bool > ( tensor . Span . ToArray ( ) ) , tensor . Dimensions . ToLong ( ) ) ;
4646 }
4747
4848
@@ -51,9 +51,9 @@ public static OrtValue CreateTensorOrtValue(this NamedMetadata metadata, TensorS
5151 /// </summary>
5252 /// <param name="metadata">The input metadata.</param>
5353 /// <param name="tensor">The tensor value.</param>
54- public static OrtValue CreateTensorOrtValue ( this NamedMetadata metadata , TensorSpan < byte > tensor )
54+ public static OrtValue CreateTensorOrtValue ( this NamedMetadata metadata , OrtMemoryInfo memoryInfo , TensorSpan < byte > tensor )
5555 {
56- return OrtValue . CreateTensorValueFromMemory ( OrtMemoryInfo . DefaultInstance , new Memory < byte > ( tensor . Span . ToArray ( ) ) , tensor . Dimensions . ToLong ( ) ) ;
56+ return OrtValue . CreateTensorValueFromMemory ( memoryInfo , new Memory < byte > ( tensor . Span . ToArray ( ) ) , tensor . Dimensions . ToLong ( ) ) ;
5757 }
5858
5959
@@ -63,9 +63,9 @@ public static OrtValue CreateTensorOrtValue(this NamedMetadata metadata, TensorS
6363 /// <typeparam name="T">The type of input value</typeparam>
6464 /// <param name="metadata">The input metadata.</param>
6565 /// <param name="value">The value.</param>
66- public static OrtValue CreateScalarOrtValue < T > ( this NamedMetadata metadata , T value ) where T : unmanaged, INumber < T >
66+ public static OrtValue CreateScalarOrtValue < T > ( this NamedMetadata metadata , OrtMemoryInfo memoryInfo , T value ) where T : unmanaged, INumber < T >
6767 {
68- return metadata . CreateTensorOrtValue ( new TensorSpan < T > ( [ value ] , [ 1 ] ) ) ;
68+ return metadata . CreateTensorOrtValue ( memoryInfo , new TensorSpan < T > ( [ value ] , [ 1 ] ) ) ;
6969 }
7070
7171
@@ -74,7 +74,7 @@ public static OrtValue CreateScalarOrtValue<T>(this NamedMetadata metadata, T va
7474 /// </summary>
7575 /// <param name="metadata">The input metadata.</param>
7676 /// <param name="value">The value.</param>
77- public static OrtValue CreateScalarOrtValue ( this NamedMetadata metadata , string value )
77+ public static OrtValue CreateScalarOrtValue ( this NamedMetadata metadata , OrtMemoryInfo memoryInfo , string value )
7878 {
7979 return metadata . CreateTensorOrtValue ( new TensorSpan < string > ( [ value ] , [ 1 ] ) ) ;
8080 }
@@ -85,9 +85,9 @@ public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, string
8585 /// </summary>
8686 /// <param name="metadata">The input metadata.</param>
8787 /// <param name="value">The value.</param>
88- public static OrtValue CreateScalarOrtValue ( this NamedMetadata metadata , bool value )
88+ public static OrtValue CreateScalarOrtValue ( this NamedMetadata metadata , OrtMemoryInfo memoryInfo , bool value )
8989 {
90- return metadata . CreateTensorOrtValue ( new TensorSpan < bool > ( [ value ] , [ 1 ] ) ) ;
90+ return metadata . CreateTensorOrtValue ( memoryInfo , new TensorSpan < bool > ( [ value ] , [ 1 ] ) ) ;
9191 }
9292
9393
@@ -96,9 +96,9 @@ public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, bool va
9696 /// </summary>
9797 /// <param name="metadata">The input metadata.</param>
9898 /// <param name="value">The value.</param>
99- public static OrtValue CreateScalarOrtValue ( this NamedMetadata metadata , byte value )
99+ public static OrtValue CreateScalarOrtValue ( this NamedMetadata metadata , OrtMemoryInfo memoryInfo , byte value )
100100 {
101- return metadata . CreateTensorOrtValue ( new TensorSpan < byte > ( [ value ] , [ 1 ] ) ) ;
101+ return metadata . CreateTensorOrtValue ( memoryInfo , new TensorSpan < byte > ( [ value ] , [ 1 ] ) ) ;
102102 }
103103
104104
@@ -108,9 +108,9 @@ public static OrtValue CreateScalarOrtValue(this NamedMetadata metadata, byte va
108108 /// <param name="metadata">The metadata.</param>
109109 /// <param name="dimensions">The dimensions.</param>
110110 /// <returns></returns>
111- public static OrtValue CreateOutputBuffer ( this NamedMetadata metadata , ReadOnlySpan < int > dimensions )
111+ public static OrtValue CreateOutputBuffer ( this NamedMetadata metadata , OrtAllocator allocator , ReadOnlySpan < int > dimensions )
112112 {
113- return OrtValue . CreateAllocatedTensorValue ( OrtAllocator . DefaultInstance , metadata . Value . ElementDataType , dimensions . ToLong ( ) ) ;
113+ return OrtValue . CreateAllocatedTensorValue ( allocator , metadata . Value . ElementDataType , dimensions . ToLong ( ) ) ;
114114 }
115115
116116
@@ -229,9 +229,9 @@ private static Tensor<T> CreateTensor<T>(OrtValue ortValue, int[] dimensions) wh
229229 /// <typeparam name="T">The type of input value</typeparam>
230230 /// <param name="metadata">The input metadata.</param>
231231 /// <param name="tensor">The tensor input.</param>
232- private static OrtValue CreateOrtValue < T > ( NamedMetadata metadata , TensorSpan < T > tensor ) where T : unmanaged, INumber < T >
232+ private static OrtValue CreateOrtValue < T > ( NamedMetadata metadata , TensorSpan < T > tensor , OrtMemoryInfo memoryInfo ) where T : unmanaged, INumber < T >
233233 {
234- return CreateOrtValue ( metadata . Value . ElementDataType , tensor ) ;
234+ return CreateOrtValue ( metadata . Value . ElementDataType , tensor , memoryInfo ) ;
235235 }
236236
237237
@@ -242,25 +242,24 @@ private static OrtValue CreateOrtValue<T>(NamedMetadata metadata, TensorSpan<T>
242242 /// <param name="ortType">Type of the ort.</param>
243243 /// <param name="tensor">The tensor.</param>
244244 /// <returns>OrtValue.</returns>
245- public static OrtValue CreateOrtValue < T > ( OrtType ortType , TensorSpan < T > tensor ) where T : unmanaged, INumber < T >
245+ private static OrtValue CreateOrtValue < T > ( OrtType ortType , TensorSpan < T > tensor , OrtMemoryInfo memoryInfo ) where T : unmanaged, INumber < T >
246246 {
247247 var buffer = tensor . Span ;
248248 var dimensions = tensor . Dimensions . ToLong ( ) ;
249- var memoryInstance = OrtMemoryInfo . DefaultInstance ;
250249 return ortType switch
251250 {
252- OrtType . Float => OrtValue . CreateTensorValueFromMemory < float > ( memoryInstance , buffer . ConvertBuffer < T , float > ( ) , dimensions ) ,
253- OrtType . UInt8 => OrtValue . CreateTensorValueFromMemory < byte > ( memoryInstance , buffer . ConvertBuffer < T , byte > ( ) , dimensions ) ,
254- OrtType . Int8 => OrtValue . CreateTensorValueFromMemory < sbyte > ( memoryInstance , buffer . ConvertBuffer < T , sbyte > ( ) , dimensions ) ,
255- OrtType . UInt16 => OrtValue . CreateTensorValueFromMemory < ushort > ( memoryInstance , buffer . ConvertBuffer < T , ushort > ( ) , dimensions ) ,
256- OrtType . Int16 => OrtValue . CreateTensorValueFromMemory < short > ( memoryInstance , buffer . ConvertBuffer < T , short > ( ) , dimensions ) ,
257- OrtType . Int32 => OrtValue . CreateTensorValueFromMemory < int > ( memoryInstance , buffer . ConvertBuffer < T , int > ( ) , dimensions ) ,
258- OrtType . Int64 => OrtValue . CreateTensorValueFromMemory < long > ( memoryInstance , buffer . ConvertBuffer < T , long > ( ) , dimensions ) ,
259- OrtType . Double => OrtValue . CreateTensorValueFromMemory < double > ( memoryInstance , buffer . ConvertBuffer < T , double > ( ) , dimensions ) ,
260- OrtType . UInt32 => OrtValue . CreateTensorValueFromMemory < uint > ( memoryInstance , buffer . ConvertBuffer < T , uint > ( ) , dimensions ) ,
261- OrtType . UInt64 => OrtValue . CreateTensorValueFromMemory < ulong > ( memoryInstance , buffer . ConvertBuffer < T , ulong > ( ) , dimensions ) ,
262- OrtType . Float16 => OrtValue . CreateTensorValueFromMemory < Float16 > ( memoryInstance , buffer . ConvertBufferFloat16 ( ) , dimensions ) ,
263- OrtType . BFloat16 => OrtValue . CreateTensorValueFromMemory < BFloat16 > ( memoryInstance , buffer . ConvertBufferBFloat16 ( ) , dimensions ) ,
251+ OrtType . Float => OrtValue . CreateTensorValueFromMemory < float > ( memoryInfo , buffer . ConvertBuffer < T , float > ( ) , dimensions ) ,
252+ OrtType . UInt8 => OrtValue . CreateTensorValueFromMemory < byte > ( memoryInfo , buffer . ConvertBuffer < T , byte > ( ) , dimensions ) ,
253+ OrtType . Int8 => OrtValue . CreateTensorValueFromMemory < sbyte > ( memoryInfo , buffer . ConvertBuffer < T , sbyte > ( ) , dimensions ) ,
254+ OrtType . UInt16 => OrtValue . CreateTensorValueFromMemory < ushort > ( memoryInfo , buffer . ConvertBuffer < T , ushort > ( ) , dimensions ) ,
255+ OrtType . Int16 => OrtValue . CreateTensorValueFromMemory < short > ( memoryInfo , buffer . ConvertBuffer < T , short > ( ) , dimensions ) ,
256+ OrtType . Int32 => OrtValue . CreateTensorValueFromMemory < int > ( memoryInfo , buffer . ConvertBuffer < T , int > ( ) , dimensions ) ,
257+ OrtType . Int64 => OrtValue . CreateTensorValueFromMemory < long > ( memoryInfo , buffer . ConvertBuffer < T , long > ( ) , dimensions ) ,
258+ OrtType . Double => OrtValue . CreateTensorValueFromMemory < double > ( memoryInfo , buffer . ConvertBuffer < T , double > ( ) , dimensions ) ,
259+ OrtType . UInt32 => OrtValue . CreateTensorValueFromMemory < uint > ( memoryInfo , buffer . ConvertBuffer < T , uint > ( ) , dimensions ) ,
260+ OrtType . UInt64 => OrtValue . CreateTensorValueFromMemory < ulong > ( memoryInfo , buffer . ConvertBuffer < T , ulong > ( ) , dimensions ) ,
261+ OrtType . Float16 => OrtValue . CreateTensorValueFromMemory < Float16 > ( memoryInfo , buffer . ConvertBufferFloat16 ( ) , dimensions ) ,
262+ OrtType . BFloat16 => OrtValue . CreateTensorValueFromMemory < BFloat16 > ( memoryInfo , buffer . ConvertBufferBFloat16 ( ) , dimensions ) ,
264263 _ => throw new NotImplementedException ( "Conversion is not currently implemented." )
265264 } ;
266265 }
@@ -271,13 +270,13 @@ public static OrtValue CreateOrtValue<T>(OrtType ortType, TensorSpan<T> tensor)
271270 /// </summary>
272271 /// <param name="original">The original.</param>
273272 /// <returns>OrtValue.</returns>
274- public static OrtValue Clone ( this OrtValue original )
273+ public static OrtValue Clone ( this OrtValue original , OrtAllocator allocator )
275274 {
276275 var info = original . GetTensorTypeAndShape ( ) ;
277276 return info . ElementDataType switch
278277 {
279- OrtType . Float => original . Clone < float > ( info ) ,
280- OrtType . Float16 => original . Clone < Float16 > ( info ) ,
278+ OrtType . Float => original . Clone < float > ( info , allocator ) ,
279+ OrtType . Float16 => original . Clone < Float16 > ( info , allocator ) ,
281280 _ => throw new NotSupportedException ( $ "Unsupported element type: { info . ElementDataType } ")
282281 } ;
283282 }
@@ -290,9 +289,9 @@ public static OrtValue Clone(this OrtValue original)
290289 /// <param name="original">The original.</param>
291290 /// <param name="info">The information.</param>
292291 /// <returns>OrtValue.</returns>
293- public static OrtValue Clone < T > ( this OrtValue original , OrtTensorTypeAndShapeInfo info ) where T : unmanaged
292+ public static OrtValue Clone < T > ( this OrtValue original , OrtTensorTypeAndShapeInfo info , OrtAllocator allocator ) where T : unmanaged
294293 {
295- var newValue = OrtValue . CreateAllocatedTensorValue ( OrtAllocator . DefaultInstance , info . ElementDataType , info . Shape ) ;
294+ var newValue = OrtValue . CreateAllocatedTensorValue ( allocator , info . ElementDataType , info . Shape ) ;
296295 var source = original . GetTensorDataAsSpan < T > ( ) ;
297296 var destination = newValue . GetTensorMutableDataAsSpan < T > ( ) ;
298297 source . CopyTo ( destination ) ;
0 commit comments