|
42 | 42 | // before raising an error. |
43 | 43 | const constexpr size_t MAX_RECURSION_DEPTH = 250; |
44 | 44 |
|
| 45 | +// This is a private type in CPython, so we need to define it here |
| 46 | +// in order to be able to use it in our code. |
| 47 | +// We cannot put it into namespace {} because the 'PyAsyncGenASend' name would then be ambiguous. |
| 48 | +// The extern "C" is not required but here to avoid any ambiguity. |
| 49 | +extern "C" |
| 50 | +{ |
| 51 | + |
| 52 | + typedef struct PyAsyncGenASend |
| 53 | + { |
| 54 | + PyObject_HEAD PyAsyncGenObject* ags_gen; |
| 55 | + } PyAsyncGenASend; |
| 56 | + |
| 57 | +#ifndef PyAsyncGenASend_CheckExact |
| 58 | +#if PY_VERSION_HEX >= 0x03090000 |
| 59 | +// Py_IS_TYPE is only available since Python 3.9 |
| 60 | +#define PyAsyncGenASend_CheckExact(obj) (Py_IS_TYPE(obj, &_PyAsyncGenASend_Type)) |
| 61 | +#else // PY_VERSION_HEX >= 0x03090000 |
| 62 | +#define PyAsyncGenASend_CheckExact(obj) (Py_TYPE(obj) == &_PyAsyncGenASend_Type) |
| 63 | +#endif // PY_VERSION_HEX < 0x03090000 |
| 64 | +#endif // defined PyAsyncGenASend_CheckExact |
| 65 | +} |
| 66 | + |
45 | 67 | class GenInfo |
46 | 68 | { |
47 | 69 | public: |
@@ -76,13 +98,30 @@ GenInfo::create(PyObject* gen_addr) |
76 | 98 | } |
77 | 99 |
|
78 | 100 | PyGenObject gen; |
79 | | - |
80 | | - if (copy_type(gen_addr, gen) || !PyCoro_CheckExact(&gen)) { |
| 101 | + if (copy_type(gen_addr, gen)) { |
81 | 102 | recursion_depth--; |
82 | 103 | return ErrorKind::GenInfoError; |
83 | 104 | } |
84 | 105 |
|
85 | | - auto origin = gen_addr; |
| 106 | + if (PyAsyncGenASend_CheckExact(&gen)) { |
| 107 | + static_assert( |
| 108 | + sizeof(PyAsyncGenASend) <= sizeof(PyGenObject), |
| 109 | + "PyAsyncGenASend must be smaller than PyGenObject in order for copy_type to have copied enough data."); |
| 110 | + |
| 111 | + // Type-pun the PyGenObject to a PyAsyncGenASend. *gen_addr was actually never a PyGenObject to begin with, |
| 112 | + // but we do not care as the only thing we will use from it is the ags_gen field. |
| 113 | + PyAsyncGenASend* asend = reinterpret_cast<PyAsyncGenASend*>(&gen); |
| 114 | + PyAsyncGenObject* gen = asend->ags_gen; |
| 115 | + auto asend_yf = reinterpret_cast<PyObject*>(gen); |
| 116 | + auto result = GenInfo::create(asend_yf); |
| 117 | + recursion_depth--; |
| 118 | + return result; |
| 119 | + } |
| 120 | + |
| 121 | + if (!PyCoro_CheckExact(&gen) && !PyAsyncGen_CheckExact(&gen)) { |
| 122 | + recursion_depth--; |
| 123 | + return ErrorKind::GenInfoError; |
| 124 | + } |
86 | 125 |
|
87 | 126 | #if PY_VERSION_HEX >= 0x030b0000 |
88 | 127 | // The frame follows the generator object |
@@ -117,7 +156,7 @@ GenInfo::create(PyObject* gen_addr) |
117 | 156 | #endif |
118 | 157 |
|
119 | 158 | recursion_depth--; |
120 | | - return std::make_unique<GenInfo>(origin, frame, std::move(await), is_running); |
| 159 | + return std::make_unique<GenInfo>(gen_addr, frame, std::move(await), is_running); |
121 | 160 | } |
122 | 161 |
|
123 | 162 | // ---------------------------------------------------------------------------- |
|
0 commit comments