#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_symnode.h>
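
// pybind11 type_caster specializations that translate between the Python
// symbolic types (torch.SymInt, torch.SymFloat, torch.SymBool) and their C++
// counterparts (c10::SymInt, c10::SymFloat, c10::SymBool), plus a one-way
// (C++ -> Python) caster for c10::Scalar.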
namespace pybind11::detail {
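
// Python -> C++. Accepts, in order: a torch.SymInt (reusing its node directly
// when it is already a C++ c10::SymNodeImpl, otherwise wrapping the Python
// node in a PythonSymNodeImpl), a 1-element integral Tensor, or any object
// usable as an integer index.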
bool type_caster<c10::SymInt>::load(py::handle src, bool) {
  if (torch::is_symint(src)) {
    auto node = src.attr("node");
    if (py::isinstance<c10::SymNodeImpl>(node)) {
      value = c10::SymInt(py::cast<c10::SymNode>(node));
      return true;
    }
    value = c10::SymInt(static_cast<c10::SymNode>(
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node)));
    return true;
  }
  auto raw_obj = src.ptr();
  if (THPVariable_Check(raw_obj)) {
    auto& var = THPVariable_Unpack(raw_obj);
    if (var.numel() == 1 &&
        at::isIntegralType(var.dtype().toScalarType(), /*includeBool*/ true)) {
      auto scalar = var.item();
      TORCH_INTERNAL_ASSERT(scalar.isIntegral(/*includeBool*/ false));
      value = scalar.toSymInt();
      return true;
    }
  }
  if (THPUtils_checkIndex(raw_obj)) {
    value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
    return true;
  }
  return false;
}
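
// C++ -> Python. A Python-backed symbolic SymInt is unwrapped back to its
// original Python node; a C++-backed symbolic SymInt gets wrapped in a new
// torch.SymInt; a non-symbolic SymInt becomes a plain Python int.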
py::handle type_caster<c10::SymInt>::cast(
    const c10::SymInt& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (si.is_symbolic()) {
    auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
        si.toSymNodeImplUnowned());
    if (py_node) {
      // Return the underlying Python object directly (unwrap)
      return torch::get_symint_class()(py_node->getPyObj()).release();
    } else {
      // Wrap the C++ node into a Python SymInt
      auto inner = py::cast(si.toSymNode());
      if (!inner) {
        throw python_error();
      }
      return torch::get_symint_class()(inner).release();
    }
  } else {
    auto m = si.maybe_as_int();
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    return py::cast(m.value()).release();
  }
}
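
// Python -> C++. Accepts a torch.SymFloat (wrapping its Python node in a
// PythonSymNodeImpl) or anything unpackable as a Python float.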
bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
  if (torch::is_symfloat(src)) {
    value = c10::SymFloat(static_cast<c10::SymNode>(
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
    return true;
  }
  auto raw_obj = src.ptr();
  if (THPUtils_checkDouble(raw_obj)) {
    value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
    return true;
  }
  return false;
}
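
// C++ -> Python. Symbolic values must be Python-backed for now (see the TODO
// below); non-symbolic values become plain Python floats.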
py::handle type_caster<c10::SymFloat>::cast(
    const c10::SymFloat& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (si.is_symbolic()) {
    // TODO: generalize this to work with C++ backed class
    auto* py_node =
        dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
    TORCH_INTERNAL_ASSERT(py_node);
    return torch::get_symfloat_class()(py_node->getPyObj()).release();
  } else {
    return py::cast(si.as_float_unchecked()).release();
  }
}
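
// Python -> C++. Accepts a torch.SymBool (wrapping its Python node in a
// PythonSymNodeImpl) or a plain Python bool.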
bool type_caster<c10::SymBool>::load(py::handle src, bool) {
  if (torch::is_symbool(src)) {
    value = c10::SymBool(static_cast<c10::SymNode>(
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
    return true;
  }
  auto raw_obj = src.ptr();
  if (THPUtils_checkBool(raw_obj)) {
    value = c10::SymBool{THPUtils_unpackBool(raw_obj)};
    return true;
  }
  return false;
}
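
// C++ -> Python. A SymBool with a known constant value becomes a Python bool;
// otherwise the node must be Python-backed (see the TODO below).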
py::handle type_caster<c10::SymBool>::cast(
    const c10::SymBool& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (auto m = si.maybe_as_bool()) {
    return py::cast(*m).release();
  } else {
    // TODO: generalize this to work with C++ backed class
    auto* py_node =
        dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
    TORCH_INTERNAL_ASSERT(py_node);
    return torch::get_symbool_class()(py_node->getPyObj()).release();
  }
}
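
// Python -> C++ loading for c10::Scalar is intentionally not implemented;
// this caster only converts in the C++ -> Python direction.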
bool type_caster<c10::Scalar>::load(py::handle src, bool) {
  TORCH_INTERNAL_ASSERT(
      0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
}
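
// C++ -> Python. Dispatches on the Scalar's tag: integral (symbolic values
// route through SymInt, concrete ones through int64_t/uint64_t to avoid
// SymInt's reserved negative range), floating point, boolean, then complex.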
py::handle type_caster<c10::Scalar>::cast(
    const c10::Scalar& scalar,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (scalar.isIntegral(/*includeBool*/ false)) {
    // We have to be careful here; we cannot unconditionally route through
    // SymInt because integer data from Tensors can easily be MIN_INT or
    // very negative, which conflicts with the allocated range.
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymInt()).release();
    } else {
      if (scalar.type() == at::ScalarType::UInt64) {
        return py::cast(scalar.toUInt64()).release();
      } else {
        return py::cast(scalar.toLong()).release();
      }
    }
  } else if (scalar.isFloatingPoint()) {
    // This isn't strictly necessary but we add it for symmetry
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymFloat()).release();
    } else {
      return py::cast(scalar.toDouble()).release();
    }
  } else if (scalar.isBoolean()) {
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymBool()).release();
    }
    return py::cast(scalar.toBool()).release();
  } else if (scalar.isComplex()) {
    return py::cast(scalar.toComplexDouble()).release();
  } else {
    TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
  }
}
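
// With these casters visible, pybind11 bindings can accept and return the
// symbolic types directly. A hypothetical usage sketch, not part of this
// file:
//
//   m.def("sym_add_one", [](const c10::SymInt& x) -> c10::SymInt {
//     return x + 1; // works for both concrete ints and symbolic values
//   });
//
// Calling it with a Python int goes through the THPUtils_checkIndex branch
// above; calling it with a torch.SymInt round-trips through the SymNode
// wrappers.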
} // namespace pybind11::detail