File: pybind.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (166 lines) | stat: -rw-r--r-- 5,169 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_symnode.h>

namespace pybind11::detail {

// Python -> c10::SymInt conversion. The second (bool) argument is unused.
//
// Candidates are tried in this order; the first match wins:
//   1. An object recognised by torch::is_symint: its `node` attribute is
//      used directly when it is already a bound c10::SymNodeImpl instance,
//      and is wrapped in a PythonSymNodeImpl otherwise.
//   2. A tensor with exactly one element and an integral dtype: the value
//      obtained from item() is converted with toSymInt().
//   3. Anything accepted by THPUtils_checkIndex, unpacked with
//      THPUtils_unpackIndex.
// Returns true and stores the result in `value` on success, false when no
// candidate matches.
bool type_caster<c10::SymInt>::load(py::handle src, bool) {
  if (torch::is_symint(src)) {
    py::object sym_node = src.attr("node");
    if (py::isinstance<c10::SymNodeImpl>(sym_node)) {
      // The node is already C++-backed; reuse it without another wrapper.
      value = c10::SymInt(py::cast<c10::SymNode>(sym_node));
    } else {
      // Pure-Python node: hold on to it through PythonSymNodeImpl.
      value = c10::SymInt(static_cast<c10::SymNode>(
          c10::make_intrusive<torch::impl::PythonSymNodeImpl>(sym_node)));
    }
    return true;
  }

  auto* const py_ptr = src.ptr();

  if (THPVariable_Check(py_ptr)) {
    const auto& tensor = THPVariable_Unpack(py_ptr);
    const bool single_integral_element = tensor.numel() == 1 &&
        at::isIntegralType(
            tensor.dtype().toScalarType(), /*include_bool*/ true);
    if (single_integral_element) {
      auto scalar = tensor.item();
      // NOTE(review): the dtype test above lets bool tensors through
      // (include_bool is true) while this assertion excludes bool scalars;
      // confirm whether a one-element bool tensor can reach this point.
      TORCH_INTERNAL_ASSERT(scalar.isIntegral(/*include bool*/ false));
      value = scalar.toSymInt();
      return true;
    }
    // Any other tensor falls through to the generic index check below.
  }

  if (!THPUtils_checkIndex(py_ptr)) {
    return false;
  }
  value = c10::SymInt{THPUtils_unpackIndex(py_ptr)};
  return true;
}

// c10::SymInt -> Python conversion. `policy` and `parent` are ignored.
//
// A non-symbolic SymInt becomes a plain Python integer. A symbolic one is
// returned as an instance of the class given by torch::get_symint_class(),
// constructed from either the original Python node object (when the node is
// a PythonSymNodeImpl) or a Python binding of the C++ node.
// Throws python_error if binding the C++ node to Python yields a null object.
py::handle type_caster<c10::SymInt>::cast(
    const c10::SymInt& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (!si.is_symbolic()) {
    auto plain_int = si.maybe_as_int();
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    return py::cast(plain_int.value()).release();
  }

  auto* python_backed = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
      si.toSymNodeImplUnowned());
  if (python_backed != nullptr) {
    // The node originated in Python: hand back the object it wraps.
    return torch::get_symint_class()(python_backed->getPyObj()).release();
  }

  // The node is C++-backed: expose it to Python first, then wrap it.
  auto bound_node = py::cast(si.toSymNode());
  if (!bound_node) {
    throw python_error();
  }
  return torch::get_symint_class()(bound_node).release();
}

// Python -> c10::SymFloat conversion. The second (bool) argument is unused.
//
// An object recognised by torch::is_symfloat has its `node` attribute wrapped
// in a PythonSymNodeImpl; otherwise anything passing THPUtils_checkDouble is
// unpacked as a double. Returns false when neither applies.
bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
  if (!torch::is_symfloat(src)) {
    auto* const py_ptr = src.ptr();
    if (!THPUtils_checkDouble(py_ptr)) {
      return false;
    }
    value = c10::SymFloat{THPUtils_unpackDouble(py_ptr)};
    return true;
  }

  py::object sym_node = src.attr("node");
  value = c10::SymFloat(static_cast<c10::SymNode>(
      c10::make_intrusive<torch::impl::PythonSymNodeImpl>(sym_node)));
  return true;
}

