Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 184 additions & 12 deletions src/nb_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ static PyObject *nb_func_vectorcall_simple_1(PyObject *, PyObject *const *,
size_t, PyObject *) noexcept;
static PyObject *nb_func_vectorcall_simple(PyObject *, PyObject *const *,
size_t, PyObject *) noexcept;
static PyObject *nb_func_vectorcall_modest(PyObject *, PyObject *const *,
size_t, PyObject *) noexcept;
static PyObject *nb_func_vectorcall_complex(PyObject *, PyObject *const *,
size_t, PyObject *) noexcept;
static uint32_t nb_func_render_signature(const func_data *f,
Expand Down Expand Up @@ -298,25 +300,27 @@ PyObject *nb_func_new(const func_data_prelim_base *f) noexcept {
make_immortal((PyObject *) func);

// Check if the complex dispatch loop is needed
bool complex_call = can_mutate_args || has_var_kwargs || has_var_args ||
bool has_kwargs = has_var_kwargs;
bool complex_call = can_mutate_args || has_var_args ||
f->nargs > NB_MAXARGS_SIMPLE;

if (has_args) {
for (size_t i = is_method; i < f->nargs; ++i) {
arg_data &a = args_in[i - is_method];
complex_call |= a.name != nullptr || a.value != nullptr ||
a.flag != cast_flags::convert;
has_kwargs |= a.name != nullptr;
complex_call |= a.value != nullptr || a.flag != cast_flags::convert;
}
}
complex_call |= has_kwargs;

uint32_t max_nargs = f->nargs;

const char *prev_doc = nullptr;

if (func_prev) {
nb_func *nb_func_prev = (nb_func *) func_prev;
complex_call |= nb_func_prev->complex_call;
max_nargs = std::max(max_nargs, nb_func_prev->max_nargs);
has_kwargs |= nb_func_prev->has_kwargs;
complex_call |= nb_func_prev->complex_call;

func_data *cur = nb_func_data(func),
*prev = nb_func_data(func_prev);
Expand All @@ -339,12 +343,15 @@ PyObject *nb_func_new(const func_data_prelim_base *f) noexcept {
}

func->max_nargs = max_nargs;
func->has_kwargs = has_kwargs;
func->complex_call = complex_call;


PyObject* (*vectorcall)(PyObject *, PyObject * const*, size_t, PyObject *);
if (complex_call) {
vectorcall = nb_func_vectorcall_complex;
if (max_nargs <= NB_MAXARGS_SIMPLE && !has_kwargs)
vectorcall = nb_func_vectorcall_modest;
else
vectorcall = nb_func_vectorcall_complex;
} else {
if (f->nargs == 0 && !prev_overloads)
vectorcall = nb_func_vectorcall_simple_0;
Expand Down Expand Up @@ -636,7 +643,7 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
cleanup_list cleanup(self_arg);

// Preallocate stack memory for function dispatch
size_t max_nargs = ((nb_func *) self)->max_nargs;
const size_t max_nargs = ((nb_func *) self)->max_nargs;
PyObject **args = (PyObject **) alloca(max_nargs * sizeof(PyObject *));
uint8_t *args_flags = (uint8_t *) alloca(max_nargs * sizeof(uint8_t));
bool *kwarg_used = (bool *) alloca(nkwargs_in * sizeof(bool));
Expand Down Expand Up @@ -715,15 +722,15 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,

// Number of C++ parameters eligible to be filled from individual
// Python positional arguments
size_t nargs_pos = f->nargs_pos;
const size_t nargs_pos = f->nargs_pos;

// Number of C++ parameters in total, except for a possible trailing
// nb::kwargs. All of these are eligible to be filled from individual
// Python arguments (keyword always, positional until index nargs_pos)
// except for a potential nb::args, which exists at index nargs_pos
// if has_var_args is true. We'll skip that one in the individual-args
// loop, and go back and fill it later with the unused positionals.
size_t nargs_step1 = f->nargs - has_var_kwargs;
const size_t nargs_step1 = f->nargs - has_var_kwargs;

if (nargs_in > nargs_pos && !has_var_args)
continue; // Too many positional arguments given for this overload
Expand Down Expand Up @@ -827,6 +834,171 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
continue;
}

if (is_constructor)
args_flags[0] |= (uint8_t) cast_flags::construct;

rv_policy policy = (rv_policy) (f->flags & 0b111);

try {
result = nullptr;

// Found a suitable overload, let's try calling it
result = f->impl((void *) f->capture, args, args_flags,
policy, &cleanup);

if (NB_UNLIKELY(!result))
error_handler = nb_func_error_noconvert;
} catch (builtin_exception &e) {
if (!set_builtin_exception_status(e))
result = NB_NEXT_OVERLOAD;
} catch (python_error &e) {
e.restore();
} catch (...) {
nb_func_convert_cpp_exception();
}

if (result != NB_NEXT_OVERLOAD) {
if (is_constructor && result != nullptr) {
nb_inst *self_arg_nb = (nb_inst *) self_arg;
self_arg_nb->destruct = true;
self_arg_nb->state = nb_inst::state_ready;
if (NB_UNLIKELY(self_arg_nb->intrusive))
nb_type_data(Py_TYPE(self_arg))
->set_self_py(inst_ptr(self_arg_nb), self_arg);
}

goto done;
}
}
}

error_handler = nb_func_error_overload;

done:
if (NB_UNLIKELY(cleanup.used()))
cleanup.release();

if (NB_UNLIKELY(error_handler))
result = error_handler(self, args_in, nargs_in, kwargs_in);

return result;
}

