Skip to content
This repository was archived by the owner on Nov 11, 2025. It is now read-only.

Commit b2070c9

Browse files
wangxfrwgk
authored andcommitted
Generate lambda expressions for all functions.
This is to make function type cast generation unnecessary. We use function type cast for overloaded functions, but sometimes the type cast might be referencing private types. For example, https://fusion2.corp.google.com/presubmit/tap/492042529/OCL:492903434:BASE:493136442:1670284688940:bfb8ae74;filter=/targets/invocations/557e8169-3ff7-46ab-802c-a1000f26a981/targets/%2F%2Fnlp%2Fsemantic_parsing%2Flearning%2Fneural%2Fdata_collection%2Fdata_mutation%2Finput_context%2Fpython:input_context_test/log (failed). There is no easy way to fix this, except not generating type casts. Also removed the logic for return value policy overrides because now we generate `py::cast(std::move(ret))` for all returns in the lambda expressions, so the overrides do not work anymore. TGP: https://fusion2.corp.google.com/presubmit/tap/487383037/OCL:487383037:BASE:493634501:1670436092899:5a58878;groups=PossiblyAlreadyFailing,PossiblyNewlyFailing/targets (~8000 failures). PiperOrigin-RevId: 493723278
1 parent fd93e12 commit b2070c9

File tree

7 files changed

+47
-244
lines changed

7 files changed

+47
-244
lines changed