// c10::SymFloat -> Python conversion. `policy` and `parent` are ignored.
//
// A non-symbolic value becomes a plain Python float. A symbolic value must be
// backed by a PythonSymNodeImpl (asserted below); the Python object it wraps
// is passed to the class given by torch::get_symfloat_class().
py::handle type_caster<c10::SymFloat>::cast(
    const c10::SymFloat& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (!si.is_symbolic()) {
    return py::cast(si.as_float_unchecked()).release();
  }

  // TODO: generalize this to work with C++ backed class
  const c10::SymNode sym_node = si.toSymNodeImpl();
  auto* py_node =
      dynamic_cast<torch::impl::PythonSymNodeImpl*>(sym_node.get());
  TORCH_INTERNAL_ASSERT(py_node);
  return torch::get_symfloat_class()(py_node->getPyObj()).release();
}

// Python -> c10::SymBool conversion. The second (bool) argument is unused.
//
// An object recognised by torch::is_symbool has its `node` attribute wrapped
// in a PythonSymNodeImpl; otherwise anything passing THPUtils_checkBool is
// unpacked as a bool. Returns false when neither applies.
bool type_caster<c10::SymBool>::load(py::handle src, bool) {
  if (!torch::is_symbool(src)) {
    auto* const py_ptr = src.ptr();
    if (!THPUtils_checkBool(py_ptr)) {
      return false;
    }
    value = c10::SymBool{THPUtils_unpackBool(py_ptr)};
    return true;
  }

  py::object sym_node = src.attr("node");
  value = c10::SymBool(static_cast<c10::SymNode>(
      c10::make_intrusive<torch::impl::PythonSymNodeImpl>(sym_node)));
  return true;
}

// c10::SymBool -> Python conversion. `policy` and `parent` are ignored.
//
// When maybe_as_bool() yields a value, that value is returned as a plain
// Python bool. Otherwise the node must be backed by a PythonSymNodeImpl
// (asserted below); the Python object it wraps is passed to the class given
// by torch::get_symbool_class().
py::handle type_caster<c10::SymBool>::cast(
    const c10::SymBool& si,
    return_value_policy /* policy */,
    handle /* parent */) {
  auto known_value = si.maybe_as_bool();
  if (known_value.has_value()) {
    return py::cast(*known_value).release();
  }

  // TODO: generalize this to work with C++ backed class
  const c10::SymNode sym_node = si.toSymNodeImpl();
  auto* py_node =
      dynamic_cast<torch::impl::PythonSymNodeImpl*>(sym_node.get());
  TORCH_INTERNAL_ASSERT(py_node);
  return torch::get_symbool_class()(py_node->getPyObj()).release();
}

// Python -> c10::Scalar conversion is not implemented: every call trips the
// internal assertion below, so neither argument is inspected.
bool type_caster<c10::Scalar>::load(
    py::handle /* src */,
    bool /* convert */) {
  TORCH_INTERNAL_ASSERT(
      0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
}

// c10::Scalar -> Python conversion. `policy` and `parent` are ignored.
//
// The scalar's category is tested in the order integral (bool excluded),
// floating point, boolean, complex. Within the first three, a symbolic scalar
// is forwarded to the matching SymInt / SymFloat / SymBool conversion and a
// concrete one is converted from its plain C++ value. A scalar in none of
// these categories trips an internal assertion.
py::handle type_caster<c10::Scalar>::cast(
    const c10::Scalar& scalar,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (scalar.isIntegral(/*includeBool*/ false)) {
    // Only symbolic integers go through SymInt. Concrete integers read from
    // tensors may be MIN_INT or otherwise very negative, which conflicts
    // with the range SymInt allocates, so they are converted directly.
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymInt()).release();
    }
    if (scalar.type() == at::ScalarType::UInt64) {
      return py::cast(scalar.toUInt64()).release();
    }
    return py::cast(scalar.toLong()).release();
  }

  if (scalar.isFloatingPoint()) {
    // The symbolic branch is not strictly required; it mirrors the integral
    // case for symmetry.
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymFloat()).release();
    }
    return py::cast(scalar.toDouble()).release();
  }

  if (scalar.isBoolean()) {
    if (scalar.isSymbolic()) {
      return py::cast(scalar.toSymBool()).release();
    }
    return py::cast(scalar.toBool()).release();
  }

  if (scalar.isComplex()) {
    return py::cast(scalar.toComplexDouble()).release();
  }

  TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
}

} // namespace pybind11::detail