Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add argument number dispatch mechanism for std::function casting #5285

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion include/pybind11/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "pybind11.h"

#include <functional>
#include <iostream>

PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
PYBIND11_NAMESPACE_BEGIN(detail)
Expand Down Expand Up @@ -101,8 +102,17 @@ struct type_caster<std::function<Return(Args...)>> {
if (detail::is_function_record_capsule(c)) {
rec = c.get_pointer<function_record>();
}

while (rec != nullptr) {
const size_t self_offset = rec->is_method ? 1 : 0;
if (rec->nargs != sizeof...(Args) + self_offset) {
rec = rec->next;
// if the overload is not feasible in terms of number of arguments, we
// continue to the next one. If there is no next one, we return false.
if (rec == nullptr) {
return false;
}
continue;
}
if (rec->is_stateless
&& same_type(typeid(function_type),
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
Expand All @@ -118,6 +128,38 @@ struct type_caster<std::function<Return(Args...)>> {
// PYPY segfaults here when passing builtin function like sum.
// Raising an fail exception here works to prevent the segfault, but only on gcc.
// See PR #1413 for full details
} else {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To try to minimize the runtime overhead, I would consider adding a rec->next check here to only trigger this if there is a another possible overload to try.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to understand your comment but I'm afraid I don't know what you are suggesting. Can you please elaborate?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the idea is to skip this logic if there's not another overload (null rec->next) to save time.

// Check number of arguments of Python function
auto get_argument_count = [](const handle &obj) -> size_t {
// Faster then `import inspect` and `inspect.signature(obj).parameters`
return obj.attr("co_argcount").cast<size_t>();
};
size_t argCount = 0;

handle empty;
object codeAttr = getattr(src, "__code__", empty);

if (codeAttr) {
argCount = get_argument_count(codeAttr);
} else {
object callAttr = getattr(src, "__call__", empty);

if (callAttr) {
object codeAttr2 = getattr(callAttr, "__code__");
argCount = get_argument_count(codeAttr2) - 1; // removing the self argument
} else {
// No __code__ or __call__ attribute, this is not a proper Python function
return false;
}
}
// if we are a method, we have to correct the argument count since we are not counting
// the self argument
const size_t self_offset = static_cast<bool>(PyMethod_Check(src.ptr())) ? 1 : 0;

argCount -= self_offset;
if (argCount != sizeof...(Args)) {
return false;
}
}

value = type_caster_std_function_specializations::func_wrapper<Return, Args...>(
Expand Down
6 changes: 6 additions & 0 deletions tests/test_callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,12 @@ TEST_SUBMODULE(callbacks, m) {
return "argument does NOT match dummy_function. This should never happen!";
});

// test_cpp_correct_overload_resolution
m.def("dummy_function_overloaded_std_func_arg",
[](const std::function<int(int)> &f) { return 3 * f(3); });
m.def("dummy_function_overloaded_std_func_arg",
[](const std::function<int(int, int)> &f) { return 2 * f(3, 4); });

class AbstractBase {
public:
// [workaround(intel)] = default does not work here
Expand Down
30 changes: 29 additions & 1 deletion tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,31 @@ def test_cpp_callable_cleanup():
assert alive_counts == [0, 1, 2, 1, 2, 1, 0]


def test_cpp_correct_overload_resolution():
henryiii marked this conversation as resolved.
Show resolved Hide resolved
def f(a):
return a

class A:
def __call__(self, a):
return a

assert m.dummy_function_overloaded_std_func_arg(f) == 9
a = A()
assert m.dummy_function_overloaded_std_func_arg(a) == 9
assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9

def f2(a, b):
return a + b

class B:
def __call__(self, a, b):
return a + b

assert m.dummy_function_overloaded_std_func_arg(f2) == 14
assert m.dummy_function_overloaded_std_func_arg(B()) == 14
assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14


def test_cpp_function_roundtrip():
"""Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer"""

Expand Down Expand Up @@ -131,7 +156,10 @@ def test_cpp_function_roundtrip():
m.test_dummy_function(lambda x, y: x + y)
assert any(
s in str(excinfo.value)
for s in ("missing 1 required positional argument", "takes exactly 2 arguments")
for s in (
"incompatible function arguments. The following argument types are",
"function test_cpp_function_roundtrip.<locals>.<lambda>",
)
)


Expand Down
16 changes: 16 additions & 0 deletions tests/test_embed/test_interpreter.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <pybind11/embed.h>
#include <pybind11/functional.h>

// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to
// catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
Expand Down Expand Up @@ -78,6 +79,12 @@ PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
d["missing"].cast<py::object>();
}

PYBIND11_EMBEDDED_MODULE(func_module, m) {
m.def("funcOverload", [](const std::function<int(int, int)> &f) {
return f(2, 3);
}).def("funcOverload", [](const std::function<int(int)> &f) { return f(2); });
}

TEST_CASE("PYTHONPATH is used to update sys.path") {
// The setup for this TEST_CASE is in catch.cpp!
auto sys_path = py::str(py::module_::import("sys").attr("path")).cast<std::string>();
Expand Down Expand Up @@ -171,6 +178,15 @@ TEST_CASE("There can be only one interpreter") {
py::initialize_interpreter();
}

TEST_CASE("Check the overload resolution from cpp_function objects to std::function") {
auto m = py::module_::import("func_module");
auto f = std::function<int(int)>([](int x) { return 2 * x; });
REQUIRE(m.attr("funcOverload")(f).template cast<int>() == 4);

auto f2 = std::function<int(int, int)>([](int x, int y) { return 2 * x * y; });
REQUIRE(m.attr("funcOverload")(f2).template cast<int>() == 12);
}

#if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
TEST_CASE("Custom PyConfig") {
py::finalize_interpreter();
Expand Down
Loading