From 6959d81d232c6d4d5dca7acf4921bf4ddc8705ad Mon Sep 17 00:00:00 2001 From: Paul Caprioli Date: Fri, 31 Oct 2025 17:06:50 -0700 Subject: [PATCH] Specialize complex function dispatcher This PR adds a variant of the complex dispatcher that applies when no overloads have kwargs and all overloads have 8 or fewer arguments. --- src/nb_func.cpp | 196 +++++++++++++++++++++++++++++-- src/nb_internals.h | 1 + tests/test_functions.cpp | 3 + tests/test_functions.py | 10 ++ tests/test_functions_ext.pyi.ref | 4 + 5 files changed, 202 insertions(+), 12 deletions(-) diff --git a/src/nb_func.cpp b/src/nb_func.cpp index d81fff9c..7a60dc4f 100644 --- a/src/nb_func.cpp +++ b/src/nb_func.cpp @@ -35,6 +35,8 @@ static PyObject *nb_func_vectorcall_simple_1(PyObject *, PyObject *const *, size_t, PyObject *) noexcept; static PyObject *nb_func_vectorcall_simple(PyObject *, PyObject *const *, size_t, PyObject *) noexcept; +static PyObject *nb_func_vectorcall_modest(PyObject *, PyObject *const *, + size_t, PyObject *) noexcept; static PyObject *nb_func_vectorcall_complex(PyObject *, PyObject *const *, size_t, PyObject *) noexcept; static uint32_t nb_func_render_signature(const func_data *f, @@ -298,16 +300,17 @@ PyObject *nb_func_new(const func_data_prelim_base *f) noexcept { make_immortal((PyObject *) func); // Check if the complex dispatch loop is needed - bool complex_call = can_mutate_args || has_var_kwargs || has_var_args || + bool has_kwargs = has_var_kwargs; + bool complex_call = can_mutate_args || has_var_args || f->nargs > NB_MAXARGS_SIMPLE; - if (has_args) { for (size_t i = is_method; i < f->nargs; ++i) { arg_data &a = args_in[i - is_method]; - complex_call |= a.name != nullptr || a.value != nullptr || - a.flag != cast_flags::convert; + has_kwargs |= a.name != nullptr; + complex_call |= a.value != nullptr || a.flag != cast_flags::convert; } } + complex_call |= has_kwargs; uint32_t max_nargs = f->nargs; @@ -315,8 +318,9 @@ PyObject *nb_func_new(const func_data_prelim_base *f) noexcept { if (func_prev) { nb_func *nb_func_prev = (nb_func *) func_prev; - complex_call |= nb_func_prev->complex_call; max_nargs = std::max(max_nargs, nb_func_prev->max_nargs); + has_kwargs |= nb_func_prev->has_kwargs; + complex_call |= nb_func_prev->complex_call; func_data *cur = nb_func_data(func), *prev = nb_func_data(func_prev); @@ -339,12 +343,15 @@ PyObject *nb_func_new(const func_data_prelim_base *f) noexcept { } func->max_nargs = max_nargs; + func->has_kwargs = has_kwargs; func->complex_call = complex_call; - PyObject* (*vectorcall)(PyObject *, PyObject * const*, size_t, PyObject *); if (complex_call) { - vectorcall = nb_func_vectorcall_complex; + if (max_nargs <= NB_MAXARGS_SIMPLE && !has_kwargs) + vectorcall = nb_func_vectorcall_modest; + else + vectorcall = nb_func_vectorcall_complex; } else { if (f->nargs == 0 && !prev_overloads) vectorcall = nb_func_vectorcall_simple_0; @@ -636,7 +643,7 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self, cleanup_list cleanup(self_arg); // Preallocate stack memory for function dispatch - size_t max_nargs = ((nb_func *) self)->max_nargs; + const size_t max_nargs = ((nb_func *) self)->max_nargs; PyObject **args = (PyObject **) alloca(max_nargs * sizeof(PyObject *)); uint8_t *args_flags = (uint8_t *) alloca(max_nargs * sizeof(uint8_t)); bool *kwarg_used = (bool *) alloca(nkwargs_in * sizeof(bool)); @@ -715,7 +722,7 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self, // Number of C++ parameters eligible to be filled from individual // Python positional arguments - size_t nargs_pos = f->nargs_pos; + const size_t nargs_pos = f->nargs_pos; // Number of C++ parameters in total, except for a possible trailing // nb::kwargs. All of these are eligible to be filled from individual @@ -723,7 +730,7 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self, // except for a potential nb::args, which exists at index nargs_pos // if has_var_args is true. We'll skip that one in the individual-args // loop, and go back and fill it later with the unused positionals. - size_t nargs_step1 = f->nargs - has_var_kwargs; + const size_t nargs_step1 = f->nargs - has_var_kwargs; if (nargs_in > nargs_pos && !has_var_args) continue; // Too many positional arguments given for this overload @@ -827,6 +834,171 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self, continue; } + if (is_constructor) + args_flags[0] |= (uint8_t) cast_flags::construct; + + rv_policy policy = (rv_policy) (f->flags & 0b111); + + try { + result = nullptr; + + // Found a suitable overload, let's try calling it + result = f->impl((void *) f->capture, args, args_flags, + policy, &cleanup); + + if (NB_UNLIKELY(!result)) + error_handler = nb_func_error_noconvert; + } catch (builtin_exception &e) { + if (!set_builtin_exception_status(e)) + result = NB_NEXT_OVERLOAD; + } catch (python_error &e) { + e.restore(); + } catch (...) { + nb_func_convert_cpp_exception(); + } + + if (result != NB_NEXT_OVERLOAD) { + if (is_constructor && result != nullptr) { + nb_inst *self_arg_nb = (nb_inst *) self_arg; + self_arg_nb->destruct = true; + self_arg_nb->state = nb_inst::state_ready; + if (NB_UNLIKELY(self_arg_nb->intrusive)) + nb_type_data(Py_TYPE(self_arg)) + ->set_self_py(inst_ptr(self_arg_nb), self_arg); + } + + goto done; + } + } + } + + error_handler = nb_func_error_overload; + +done: + if (NB_UNLIKELY(cleanup.used())) + cleanup.release(); + + if (NB_UNLIKELY(error_handler)) + result = error_handler(self, args_in, nargs_in, kwargs_in); + + return result; +} + +/// Simplified nb_func_vectorcall variant for functions w/o keyword arguments +/// and with no more than NB_MAXARGS_SIMPLE arguments +static PyObject *nb_func_vectorcall_modest(PyObject *self, + PyObject *const *args_in, + size_t nargsf, + PyObject *kwargs_in) noexcept { + const size_t count = (size_t) Py_SIZE(self), + nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf); + + func_data *fr = nb_func_data(self); + + const bool is_method = fr->flags & (uint32_t) func_flags::is_method, + is_constructor = fr->flags & (uint32_t) func_flags::is_constructor; + + PyObject *result = nullptr, + *self_arg = (is_method && nargs_in > 0) ? args_in[0] : nullptr; + + // Handler routine that will be invoked in case of an error condition + PyObject *(*error_handler)(PyObject *, PyObject *const *, size_t, + PyObject *) noexcept = nullptr; + + // Small array holding temporaries (implicit conversion/*args) + cleanup_list cleanup(self_arg); + + if (kwargs_in != nullptr) { // keyword arguments are unsupported + error_handler = nb_func_error_overload; + goto done; + } + + // Stack memory for function dispatch + PyObject* args[NB_MAXARGS_SIMPLE]; + uint8_t args_flags[NB_MAXARGS_SIMPLE]; + + /* The logic below tries to find a suitable overload using two passes + of the overload chain (or 1, if there are no overloads). The first pass + is strict and permits no implicit conversions, while the second pass + allows them. + + The following is done per overload during a pass + + 1. Copy individual arguments, substituting missing entries using + default argument values provided in the bindings, if available. + + 2. Any positional arguments still left get put into a tuple. + + 3. Pack everything into a vector; if we have nb::args, it becomes + a tuple at the end of the positional arguments. + + 4. Call the function call dispatcher (func_data::impl) + + If one of these fail, move on to the next overload and keep trying + until we get a result other than NB_NEXT_OVERLOAD. + */ + + for (size_t pass = (count > 1) ? 0 : 1; pass < 2; ++pass) { + for (size_t k = 0; k < count; ++k) { + const func_data *f = fr + k; + + const bool has_args = f->flags & (uint32_t) func_flags::has_args, + has_var_args = f->flags & (uint32_t) func_flags::has_var_args; + + // Number of C++ parameters eligible to be filled from individual + // Python positional arguments + const size_t nargs_pos = f->nargs_pos; + + if (nargs_in > nargs_pos && !has_var_args) + continue; // Too many positional arguments for this overload + + if (nargs_in < nargs_pos && !has_args) + continue; // Not enough positional arguments + + // 1. Copy individual arguments, potentially substitute defaults + size_t i = 0; + for (; i < nargs_pos; ++i) { + PyObject *arg = nullptr; + uint8_t arg_flag = 1; + + if (i < nargs_in) + arg = args_in[i]; + + if (has_args) { + const arg_data &ad = f->args[i]; + + if (!arg) + arg = ad.value; + arg_flag = ad.flag; + } + + if (!arg || (arg == Py_None && (arg_flag & cast_flags::accepts_none) == 0)) + break; + + // Implicit conversion only active in the 2nd pass + args_flags[i] = arg_flag & ~uint8_t(pass == 0); + args[i] = arg; + } + + // Skip this overload if any arguments were unavailable + if (i != nargs_pos) + continue; + + // Deal with remaining positional arguments + if (has_var_args) { + PyObject *tuple = PyTuple_New( + nargs_in > nargs_pos ? (Py_ssize_t) (nargs_in - nargs_pos) : 0); + + for (size_t j = nargs_pos; j < nargs_in; ++j) { + PyObject *o = args_in[j]; + Py_INCREF(o); + NB_TUPLE_SET_ITEM(tuple, j - nargs_pos, o); + } + + args[nargs_pos] = tuple; + args_flags[nargs_pos] = 0; + cleanup.append(tuple); + } if (is_constructor) args_flags[0] |= (uint8_t) cast_flags::construct; @@ -887,8 +1059,8 @@ static PyObject *nb_func_vectorcall_simple(PyObject *self, uint8_t args_flags[NB_MAXARGS_SIMPLE]; func_data *fr = nb_func_data(self); - const size_t count = (size_t) Py_SIZE(self), - nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf); + const size_t count = (size_t) Py_SIZE(self), + nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf); const bool is_method = fr->flags & (uint32_t) func_flags::is_method, is_constructor = fr->flags & (uint32_t) func_flags::is_constructor; diff --git a/src/nb_internals.h b/src/nb_internals.h index 3209d32b..f61e6673 100644 --- a/src/nb_internals.h +++ b/src/nb_internals.h @@ -100,6 +100,7 @@ struct nb_func { PyObject_VAR_HEAD PyObject* (*vectorcall)(PyObject *, PyObject * const*, size_t, PyObject *); uint32_t max_nargs; // maximum value of func_data::nargs for any overload + bool has_kwargs; // whether any overload has keyword arguments bool complex_call; bool doc_uniform; }; diff --git a/tests/test_functions.cpp b/tests/test_functions.cpp index 3c2010df..9c512764 100644 --- a/tests/test_functions.cpp +++ b/tests/test_functions.cpp @@ -95,6 +95,9 @@ NB_MODULE(test_functions_ext, m) { // Simple binary function (via function pointer) auto test_02 = [](int up, int down) -> int { return up - down; }; m.def("test_02", (int (*)(int, int)) test_02, "up"_a = 8, "down"_a = 1); + m.def("test_02p", (int (*)(int, int)) test_02, nb::arg()=8, nb::arg()=1); + m.def("test_02nc", (int (*)(int, int)) test_02, nb::arg().noconvert()=8, + nb::arg().noconvert()=1); // Simple binary function with capture object int i = 42; diff --git a/tests/test_functions.py b/tests/test_functions.py index d9da6ea6..32272015 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -20,6 +20,8 @@ def test01_capture(): # Functions with and without capture object of different sizes assert t.test_01() is None assert t.test_02(5, 3) == 2 + assert t.test_02p(5, 3) == 2 + assert t.test_02nc(5, 3) == 2 assert t.test_03(5, 3) == 44 assert t.test_04() == 60 assert t.test_simple(0, 1, 2, 3, 4, 5, 6, 7) == 14 @@ -29,6 +31,14 @@ def test02_default_args(): # Default arguments assert t.test_02() == 7 assert t.test_02(7) == 6 + assert t.test_02('17') == 16 + assert t.test_02p() == 7 + assert t.test_02p(7) == 6 + assert t.test_02p('17') == 16 + assert t.test_02nc() == 7 + assert t.test_02nc(7) == 6 + with pytest.raises(TypeError): + t.test_02nc('17') def test03_kwargs(): diff --git a/tests/test_functions_ext.pyi.ref b/tests/test_functions_ext.pyi.ref index ecf60146..98455d6d 100644 --- a/tests/test_functions_ext.pyi.ref +++ b/tests/test_functions_ext.pyi.ref @@ -7,6 +7,10 @@ def test_01() -> None: ... def test_02(up: int = 8, down: int = 1) -> int: ... +def test_02p(arg0: int = 8, arg1: int = 1) -> int: ... + +def test_02nc(arg0: int = 8, arg1: int = 1) -> int: ... + def test_03(arg0: int, arg1: int, /) -> int: ... def test_04() -> int: ...