clif/pybind11/function.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,24 +86,9 @@ def _generate_function(
8686
"""Generates pybind11 bindings code for ast_pb2.FuncDecl."""
8787
if operators.needs_operator_overloading(func_decl):
8888
yield from operators.generate_operator(module_name, func_decl, class_decl)
89-
elif lambdas.needs_lambda(func_decl, codegen_info, class_decl):
89+
else:
9090
yield from lambdas.generate_lambda(
9191
module_name, func_decl, codegen_info, class_decl)
92-
else:
93-
yield from _generate_simple_function(module_name, func_decl, class_decl)
94-
95-
96-
def _generate_simple_function(
97-
module_name: str, func_decl: ast_pb2.FuncDecl,
98-
class_decl: Optional[ast_pb2.ClassDecl] = None
99-
) -> Generator[str, None, None]:
100-
func_name = func_decl.name.native.rstrip('#') # @sequential
101-
yield f'{module_name}.{function_lib.generate_def(func_decl)}("{func_name}",'
102-
yield I + function_lib.generate_cpp_function_cast(func_decl, class_decl)
103-
yield I + f'&{func_decl.name.cpp_name},'
104-
is_member_function = (class_decl is not None)
105-
yield I + function_lib.generate_function_suffixes(
106-
func_decl, is_member_function=is_member_function)
10792

10893

10994
def _generate_overload_for_unknown_default_function(

clif/pybind11/function_lib.py

Lines changed: 1 addition & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -102,20 +102,6 @@ def is_cpp_set(param_type: ast_pb2.Type) -> bool:
102102
param_type.lang_type.startswith('set'))
103103

104104

105-
def func_has_vector_param(func_decl: ast_pb2.FuncDecl) -> bool:
106-
for param in func_decl.params:
107-
if is_cpp_vector(param.type):
108-
return True
109-
return False
110-
111-
112-
def func_has_set_param(func_decl: ast_pb2.FuncDecl) -> bool:
113-
for param in func_decl.params:
114-
if is_cpp_set(param.type):
115-
return True
116-
return False
117-
118-
119105
def num_unknown_default_values(func_decl: ast_pb2.FuncDecl) -> int:
120106
num_unknown = 0
121107
for param in func_decl.params:
@@ -176,7 +162,7 @@ def generate_function_suffixes(
176162
suffix = ''
177163
if py_args:
178164
suffix += f'{py_args}, '
179-
suffix += f'{generate_return_value_policy(func_decl)}'
165+
suffix += 'py::return_value_policy::_clif_automatic'
180166
if func_decl.docstring:
181167
suffix += f', {generate_docstring(func_decl.docstring)}'
182168
if release_gil and not func_decl.py_keep_gil:
@@ -377,38 +363,3 @@ def func_keeps_gil(func_decl: ast_pb2.FuncDecl) -> bool:
377363
if func_decl.py_keep_gil:
378364
return True
379365
return False
380-
381-
382-
def generate_return_value_policy(func_decl: ast_pb2.FuncDecl) -> str:
383-
"""Generates pybind11 return value policy based on function return type.
384-
385-
Emulates the behavior of the generated Python C API code.
386-
387-
Args:
388-
func_decl: The function declaration that needs to be processed.
389-
390-
Returns:
391-
pybind11 return value policy based on the function return value.
392-
"""
393-
prefix = 'py::return_value_policy::'
394-
if has_bytes_return(func_decl):
395-
return prefix + '_return_as_bytes'
396-
if func_decl.cpp_void_return or not func_decl.returns:
397-
return prefix + 'automatic'
398-
if func_decl.HasField('return_value_policy'):
399-
if func_decl.return_value_policy == ast_pb2.FuncDecl.AUTOMATIC_REFERENCE:
400-
return prefix + 'automatic_reference'
401-
if func_decl.return_value_policy == ast_pb2.FuncDecl.TAKE_OWNERSHIP:
402-
return prefix + 'take_ownership'
403-
if func_decl.return_value_policy == ast_pb2.FuncDecl.COPY:
404-
return prefix + 'copy'
405-
if func_decl.return_value_policy == ast_pb2.FuncDecl.MOVE:
406-
return prefix + 'move'
407-
if func_decl.return_value_policy == ast_pb2.FuncDecl.REFERENCE:
408-
return prefix + 'reference'
409-
if func_decl.return_value_policy == ast_pb2.FuncDecl.REFERENCE_INTERNAL:
410-
return prefix + 'reference_internal'
411-
if func_decl.return_value_policy == ast_pb2.FuncDecl.RETURN_AS_BYTES:
412-
return prefix + '_return_as_bytes'
413-
return prefix + 'automatic'
414-
return prefix + '_clif_automatic'

clif/pybind11/lambdas.py

Lines changed: 44 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Generates C++ lambda functions inside pybind11 bindings code."""
1515

16-
from typing import Generator, List, Optional, Set
16+
from typing import Generator, List, Optional
1717

1818
from clif.protos import ast_pb2
1919
from clif.pybind11 import function_lib
@@ -62,27 +62,6 @@ def generate_lambda(
6262
yield f'}}, {function_suffix}'
6363

6464

65-
def needs_lambda(
66-
func_decl: ast_pb2.FuncDecl, codegen_info: utils.CodeGenInfo,
67-
class_decl: Optional[ast_pb2.ClassDecl] = None) -> bool:
68-
if class_decl and _is_inherited_method(class_decl, func_decl):
69-
return True
70-
return (bool(func_decl.postproc) or
71-
func_decl.is_overloaded or
72-
_func_is_extend_static_method(func_decl, class_decl) or
73-
function_lib.func_has_vector_param(func_decl) or
74-
function_lib.func_has_set_param(func_decl) or
75-
_func_is_context_manager(func_decl) or
76-
_func_needs_index_check(func_decl) or
77-
_func_has_capsule_params(func_decl, codegen_info.capsule_types) or
78-
_func_needs_implicit_conversion(func_decl) or
79-
_func_has_pointer_params(func_decl) or
80-
function_lib.func_has_py_object_params(func_decl) or
81-
_func_has_status_params(func_decl, codegen_info.requires_status) or
82-
_func_has_status_callback(func_decl, codegen_info.requires_status) or
83-
func_decl.cpp_num_params != len(func_decl.params))
84-
85-
8665
def generate_check_nullptr(
8766
func_decl: ast_pb2.FuncDecl, param_name: str) -> Generator[str, None, None]:
8867
yield I + f'if ({param_name} == nullptr) {{'
@@ -125,9 +104,6 @@ def _generate_lambda_body(
125104
function_call = generate_function_call(func_decl, class_decl)
126105
params_str = ', '.join([p.function_argument for p in params])
127106
function_call_params = generate_function_call_params(func_decl, params_str)
128-
function_call_returns = generate_function_call_returns(
129-
func_decl, codegen_info.capsule_types)
130-
131107
cpp_void_return = func_decl.cpp_void_return or not func_decl.returns
132108

133109
# Generates void pointer check for parameters that are converted from non
@@ -158,32 +134,42 @@ def _generate_lambda_body(
158134
yield I +'}'
159135
yield I + f'{index.gen_name} = {index.gen_name}_;'
160136

161-
# Generates declarations of return values
137+
if not cpp_void_return:
138+
yield I + 'py::object ret0;'
139+
140+
# Generates declarations of pointer return values outside of scope
162141
for i, r in enumerate(func_decl.returns):
163142
if i or cpp_void_return:
164143
yield I + f'{r.type.cpp_type} ret{i}{{}};'
165144

145+
yield I + '{'
166146
if not function_lib.func_keeps_gil(func_decl):
167-
yield I + 'PyThreadState* _save;'
168-
yield I + 'Py_UNBLOCK_THREADS'
147+
yield I + I + 'py::gil_scoped_release gil_release;'
169148

170149
# Generates call to the wrapped function
171150
cpp_void_return = func_decl.cpp_void_return or not func_decl.returns
151+
ret0_with_py_cast = ''
172152
if not cpp_void_return:
173153
ret0 = func_decl.returns[0]
174154
if not ret0.type.cpp_type:
175155
callback_cpp_type = function_lib.generate_callback_signature(ret0)
176-
yield I + (f'{callback_cpp_type} ret0 = '
177-
f'{function_call}({function_call_params});')
156+
yield I + I + (f'{callback_cpp_type} ret0_ = '
157+
f'{function_call}({function_call_params});')
178158
else:
179-
yield I + (f'{ret0.type.cpp_type} ret0 = '
180-
f'{function_call}({function_call_params});')
159+
yield I + I + (f'{ret0.type.cpp_type} ret0_ = '
160+
f'{function_call}({function_call_params});')
161+
ret0_with_py_cast = generate_function_call_return(
162+
func_decl, ret0, 'ret0_', codegen_info)
163+
yield I + I + '{'
164+
yield I + I + I + 'py::gil_scoped_acquire gil_acquire;'
165+
yield I + I + I + f'ret0 = {ret0_with_py_cast};'
166+
yield I + I + '}'
181167
else:
182-
yield I + f'{function_call}({function_call_params});'
183-
184-
if not function_lib.func_keeps_gil(func_decl):
185-
yield I + 'Py_BLOCK_THREADS'
168+
yield I + I + f'{function_call}({function_call_params});'
169+
yield I + '}'
186170

171+
function_call_returns = generate_function_call_returns(
172+
func_decl, codegen_info)
187173
# Generates returns of the lambda expression
188174
self_param = 'self'
189175
if func_decl.is_extend_method and len(params):
@@ -212,32 +198,35 @@ def generate_function_call_params(
212198

213199

214200
def generate_function_call_returns(
215-
func_decl: ast_pb2.FuncDecl, capsule_types: Set[str],
216-
requires_status: bool = True) -> str:
201+
func_decl: ast_pb2.FuncDecl, codegen_info: utils.CodeGenInfo) -> str:
217202
"""Generates return values of cpp function."""
218203
all_returns_list = []
219204
for i, r in enumerate(func_decl.returns):
220-
if function_lib.is_bytes_type(r.type):
221-
all_returns_list.append(
222-
f'py::cast(ret{i}, py::return_value_policy::_return_as_bytes)')
223-
elif r.type.lang_type in capsule_types:
224-
all_returns_list.append(
225-
f'clif::CapsuleWrapper<{r.type.cpp_type}>(ret{i})')
226-
elif function_lib.is_status_param(r, requires_status):
227-
status_type = function_lib.generate_status_type(func_decl, r)
228-
all_returns_list.append(f'py::cast(({status_type})(std::move(ret{i})), '
229-
'py::return_value_policy::_clif_automatic)')
230-
# When the lambda expression returns multiple values, we construct an
231-
# `std::tuple` with those return values. For uncopyable return values, we
232-
# need `std::move` when constructing the `std::tuple`.
233-
elif (len(func_decl.returns) > 1 and
234-
('std::unique_ptr' in r.cpp_exact_type or not r.type.cpp_copyable)):
235-
all_returns_list.append(f'std::move(ret{i})')
205+
if i == 0 and not func_decl.cpp_void_return:
206+
all_returns_list.append('ret0')
236207
else:
237-
all_returns_list.append(f'ret{i}')
208+
ret = generate_function_call_return(func_decl, r, f'ret{i}', codegen_info)
209+
all_returns_list.append(ret)
238210
return ', '.join(all_returns_list)
239211

240212

213+
def generate_function_call_return(
214+
func_decl: ast_pb2.FuncDecl, return_value: ast_pb2.ParamDecl,
215+
return_value_name: str, codegen_info: utils.CodeGenInfo) -> str:
216+
"""Generates return values of cpp function."""
217+
ret = f'std::move({return_value_name})'
218+
return_value_policy = 'py::return_value_policy::_clif_automatic'
219+
if function_lib.is_bytes_type(return_value.type):
220+
return_value_policy = 'py::return_value_policy::_return_as_bytes'
221+
elif return_value.type.lang_type in codegen_info.capsule_types:
222+
ret = (f'clif::CapsuleWrapper<{return_value.type.cpp_type}>'
223+
f'({return_value_name})')
224+
elif function_lib.is_status_param(return_value, codegen_info.requires_status):
225+
status_type = function_lib.generate_status_type(func_decl, return_value)
226+
ret = f'({status_type})(std::move({return_value_name}))'
227+
return f'py::cast({ret}, {return_value_policy})'
228+
229+
241230
def _generate_lambda_params_with_types(
242231
func_decl: ast_pb2.FuncDecl,
243232
params: List[function_lib.Parameter],
@@ -264,87 +253,6 @@ def generate_function_call(
264253
return f'self.{method_name}'
265254

266255

267-
def _func_is_extend_static_method(
268-
func_decl: ast_pb2.FuncDecl,
269-
class_decl: Optional[ast_pb2.ClassDecl] = None) -> bool:
270-
return class_decl and func_decl.is_extend_method and func_decl.classmethod
271-
272-
273-
def _func_has_pointer_params(func_decl: ast_pb2.FuncDecl) -> bool:
274-
num_returns = len(func_decl.returns)
275-
return num_returns >= 2 or (num_returns == 1 and func_decl.cpp_void_return)
276-
277-
278-
def _func_has_status_params(func_decl: ast_pb2.FuncDecl,
279-
requires_status: bool) -> bool:
280-
for p in func_decl.params:
281-
if function_lib.is_status_param(p, requires_status):
282-
return True
283-
for r in func_decl.returns:
284-
if function_lib.is_status_param(r, requires_status):
285-
return True
286-
return False
287-
288-
289-
def _func_has_status_callback(func_decl: ast_pb2.FuncDecl,
290-
requires_status: bool) -> bool:
291-
for r in func_decl.returns:
292-
if function_lib.is_status_callback(r, requires_status):
293-
return True
294-
return False
295-
296-
297-
def _func_has_capsule_params(
298-
func_decl: ast_pb2.FuncDecl, capsule_types: Set[str]) -> bool:
299-
for p in func_decl.params:
300-
if p.type.lang_type in capsule_types:
301-
return True
302-
for r in func_decl.returns:
303-
if r.type.lang_type in capsule_types:
304-
return True
305-
return False
306-
307-
308-
def _func_is_context_manager(func_decl: ast_pb2.FuncDecl) -> bool:
309-
return func_decl.name.native in ('__enter__@', '__exit__@')
310-
311-
312-
def _func_needs_index_check(func_decl: ast_pb2.FuncDecl) -> bool:
313-
return func_decl.name.native in _NEEDS_INDEX_CHECK_METHODS
314-
315-
316-
def _is_inherited_method(class_decl: ast_pb2.ClassDecl,
317-
func_decl: ast_pb2.FuncDecl) -> bool:
318-
if class_decl.cpp_bases and not func_decl.is_extend_method:
319-
namespaces = func_decl.name.cpp_name.split('::')
320-
if (len(namespaces) > 1 and
321-
namespaces[-2] != class_decl.name.cpp_name.strip(':')):
322-
return True
323-
return False
324-
325-
326-
def _func_needs_implicit_conversion(func_decl: ast_pb2.FuncDecl) -> bool:
327-
"""Check if a function contains an implicitly converted parameter."""
328-
for param in func_decl.params:
329-
if (_extract_bare_type(param.cpp_exact_type) !=
330-
_extract_bare_type(param.type.cpp_type) and
331-
param.type.cpp_toptr_conversion and
332-
param.type.cpp_touniqptr_conversion):
333-
return True
334-
return False
335-
336-
337-
def _extract_bare_type(cpp_name: str) -> str:
338-
# This helper function is not general and only meant
339-
# to be used in _func_needs_implicit_conversion.
340-
t = cpp_name.split(' ')
341-
if t[0] == 'const':
342-
t = t[1:]
343-
if t[-1] in {'&', '*'}: # Minimum viable approach. To be refined as needed.
344-
t = t[:-1]
345-
return ' '.join(t)
346-
347-
348256
def generate_return_value_cpp_type(
349257
func_decl: ast_pb2.FuncDecl, codegen_info: utils.CodeGenInfo) -> str:
350258
"""Generates type for the return value of the C++ function."""

clif/pybind11/unknown_default_value.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def generate_from(
127127
# Generates call to the C++ function
128128
function_call = lambdas.generate_function_call(func_decl, class_decl)
129129
function_call_returns = lambdas.generate_function_call_returns(
130-
func_decl, codegen_info.capsule_types)
130+
func_decl, codegen_info)
131131
yield I + 'switch (nargs) {'
132132
for n in range(minargs, nargs+1):
133133
yield I + I + f'case {n}:'

clif/testing/python/return_value_policy.clif

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,6 @@ from "clif/testing/return_value_policy.h":
1818
mtxt: str
1919

2020
def return_value() -> Obj
21-
@return_value_policy_copy
22-
def `return_pointer` as return_value_policy_copy() -> Obj
23-
@return_value_policy_move
24-
def `return_pointer` as return_value_policy_move() -> Obj
25-
@return_value_policy_reference
26-
def `return_pointer` as return_value_policy_reference() -> Obj
27-
@return_value_policy_take_ownership
28-
def `return_pointer_unowned` as return_value_policy_take_ownership() -> Obj
29-
@return_value_policy_automatic
30-
def `return_pointer_unowned` as return_value_policy_automatic() -> Obj
31-
@return_value_policy_return_as_bytes
32-
def return_string() -> str
3321
def return_reference() -> Obj
3422
def return_const_reference() -> Obj
3523
def return_pointer() -> Obj

0 commit comments

Comments
 (0)