Skip to content

Commit d3a96fc

Browse files
committed
Get EP devices
1 parent 4bd6ea6 commit d3a96fc

File tree

9 files changed

+242
-66
lines changed

9 files changed

+242
-66
lines changed

TensorStack.Common/ExecutionProvider.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,25 @@ public SessionOptions CreateSession(ModelConfig modelConfig)
2626
{
2727
return _sessionOptionsFactory(modelConfig);
2828
}
29+
30+
/// <summary>
31+
/// Gets default CPU provider.
32+
/// </summary>
33+
/// <param name="optimizationLevel">The optimization level.</param>
34+
/// <returns>ExecutionProvider.</returns>
35+
public static ExecutionProvider GetDefault(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
36+
{
37+
return new ExecutionProvider("CPU Provider", OrtMemoryInfo.DefaultInstance, configuration =>
38+
{
39+
var sessionOptions = new SessionOptions
40+
{
41+
EnableCpuMemArena = true,
42+
EnableMemoryPattern = true,
43+
GraphOptimizationLevel = optimizationLevel
44+
};
45+
sessionOptions.AppendExecutionProvider_CPU();
46+
return sessionOptions;
47+
});
48+
}
2949
}
3050
}

TensorStack.Provider.CUDA/Provider.cs

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,44 @@
11
using Microsoft.ML.OnnxRuntime;
2+
using System.Collections.Generic;
23
using TensorStack.Common;
34

