Skip to content

Commit 66064dd

Browse files
KarhouTamSilv3S
authored andcommitted
[OpenReg][Feat][Docs] Enrich OpenReg device management implementation and add focused documentation (pytorch#165897)
## Summary This PR enriches OpenReg device management codes and adds focused documentation. ## Key Changes - Introduced device management documentation in `device.md`. - Updated `OpenRegFunctions.h` and `OpenRegFunctions.cpp` to use `DeviceIndex` and added error handling. - Implemented `check_device_index` function for validating device indices. - Enhanced Python bindings in `Module.cpp` for device management. - Added tests for invalid device index handling in `test_device.py`. Pull Request resolved: pytorch#165897 Approved by: https://github.com/fffrog
1 parent b30634b commit 66064dd

File tree

10 files changed

+178
-34
lines changed

10 files changed

+178
-34
lines changed

docs/source/accelerator/device.md

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Device Management
2+
3+
## Background
4+
5+
Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
6+
7+
The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
8+
9+
## Design
10+
11+
Accelerator vendors need to implement these core functions:
12+
13+
| Function Name | Description | Application Scenarios |
14+
| ------------------------- | ---------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- |
15+
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
16+
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
17+
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
18+
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
19+
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
20+
21+
These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
22+
23+
## Implementation
24+
25+
This section shows how to implement device management using `set_device` as an example. The implementation requires:
26+
1. C++ wrappers around the device runtime
27+
2. Python bindings to expose the C++ functions
28+
3. User-friendly Python APIs
29+
30+
### C++ Side
31+
32+
Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
33+
34+
```{eval-rst}
35+
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
36+
:language: c++
37+
:start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
38+
:end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
39+
:linenos:
40+
```
41+
```{eval-rst}
42+
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
43+
:language: c++
44+
:start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
45+
:end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
46+
:linenos:
47+
```
48+
49+
### Binding
50+
51+
Expose the C++ functions to Python using pybind11:
52+
53+
```{eval-rst}
54+
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
55+
:language: c++
56+
:start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
57+
:end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
58+
:linenos:
59+
```
60+
```{eval-rst}
61+
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
62+
:language: c++
63+
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
64+
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
65+
:linenos:
66+
:emphasize-lines: 5
67+
```
68+
69+
### Python Side
70+
71+
Wrap the C++ bindings with user-friendly Python functions:
72+
73+
```{eval-rst}
74+
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
75+
:language: python
76+
:start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
77+
:end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
78+
:linenos:
79+
```
80+
81+
Here's the complete mapping from C++ to Python:
82+
83+
| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
84+
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
85+
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
86+
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
87+
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
88+
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
89+
90+
## Guard
91+
92+
Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.
93+
94+
Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
95+
96+
```{eval-rst}
97+
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
98+
:language: c++
99+
:start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
100+
:end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
101+
:linenos:
102+
```
103+
104+
**What needs to be implemented:**
105+
106+
1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
107+
2. **getDevice()**: Get the current device
108+
3. **setDevice()**: Set the active device
109+
4. **Type checking**: Validate that device type matches the backend
110+
111+
This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
112+
113+
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"

docs/source/accelerator/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
4242
:glob:
4343
:maxdepth: 1
4444
45+
device
4546
hooks
4647
autoload
4748
operators

test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegException.h

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,12 @@
44

55
#include <c10/util/Exception.h>
66

7-
void orCheckFail(
8-
const char* func,
9-
const char* file,
10-
uint32_t line,
11-
const char* msg = "");
12-
13-
#define OPENREG_CHECK(EXPR, ...) \
14-
do { \
15-
const orError_t __err = EXPR; \
16-
if (__err != orSuccess) { \
17-
orCheckFail( \
18-
__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
19-
} \
7+
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
8+
9+
#define OPENREG_CHECK(EXPR, ...) \
10+
do { \
11+
const orError_t __err = EXPR; \
12+
if (C10_UNLIKELY(__err != orSuccess)) { \
13+
orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
14+
} \
2015
} while (0)
Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <c10/util/Exception.h>
12
#include <include/openreg.h>
23

34
#include "OpenRegException.h"
@@ -9,56 +10,60 @@ orError_t GetDeviceCount(int* dev_count) {
910
return orGetDeviceCount(dev_count);
1011
}
1112

12-
orError_t GetDevice(c10::DeviceIndex* device) {
13+
orError_t GetDevice(DeviceIndex* device) {
1314
int tmp_device = -1;
1415
auto err = orGetDevice(&tmp_device);
15-
*device = static_cast<c10::DeviceIndex>(tmp_device);
16+
*device = static_cast<DeviceIndex>(tmp_device);
1617
return err;
1718
}
18-
19-
orError_t SetDevice(c10::DeviceIndex device) {
19+
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
20+
orError_t SetDevice(DeviceIndex device) {
2021
int cur_device = -1;
21-
orGetDevice(&cur_device);
22+
OPENREG_CHECK(orGetDevice(&cur_device));
2223
if (device == cur_device) {
2324
return orSuccess;
2425
}
2526
return orSetDevice(device);
2627
}
28+
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION
2729

2830
int device_count_impl() {
2931
int count = 0;
3032
GetDeviceCount(&count);
3133
return count;
3234
}
3335

34-
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
36+
OPENREG_EXPORT DeviceIndex device_count() noexcept {
3537
// initialize number of devices only once
3638
static int count = []() {
3739
try {
3840
auto result = device_count_impl();
3941
TORCH_CHECK(
40-
result <= std::numeric_limits<c10::DeviceIndex>::max(),
42+
result <= std::numeric_limits<DeviceIndex>::max(),
4143
"Too many devices, DeviceIndex overflowed");
4244
return result;
43-
} catch (const c10::Error& ex) {
45+
} catch (const Error& ex) {
4446
// We don't want to fail, but still log the warning
4547
// msg() returns the message without the stack trace
4648
TORCH_WARN("Device initialization: ", ex.msg());
4749
return 0;
4850
}
4951
}();
50-
return static_cast<c10::DeviceIndex>(count);
52+
return static_cast<DeviceIndex>(count);
5153
}
5254

53-
OPENREG_EXPORT c10::DeviceIndex current_device() {
54-
c10::DeviceIndex cur_device = -1;
55-
GetDevice(&cur_device);
55+
OPENREG_EXPORT DeviceIndex current_device() {
56+
DeviceIndex cur_device = -1;
57+
OPENREG_CHECK(GetDevice(&cur_device));
5658
return cur_device;
5759
}
5860

59-
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
60-
SetDevice(device);
61+
// LITERALINCLUDE START: OPENREG set_device FUNCTION
62+
OPENREG_EXPORT void set_device(DeviceIndex device) {
63+
check_device_index(device);
64+
OPENREG_CHECK(SetDevice(device));
6165
}
66+
// LITERALINCLUDE END: OPENREG set_device FUNCTION
6267

6368
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
6469
int current_device = -1;
@@ -71,4 +76,8 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
7176
return current_device;
7277
}
7378

79+
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
80+
check_device_index(to_device);
81+
return ExchangeDevice(to_device);
82+
}
7483
} // namespace c10::openreg

test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,20 @@
99

1010
namespace c10::openreg {
1111

12-
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
13-
OPENREG_EXPORT c10::DeviceIndex current_device();
14-
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
12+
OPENREG_EXPORT DeviceIndex device_count() noexcept;
13+
OPENREG_EXPORT DeviceIndex current_device();
14+
OPENREG_EXPORT void set_device(DeviceIndex device);
15+
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);
1516

1617
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
1718

19+
static inline void check_device_index(int64_t device) {
20+
TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
21+
"The device index is out of range. It must be in [0, ",
22+
static_cast<int>(c10::openreg::device_count()),
23+
"), but got ",
24+
static_cast<int>(device),
25+
".");
26+
}
27+
1828
} // namespace c10::openreg

test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
namespace c10::openreg {
44

5+
// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
56
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
7+
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION
68

79
} // namespace c10::openreg

test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
namespace c10::openreg {
1313

14+
// LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
1415
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
1516
static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;
1617

@@ -58,6 +59,7 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
5859

5960
set_device(d.index());
6061
}
62+
// LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
6163

6264
/**
6365
* Set the current device to c10::Device, without checking for errors

test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ def test_device_context(self):
2727
self.assertEqual(torch.accelerator.current_device_index(), 1)
2828
self.assertEqual(torch.accelerator.current_device_index(), device)
2929

30+
def test_invalid_device_index(self):
31+
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
32+
torch.accelerator.set_device_index(2)
33+
3034

3135
if __name__ == "__main__":
3236
run_tests()

test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,21 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
3434
}
3535
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
3636

37+
// LITERALINCLUDE START: MODULE SET DEVICE HELPER
38+
3739
PyObject* _setDevice(PyObject* self, PyObject* arg) {
3840
HANDLE_TH_ERRORS
3941
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
40-
auto device = THPUtils_unpackLong(arg);
41-
42+
auto device = THPUtils_unpackDeviceIndex(arg);
4243
torch::utils::device_lazy_init(at::kPrivateUse1);
43-
c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));
44+
c10::openreg::set_device(device);
4445

4546
Py_RETURN_NONE;
4647
END_HANDLE_TH_ERRORS
4748
}
4849

50+
// LITERALINCLUDE END: MODULE SET DEVICE HELPER
51+
4952
PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
5053
HANDLE_TH_ERRORS
5154
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");

test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,13 @@ def current_device():
4141
return torch_openreg._C._get_device()
4242

4343

44+
# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
4445
def set_device(device) -> None:
45-
return torch_openreg._C._set_device(device)
46+
if device >= 0:
47+
torch_openreg._C._set_device(device)
48+
49+
50+
# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
4651

4752

4853
def init():

0 commit comments

Comments
 (0)