11import ctypes
2+ import inspect
3+ from types import FunctionType
24from typing import (
3- TYPE_CHECKING , Any , Dict , Iterator , Optional , Tuple , Type , TypeVar , Union
5+ TYPE_CHECKING , Any , Callable , Dict , Iterable , Iterator , Optional , Sequence ,
6+ Type , TypeVar , Union
47)
58
9+ from _pointers import add_ref
10+
611from . import _cstd
712from ._cstd import STRUCT_MAP , DivT , Lconv , LDivT , Tm
813from ._cstd import c_calloc as _calloc
1924 from .struct import Struct
2025
2126T = TypeVar ("T" )
27+
2228PointerLike = Union [TypedCPointer [Any ], VoidPointer , None ]
2329StringLike = Union [str , bytes , VoidPointer , TypedCPointer [bytes ]]
2430Format = Union [StringLike , PointerLike ]
2531TypedPtr = Optional [TypedCPointer [T ]]
32+ PyCFuncPtrType = type (ctypes .CFUNCTYPE (None ))
2633
2734__all__ = (
2835 "isalnum" ,
134141 "c_realloc" ,
135142 "c_free" ,
136143 "gmtime" ,
144+ "signal" ,
145+ "qsort" ,
146+ "bsearch" ,
147+ "sizeof" ,
148+ "PointerLike" ,
149+ "StringLike" ,
150+ "Format" ,
151+ "TypedPtr" ,
137152)
138153
139154
@@ -145,10 +160,29 @@ def _not_null(data: Optional[T]) -> T:
145160StructMap = Dict [Type [ctypes .Structure ], Type ["Struct" ]]
146161
147162
148- def _decode_response (
163+ class _CFuncTransport :
164+ def __init__ (
165+ self ,
166+ c_func : "ctypes._FuncPointer" ,
167+ py_func : Callable ,
168+ ) -> None :
169+ add_ref (c_func )
170+ self ._c_func = c_func
171+ self ._py_func = py_func
172+
173+ @property
174+ def c_func (self ) -> "ctypes._FuncPointer" :
175+ return self ._c_func
176+
177+ @property
178+ def py_func (self ) -> Callable :
179+ return self ._py_func
180+
181+
182+ def _decode_type (
149183 res : Any ,
150184 struct_map : StructMap ,
151- fn : "ctypes._NamedFuncPointer" ,
185+ current : Optional [ Type [ "ctypes._CData" ]] ,
152186) -> Any :
153187 res_typ = type (res )
154188
@@ -164,7 +198,7 @@ def _decode_response(
164198 else StructPointer (id (struct ), type (_not_null (struct )), struct )
165199 )
166200 # type safety gets mad if i dont use elif here
167- elif fn . restype is ctypes .c_void_p :
201+ elif current is ctypes .c_void_p :
168202 res = VoidPointer (res , ctypes .sizeof (ctypes .c_void_p (res )))
169203
170204 elif issubclass (res_typ , ctypes .Structure ):
@@ -175,22 +209,49 @@ def _decode_response(
175209 return res
176210
177211
178- def _validate_args (
179- args : Tuple [Any , ...],
212+ def _decode_response (
213+ res : Any ,
214+ struct_map : StructMap ,
180215 fn : "ctypes._NamedFuncPointer" ,
216+ ) -> Any :
217+ return _decode_type (res , struct_map , fn .restype ) # type: ignore
218+
219+
220+ def _process_args (
221+ args : Iterable [Any ],
222+ argtypes : Sequence [Type ["ctypes._CData" ]],
223+ name : str ,
181224) -> None :
182- if not fn .argtypes :
183- return
225+ for index , (value , typ ) in enumerate (zip (args , argtypes )):
226+ if value is inspect ._empty :
227+ continue
228+
229+ if isinstance (value , _CFuncTransport ):
230+ py_func = value .py_func
231+ sig = inspect .signature (py_func )
232+ _process_args (
233+ [param .annotation for param in sig .parameters .values ()],
234+ value .c_func ._argtypes_ , # type: ignore
235+ py_func .__name__ ,
236+ )
237+ continue
238+ is_c_func : bool = isinstance (
239+ typ ,
240+ PyCFuncPtrType ,
241+ )
242+ n_type = VoidPointer .get_py (typ ) if not is_c_func else FunctionType
184243
185- for index , (value , typ ) in enumerate (zip (args , fn .argtypes )):
186- n_type = VoidPointer .get_py (typ )
244+ is_type : bool = isinstance (value , type )
187245
188- if not isinstance (value , n_type ):
189- v_type = type (value )
246+ if not ( isinstance if not is_type else issubclass ) (value , n_type ):
247+ v_type = type (value ) if not is_type else value
190248
191249 if (n_type is Pointer ) and (value is None ):
192250 continue
193251
252+ if (n_type is FunctionType ) and is_c_func :
253+ continue
254+
194255 if (
195256 typ
196257 in {
@@ -207,21 +268,80 @@ def _validate_args(
207268 continue
208269
209270 raise InvalidBindingParameter (
210- f"argument { index + 1 } got invalid type: expected { n_type .__name__ } , got { v_type .__name__ } " # noqa
271+ f"argument { index + 1 } of { name } got invalid type: expected { n_type .__name__ } , got { v_type .__name__ } " # noqa
211272 )
212273
213274
275+ def _validate_args (
276+ args : Iterable [Any ],
277+ fn : "ctypes._NamedFuncPointer" ,
278+ ) -> None :
279+ if not fn .argtypes :
280+ return
281+
282+ _process_args (args , fn .argtypes , fn .__name__ )
283+
284+
285+ def _solve_func (
286+ fn : Callable ,
287+ ct_fn : "ctypes._FuncPointer" ,
288+ struct_map : StructMap ,
289+ ) -> _CFuncTransport :
290+ at = ct_fn ._argtypes_ # type: ignore
291+
292+ @ctypes .CFUNCTYPE (ct_fn ._restype_ , * at ) # type: ignore
293+ def wrapper (* args ):
294+ callback_args = []
295+
296+ for value , ctype in zip (args , at ):
297+ callback_args .append (_decode_type (value , struct_map , ctype ))
298+
299+ return fn (* callback_args )
300+
301+ return _CFuncTransport (wrapper , fn )
302+
303+
214304def _base (
215305 fn : "ctypes._NamedFuncPointer" ,
216306 * args ,
217307 map_extra : Optional [StructMap ] = None ,
218308) -> Any :
219- _validate_args (args , fn )
220- res = fn (* args )
309+ smap = {** STRUCT_MAP , ** (map_extra or {})}
310+
311+ validator_args = [
312+ arg
313+ if ((not callable (arg )) and (not isinstance (arg , PyCFuncPtrType )))
314+ else _solve_func (
315+ arg ,
316+ typ , # type: ignore
317+ smap ,
318+ )
319+ for arg , typ in zip (
320+ args ,
321+ fn .argtypes or [None for _ in args ], # type: ignore
322+ )
323+ ]
324+
325+ _validate_args (
326+ validator_args ,
327+ fn ,
328+ )
329+
330+ res = fn (
331+ * [
332+ i
333+ if not isinstance (
334+ i ,
335+ _CFuncTransport ,
336+ )
337+ else i .c_func
338+ for i in validator_args
339+ ]
340+ )
221341
222342 return _decode_response (
223343 res ,
224- { ** STRUCT_MAP , ** ( map_extra or {})} ,
344+ smap ,
225345 fn ,
226346 )
227347
@@ -246,7 +366,11 @@ def _make_char_pointer(data: StringLike) -> Union[bytes, ctypes.c_char_p]:
246366
247367 return ctypes .c_char_p (data .address )
248368
249- return data .encode ()
369+ if isinstance (data , str ):
370+ return data .encode ()
371+
372+ assert isinstance (data , ctypes .c_char_p ), f"{ data } is not a char*"
373+ return data
250374
251375
252376def _make_format (* args : Format ) -> Iterator [Format ]:
@@ -866,3 +990,39 @@ def c_free(ptr: PointerLike) -> None:
866990
867991def gmtime (timer : PointerLike ) -> StructPointer [Tm ]:
868992 return _base (dll .gmtime , timer )
993+
994+
995+ def signal (signum : int , func : Callable [[int , None ], int ]) -> None :
996+ return _base (dll .signal , signum , func )
997+
998+
999+ def qsort (
1000+ base : PointerLike ,
1001+ nitem : int ,
1002+ size : int ,
1003+ compar : Callable [
1004+ [Any , Any ],
1005+ int ,
1006+ ],
1007+ ) -> None :
1008+ return _base (dll .qsort , base , nitem , size , compar )
1009+
1010+
1011+ def bsearch (
1012+ key : PointerLike ,
1013+ base : PointerLike ,
1014+ nitems : int ,
1015+ size : int ,
1016+ compar : Callable [
1017+ [Any , Any ],
1018+ int ,
1019+ ],
1020+ ) -> VoidPointer :
1021+ return _base (dll .bsearch , key , base , nitems , size , compar )
1022+
1023+
1024+ def sizeof (obj : Any ) -> int :
1025+ try :
1026+ return ctypes .sizeof (obj )
1027+ except TypeError :
1028+ return ctypes .sizeof (VoidPointer .get_mapped (obj ))
0 commit comments