Skip to content

Commit 7a3fbcf

Browse files
committed
Refactor device providers
1 parent d3a96fc commit 7a3fbcf

File tree

14 files changed

+588
-203
lines changed

14 files changed

+588
-203
lines changed

TensorStack.Audio.Windows/AudioManager.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public static class AudioManager
2525
/// <param name="ffmpegPath">The ffmpeg path.</param>
2626
/// <param name="ffprobePath">The ffprobe path.</param>
2727
/// <param name="directoryTemp">The directory temporary.</param>
28-
public static void Configure(string ffmpegPath, string ffprobePath, string directoryTemp)
28+
public static void Initialize(string ffmpegPath, string ffprobePath, string directoryTemp)
2929
{
3030
FFMpegPath = ffmpegPath;
3131
FFProbePath = ffprobePath;

TensorStack.Provider/Device.cs renamed to TensorStack.Common/Device.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
namespace TensorStack.Providers
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
namespace TensorStack.Common
24
{
35
public record Device
46
{
5-
public int DeviceId => PerformanceIndex; // TODO:
7+
public int Id { get; init; }
8+
public int DeviceId { get; init; }
69
public string Name { get; init; }
710
public DeviceType Type { get; init; }
811
public int Memory { get; init; }
9-
public int MemoryGB => Memory / 1000;
10-
public int AdapterIndex { get; init; }
11-
public int PerformanceIndex { get; init; }
12+
public int MemoryGB => Memory / 1024;
1213

1314
public int HardwareID { get; init; }
1415
public int HardwareLUID { get; init; }

TensorStack.Provider/Devices.cs renamed to TensorStack.Common/DeviceManager.cs

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1-
using Microsoft.ML.OnnxRuntime;
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
using Microsoft.ML.OnnxRuntime;
24
using System;
35
using System.Collections.Generic;
46
using System.Linq;
57

6-
namespace TensorStack.Providers
8+
namespace TensorStack.Common
79
{
8-
public static class Devices
10+
public static class DeviceManager
911
{
1012
private static OrtEnv _environment;
1113
private static EnvironmentCreationOptions _environmentOptions;
14+
private static IReadOnlyList<Device> _devices;
15+
private static string _deviceProvider;
1216

1317
/// <summary>
1418
/// Initializes this instance.
1519
/// </summary>
16-
public static void Initialize()
20+
public static void Initialize(string executionProvider, string libraryPath = default)
1721
{
1822
Initialize(new EnvironmentCreationOptions
1923
{
@@ -24,50 +28,52 @@ public static void Initialize()
2428
GlobalInterOpNumThreads = 1,
2529
GlobalIntraOpNumThreads = 1
2630
}
27-
});
31+
}, executionProvider, libraryPath);
2832
}
2933

3034

3135
/// <summary>
3236
/// Initializes the specified environment options.
3337
/// </summary>
3438
/// <param name="environmentOptions">The environment options.</param>
35-
public static void Initialize(EnvironmentCreationOptions environmentOptions)
39+
public static void Initialize(EnvironmentCreationOptions environmentOptions, string executionProvider, string libraryPath = default)
3640
{
41+
if (_environment is not null)
42+
throw new Exception("Environment is already initialized.");
43+
44+
_deviceProvider = executionProvider;
3745
_environmentOptions = environmentOptions;
3846
_environment = OrtEnv.CreateInstanceWithOptions(ref _environmentOptions);
39-
}
40-
41-
42-
/// <summary>
43-
/// Gets the devices.
44-
/// </summary>
45-
/// <param name="executionProvider">The execution provider.</param>
46-
/// <param name="libraryPath">The library path.</param>
47-
public static IReadOnlyList<Device> GetDevices(string executionProvider, string libraryPath = default)
48-
{
49-
if (_environment == null)
50-
Initialize();
5147

5248
var providers = _environment.GetAvailableProviders();
53-
if (!providers.Contains(executionProvider, StringComparer.OrdinalIgnoreCase))
54-
return [];
49+
if (!providers.Contains(_deviceProvider, StringComparer.OrdinalIgnoreCase))
50+
throw new Exception($"Provider {_deviceProvider} was not found in GetAvailableProviders().");
5551

56-
if (!string.IsNullOrEmpty(executionProvider))
57-
_environment.RegisterExecutionProviderLibrary(executionProvider, libraryPath);
52+
if (!string.IsNullOrEmpty(libraryPath))
53+
_environment.RegisterExecutionProviderLibrary(_deviceProvider, libraryPath);
5854

5955
var devices = new List<Device>();
6056
foreach (var epDevice in _environment.GetEpDevices())
6157
{
62-
if (!epDevice.EpName.Equals(executionProvider, StringComparison.OrdinalIgnoreCase))
63-
continue;
64-
65-
devices.Add(CreateDevice(epDevice));
58+
if (epDevice.HardwareDevice.Type == OrtHardwareDeviceType.CPU || epDevice.EpName.Equals(_deviceProvider, StringComparison.OrdinalIgnoreCase))
59+
devices.Add(CreateDevice(epDevice));
6660
}
67-
return devices;
61+
_devices = devices;
6862
}
6963

7064

65+
/// <summary>
66+
/// Gets the devices.
67+
/// </summary>
68+
public static IReadOnlyList<Device> Devices => _devices;
69+
70+
71+
/// <summary>
72+
/// The cpu provider name
73+
/// </summary>
74+
public const string CPUProviderName = "CPUExecutionProvider";
75+
76+
7177
/// <summary>
7278
/// Creates the device.
7379
/// </summary>
@@ -79,11 +85,11 @@ private static Device CreateDevice(OrtEpDevice epDevice)
7985
var metadata = device.Metadata.Entries;
8086
return new Device
8187
{
88+
Id = metadata.ParseOrDefault("DxgiAdapterNumber", 0),
89+
DeviceId = metadata.ParseOrDefault("DxgiHighPerformanceIndex", 0),
8290
Type = Enum.Parse<DeviceType>(device.Type.ToString()),
8391
Name = metadata.ParseOrDefault("Description", string.Empty),
8492
Memory = metadata.ParseOrDefault("DxgiVideoMemory", 0, " MB"),
85-
AdapterIndex = metadata.ParseOrDefault("DxgiAdapterNumber", 0),
86-
PerformanceIndex = metadata.ParseOrDefault("DxgiHighPerformanceIndex", 0),
8793
HardwareLUID = metadata.ParseOrDefault("LUID", 0),
8894
HardwareID = (int)device.DeviceId,
8995
HardwareVendor = device.Vendor,

TensorStack.Common/DeviceType.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
namespace TensorStack.Common
4+
{
5+
public enum DeviceType
6+
{
7+
CPU = 0,
8+
GPU = 1,
9+
NPU = 2
10+
}
11+
}
Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3-
using System;
43
using Microsoft.ML.OnnxRuntime;
4+
using System;
55

66
namespace TensorStack.Common
77
{
88
public class ExecutionProvider
99
{
1010
private readonly string _name;
1111
private readonly OrtMemoryInfo _memoryInfo;
12-
1312
private readonly Func<ModelConfig, SessionOptions> _sessionOptionsFactory;
1413

1514
public ExecutionProvider(string name, OrtMemoryInfo memoryInfo, Func<ModelConfig, SessionOptions> sessionOptionsFactory)
@@ -26,25 +25,5 @@ public SessionOptions CreateSession(ModelConfig modelConfig)
2625
{
2726
return _sessionOptionsFactory(modelConfig);
2827
}
29-
30-
/// <summary>
31-
/// Gets default CPU provider.
32-
/// </summary>
33-
/// <param name="optimizationLevel">The optimization level.</param>
34-
/// <returns>ExecutionProvider.</returns>
35-
public static ExecutionProvider GetDefault(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
36-
{
37-
return new ExecutionProvider("CPU Provider", OrtMemoryInfo.DefaultInstance, configuration =>
38-
{
39-
var sessionOptions = new SessionOptions
40-
{
41-
EnableCpuMemArena = true,
42-
EnableMemoryPattern = true,
43-
GraphOptimizationLevel = optimizationLevel
44-
};
45-
sessionOptions.AppendExecutionProvider_CPU();
46-
return sessionOptions;
47-
});
48-
}
4928
}
5029
}

TensorStack.Provider.CUDA/Provider.cs

Lines changed: 0 additions & 67 deletions
This file was deleted.

TensorStack.Provider.DML/Provider.cs

Lines changed: 0 additions & 66 deletions
This file was deleted.

TensorStack.Provider/DeviceType.cs

Lines changed: 0 additions & 9 deletions
This file was deleted.

0 commit comments

Comments
 (0)