Skip to content

Commit d3949e3

Browse files
committed
implement #20
1 parent 86916a4 commit d3949e3

File tree

6 files changed

+331
-39
lines changed

6 files changed

+331
-39
lines changed

src/pointers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .bindings import *
22
from .c_pointer import (
3-
StructPointer, TypedCPointer, VoidPointer, cast, to_c_ptr, to_struct_ptr
3+
StructPointer, TypedCPointer, VoidPointer, array, cast, to_c_ptr,
4+
to_struct_ptr
45
)
56
from .calloc import CallocPointer, calloc
67
from .custom_binding import binding, binds

src/pointers/_cstd.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,6 @@ class lconv(ctypes.Structure):
288288
# int printf(const char* format, ...)
289289
dll.printf.restype = ctypes.c_int
290290
# int sprintf(char* str, const char* format, ...)
291-
dll.sprintf.argtypes = (ctypes.c_char_p, ctypes.c_char_p)
292291
dll.sprintf.restype = ctypes.c_int
293292
# int vfprintf(FILE* stream, const char* format, va_list arg)
294293
# int vprintf(const char* format, va_list arg)
@@ -548,6 +547,45 @@ class lconv(ctypes.Structure):
548547
# time_t time(time_t* timer)
549548
dll.time.argtypes = (ctypes.POINTER(ctypes.c_int),)
550549
dll.time.restype = ctypes.c_int
550+
# void (*signal(int sig, void (*func)(int)))(int)
551+
dll.signal.argtypes = (ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int))
552+
dll.signal.restype = None
553+
# void qsort(
554+
# void *base,
555+
# size_t nitems,
556+
# size_t size,
557+
# int (*compar)(const void *, const void*)
558+
# )
559+
dll.qsort.argtypes = (
560+
ctypes.c_void_p,
561+
ctypes.c_size_t,
562+
ctypes.c_size_t,
563+
ctypes.CFUNCTYPE(
564+
ctypes.c_int,
565+
ctypes.c_void_p,
566+
ctypes.c_void_p,
567+
),
568+
)
569+
dll.qsort.restype = None
570+
# void *bsearch(
571+
# const void *key,
572+
# const void *base,
573+
# size_t nitems,
574+
# size_t size,
575+
# int (*compar)(const void *, const void *)
576+
# )
577+
dll.bsearch.argtypes = (
578+
ctypes.c_void_p,
579+
ctypes.c_void_p,
580+
ctypes.c_size_t,
581+
ctypes.c_size_t,
582+
ctypes.CFUNCTYPE(
583+
ctypes.c_int,
584+
ctypes.c_void_p,
585+
ctypes.c_void_p,
586+
),
587+
)
588+
dll.bsearch.restype = ctypes.c_void_p
551589

552590
c_malloc = dll.malloc
553591
c_free = dll.free

src/pointers/bindings.py

Lines changed: 177 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import ctypes
2+
import inspect
3+
from types import FunctionType
24
from typing import (
3-
TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, Type, TypeVar, Union
5+
TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, Optional, Sequence,
6+
Type, TypeVar, Union
47
)
58

9+
from _pointers import add_ref
10+
611
from . import _cstd
712
from ._cstd import STRUCT_MAP, DivT, Lconv, LDivT, Tm
813
from ._cstd import c_calloc as _calloc
@@ -19,10 +24,12 @@
1924
from .struct import Struct
2025

2126
T = TypeVar("T")
27+
2228
PointerLike = Union[TypedCPointer[Any], VoidPointer, None]
2329
StringLike = Union[str, bytes, VoidPointer, TypedCPointer[bytes]]
2430
Format = Union[StringLike, PointerLike]
2531
TypedPtr = Optional[TypedCPointer[T]]
32+
PyCFuncPtrType = type(ctypes.CFUNCTYPE(None))
2633