/// Simplified nb_func_vectorcall variant for functions without keyword
/// arguments whose overloads take no more than NB_MAXARGS_SIMPLE arguments
static PyObject *nb_func_vectorcall_modest(PyObject *self,
PyObject *const *args_in,
size_t nargsf,
PyObject *kwargs_in) noexcept {
const size_t count = (size_t) Py_SIZE(self),
nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf);

func_data *fr = nb_func_data(self);

const bool is_method = fr->flags & (uint32_t) func_flags::is_method,
is_constructor = fr->flags & (uint32_t) func_flags::is_constructor;

PyObject *result = nullptr,
*self_arg = (is_method && nargs_in > 0) ? args_in[0] : nullptr;

// Handler routine that will be invoked in case of an error condition
PyObject *(*error_handler)(PyObject *, PyObject *const *, size_t,
PyObject *) noexcept = nullptr;

// Small array holding temporaries (implicit conversion/*args)
cleanup_list cleanup(self_arg);

if (kwargs_in != nullptr) { // keyword arguments are unsupported
error_handler = nb_func_error_overload;
goto done;
}

// Stack memory for function dispatch
PyObject* args[NB_MAXARGS_SIMPLE];
uint8_t args_flags[NB_MAXARGS_SIMPLE];

/* The logic below tries to find a suitable overload using two passes
of the overload chain (or 1, if there are no overloads). The first pass
is strict and permits no implicit conversions, while the second pass
allows them.

The following is done per overload during a pass

1. Copy individual arguments, substituting missing entries using
default argument values provided in the bindings, if available.

2. Any positional arguments still left get put into a tuple.

3. Pack everything into a vector; if we have nb::args, it becomes
a tuple at the end of the positional arguments.

4. Call the function call dispatcher (func_data::impl)

If any of these steps fails, move on to the next overload and keep
trying until we get a result other than NB_NEXT_OVERLOAD.
*/

