1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378
|
//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
#include "mlir-c/Support.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/DataTypes.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace mlir {
namespace python {
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
/// static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
/// static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
using ReferrentTy = T;
/// Type casters require the type to be default constructible, but using
/// such an instance is illegal.
Defaulting() = default;
Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}
ReferrentTy *get() const { return referrent; }
ReferrentTy *operator->() { return referrent; }
private:
ReferrentTy *referrent = nullptr;
};
} // namespace python
} // namespace mlir
namespace pybind11 {
namespace detail {
template <typename DefaultingTy>
struct MlirDefaultingCaster {
PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
bool load(pybind11::handle src, bool) {
if (src.is_none()) {
// Note that we do want an exception to propagate from here as it will be
// the most informative.
value = DefaultingTy{DefaultingTy::resolve()};
return true;
}
// Unlike many casters that chain, these casters are expected to always
// succeed, so instead of doing an isinstance check followed by a cast,
// just cast in one step and handle the exception. Returning false (vs
// letting the exception propagate) causes higher level signature parsing
// code to produce nice error messages (other than "Cannot cast...").
try {
value = DefaultingTy{
pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
return true;
} catch (std::exception &) {
return false;
}
}
static handle cast(DefaultingTy src, return_value_policy policy,
handle parent) {
return pybind11::cast(src, policy);
}
};
} // namespace detail
} // namespace pybind11
//------------------------------------------------------------------------------
// Conversion utilities.
//------------------------------------------------------------------------------
namespace mlir {
/// Accumulates into a python string from a method that accepts an
/// MlirStringCallback.
struct PyPrintAccumulator {
pybind11::list parts;
void *getUserData() { return this; }
MlirStringCallback getCallback() {
return [](MlirStringRef part, void *userData) {
PyPrintAccumulator *printAccum =
static_cast<PyPrintAccumulator *>(userData);
pybind11::str pyPart(part.data,
part.length); // Decodes as UTF-8 by default.
printAccum->parts.append(std::move(pyPart));
};
}
pybind11::str join() {
pybind11::str delim("", 0);
return delim.attr("join")(parts);
}
};
/// Accumulates int a python file-like object, either writing text (default)
/// or binary.
class PyFileAccumulator {
public:
PyFileAccumulator(const pybind11::object &fileObject, bool binary)
: pyWriteFunction(fileObject.attr("write")), binary(binary) {}
void *getUserData() { return this; }
MlirStringCallback getCallback() {
return [](MlirStringRef part, void *userData) {
pybind11::gil_scoped_acquire acquire;
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
if (accum->binary) {
// Note: Still has to copy and not avoidable with this API.
pybind11::bytes pyBytes(part.data, part.length);
accum->pyWriteFunction(pyBytes);
} else {
pybind11::str pyStr(part.data,
part.length); // Decodes as UTF-8 by default.
accum->pyWriteFunction(pyStr);
}
};
}
private:
pybind11::object pyWriteFunction;
bool binary;
};
/// Accumulates into a python string from a method that is expected to make
/// one (no more, no less) call to the callback (asserts internally on
/// violation).
struct PySinglePartStringAccumulator {
void *getUserData() { return this; }
MlirStringCallback getCallback() {
return [](MlirStringRef part, void *userData) {
PySinglePartStringAccumulator *accum =
static_cast<PySinglePartStringAccumulator *>(userData);
assert(!accum->invoked &&
"PySinglePartStringAccumulator called back multiple times");
accum->invoked = true;
accum->value = pybind11::str(part.data, part.length);
};
}
pybind11::str takeValue() {
assert(invoked && "PySinglePartStringAccumulator not called back");
return std::move(value);
}
private:
pybind11::str value;
bool invoked = false;
};
/// A CRTP base class for pseudo-containers willing to support Python-type
/// slicing access on top of indexed access. Calling ::bind on this class
/// will define `__len__` as well as `__getitem__` with integer and slice
/// arguments.
///
/// This is intended for pseudo-containers that can refer to arbitrary slices of
/// underlying storage indexed by a single integer. Indexing those with an
/// integer produces an instance of ElementTy. Indexing those with a slice
/// produces a new instance of Derived, which can be sliced further.
///
/// A derived class must provide the following:
/// - a `static const char *pyClassName ` field containing the name of the
/// Python class to bind;
/// - an instance method `intptr_t getRawNumElements()` that returns the
/// number
/// of elements in the backing container (NOT that of the slice);
/// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
/// single element at the given linear index (NOT slice index);
/// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
/// constructs a new instance of the derived pseudo-container with the
/// given slice parameters (to be forwarded to the Sliceable constructor).
///
/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
/// throw.
///
/// A derived class may additionally define:
/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
/// the python class.
template <typename Derived, typename ElementTy>
class Sliceable {
protected:
using ClassTy = pybind11::class_<Derived>;
/// Transforms `index` into a legal value to access the underlying sequence.
/// Returns <0 on failure.
intptr_t wrapIndex(intptr_t index) {
if (index < 0)
index = length + index;
if (index < 0 || index >= length)
return -1;
return index;
}
/// Computes the linear index given the current slice properties.
intptr_t linearizeIndex(intptr_t index) {
intptr_t linearIndex = index * step + startIndex;
assert(linearIndex >= 0 &&
linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
"linear index out of bounds, the slice is ill-formed");
return linearIndex;
}
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
pybind11::object getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
PyErr_SetString(PyExc_IndexError, "index out of range");
return {};
}
return pybind11::cast(
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
}
/// Returns a new instance of the pseudo-container restricted to the given
/// slice. Returns a nullptr object on failure.
pybind11::object getItemSlice(PyObject *slice) {
ssize_t start, stop, extraStep, sliceLength;
if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
&sliceLength) != 0) {
PyErr_SetString(PyExc_IndexError, "index out of range");
return {};
}
return pybind11::cast(static_cast<Derived *>(this)->slice(
startIndex + start * step, sliceLength, step * extraStep));
}
public:
explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
: startIndex(startIndex), length(length), step(step) {
assert(length >= 0 && "expected non-negative slice length");
}
/// Returns the `index`-th element in the slice, supports negative indices.
/// Throws if the index is out of bounds.
ElementTy getElement(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
throw pybind11::index_error("index out of range");
}
return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
}
/// Returns the size of slice.
intptr_t size() { return length; }
/// Returns a new vector (mapped to Python list) containing elements from two
/// slices. The new vector is necessary because slices may not be contiguous
/// or even come from the same original sequence.
std::vector<ElementTy> dunderAdd(Derived &other) {
std::vector<ElementTy> elements;
elements.reserve(length + other.length);
for (intptr_t i = 0; i < length; ++i) {
elements.push_back(static_cast<Derived *>(this)->getElement(i));
}
for (intptr_t i = 0; i < other.length; ++i) {
elements.push_back(static_cast<Derived *>(&other)->getElement(i));
}
return elements;
}
/// Binds the indexing and length methods in the Python class.
static void bind(pybind11::module &m) {
auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
pybind11::module_local())
.def("__add__", &Sliceable::dunderAdd);
Derived::bindDerived(clazz);
// Manually implement the sequence protocol via the C API. We do this
// because it is approx 4x faster than via pybind11, largely because that
// formulation requires a C++ exception to be thrown to detect end of
// sequence.
// Since we are in a C-context, any C++ exception that happens here
// will terminate the program. There is nothing in this implementation
// that should throw in a non-terminal way, so we forgo further
// exception marshalling.
// See: https://github.com/pybind/pybind11/issues/2842
auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
"must be heap type");
heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
auto self = pybind11::cast<Derived *>(rawSelf);
return self->length;
};
// sq_item is called as part of the sequence protocol for iteration,
// list construction, etc.
heap_type->as_sequence.sq_item =
+[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
auto self = pybind11::cast<Derived *>(rawSelf);
return self->getItem(index).release().ptr();
};
// mp_subscript is used for both slices and integer lookups.
heap_type->as_mapping.mp_subscript =
+[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
auto self = pybind11::cast<Derived *>(rawSelf);
Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
if (!PyErr_Occurred()) {
// Integer indexing.
return self->getItem(index).release().ptr();
}
PyErr_Clear();
// Assume slice-based indexing.
if (PySlice_Check(rawSubscript)) {
return self->getItemSlice(rawSubscript).release().ptr();
}
PyErr_SetString(PyExc_ValueError, "expected integer or slice");
return nullptr;
};
}
/// Hook for derived classes willing to bind more methods.
static void bindDerived(ClassTy &) {}
private:
intptr_t startIndex;
intptr_t length;
intptr_t step;
};
} // namespace mlir
namespace llvm {
template <>
struct DenseMapInfo<MlirTypeID> {
static inline MlirTypeID getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlirTypeIDCreate(pointer);
}
static inline MlirTypeID getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlirTypeIDCreate(pointer);
}
static inline unsigned getHashValue(const MlirTypeID &val) {
return mlirTypeIDHashValue(val);
}
static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
return mlirTypeIDEqual(lhs, rhs);
}
};
} // namespace llvm
#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
|