1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
// This is an example of using nb::call_policy to support binding an
// object that takes non-owning callbacks. Since the callbacks can't
// directly keep a Python object alive (they're trivially copyable), we
// maintain a sideband structure to manage the lifetimes.
#include <algorithm>
#include <unordered_set>
#include <vector>
#include <nanobind/nanobind.h>
#include <nanobind/stl/unordered_set.h>
namespace nb = nanobind;
// The callback type accepted by the object, which we assume we can't change.
// It's trivially copyable, so it can't directly keep a Python object alive.
struct callback {
    // Opaque pointer handed back to `func` as its first argument.
    void *context;
    // Plain function pointer that receives `context` and the emitted value.
    void (*func)(void *context, int arg);

    // Invoke the callback with `arg`, forwarding the stored context.
    void operator()(int arg) const { func(context, arg); }

    // Two callbacks compare equal only when both the function pointer and
    // the context pointer are identical.
    bool operator==(const callback& other) const {
        if (func != other.func)
            return false;
        return context == other.context;
    }
};
// An object that uses these callbacks, which we want to write bindings for
class publisher {
public:
void subscribe(callback cb) { cbs.push_back(cb); }
void unsubscribe(callback cb) {
cbs.erase(std::remove(cbs.begin(), cbs.end(), cb), cbs.end());
}
void emit(int arg) const { for (auto cb : cbs) cb(arg); }
private:
std::vector<callback> cbs;
};
// Type caster that converts between Python callables and the C++ `callback`
// struct. The resulting `callback` only stores a raw pointer to the Python
// object; it does not take a reference to it.
template <> struct nanobind::detail::type_caster<callback> {
    // Trampoline stored in `callback::func`: `context` is the PyObject* of
    // the Python callable, which is borrowed for the duration of the call.
    static void wrap_call(void *context, int arg) {
        callable target = borrow<callable>(static_cast<PyObject *>(context));
        target(arg);
    }

    // Python -> C++: accept any callable and record its raw pointer together
    // with the trampoline. No reference count is taken here.
    bool from_python(handle src, uint8_t, cleanup_list*) noexcept {
        const bool accepted = isinstance<callable>(src);
        if (accepted) {
            value.context = static_cast<void *>(src.ptr());
            value.func = &wrap_call;
        }
        return accepted;
    }

    // C++ -> Python: a callback that wraps a Python callable is unwrapped to
    // that same object (with a new reference); any other callback is wrapped
    // in a cpp_function unless the policy is rv_policy::none.
    static handle from_cpp(callback cb, rv_policy policy, cleanup_list*) noexcept {
        const bool wraps_python_callable = (cb.func == &wrap_call);
        if (wraps_python_callable) {
            handle original(static_cast<PyObject *>(cb.context));
            return original.inc_ref();
        }
        if (policy != rv_policy::none)
            return cpp_function(cb, policy).release();
        return handle();
    }

    NB_TYPE_CASTER(callback, const_name("Callable[[int], None]"))
};
// Fetch the `registry` attribute of the `test_callbacks_ext` module as a dict.
nb::dict cb_registry() {
    nb::module_ ext = nb::module_::import_("test_callbacks_ext");
    return nb::cast<nb::dict>(ext.attr("registry"));
}
struct callback_data {
struct py_hash {
size_t operator()(const nb::object& obj) const { return nb::hash(obj); }
};
struct py_eq {
bool operator()(const nb::object& a, const nb::object& b) const {
return a.equal(b);
}
};
std::unordered_set<nb::object, py_hash, py_eq> subscribers;
};
// Return the callback_data associated with `publisher`, creating and
// registering a fresh one if the registry has no entry yet.
//
// The registry key is a weak reference to the publisher whose callback is
// the registry's own `__delitem__`, so the entry is intended to be dropped
// when the publisher goes away. The returned reference points into an
// object held by the registry dict.
//
// Throws nb::python_error if the dictionary lookup itself raises (for
// example from a failing `__hash__` or `__eq__`).
callback_data& callbacks_for(nb::handle publisher) {
    auto registry = cb_registry();
    nb::weakref key(publisher, registry.attr("__delitem__"));

    // PyDict_GetItem() would silently swallow exceptions raised during the
    // lookup; the WithError variant lets us tell "missing" from "failed".
    // Both return a borrowed reference.
    PyObject *existing = PyDict_GetItemWithError(registry.ptr(), key.ptr());
    if (existing)
        return nb::cast<callback_data&>(nb::handle(existing));
    if (PyErr_Occurred())
        throw nb::python_error();

    // No entry yet: create one and hand ownership to the registry.
    nb::object new_data = nb::cast(callback_data{});
    registry[key] = new_data;
    return nb::cast<callback_data&>(new_data);
}
// Shared precall hook for the subscribe/unsubscribe call policies.
struct cb_policy_common {
    // Tag type for the argument count: both bound methods take (self, cb).
    using TwoArgs = std::integral_constant<size_t, 2>;

    // Runs before the bound C++ method. If the registry already holds a
    // callable that is equal to, but not the same object as, the one passed
    // in, the argument is replaced by the registered one.
    static void precall(PyObject **args, TwoArgs,
                        nb::detail::cleanup_list *cleanup) {
        nb::handle self = args[0], cb = args[1];
        auto& known = callbacks_for(self);
        auto match = known.subscribers.find(nb::borrow(cb));

        // Nothing to normalize when the callable is unknown or is already
        // the very object stored in the registry.
        if (match == known.subscribers.end() || match->is(cb))
            return;

        // The C++ publisher compares raw pointers and knows nothing about
        // Python equality, so hand it the registered object instead.
        args[1] = match->ptr();

        // Hold an extra reference until the call's cleanup runs. Without
        // it, the unsubscribe postcall hook could drop the last reference
        // when it erases the entry, and any other postcall hook or
        // keep_alive inspecting the arguments would then touch a freed
        // object. If each argument is inspected by only one call policy or
        // keep_alive, this is not strictly required.
        cleanup->append(match->inc_ref().ptr());
    }
};
struct subscribe_policy : cb_policy_common {
static void postcall(PyObject **args, TwoArgs, nb::handle) {
nb::handle self = args[0], cb = args[1];
callbacks_for(self).subscribers.insert(nb::borrow(cb));
}
};
struct unsubscribe_policy : cb_policy_common {
static void postcall(PyObject **args, TwoArgs, nb::handle) {
nb::handle self = args[0], cb = args[1];
callbacks_for(self).subscribers.erase(nb::borrow(cb));
}
};
// Module definition: exposes the registry dict, the callback_data record
// and the publisher class with its lifetime-managing call policies.
NB_MODULE(test_callbacks_ext, m) {
    // Registry dict that callbacks_for() reads and fills.
    m.attr("registry") = nb::dict();

    // Read-only view of the per-publisher subscriber set.
    nb::class_<callback_data> data_cls(m, "callback_data");
    data_cls.def_ro("subscribers", &callback_data::subscribers);

    // Publishers are declared weak-referenceable; callbacks_for() builds
    // its registry keys from weak references to them.
    nb::class_<publisher> publisher_cls(m, "publisher",
                                        nb::is_weak_referenceable());
    publisher_cls.def(nb::init<>());
    publisher_cls.def("subscribe", &publisher::subscribe,
                      nb::call_policy<subscribe_policy>());
    publisher_cls.def("unsubscribe", &publisher::unsubscribe,
                      nb::call_policy<unsubscribe_policy>());
    publisher_cls.def("emit", &publisher::emit);
}
|