1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
|
// Copyright 2018-2019 Henry Schreiner and Hans Dembinski
//
// Distributed under the 3-Clause BSD License. See accompanying
// file LICENSE or https://github.com/scikit-hep/boost-histogram for details.
#pragma once
#include <bh_python/pybind11.hpp>

#include <boost/core/nvp.hpp>
#include <boost/histogram/axis/regular.hpp>
#include <pybind11/functional.h>

#include <cstdint>
#include <functional>
#include <tuple>
#include <utility>
namespace bh = boost::histogram;
/// Axis transform defined by a pair of user-supplied double(double) functions.
///
/// Each Python callable is optionally passed through a converter first. It is
/// then resolved to a raw C function pointer, so forward()/inverse() call the
/// pointer directly. The Python objects are retained next to the pointers for
/// reference counting, repr, equality and pickling.
struct func_transform {
    using raw_t = double(double);

    // Resolved function pointers. They stay null until compute() has run,
    // e.g. for a default-constructed instance that is about to be deserialized.
    raw_t* _forward = nullptr;
    raw_t* _inverse = nullptr;

    py::object _forward_ob; // Held for reference counting, repr, and pickling
    py::object _inverse_ob;

    py::object _forward_converted; // Held for reference counting if conversion makes a
                                   // new object (ctypes does not bump the refcount)
    py::object _inverse_converted;

    py::object _convert_ob; // Called before computing transform if not None
    py::str _name;          // Optional name (uses repr from objects otherwise)

    /// Resolve a Python object into a raw double(double) function pointer.
    ///
    /// Accepts ctypes CFUNCTYPE(c_double, c_double) pointers, objects exposing
    /// such a pointer through a ``ctypes`` attribute (numba, for example), and
    /// stateless pybind11-wrapped C++ double(double) functions.
    ///
    /// @param input Object to resolve. It is first passed through
    ///              ``_convert_ob`` unless that is None.
    /// @return The raw function pointer, plus the Python object that owns it.
    ///         The caller must keep that object alive while using the pointer.
    /// @throws py::type_error if the object is a NULL ctypes pointer, is not a
    ///         function, or is not a stateless double(double) C++ function.
    std::tuple<raw_t*, py::object> compute(py::object& input) {
        // Run the conversion function on the input (unless conversion is None)
        py::object const tmp_src = _convert_ob.is_none() ? input : _convert_ob(input);

        // If a ctypes object is present, just use that (numba, for example)
        py::object const src = py::getattr(tmp_src, "ctypes", tmp_src);

        // function_type = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
        py::module const ctypes    = py::module::import("ctypes");
        py::handle const CFUNCTYPE = ctypes.attr("CFUNCTYPE");
        py::handle const c_double  = ctypes.attr("c_double");
        py::object const function_type = CFUNCTYPE(c_double, c_double);

        if(py::isinstance(src, function_type)) {
            py::handle const cast     = ctypes.attr("cast");
            py::handle const c_void_p = ctypes.attr("c_void_p");
            // ctypes.cast(src, ctypes.c_void_p).value is None for a NULL pointer.
            // Reject that explicitly instead of failing inside py::cast (or,
            // worse, calling through a null pointer later).
            py::object const addr_obj = cast(src, c_void_p).attr("value");
            if(addr_obj.is_none())
                throw py::type_error(
                    "ctypes double(double) function pointer must not be NULL");
            auto const addr = py::cast<std::uintptr_t>(addr_obj);
            auto* const ptr
                = reinterpret_cast<raw_t*>(addr); // NOLINT(performance-no-int-to-ptr)
            return std::make_tuple(ptr, src);
        }

        // If we made it to this point, we probably have a C++ pybind object or an
        // invalid object. The following is based on the std::function conversion in
        // pybind11/functional.hpp
        if(!py::isinstance<py::function>(src))
            throw py::type_error("Only ctypes double(double) and C++ functions allowed "
                                 "(must be function)");

        py::detail::make_caster<std::function<raw_t>> func_caster;
        if(!func_caster.load(src, /*convert*/ false)) {
            // Note that each error is slightly different just to help with debugging
            throw py::type_error("Only ctypes double(double) and C++ functions allowed "
                                 "(must be stateless)");
        }

        // Bind by reference: no need to copy the std::function out of the caster.
        auto& func  = static_cast<std::function<raw_t>&>(func_caster);
        auto* cfunc = func.target<raw_t*>();
        // A Python (non-C++) callable is wrapped by the caster and has no raw
        // function-pointer target.
        if(cfunc == nullptr) {
            throw py::type_error(
                "Retrieving double(double) function failed (must be stateless)");
        }
        return std::make_tuple(*cfunc, src);
    }

