1313# limitations under the License.
1414"""Generates C++ lambda functions inside pybind11 bindings code."""
1515
16- from typing import Generator , List , Optional , Set
16+ from typing import Generator , List , Optional
1717
1818from clif .protos import ast_pb2
1919from clif .pybind11 import function_lib
@@ -62,27 +62,6 @@ def generate_lambda(
6262 yield f'}}, { function_suffix } '
6363
6464
65- def needs_lambda (
66- func_decl : ast_pb2 .FuncDecl , codegen_info : utils .CodeGenInfo ,
67- class_decl : Optional [ast_pb2 .ClassDecl ] = None ) -> bool :
68- if class_decl and _is_inherited_method (class_decl , func_decl ):
69- return True
70- return (bool (func_decl .postproc ) or
71- func_decl .is_overloaded or
72- _func_is_extend_static_method (func_decl , class_decl ) or
73- function_lib .func_has_vector_param (func_decl ) or
74- function_lib .func_has_set_param (func_decl ) or
75- _func_is_context_manager (func_decl ) or
76- _func_needs_index_check (func_decl ) or
77- _func_has_capsule_params (func_decl , codegen_info .capsule_types ) or
78- _func_needs_implicit_conversion (func_decl ) or
79- _func_has_pointer_params (func_decl ) or
80- function_lib .func_has_py_object_params (func_decl ) or
81- _func_has_status_params (func_decl , codegen_info .requires_status ) or
82- _func_has_status_callback (func_decl , codegen_info .requires_status ) or
83- func_decl .cpp_num_params != len (func_decl .params ))
84-
85-
8665def generate_check_nullptr (
8766 func_decl : ast_pb2 .FuncDecl , param_name : str ) -> Generator [str , None , None ]:
8867 yield I + f'if ({ param_name } == nullptr) {{'
@@ -125,9 +104,6 @@ def _generate_lambda_body(
125104 function_call = generate_function_call (func_decl , class_decl )
126105 params_str = ', ' .join ([p .function_argument for p in params ])
127106 function_call_params = generate_function_call_params (func_decl , params_str )
128- function_call_returns = generate_function_call_returns (
129- func_decl , codegen_info .capsule_types )
130-
131107 cpp_void_return = func_decl .cpp_void_return or not func_decl .returns
132108
133109 # Generates void pointer check for parameters that are converted from non
@@ -158,32 +134,42 @@ def _generate_lambda_body(
158134 yield I + '}'
159135 yield I + f'{ index .gen_name } = { index .gen_name } _;'
160136
161- # Generates declarations of return values
137+ if not cpp_void_return :
138+ yield I + 'py::object ret0;'
139+
140+ # Generates declarations of pointer return values outside of scope
162141 for i , r in enumerate (func_decl .returns ):
163142 if i or cpp_void_return :
164143 yield I + f'{ r .type .cpp_type } ret{ i } {{}};'
165144
145+ yield I + '{'
166146 if not function_lib .func_keeps_gil (func_decl ):
167- yield I + 'PyThreadState* _save;'
168- yield I + 'Py_UNBLOCK_THREADS'
147+ yield I + I + 'py::gil_scoped_release gil_release;'
169148
170149 # Generates call to the wrapped function
171150 cpp_void_return = func_decl .cpp_void_return or not func_decl .returns
151+ ret0_with_py_cast = ''
172152 if not cpp_void_return :
173153 ret0 = func_decl .returns [0 ]
174154 if not ret0 .type .cpp_type :
175155 callback_cpp_type = function_lib .generate_callback_signature (ret0 )
176- yield I + (f'{ callback_cpp_type } ret0 = '
177- f'{ function_call } ({ function_call_params } );' )
156+ yield I + I + (f'{ callback_cpp_type } ret0_ = '
157+ f'{ function_call } ({ function_call_params } );' )
178158 else :
179- yield I + (f'{ ret0 .type .cpp_type } ret0 = '
180- f'{ function_call } ({ function_call_params } );' )
159+ yield I + I + (f'{ ret0 .type .cpp_type } ret0_ = '
160+ f'{ function_call } ({ function_call_params } );' )
161+ ret0_with_py_cast = generate_function_call_return (
162+ func_decl , ret0 , 'ret0_' , codegen_info )
163+ yield I + I + '{'
164+ yield I + I + I + 'py::gil_scoped_acquire gil_acquire;'
165+ yield I + I + I + f'ret0 = { ret0_with_py_cast } ;'
166+ yield I + I + '}'
181167 else :
182- yield I + f'{ function_call } ({ function_call_params } );'
183-
184- if not function_lib .func_keeps_gil (func_decl ):
185- yield I + 'Py_BLOCK_THREADS'
168+ yield I + I + f'{ function_call } ({ function_call_params } );'
169+ yield I + '}'
186170
171+ function_call_returns = generate_function_call_returns (
172+ func_decl , codegen_info )
187173 # Generates returns of the lambda expression
188174 self_param = 'self'
189175 if func_decl .is_extend_method and len (params ):
@@ -212,32 +198,35 @@ def generate_function_call_params(
212198
213199
214200def generate_function_call_returns (
215- func_decl : ast_pb2 .FuncDecl , capsule_types : Set [str ],
216- requires_status : bool = True ) -> str :
201+ func_decl : ast_pb2 .FuncDecl , codegen_info : utils .CodeGenInfo ) -> str :
217202 """Generates return values of cpp function."""
218203 all_returns_list = []
219204 for i , r in enumerate (func_decl .returns ):
220- if function_lib .is_bytes_type (r .type ):
221- all_returns_list .append (
222- f'py::cast(ret{ i } , py::return_value_policy::_return_as_bytes)' )
223- elif r .type .lang_type in capsule_types :
224- all_returns_list .append (
225- f'clif::CapsuleWrapper<{ r .type .cpp_type } >(ret{ i } )' )
226- elif function_lib .is_status_param (r , requires_status ):
227- status_type = function_lib .generate_status_type (func_decl , r )
228- all_returns_list .append (f'py::cast(({ status_type } )(std::move(ret{ i } )), '
229- 'py::return_value_policy::_clif_automatic)' )
230- # When the lambda expression returns multiple values, we construct an
231- # `std::tuple` with those return values. For uncopyable return values, we
232- # need `std::move` when constructing the `std::tuple`.
233- elif (len (func_decl .returns ) > 1 and
234- ('std::unique_ptr' in r .cpp_exact_type or not r .type .cpp_copyable )):
235- all_returns_list .append (f'std::move(ret{ i } )' )
205+ if i == 0 and not func_decl .cpp_void_return :
206+ all_returns_list .append ('ret0' )
236207 else :
237- all_returns_list .append (f'ret{ i } ' )
208+ ret = generate_function_call_return (func_decl , r , f'ret{ i } ' , codegen_info )
209+ all_returns_list .append (ret )
238210 return ', ' .join (all_returns_list )
239211
240212
213+ def generate_function_call_return (
214+ func_decl : ast_pb2 .FuncDecl , return_value : ast_pb2 .ParamDecl ,
215+ return_value_name : str , codegen_info : utils .CodeGenInfo ) -> str :
216+ """Generates return values of cpp function."""
217+ ret = f'std::move({ return_value_name } )'
218+ return_value_policy = 'py::return_value_policy::_clif_automatic'
219+ if function_lib .is_bytes_type (return_value .type ):
220+ return_value_policy = 'py::return_value_policy::_return_as_bytes'
221+ elif return_value .type .lang_type in codegen_info .capsule_types :
222+ ret = (f'clif::CapsuleWrapper<{ return_value .type .cpp_type } >'
223+ f'({ return_value_name } )' )
224+ elif function_lib .is_status_param (return_value , codegen_info .requires_status ):
225+ status_type = function_lib .generate_status_type (func_decl , return_value )
226+ ret = f'({ status_type } )(std::move({ return_value_name } ))'
227+ return f'py::cast({ ret } , { return_value_policy } )'
228+
229+
241230def _generate_lambda_params_with_types (
242231 func_decl : ast_pb2 .FuncDecl ,
243232 params : List [function_lib .Parameter ],
@@ -264,87 +253,6 @@ def generate_function_call(
264253 return f'self.{ method_name } '
265254
266255
267- def _func_is_extend_static_method (
268- func_decl : ast_pb2 .FuncDecl ,
269- class_decl : Optional [ast_pb2 .ClassDecl ] = None ) -> bool :
270- return class_decl and func_decl .is_extend_method and func_decl .classmethod
271-
272-
273- def _func_has_pointer_params (func_decl : ast_pb2 .FuncDecl ) -> bool :
274- num_returns = len (func_decl .returns )
275- return num_returns >= 2 or (num_returns == 1 and func_decl .cpp_void_return )
276-
277-
278- def _func_has_status_params (func_decl : ast_pb2 .FuncDecl ,
279- requires_status : bool ) -> bool :
280- for p in func_decl .params :
281- if function_lib .is_status_param (p , requires_status ):
282- return True
283- for r in func_decl .returns :
284- if function_lib .is_status_param (r , requires_status ):
285- return True
286- return False
287-
288-
289- def _func_has_status_callback (func_decl : ast_pb2 .FuncDecl ,
290- requires_status : bool ) -> bool :
291- for r in func_decl .returns :
292- if function_lib .is_status_callback (r , requires_status ):
293- return True
294- return False
295-
296-
297- def _func_has_capsule_params (
298- func_decl : ast_pb2 .FuncDecl , capsule_types : Set [str ]) -> bool :
299- for p in func_decl .params :
300- if p .type .lang_type in capsule_types :
301- return True
302- for r in func_decl .returns :
303- if r .type .lang_type in capsule_types :
304- return True
305- return False
306-
307-
308- def _func_is_context_manager (func_decl : ast_pb2 .FuncDecl ) -> bool :
309- return func_decl .name .native in ('__enter__@' , '__exit__@' )
310-
311-
312- def _func_needs_index_check (func_decl : ast_pb2 .FuncDecl ) -> bool :
313- return func_decl .name .native in _NEEDS_INDEX_CHECK_METHODS
314-
315-
316- def _is_inherited_method (class_decl : ast_pb2 .ClassDecl ,
317- func_decl : ast_pb2 .FuncDecl ) -> bool :
318- if class_decl .cpp_bases and not func_decl .is_extend_method :
319- namespaces = func_decl .name .cpp_name .split ('::' )
320- if (len (namespaces ) > 1 and
321- namespaces [- 2 ] != class_decl .name .cpp_name .strip (':' )):
322- return True
323- return False
324-
325-
326- def _func_needs_implicit_conversion (func_decl : ast_pb2 .FuncDecl ) -> bool :
327- """Check if a function contains an implicitly converted parameter."""
328- for param in func_decl .params :
329- if (_extract_bare_type (param .cpp_exact_type ) !=
330- _extract_bare_type (param .type .cpp_type ) and
331- param .type .cpp_toptr_conversion and
332- param .type .cpp_touniqptr_conversion ):
333- return True
334- return False
335-
336-
337- def _extract_bare_type (cpp_name : str ) -> str :
338- # This helper function is not general and only meant
339- # to be used in _func_needs_implicit_conversion.
340- t = cpp_name .split (' ' )
341- if t [0 ] == 'const' :
342- t = t [1 :]
343- if t [- 1 ] in {'&' , '*' }: # Minimum viable approach. To be refined as needed.
344- t = t [:- 1 ]
345- return ' ' .join (t )
346-
347-
348256def generate_return_value_cpp_type (
349257 func_decl : ast_pb2 .FuncDecl , codegen_info : utils .CodeGenInfo ) -> str :
350258 """Generates type for the return value of the C++ function."""
0 commit comments