for (size_t pass = (count > 1) ? 0 : 1; pass < 2; ++pass) {
for (size_t k = 0; k < count; ++k) {
const func_data *f = fr + k;

const bool has_args = f->flags & (uint32_t) func_flags::has_args,
has_var_args = f->flags & (uint32_t) func_flags::has_var_args;

// Number of C++ parameters eligible to be filled from individual
// Python positional arguments
const size_t nargs_pos = f->nargs_pos;

if (nargs_in > nargs_pos && !has_var_args)
continue; // Too many positional arguments for this overload

if (nargs_in < nargs_pos && !has_args)
continue; // Not enough positional arguments

// 1. Copy individual arguments, potentially substitute defaults
size_t i = 0;
for (; i < nargs_pos; ++i) {
PyObject *arg = nullptr;
uint8_t arg_flag = 1;

if (i < nargs_in)
arg = args_in[i];

if (has_args) {
const arg_data &ad = f->args[i];

if (!arg)
arg = ad.value;
arg_flag = ad.flag;
}

if (!arg || (arg == Py_None && (arg_flag & cast_flags::accepts_none) == 0))
break;

// Implicit conversion only active in the 2nd pass
args_flags[i] = arg_flag & ~uint8_t(pass == 0);
args[i] = arg;
}

// Skip this overload if any arguments were unavailable
if (i != nargs_pos)
continue;

// Deal with remaining positional arguments
if (has_var_args) {
PyObject *tuple = PyTuple_New(
nargs_in > nargs_pos ? (Py_ssize_t) (nargs_in - nargs_pos) : 0);

for (size_t j = nargs_pos; j < nargs_in; ++j) {
PyObject *o = args_in[j];
Py_INCREF(o);
NB_TUPLE_SET_ITEM(tuple, j - nargs_pos, o);
}

args[nargs_pos] = tuple;
args_flags[nargs_pos] = 0;
cleanup.append(tuple);
}

if (is_constructor)
args_flags[0] |= (uint8_t) cast_flags::construct;
Expand Down Expand Up @@ -887,8 +1059,8 @@ static PyObject *nb_func_vectorcall_simple(PyObject *self,
uint8_t args_flags[NB_MAXARGS_SIMPLE];
func_data *fr = nb_func_data(self);

const size_t count = (size_t) Py_SIZE(self),
nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf);
const size_t count = (size_t) Py_SIZE(self),
nargs_in = (size_t) NB_VECTORCALL_NARGS(nargsf);

const bool is_method = fr->flags & (uint32_t) func_flags::is_method,
is_constructor = fr->flags & (uint32_t) func_flags::is_constructor;
Expand Down
1 change: 1 addition & 0 deletions src/nb_internals.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ struct nb_func {
PyObject_VAR_HEAD
PyObject* (*vectorcall)(PyObject *, PyObject * const*, size_t, PyObject *);
uint32_t max_nargs; // maximum value of func_data::nargs for any overload
bool has_kwargs; // whether any overload has keyword arguments
bool complex_call;
bool doc_uniform;
};
Expand Down
3 changes: 3 additions & 0 deletions tests/test_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ NB_MODULE(test_functions_ext, m) {
// Simple binary function (via function pointer)
auto test_02 = [](int up, int down) -> int { return up - down; };
m.def("test_02", (int (*)(int, int)) test_02, "up"_a = 8, "down"_a = 1);
m.def("test_02p", (int (*)(int, int)) test_02, nb::arg()=8, nb::arg()=1);
m.def("test_02nc", (int (*)(int, int)) test_02, nb::arg().noconvert()=8,
nb::arg().noconvert()=1);

// Simple binary function with capture object
int i = 42;
Expand Down
10 changes: 10 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def test01_capture():
# Functions with and without capture object of different sizes
assert t.test_01() is None
assert t.test_02(5, 3) == 2
assert t.test_02p(5, 3) == 2
assert t.test_02nc(5, 3) == 2
assert t.test_03(5, 3) == 44
assert t.test_04() == 60
assert t.test_simple(0, 1, 2, 3, 4, 5, 6, 7) == 14
Expand All @@ -29,6 +31,14 @@ def test02_default_args():
# Default arguments
assert t.test_02() == 7
assert t.test_02(7) == 6
assert t.test_02('17') == 16
assert t.test_02p() == 7
assert t.test_02p(7) == 6
assert t.test_02p('17') == 16
assert t.test_02nc() == 7
assert t.test_02nc(7) == 6
with pytest.raises(TypeError):
t.test_02nc('17')


def test03_kwargs():
Expand Down
4 changes: 4 additions & 0 deletions tests/test_functions_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ def test_01() -> None: ...

def test_02(up: int = 8, down: int = 1) -> int: ...

def test_02p(arg0: int = 8, arg1: int = 1) -> int: ...

def test_02nc(arg0: int = 8, arg1: int = 1) -> int: ...

def test_03(arg0: int, arg1: int, /) -> int: ...

def test_04() -> int: ...
Expand Down
Loading