|
| 1 | +using Microsoft.ML.OnnxRuntime; |
| 2 | +using System; |
| 3 | +using System.Collections.Generic; |
| 4 | +using System.Linq; |
| 5 | + |
| 6 | +namespace TensorStack.Providers |
| 7 | +{ |
| 8 | + public static class Devices |
| 9 | + { |
| 10 | + private static OrtEnv _environment; |
| 11 | + private static EnvironmentCreationOptions _environmentOptions; |
| 12 | + |
| 13 | + /// <summary> |
| 14 | + /// Initializes this instance. |
| 15 | + /// </summary> |
| 16 | + public static void Initialize() |
| 17 | + { |
| 18 | + Initialize(new EnvironmentCreationOptions |
| 19 | + { |
| 20 | + logId = "TensorStack", |
| 21 | + threadOptions = new OrtThreadingOptions |
| 22 | + { |
| 23 | + GlobalSpinControl = true, |
| 24 | + GlobalInterOpNumThreads = 1, |
| 25 | + GlobalIntraOpNumThreads = 1 |
| 26 | + } |
| 27 | + }); |
| 28 | + } |
| 29 | + |
| 30 | + |
| 31 | + /// <summary> |
| 32 | + /// Initializes the specified environment options. |
| 33 | + /// </summary> |
| 34 | + /// <param name="environmentOptions">The environment options.</param> |
| 35 | + public static void Initialize(EnvironmentCreationOptions environmentOptions) |
| 36 | + { |
| 37 | + _environmentOptions = environmentOptions; |
| 38 | + _environment = OrtEnv.CreateInstanceWithOptions(ref _environmentOptions); |
| 39 | + } |
| 40 | + |
| 41 | + |
| 42 | + /// <summary> |
| 43 | + /// Gets the devices. |
| 44 | + /// </summary> |
| 45 | + /// <param name="executionProvider">The execution provider.</param> |
| 46 | + /// <param name="libraryPath">The library path.</param> |
| 47 | + public static IReadOnlyList<Device> GetDevices(string executionProvider, string libraryPath = default) |
| 48 | + { |
| 49 | + if (_environment == null) |
| 50 | + Initialize(); |
| 51 | + |
| 52 | + var providers = _environment.GetAvailableProviders(); |
| 53 | + if (!providers.Contains(executionProvider, StringComparer.OrdinalIgnoreCase)) |
| 54 | + return []; |
| 55 | + |
| 56 | + if (!string.IsNullOrEmpty(executionProvider)) |
| 57 | + _environment.RegisterExecutionProviderLibrary(executionProvider, libraryPath); |
| 58 | + |
| 59 | + var devices = new List<Device>(); |
| 60 | + foreach (var epDevice in _environment.GetEpDevices()) |
| 61 | + { |
| 62 | + if (!epDevice.EpName.Equals(executionProvider, StringComparison.OrdinalIgnoreCase)) |
| 63 | + continue; |
| 64 | + |
| 65 | + devices.Add(CreateDevice(epDevice)); |
| 66 | + } |
| 67 | + return devices; |
| 68 | + } |
| 69 | + |
| 70 | + |
| 71 | + /// <summary> |
| 72 | + /// Creates the device. |
| 73 | + /// </summary> |
| 74 | + /// <param name="epDevice">The ep device.</param> |
| 75 | + /// <returns>Device.</returns> |
| 76 | + private static Device CreateDevice(OrtEpDevice epDevice) |
| 77 | + { |
| 78 | + var device = epDevice.HardwareDevice; |
| 79 | + var metadata = device.Metadata.Entries; |
| 80 | + return new Device |
| 81 | + { |
| 82 | + Type = Enum.Parse<DeviceType>(device.Type.ToString()), |
| 83 | + Name = metadata.ParseOrDefault("Description", string.Empty), |
| 84 | + Memory = metadata.ParseOrDefault("DxgiVideoMemory", 0, " MB"), |
| 85 | + AdapterIndex = metadata.ParseOrDefault("DxgiAdapterNumber", 0), |
| 86 | + PerformanceIndex = metadata.ParseOrDefault("DxgiHighPerformanceIndex", 0), |
| 87 | + HardwareLUID = metadata.ParseOrDefault("LUID", 0), |
| 88 | + HardwareID = (int)device.DeviceId, |
| 89 | + HardwareVendor = device.Vendor, |
| 90 | + HardwareVendorId = (int)device.VendorId, |
| 91 | + }; |
| 92 | + } |
| 93 | + |
| 94 | + |
| 95 | + /// <summary> |
| 96 | + /// Parse Metadata values |
| 97 | + /// </summary> |
| 98 | + /// <typeparam name="T"></typeparam> |
| 99 | + /// <param name="metadata">The metadata.</param> |
| 100 | + /// <param name="key">The key.</param> |
| 101 | + /// <param name="defaultValue">The default value.</param> |
| 102 | + /// <param name="replace">The replace.</param> |
| 103 | + /// <returns>T.</returns> |
| 104 | + private static T ParseOrDefault<T>(this IReadOnlyDictionary<string, string> metadata, string key, T defaultValue, string replace = null) |
| 105 | + { |
| 106 | + if (!metadata.ContainsKey(key)) |
| 107 | + return defaultValue; |
| 108 | + |
| 109 | + var value = metadata[key].Trim(); |
| 110 | + if (!string.IsNullOrEmpty(replace)) |
| 111 | + value = value.Replace(replace, string.Empty); |
| 112 | + |
| 113 | + if (typeof(T) == typeof(string)) |
| 114 | + { |
| 115 | + return (T)(object)value; |
| 116 | + } |
| 117 | + else if (typeof(T) == typeof(int)) |
| 118 | + { |
| 119 | + if (!int.TryParse(value, out var intResult)) |
| 120 | + return defaultValue; |
| 121 | + |
| 122 | + return (T)(object)intResult; |
| 123 | + } |
| 124 | + else if (typeof(T) == typeof(Enum)) |
| 125 | + { |
| 126 | + if (!Enum.TryParse(typeof(T), value, out var enumResult)) |
| 127 | + return defaultValue; |
| 128 | + |
| 129 | + return (T)enumResult; |
| 130 | + } |
| 131 | + return defaultValue; |
| 132 | + } |
| 133 | + } |
| 134 | +} |
0 commit comments