1- using Microsoft . ML . OnnxRuntime ;
1+ // Copyright (c) TensorStack. All rights reserved.
2+ // Licensed under the Apache 2.0 License.
3+ using Microsoft . ML . OnnxRuntime ;
24using System ;
35using System . Collections . Generic ;
46using System . Linq ;
57
6- namespace TensorStack . Providers
8+ namespace TensorStack . Common
79{
8- public static class Devices
10+ public static class DeviceManager
911 {
1012 private static OrtEnv _environment ;
1113 private static EnvironmentCreationOptions _environmentOptions ;
14+ private static IReadOnlyList < Device > _devices ;
15+ private static string _deviceProvider ;
1216
1317 /// <summary>
1418 /// Initializes this instance.
1519 /// </summary>
16- public static void Initialize ( )
20+ public static void Initialize ( string executionProvider , string libraryPath = default )
1721 {
1822 Initialize ( new EnvironmentCreationOptions
1923 {
@@ -24,50 +28,52 @@ public static void Initialize()
2428 GlobalInterOpNumThreads = 1 ,
2529 GlobalIntraOpNumThreads = 1
2630 }
27- } ) ;
31+ } , executionProvider , libraryPath ) ;
2832 }
2933
3034
3135 /// <summary>
3236 /// Initializes the specified environment options.
3337 /// </summary>
3438 /// <param name="environmentOptions">The environment options.</param>
35- public static void Initialize ( EnvironmentCreationOptions environmentOptions )
39+ public static void Initialize ( EnvironmentCreationOptions environmentOptions , string executionProvider , string libraryPath = default )
3640 {
41+ if ( _environment is not null )
42+ throw new Exception ( "Environment is already initialized." ) ;
43+
44+ _deviceProvider = executionProvider ;
3745 _environmentOptions = environmentOptions ;
3846 _environment = OrtEnv . CreateInstanceWithOptions ( ref _environmentOptions ) ;
39- }
40-
41-
42- /// <summary>
43- /// Gets the devices.
44- /// </summary>
45- /// <param name="executionProvider">The execution provider.</param>
46- /// <param name="libraryPath">The library path.</param>
47- public static IReadOnlyList < Device > GetDevices ( string executionProvider , string libraryPath = default )
48- {
49- if ( _environment == null )
50- Initialize ( ) ;
5147
5248 var providers = _environment . GetAvailableProviders ( ) ;
53- if ( ! providers . Contains ( executionProvider , StringComparer . OrdinalIgnoreCase ) )
54- return [ ] ;
49+ if ( ! providers . Contains ( _deviceProvider , StringComparer . OrdinalIgnoreCase ) )
50+ throw new Exception ( $ "Provider { _deviceProvider } was not found in GetAvailableProviders()." ) ;
5551
56- if ( ! string . IsNullOrEmpty ( executionProvider ) )
57- _environment . RegisterExecutionProviderLibrary ( executionProvider , libraryPath ) ;
52+ if ( ! string . IsNullOrEmpty ( libraryPath ) )
53+ _environment . RegisterExecutionProviderLibrary ( _deviceProvider , libraryPath ) ;
5854
5955 var devices = new List < Device > ( ) ;
6056 foreach ( var epDevice in _environment . GetEpDevices ( ) )
6157 {
62- if ( ! epDevice . EpName . Equals ( executionProvider , StringComparison . OrdinalIgnoreCase ) )
63- continue ;
64-
65- devices . Add ( CreateDevice ( epDevice ) ) ;
58+ if ( epDevice . HardwareDevice . Type == OrtHardwareDeviceType . CPU || epDevice . EpName . Equals ( _deviceProvider , StringComparison . OrdinalIgnoreCase ) )
59+ devices . Add ( CreateDevice ( epDevice ) ) ;
6660 }
67- return devices ;
61+ _devices = devices ;
6862 }
6963
7064
65+ /// <summary>
66+ /// Gets the devices.
67+ /// </summary>
68+ public static IReadOnlyList < Device > Devices => _devices ;
69+
70+
71+ /// <summary>
72+ /// The cpu provider name
73+ /// </summary>
74+ public const string CPUProviderName = "CPUExecutionProvider" ;
75+
76+
7177 /// <summary>
7278 /// Creates the device.
7379 /// </summary>
@@ -79,11 +85,11 @@ private static Device CreateDevice(OrtEpDevice epDevice)
7985 var metadata = device . Metadata . Entries ;
8086 return new Device
8187 {
88+ Id = metadata . ParseOrDefault ( "DxgiAdapterNumber" , 0 ) ,
89+ DeviceId = metadata . ParseOrDefault ( "DxgiHighPerformanceIndex" , 0 ) ,
8290 Type = Enum . Parse < DeviceType > ( device . Type . ToString ( ) ) ,
8391 Name = metadata . ParseOrDefault ( "Description" , string . Empty ) ,
8492 Memory = metadata . ParseOrDefault ( "DxgiVideoMemory" , 0 , " MB" ) ,
85- AdapterIndex = metadata . ParseOrDefault ( "DxgiAdapterNumber" , 0 ) ,
86- PerformanceIndex = metadata . ParseOrDefault ( "DxgiHighPerformanceIndex" , 0 ) ,
8793 HardwareLUID = metadata . ParseOrDefault ( "LUID" , 0 ) ,
8894 HardwareID = ( int ) device . DeviceId ,
8995 HardwareVendor = device . Vendor ,
0 commit comments