2734
__all__ = (
2835
"isalnum",
@@ -134,6 +141,14 @@
134141
"c_realloc",
135142
"c_free",
136143
"gmtime",
144+
"signal",
145+
"qsort",
146+
"bsearch",
147+
"sizeof",
148+
"PointerLike",
149+
"StringLike",
150+
"Format",
151+
"TypedPtr",
137152
)
138153

139154

@@ -145,10 +160,29 @@ def _not_null(data: Optional[T]) -> T:
145160
StructMap = Dict[Type[ctypes.Structure], Type["Struct"]]
146161

147162

148-
def _decode_response(
163+
class _CFuncTransport:
164+
def __init__(
165+
self,
166+
c_func: "ctypes._FuncPointer",
167+
py_func: Callable,
168+
) -> None:
169+
add_ref(c_func)
170+
self._c_func = c_func
171+
self._py_func = py_func
172+
173+
@property
174+
def c_func(self) -> "ctypes._FuncPointer":
175+
return self._c_func
176+
177+
@property
178+
def py_func(self) -> Callable:
179+
return self._py_func
180+
181+
182+
def _decode_type(
149183
res: Any,
150184
struct_map: StructMap,
151-
fn: "ctypes._NamedFuncPointer",
185+
current: Optional[Type["ctypes._CData"]],
152186
) -> Any:
153187
res_typ = type(res)
154188

@@ -164,7 +198,7 @@ def _decode_response(
164198
else StructPointer(id(struct), type(_not_null(struct)), struct)
165199
)
166200
# type safety gets mad if i dont use elif here
167-
elif fn.restype is ctypes.c_void_p:
201+
elif current is ctypes.c_void_p:
168202
res = VoidPointer(res, ctypes.sizeof(ctypes.c_void_p(res)))
169203

170204
elif issubclass(res_typ, ctypes.Structure):
@@ -175,22 +209,49 @@ def _decode_response(
175209
return res
176210

177211

178-
def _validate_args(
179-
args: Tuple[Any, ...],
212+
def _decode_response(
213+
res: Any,
214+
struct_map: StructMap,
180215
fn: "ctypes._NamedFuncPointer",
216+
) -> Any:
217+
return _decode_type(res, struct_map, fn.restype) # type: ignore
218+
219+
220+
def _process_args(
221+
args: Iterable[Any],
222+
argtypes: Sequence[Type["ctypes._CData"]],
223+
name: str,
181224
) -> None:
182-
if not fn.argtypes:
183-
return
225+
for index, (value, typ) in enumerate(zip(args, argtypes)):
226+
if value is inspect._empty:
227+
continue
228+
229+
if isinstance(value, _CFuncTransport):
230+
py_func = value.py_func
231+
sig = inspect.signature(py_func)
232+
_process_args(
233+
[param.annotation for param in sig.parameters.values()],
234+
value.c_func._argtypes_, # type: ignore
235+
py_func.__name__,
236+
)
237+
continue
238+
is_c_func: bool = isinstance(
239+
typ,
240+
PyCFuncPtrType,
241+
)
242+
n_type = VoidPointer.get_py(typ) if not is_c_func else FunctionType
184243

185-
for index, (value, typ) in enumerate(zip(args, fn.argtypes)):
186-
n_type = VoidPointer.get_py(typ)
244+
is_type: bool = isinstance(value, type)
187245

188-
if not isinstance(value, n_type):
189-
v_type = type(value)
246+
if not (isinstance if not is_type else issubclass)(value, n_type):
247+
v_type = type(value) if not is_type else value
190248

191249
if (n_type is Pointer) and (value is None):
192250
continue
193251

252+
if (n_type is FunctionType) and is_c_func:
253+
continue
254+
194255
if (
195256
typ
196257
in {
@@ -207,21 +268,80 @@ def _validate_args(
207268
continue
208269

209270
raise InvalidBindingParameter(
210-
f"argument {index + 1} got invalid type: expected {n_type.__name__}, got {v_type.__name__}" # noqa
271+
f"argument {index + 1} of {name} got invalid type: expected {n_type.__name__}, got {v_type.__name__}" # noqa
211272
)
212273

213274

275+
def _validate_args(
276+
args: Iterable[Any],
277+
fn: "ctypes._NamedFuncPointer",
278+
) -> None:
279+
if not fn.argtypes:
280+
return
281+
282+
_process_args(args, fn.argtypes, fn.__name__)
283+
284+
285+
def _solve_func(
286+
fn: Callable,
287+
ct_fn: "ctypes._FuncPointer",
288+
struct_map: StructMap,
289+
) -> _CFuncTransport:
290+
at = ct_fn._argtypes_ # type: ignore
291+
292+
@ctypes.CFUNCTYPE(ct_fn._restype_, *at) # type: ignore
293+
def wrapper(*args):
294+
callback_args = []
295+
296+
for value, ctype in zip(args, at):
297+
callback_args.append(_decode_type(value, struct_map, ctype))
298+
299+
return fn(*callback_args)
300+
301+
return _CFuncTransport(wrapper, fn)
302+
303+
214304
def _base(
215305
fn: "ctypes._NamedFuncPointer",
216306
*args,
217307
map_extra: Optional[StructMap] = None,
218308
) -> Any:
219-
_validate_args(args, fn)
220-
res = fn(*args)
309+
smap = {**STRUCT_MAP, **(map_extra or {})}
310+
311+
validator_args = [
312+
arg
313+
if ((not callable(arg)) and (not isinstance(arg, PyCFuncPtrType)))
314+
else _solve_func(
315+
arg,
316+
typ, # type: ignore
317+
smap,
318+
)
319+
for arg, typ in zip(
320+
args,
321+
fn.argtypes or [None for _ in args], # type: ignore
322+
)
323+
]
324+
325+
_validate_args(
326+
validator_args,
327+
fn,
328+
)
329+
330+
res = fn(
331+
*[
332+
i
333+
if not isinstance(
334+
i,
335+
_CFuncTransport,
336+
)
337+
else i.c_func
338+
for i in validator_args
339+
]
340+
)
221341

222342
return _decode_response(
223343
res,
224-
{**STRUCT_MAP, **(map_extra or {})},
344+
smap,
225345
fn,
226346
)
227347

@@ -246,7 +366,11 @@ def _make_char_pointer(data: StringLike) -> Union[bytes, ctypes.c_char_p]:
246366

247367
return ctypes.c_char_p(data.address)
248368

249-
return data.encode()
369+
if isinstance(data, str):
370+
return data.encode()
371+
372+
assert isinstance(data, ctypes.c_char_p), f"{data} is not a char*"
373+
return data
250374

251375

252376
def _make_format(*args: Format) -> Iterator[Format]:
@@ -866,3 +990,39 @@ def c_free(ptr: PointerLike) -> None:
866990

867991
def gmtime(timer: PointerLike) -> StructPointer[Tm]:
868992
return _base(dll.gmtime, timer)
993+
994+
995+
def signal(signum: int, func: Callable[[int, None], int]) -> None:
996+
return _base(dll.signal, signum, func)
997+
998+
999+
def qsort(
1000+
base: PointerLike,
1001+
nitem: int,
1002+
size: int,
1003+
compar: Callable[
1004+
[Any, Any],
1005+
int,
1006+
],
1007+
) -> None:
1008+
return _base(dll.qsort, base, nitem, size, compar)
1009+
1010+
1011+
def bsearch(
1012+
key: PointerLike,
1013+
base: PointerLike,
1014+
nitems: int,
1015+
size: int,
1016+
compar: Callable[
1017+
[Any, Any],
1018+
int,
1019+
],
1020+
) -> VoidPointer:
1021+
return _base(dll.bsearch, key, base, nitems, size, compar)
1022+
1023+
1024+
def sizeof(obj: Any) -> int:
1025+
try:
1026+
return ctypes.sizeof(obj)
1027+
except TypeError:
1028+
return ctypes.sizeof(VoidPointer.get_mapped(obj))

0 commit comments

Comments
 (0)