File: transform.hpp

package info (click to toggle)
python-boost-histogram 1.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,232 kB
  • sloc: python: 7,745; cpp: 3,243; makefile: 22; sh: 1
file content (162 lines) | stat: -rw-r--r-- 6,475 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Copyright 2018-2019 Henry Schreiner and Hans Dembinski
//
// Distributed under the 3-Clause BSD License.  See accompanying
// file LICENSE or https://github.com/scikit-hep/boost-histogram for details.

#pragma once

#include <bh_python/pybind11.hpp>

#include <boost/core/nvp.hpp>
#include <boost/histogram/axis/regular.hpp>
#include <utility>

#include <pybind11/functional.h>

namespace bh = boost::histogram;

struct func_transform {
    using raw_t = double(double);

    raw_t* _forward = nullptr;
    raw_t* _inverse = nullptr;
    py::object _forward_ob; // Held for reference counting, repr, and pickling
    py::object _inverse_ob;
    py::object _forward_converted; // Held for reference counting if conversion makes a
                                   // new object (ctypes does not bump the refcount)
    py::object _inverse_converted;
    py::object _convert_ob; // Called before computing transform if not None
    py::str _name;          // Optional name (uses repr from objects otherwise)

    /// Convert an object into a std::function. Can handle ctypes
    /// function pointers and pybind11 C++ functions, or anything
    /// else with a defined convert function
    std::tuple<raw_t*, py::object> compute(py::object& input) {
        // Run the conversion function on the input (unless conversion is None)
        py::object const tmp_src = _convert_ob.is_none() ? input : _convert_ob(input);

        // If a CTypes object is present, just use that (numba, for example)
        py::object const src = py::getattr(tmp_src, "ctypes", tmp_src);

        // import ctypes
        py::module const ctypes = py::module::import("ctypes");

        // Get the type: double(double)
        // function_type = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
        py::handle const CFUNCTYPE     = ctypes.attr("CFUNCTYPE");
        py::handle const c_double      = ctypes.attr("c_double");
        py::object const function_type = CFUNCTYPE(c_double, c_double);

        if(py::isinstance(src, function_type)) {
            py::handle const cast     = ctypes.attr("cast");
            py::handle const c_void_p = ctypes.attr("c_void_p");

            // ctypes.cast(in, ctypes.c_void_p).value
            py::object const addr_obj = cast(src, c_void_p);
            auto addr = py::cast<std::uintptr_t>(addr_obj.attr("value"));
            auto ptr
                = reinterpret_cast<raw_t*>(addr); // NOLINT(performance-no-int-to-ptr)
            return std::make_tuple(ptr, src);
        }

        // If we made it to this point, we probably have a C++ pybind object or an
        // invalid object. The following is based on the std::function conversion in
        // pybind11/functional.hpp
        if(!py::isinstance<py::function>(src))
            throw py::type_error("Only ctypes double(double) and C++ functions allowed "
                                 "(must be function)");

        py::detail::make_caster<std::function<raw_t>> func_caster;
        if(!func_caster.load(src, /*convert*/ false)) {
            // Note that each error is slightly different just to help with debugging
            throw py::type_error("Only ctypes double(double) and C++ functions allowed "
                                 "(must be stateless)");
        }
        auto func   = static_cast<std::function<raw_t>&>(func_caster);
        auto* cfunc = func.target<raw_t*>();
        if(cfunc == nullptr) {
            throw py::type_error(
                "Retrieving double(double) function failed (must be stateless)");
        }
        return std::make_tuple(*cfunc, src);
    }

    func_transform(py::object f, py::object i, py::object c, py::str n)
        : _forward_ob(f)
        , _inverse_ob(i)
        , _convert_ob(std::move(c))
        , _name(std::move(n)) {
        std::tie(_forward, _forward_converted) = compute(f);
        std::tie(_inverse, _inverse_converted) = compute(i);
    }

    func_transform()                          = default;
    ~func_transform()                         = default;
    func_transform(const func_transform&)     = default;
    func_transform(func_transform&&) noexcept = default;

    func_transform& operator=(const func_transform&)     = default;
    func_transform& operator=(func_transform&&) noexcept = default;

    double forward(double x) const { return _forward(x); }

    double inverse(double x) const { return _inverse(x); }

    bool operator==(const func_transform& other) const noexcept {
        try {
            return _forward_ob.equal(other._forward_ob)
                   && _inverse_ob.equal(other._inverse_ob);
        } catch(const py::error_already_set&) {
            return false;
        }
    }

    template <class Archive>
    void serialize(Archive& ar, unsigned /* version */) {
        ar& boost::make_nvp("forward", _forward_ob);
        ar& boost::make_nvp("inverse", _inverse_ob);
        ar& boost::make_nvp("convert", _convert_ob);
        ar& boost::make_nvp("name", _name);

        if(Archive::is_loading::value) {
            std::tie(_forward, _forward_converted) = compute(_forward_ob);
            std::tie(_inverse, _inverse_converted) = compute(_inverse_ob);
        }
    }
};

namespace boost {
namespace histogram {
namespace detail {

/// Overload of boost::histogram::detail::axis_suffix for axes using a
/// func_transform; always yields the suffix "_trans".
inline const char* axis_suffix(const ::func_transform& /* transform */) {
    return "_trans";
}

} // namespace detail
} // namespace histogram
} // namespace boost

/// Generic deep copy for types that hold no Python objects: a plain copy
/// construction is already deep, so the memo object is ignored.
template <class T>
T deep_copy(const T& input, py::object& /* memo */) {
    return static_cast<T>(input);
}

/// Deep copy of a func_transform. Each Python member goes through
/// copy.deepcopy with the shared `memo`, and a new transform is built from the
/// copies. Its constructor resolves the native function pointers again.
template <>
inline func_transform deep_copy<func_transform>(const func_transform& input,
                                                py::object& memo) {
    py::object const deepcopy = py::module::import("copy").attr("deepcopy");

    // Copied one after another, in constructor-parameter order.
    py::object const forward_copy = deepcopy(input._forward_ob, memo);
    py::object const inverse_copy = deepcopy(input._inverse_ob, memo);
    py::object const convert_copy = deepcopy(input._convert_ob, memo);
    py::str const name_copy       = deepcopy(input._name, memo);

    return func_transform(forward_copy, inverse_copy, convert_copy, name_copy);
}

/// Stream a func_transform for repr output; only the fixed label
/// "func_transform" is written, the transform's contents are not shown.
template <class CharT, class Traits>
std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>& os,
                                              const func_transform& /* transform */) {
    os << "func_transform";
    return os;
}