    /// Build a transform from forward/inverse callables, an optional converter
    /// (None to skip) and a display name. Throws py::type_error via compute().
    func_transform(py::object f, py::object i, py::object c, py::str n)
        : _forward_ob(std::move(f))
        , _inverse_ob(std::move(i))
        , _convert_ob(std::move(c))
        , _name(std::move(n)) {
        // _convert_ob is already initialized here, which compute() relies on.
        std::tie(_forward, _forward_converted) = compute(_forward_ob);
        std::tie(_inverse, _inverse_converted) = compute(_inverse_ob);
    }

    func_transform()  = default;
    ~func_transform() = default;
    func_transform(const func_transform&)                = default;
    func_transform(func_transform&&) noexcept            = default;
    func_transform& operator=(const func_transform&)     = default;
    func_transform& operator=(func_transform&&) noexcept = default;

    // These call the resolved pointers directly. They must not be used on an
    // instance whose pointers were never resolved (still nullptr).
    double forward(double x) const { return _forward(x); }
    double inverse(double x) const { return _inverse(x); }

    /// Equal when the forward and inverse Python objects compare equal.
    /// The converter and name are not compared. Python errors raised during
    /// the comparison are treated as "not equal".
    bool operator==(const func_transform& other) const noexcept {
        try {
            return _forward_ob.equal(other._forward_ob)
                   && _inverse_ob.equal(other._inverse_ob);
        } catch(const py::error_already_set&) {
            return false;
        }
    }

    /// Serialize the Python objects only. The raw function pointers are
    /// process-specific, so they are recomputed after loading. The converter
    /// is loaded before compute() runs, because compute() reads it.
    template <class Archive>
    void serialize(Archive& ar, unsigned /* version */) {
        ar& boost::make_nvp("forward", _forward_ob);
        ar& boost::make_nvp("inverse", _inverse_ob);
        ar& boost::make_nvp("convert", _convert_ob);
        ar& boost::make_nvp("name", _name);
        if(Archive::is_loading::value) {
            std::tie(_forward, _forward_converted) = compute(_forward_ob);
            std::tie(_inverse, _inverse_converted) = compute(_inverse_ob);
        }
    }
};
namespace boost {
namespace histogram {
namespace detail {

/// Suffix used for axes whose transform is a func_transform.
inline const char* axis_suffix(const ::func_transform& /* transform */) {
    const char* const suffix = "_trans";
    return suffix;
}

} // namespace detail
} // namespace histogram
} // namespace boost
/// Deep copy for types that hold no Python objects.
/// The ordinary C++ copy constructor already produces an independent object,
/// so the deepcopy memo is not consulted.
template <class T>
T deep_copy(const T& input, py::object& /* memo */) {
    return T(input);
}
/// Deep copy for func_transform, which holds Python objects.
/// Each held object goes through Python's copy.deepcopy with the shared memo.
/// The transform is then rebuilt through its constructor, so the raw function
/// pointers are resolved again from the copies.
template <>
inline func_transform deep_copy<func_transform>(const func_transform& input,
                                                py::object& memo) {
    py::module const copy_module = py::module::import("copy");

    // Copied one at a time, in a fixed order (forward, inverse, convert, name).
    py::object const fwd_copy  = copy_module.attr("deepcopy")(input._forward_ob, memo);
    py::object const inv_copy  = copy_module.attr("deepcopy")(input._inverse_ob, memo);
    py::object const conv_copy = copy_module.attr("deepcopy")(input._convert_ob, memo);
    py::str const name_copy    = copy_module.attr("deepcopy")(input._name, memo);

    return func_transform(fwd_copy, inv_copy, conv_copy, name_copy);
}
/// Stream a fixed placeholder for a func_transform (used when printing the
/// axis repr); the transform's contents are not inspected.
template <class CharT, class Traits>
std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>& os,
                                              const func_transform& /* transform */) {
    os << "func_transform";
    return os;
}
|