Skip to content

Commit a147485

Browse files
committed
oh how i love null checks
1 parent b0692cb commit a147485

File tree

10 files changed

+108
-45
lines changed

10 files changed

+108
-45
lines changed

src/mod.c

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
#include <signal.h>
44
#include <setjmp.h>
55
#include <stdbool.h>
6+
#include <stdio.h>
7+
#include <frameobject.h> // needed to get members of PyFrameObject
68
#define GETOBJ PyObject* obj; if (!PyArg_ParseTuple(args, "O", &obj)) return NULL
7-
#define CALL_ATTR(ob, attr) PyObject_Call(PyObject_GetAttrString(ob, attr), PyTuple_New(0), NULL)
89
static jmp_buf buf;
910

1011
static PyObject* add_ref(PyObject* self, PyObject* args) {
@@ -47,8 +48,7 @@ void handler(int signum) {
4748
static PyObject* handle(PyObject* self, PyObject* args) {
4849
PyObject* func;
4950
PyObject* params = NULL;
50-
PyObject* kwargs = NULL;
51-
PyObject* faulthandler = PyImport_ImportModule("faulthandler");
51+
PyObject* kwargs = NULL;;
5252

5353
if (!PyArg_ParseTuple(
5454
args,
@@ -64,27 +64,26 @@ static PyObject* handle(PyObject* self, PyObject* args) {
6464
if (!params) params = PyTuple_New(0);
6565
if (!kwargs) kwargs = PyDict_New();
6666

67-
int val = setjmp(buf);
68-
69-
CALL_ATTR(faulthandler, "disable");
70-
// faulthandler needs to be shut off in case of a segfault or its message will still print
71-
7267
if (setjmp(buf)) {
73-
CALL_ATTR(faulthandler, "enable");
68+
PyFrameObject* frame = PyEval_GetFrame();
69+
PyCodeObject* code = frame->f_code;
70+
Py_INCREF(code);
71+
72+
// this is basically a copy of PyFrame_GetCode, which is only available on 3.9+
7473

7574
PyErr_Format(
7675
PyExc_RuntimeError,
7776
"segment violation occured during execution of %S",
78-
PyObject_GetAttrString(func, "__name__")
77+
code->co_name
7978
);
79+
80+
Py_DECREF(code);
8081
return NULL;
8182
}
8283

8384
PyObject* result = PyObject_Call(func, params, kwargs);
84-
8585
if (!result) return NULL;
8686

87-
CALL_ATTR(faulthandler, "enable");
8887
return result;
8988
}
9089

src/pointers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
TypedCPointer, VoidPointer, array, cast, to_c_ptr, to_struct_ptr, to_voidp
1515
)
1616
from .calloc import AllocatedArrayPointer, calloc
17-
from .constants import NULL, Nullable, raw_type
17+
from .constants import NULL, Nullable, handle, raw_type
1818
from .custom_binding import binding, binds
1919
from .decay import decay, decay_annotated, decay_wrapped
2020
from .exceptions import (

src/pointers/base_pointers.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212

1313
from typing_extensions import final
1414

15-
from _pointers import add_ref, handle, remove_ref
15+
from _pointers import add_ref, remove_ref
1616

1717
from ._utils import deref, force_set_attr, move_to_mem
18-
from .constants import NULL, Nullable
18+
from .constants import NULL, Nullable, handle
1919
from .exceptions import DereferenceError, FreedMemoryError, NullPointerError
2020

2121
__all__ = (
@@ -96,22 +96,13 @@ def ensure(self) -> int:
9696

9797
class Movable(ABC, Generic[T, A]):
9898
@abstractmethod
99-
def _move(
100-
self,
101-
data: Union[A, T],
102-
*,
103-
unsafe: bool = False,
104-
) -> None:
105-
...
106-
107-
@final
10899
def move(
109100
self,
110101
data: Union[A, T],
111102
*,
112103
unsafe: bool = False,
113104
) -> None:
114-
handle(self._move, (data,), {"unsafe": unsafe})
105+
...
115106

116107
def __ilshift__(self, data: Union[A, T]):
117108
self.move(data)
@@ -126,16 +117,12 @@ class Dereferencable(ABC, Generic[T]):
126117
"""Abstract class for an object that may be dereferenced."""
127118

128119
@abstractmethod
129-
def _dereference(self) -> T:
130-
...
131-
132-
@final
133120
def dereference(self) -> T:
134121
"""Dereference the pointer.
135122
136123
Returns:
137124
Value at the pointers address."""
138-
return handle(self._dereference)
125+
...
139126

140127
@final
141128
def __invert__(self) -> T:
@@ -204,6 +191,8 @@ def size(self) -> int:
204191
"""Size of the target value."""
205192
...
206193

194+
@handle
195+
@final
207196
def make_ct_pointer(self) -> "ctypes._PointerLike":
208197
"""Convert the address to a ctypes pointer.
209198
@@ -262,13 +251,15 @@ def type(self) -> Type[T]:
262251

263252
return self._type
264253

254+
@handle
265255
def set_attr(self, key: str, value: Any) -> None:
266256
v: Any = ~self # mypy gets angry if this isnt any
267257
if not isinstance(~self, type):
268258
v = type(v)
269259

270260
force_set_attr(v, key, value)
271261

262+
@handle
272263
def assign(
273264
self,
274265
target: Nullable[Union["BaseObjectPointer[T]", T]],
@@ -299,7 +290,7 @@ def assign(
299290
def address(self) -> Optional[int]:
300291
return self._address
301292

302-
def _dereference(self) -> T:
293+
def dereference(self) -> T:
303294
return deref(self.ensure())
304295

305296
def __irshift__(
@@ -366,7 +357,8 @@ def _make_stream_and_ptr(
366357
bytes_a = (ctypes.c_ubyte * size).from_address(address)
367358
return self.make_ct_pointer(), bytes(bytes_a)
368359

369-
def _move(
360+
@handle
361+
def move(
370362
self,
371363
data: Union["BaseCPointer[T]", T],
372364
*,
@@ -392,6 +384,7 @@ def __ixor__(self, data: Union["BaseCPointer[T]", T]):
392384
self.move(data, unsafe=True)
393385
return self
394386

387+
@handle
395388
def make_ct_pointer(self):
396389
return ctypes.cast(
397390
self.ensure(),
@@ -436,7 +429,8 @@ def assigned(self) -> bool:
436429
def assigned(self, value: bool) -> None:
437430
self._assigned = value
438431

439-
def _move(
432+
@handle
433+
def move(
440434
self,
441435
data: Union[BasePointer[T], T],
442436
unsafe: bool = False,
@@ -456,7 +450,8 @@ def _move(
456450
self.assigned = True
457451
remove_ref(data)
458452

459-
def _dereference(self) -> T:
453+
@handle
454+
def dereference(self) -> T:
460455
if self.freed:
461456
raise FreedMemoryError(
462457
"cannot dereference memory that has been freed",

src/pointers/c_pointer.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from ._utils import get_mapped, map_type
88
from .base_pointers import BaseCPointer, IterDereferencable, Typed
9+
from .constants import handle
910

1011
if TYPE_CHECKING:
1112
from .structure import Struct, StructPointer
@@ -31,11 +32,13 @@ class VoidPointer(BaseCPointer[Any]):
3132
def size(self) -> int:
3233
return self._size
3334

34-
@property
35+
@property # type: ignore
36+
@handle
3537
def _as_parameter_(self) -> ctypes.c_void_p:
3638
return ctypes.c_void_p(self.address)
3739

38-
def _dereference(self) -> Optional[int]:
40+
@handle
41+
def dereference(self) -> Optional[int]:
3942
"""Dereference the pointer."""
4043
deref = ctypes.c_void_p.from_address(self.ensure())
4144
return deref.value
@@ -61,6 +64,7 @@ def decref(self) -> bool:
6164
"""Whether the target objects reference count should be decremented when the pointer is garbage collected.""" # noqa
6265
...
6366

67+
@handle
6468
def _cleanup(self):
6569
if self.address:
6670
if (type(~self) is not str) and (self.decref):
@@ -120,18 +124,20 @@ class TypedCPointer(_TypedPointer[T]):
120124
def address(self) -> Optional[int]:
121125
return self._address
122126

123-
@property
127+
@property # type: ignore
128+
@handle
124129
def _as_parameter_(self):
125130
ctype = get_mapped(self.type)
126131
deref = ctype.from_address(self.ensure())
127132
value = deref.value # type: ignore
128133

129134
if isinstance(value, (TypedCPointer, VoidPointer)):
130-
return ctypes.pointer(value._as_parameter_)
135+
return ctypes.pointer(value._as_parameter_) # type: ignore
131136

132137
return ctypes.pointer(deref)
133138

134-
def _dereference(self) -> T:
139+
@handle
140+
def dereference(self) -> T:
135141
"""Dereference the pointer."""
136142
ctype = get_mapped(self.type)
137143

@@ -177,14 +183,16 @@ def decref(self) -> bool:
177183
def type(self) -> Type[T]: # type: ignore
178184
return self._type
179185

180-
@property
186+
@property # type: ignore
187+
@handle
181188
def _as_parameter_(self) -> "ctypes.Array[ctypes._CData]":
182189
ctype = get_mapped(self.type)
183190

184191
deref = (ctype * self._length).from_address(self.ensure())
185192
return deref
186193

187-
def _dereference(self) -> List[T]:
194+
@handle
195+
def dereference(self) -> List[T]:
188196
"""Dereference the pointer."""
189197
array = self._as_parameter_
190198
return [array[i] for i in range(self._length)] # type: ignore
@@ -200,6 +208,7 @@ def __getitem__(self, index: int) -> T:
200208
return array[index]
201209

202210

211+
@handle
203212
def cast(ptr: VoidPointer, data_type: Type[T]) -> TypedCPointer[T]:
204213
"""Cast a void pointer to a typed pointer."""
205214

@@ -237,6 +246,7 @@ def to_struct_ptr(struct: A) -> "StructPointer[A]":
237246
return StructPointer(id(struct), type(struct))
238247

239248

249+
@handle
240250
def array(*seq: T) -> CArrayPointer[T]:
241251
f_type = type(seq[0])
242252

src/pointers/calloc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._cstd import c_calloc, c_free
44
from .base_pointers import BaseAllocatedPointer
5+
from .constants import handle
56
from .exceptions import AllocationError
67

78
__all__ = ("AllocatedArrayPointer", "calloc")
@@ -108,6 +109,7 @@ def __setitem__(self, index: int, value: T) -> None:
108109
chunk = self._get_chunk_at(index)
109110
chunk <<= value
110111

112+
@handle
111113
def free(self) -> None:
112114
first = self[0]
113115
first.ensure_valid()

src/pointers/constants.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
11
import ctypes
2-
from typing import Any, NamedTuple, Type, TypeVar, Union
2+
import faulthandler
3+
from contextlib import suppress
4+
from functools import wraps
5+
from io import UnsupportedOperation
6+
from typing import Any, Callable, NamedTuple, Type, TypeVar, Union
7+
8+
from typing_extensions import ParamSpec
9+
10+
from _pointers import handle as _handle
11+
12+
from .exceptions import SegmentViolation
13+
14+
with suppress(
15+
UnsupportedOperation
16+
): # in case its running in idle or something like that
17+
faulthandler.enable()
318

419
__all__ = (
520
"NULL",
621
"Nullable",
722
"raw_type",
23+
"handle",
824
)
925

1026
T = TypeVar("T")
27+
P = ParamSpec("P")
1128

1229

1330
class NULL:
@@ -24,3 +41,28 @@ class RawType(NamedTuple):
2441
def raw_type(ct: Type["ctypes._CData"]) -> Any:
2542
"""Set a raw ctypes type for a struct."""
2643
return RawType(ct)
44+
45+
46+
def handle(func: Callable[P, T]) -> Callable[P, T]:
47+
@wraps(func)
48+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
49+
try:
50+
faulthandler.disable()
51+
call = _handle(func, args, kwargs)
52+
53+
with suppress(UnsupportedOperation):
54+
faulthandler.enable()
55+
56+
return call
57+
except RuntimeError as e:
58+
msg = str(e)
59+
60+
if not msg.startswith("segment"):
61+
raise
62+
63+
with suppress(UnsupportedOperation):
64+
faulthandler.enable()
65+
66+
raise SegmentViolation(str(e)) from None
67+
68+
return wrapper

src/pointers/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"InvalidBindingParameter",
77
"NullPointerError",
88
"InvalidVersionError",
9+
"SegmentViolation",
910
)
1011

1112

@@ -49,3 +50,9 @@ class InvalidVersionError(Exception):
4950
"""Python version is not high enough."""
5051

5152
pass
53+
54+
55+
class SegmentViolation(Exception):
56+
"""SIGSEGV was sent to Python."""
57+
58+
pass

0 commit comments

Comments
 (0)