45
namespace TensorStack.Providers
56
{
67
public static class Provider
78
{
8-
public const string CPUProviderName = "CPU Provider";
9-
public const string CUDAProviderName = "CUDA Provider";
9+
private const string CUDAProviderName = "CUDAExecutionProvider";
10+
private const string CUDALibraryName = "onnxruntime_providers_cuda.dll";
11+
private static IReadOnlyList<Device> _devices;
1012

1113
/// <summary>
12-
/// Gets the CPU provider.
14+
/// Initializes the Provider with the specified environment options.
1315
/// </summary>
16+
/// <param name="environmentOptions">The environment options.</param>
17+
public static void Initialize(EnvironmentCreationOptions environmentOptions)
18+
{
19+
Devices.Initialize(environmentOptions);
20+
GetDevices();
21+
}
22+
23+
24+
/// <summary>
25+
/// Gets the DirectML devices.
26+
/// </summary>
27+
public static IReadOnlyList<Device> GetDevices()
28+
{
29+
_devices ??= Devices.GetDevices(CUDAProviderName, CUDALibraryName);
30+
return _devices;
31+
}
32+
33+
34+
/// <summary>
35+
/// Gets the CUDA provider.
36+
/// </summary>
37+
/// <param name="device">The device.</param>
1438
/// <param name="optimizationLevel">The optimization level.</param>
15-
/// <returns>ExecutionProvider.</returns>
16-
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
39+
public static ExecutionProvider GetProvider(Device device, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
1740
{
18-
return new ExecutionProvider(CPUProviderName, OrtMemoryInfo.DefaultInstance, configuration =>
19-
{
20-
var sessionOptions = new SessionOptions
21-
{
22-
EnableCpuMemArena = true,
23-
EnableMemoryPattern = true,
24-
GraphOptimizationLevel = optimizationLevel
25-
};
26-
sessionOptions.AppendExecutionProvider_CPU();
27-
return sessionOptions;
28-
});
41+
return GetProvider(device.DeviceId, optimizationLevel);
2942
}
3043

3144

TensorStack.Provider.CUDA/TensorStack.Provider.CUDA.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88

99
<!--Projects-->
1010
<ItemGroup Condition=" '$(Configuration)' == 'Debug'">
11-
<ProjectReference Include="..\TensorStack.Common\TensorStack.Common.csproj" />
11+
<ProjectReference Include="..\TensorStack.Provider\TensorStack.Provider.csproj" />
1212
</ItemGroup>
1313

1414
<!--Packages-->
1515
<ItemGroup Condition=" '$(Configuration)' == 'Release'">
16-
<PackageReference Include="TensorStack.Common" Version="$(Version)" />
16+
<PackageReference Include="TensorStack.Provider" Version="$(Version)" />
1717
</ItemGroup>
1818

1919
<!--Other Packages-->

TensorStack.Provider.DML/Provider.cs

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,43 @@
11
using Microsoft.ML.OnnxRuntime;
2+
using System.Collections.Generic;
23
using TensorStack.Common;
34

45
namespace TensorStack.Providers
56
{
67
public static class Provider
78
{
8-
public const string CPUProviderName = "CPU Provider";
9-
public const string DMLProviderName = "DirectML Provider";
9+
private const string DMLProviderName = "DMLExecutionProvider";
10+
private static IReadOnlyList<Device> _devices;
1011

1112
/// <summary>
12-
/// Gets the CPU provider.
13+
/// Initializes the Provider with the specified environment options.
1314
/// </summary>
15+
/// <param name="environmentOptions">The environment options.</param>
16+
public static void Initialize(EnvironmentCreationOptions environmentOptions)
17+
{
18+
Devices.Initialize(environmentOptions);
19+
GetDevices();
20+
}
21+
22+
23+
/// <summary>
24+
/// Gets the DirectML devices.
25+
/// </summary>
26+
public static IReadOnlyList<Device> GetDevices()
27+
{
28+
_devices ??= Devices.GetDevices(DMLProviderName);
29+
return _devices;
30+
}
31+
32+
33+
/// <summary>
34+
/// Gets the DirectML provider.
35+
/// </summary>
36+
/// <param name="device">The device.</param>
1437
/// <param name="optimizationLevel">The optimization level.</param>
15-
/// <returns>ExecutionProvider.</returns>
16-
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
38+
public static ExecutionProvider GetProvider(Device device, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
1739
{
18-
return new ExecutionProvider(CPUProviderName, OrtMemoryInfo.DefaultInstance, configuration =>
19-
{
20-
var sessionOptions = new SessionOptions
21-
{
22-
EnableCpuMemArena = true,
23-
EnableMemoryPattern = true,
24-
GraphOptimizationLevel = optimizationLevel
25-
};
26-
sessionOptions.AppendExecutionProvider_CPU();
27-
return sessionOptions;
28-
});
40+
return GetProvider(device.DeviceId, optimizationLevel);
2941
}
3042

3143

TensorStack.Provider.DML/TensorStack.Provider.DML.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88

99
<!--Projects-->
1010
<ItemGroup Condition=" '$(Configuration)' == 'Debug'">
11-
<ProjectReference Include="..\TensorStack.Common\TensorStack.Common.csproj" />
11+
<ProjectReference Include="..\TensorStack.Provider\TensorStack.Provider.csproj" />
1212
</ItemGroup>
1313

1414
<!--Packages-->
1515
<ItemGroup Condition=" '$(Configuration)' == 'Release'">
16-
<PackageReference Include="TensorStack.Common" Version="$(Version)" />
16+
<PackageReference Include="TensorStack.Provider" Version="$(Version)" />
1717
</ItemGroup>
1818

1919
<!--Other Packages-->

TensorStack.Provider/Device.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
namespace TensorStack.Providers
2+
{
3+
public record Device
4+
{
5+
public int DeviceId => PerformanceIndex; // TODO:
6+
public string Name { get; init; }
7+
public DeviceType Type { get; init; }
8+
public int Memory { get; init; }
9+
public int MemoryGB => Memory / 1000;
10+
public int AdapterIndex { get; init; }
11+
public int PerformanceIndex { get; init; }
12+
13+
public int HardwareID { get; init; }
14+
public int HardwareLUID { get; init; }
15+
public int HardwareVendorId { get; init; }
16+
public string HardwareVendor { get; init; }
17+
}
18+
}

TensorStack.Provider/DeviceType.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
namespace TensorStack.Providers
2+
{
3+
public enum DeviceType
4+
{
5+
CPU = 0,
6+
GPU = 1,
7+
NPU = 2
8+
}
9+
}

TensorStack.Provider/Devices.cs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
6+
namespace TensorStack.Providers
7+
{
8+
public static class Devices
9+
{
10+
private static OrtEnv _environment;
11+
private static EnvironmentCreationOptions _environmentOptions;
12+
13+
/// <summary>
14+
/// Initializes this instance.
15+
/// </summary>
16+
public static void Initialize()
17+
{
18+
Initialize(new EnvironmentCreationOptions
19+
{
20+
logId = "TensorStack",
21+
threadOptions = new OrtThreadingOptions
22+
{
23+
GlobalSpinControl = true,
24+
GlobalInterOpNumThreads = 1,
25+
GlobalIntraOpNumThreads = 1
26+
}
27+
});
28+
}
29+
30+
31+
/// <summary>
32+
/// Initializes the specified environment options.
33+
/// </summary>
34+
/// <param name="environmentOptions">The environment options.</param>
35+
public static void Initialize(EnvironmentCreationOptions environmentOptions)
36+
{
37+
_environmentOptions = environmentOptions;
38+
_environment = OrtEnv.CreateInstanceWithOptions(ref _environmentOptions);
39+
}
40+
41+
42+
/// <summary>
43+
/// Gets the devices.
44+
/// </summary>
45+
/// <param name="executionProvider">The execution provider.</param>
46+
/// <param name="libraryPath">The library path.</param>
47+
public static IReadOnlyList<Device> GetDevices(string executionProvider, string libraryPath = default)
48+
{
49+
if (_environment == null)
50+
Initialize();
51+
52+
var providers = _environment.GetAvailableProviders();
53+
if (!providers.Contains(executionProvider, StringComparer.OrdinalIgnoreCase))
54+
return [];
55+
56+
if (!string.IsNullOrEmpty(executionProvider))
57+
_environment.RegisterExecutionProviderLibrary(executionProvider, libraryPath);
58+
59+
var devices = new List<Device>();
60+
foreach (var epDevice in _environment.GetEpDevices())
61+
{
62+
if (!epDevice.EpName.Equals(executionProvider, StringComparison.OrdinalIgnoreCase))
63+
continue;
64+
65+
devices.Add(CreateDevice(epDevice));
66+
}
67+
return devices;
68+
}
69+
70+
71+
/// <summary>
72+
/// Creates the device.
73+
/// </summary>
74+
/// <param name="epDevice">The ep device.</param>
75+
/// <returns>Device.</returns>
76+
private static Device CreateDevice(OrtEpDevice epDevice)
77+
{
78+
var device = epDevice.HardwareDevice;
79+
var metadata = device.Metadata.Entries;
80+
return new Device
81+
{
82+
Type = Enum.Parse<DeviceType>(device.Type.ToString()),
83+
Name = metadata.ParseOrDefault("Description", string.Empty),
84+
Memory = metadata.ParseOrDefault("DxgiVideoMemory", 0, " MB"),
85+
AdapterIndex = metadata.ParseOrDefault("DxgiAdapterNumber", 0),
86+
PerformanceIndex = metadata.ParseOrDefault("DxgiHighPerformanceIndex", 0),
87+
HardwareLUID = metadata.ParseOrDefault("LUID", 0),
88+
HardwareID = (int)device.DeviceId,
89+
HardwareVendor = device.Vendor,
90+
HardwareVendorId = (int)device.VendorId,
91+
};
92+
}
93+
94+
95+
/// <summary>
96+
/// Parse Metadata values
97+
/// </summary>
98+
/// <typeparam name="T"></typeparam>
99+
/// <param name="metadata">The metadata.</param>
100+
/// <param name="key">The key.</param>
101+
/// <param name="defaultValue">The default value.</param>
102+
/// <param name="replace">The replace.</param>
103+
/// <returns>T.</returns>
104+
private static T ParseOrDefault<T>(this IReadOnlyDictionary<string, string> metadata, string key, T defaultValue, string replace = null)
105+
{
106+
if (!metadata.ContainsKey(key))
107+
return defaultValue;
108+
109+
var value = metadata[key].Trim();
110+
if (!string.IsNullOrEmpty(replace))
111+
value = value.Replace(replace, string.Empty);
112+
113+
if (typeof(T) == typeof(string))
114+
{
115+
return (T)(object)value;
116+
}
117+
else if (typeof(T) == typeof(int))
118+
{
119+
if (!int.TryParse(value, out var intResult))
120+
return defaultValue;
121+
122+
return (T)(object)intResult;
123+
}
124+
else if (typeof(T) == typeof(Enum))
125+
{
126+
if (!Enum.TryParse(typeof(T), value, out var enumResult))
127+
return defaultValue;
128+
129+
return (T)enumResult;
130+
}
131+
return defaultValue;
132+
}
133+
}
134+
}

TensorStack.Provider/Provider.cs

Lines changed: 0 additions & 30 deletions
This file was deleted.

0 commit comments

Comments
 (0)