File: python_scalars.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (119 lines) | stat: -rw-r--r-- 3,801 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#pragma once

#include <ATen/ATen.h>
#include <c10/util/TypeCast.h>
#include <torch/csrc/python_headers.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/python_numbers.h>

namespace torch {
namespace utils {

/// Converts the Python number `obj` to the integral C++ type `T`.
///
/// @param obj  Python object to read the value from.
/// @param type Human-readable name of the target type; it is forwarded to
///             c10::checked_convert on the Python >= 3.10 path and is unused
///             otherwise. (Presumably used for error reporting -- confirm in
///             c10/util/TypeCast.h.)
/// @return The value of `obj` as a `T`.
template <typename T>
inline T unpackIntegral(PyObject* obj, const char* type) {
#if PY_VERSION_HEX >= 0x030a00f0
  // Starting with Python 3.10 a float is no longer silently accepted where an
  // integer is expected. To stay backward compatible for now, float objects
  // are read as doubles and everything else as a long; both routes go through
  // c10::checked_convert.
  return PyFloat_Check(obj)
      ? c10::checked_convert<T>(THPUtils_unpackDouble(obj), type)
      : c10::checked_convert<T>(THPUtils_unpackLong(obj), type);
#else
  // Older interpreters: read as a long and narrow with a plain static_cast
  // (no checked_convert on this path).
  const auto unpacked = THPUtils_unpackLong(obj);
  return static_cast<T>(unpacked);
#endif
}

inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
  switch (scalarType) {
    case at::kByte:
      *(uint8_t*)data = unpackIntegral<uint8_t>(obj, "uint8");
      break;
    case at::kChar:
      *(int8_t*)data = unpackIntegral<int8_t>(obj, "int8");
      break;
    case at::kShort:
      *(int16_t*)data = unpackIntegral<int16_t>(obj, "int16");
      break;
    case at::kInt:
      *(int32_t*)data = unpackIntegral<int32_t>(obj, "int32");
      break;
    case at::kLong:
      *(int64_t*)data = unpackIntegral<int64_t>(obj, "int64");
      break;
    case at::kHalf:
      *(at::Half*)data =
          at::convert<at::Half, double>(THPUtils_unpackDouble(obj));
      break;
    case at::kFloat:
      *(float*)data = (float)THPUtils_unpackDouble(obj);
      break;
    case at::kDouble:
      *(double*)data = THPUtils_unpackDouble(obj);
      break;
    case at::kComplexHalf:
      *(c10::complex<at::Half>*)data =
          (c10::complex<at::Half>)static_cast<c10::complex<float>>(
              THPUtils_unpackComplexDouble(obj));
      break;
    case at::kComplexFloat:
      *(c10::complex<float>*)data =
          (c10::complex<float>)THPUtils_unpackComplexDouble(obj);
      break;
    case at::kComplexDouble:
      *(c10::complex<double>*)data = THPUtils_unpackComplexDouble(obj);
      break;
    case at::kBool:
      *(bool*)data = THPUtils_unpackNumberAsBool(obj);
      break;
    case at::kBFloat16:
      *(at::BFloat16*)data =
          at::convert<at::BFloat16, double>(THPUtils_unpackDouble(obj));
      break;
    default:
      throw std::runtime_error("invalid type");
  }
}

/// Reads one element of the C++ type selected by `scalarType` from `data` and
/// wraps it in a Python object.
///
/// Integral types are packed with THPUtils_packInt64, floating-point types
/// (including Half and BFloat16, widened to double) with PyFloat_FromDouble,
/// complex types with PyComplex_FromDoubles, and bool with PyBool_FromLong.
///
/// @param data       Source; it is read through a pointer to the C++ type
///                   matching `scalarType`.
/// @param scalarType Selects the element type stored at `data`.
/// @return The object produced by the corresponding packing call.
/// @throws std::runtime_error with the message "invalid type" when
///         `scalarType` is none of the cases handled below.
inline PyObject* load_scalar(void* data, at::ScalarType scalarType) {
  switch (scalarType) {
    case at::kByte:
      return THPUtils_packInt64(*static_cast<uint8_t*>(data));
    case at::kChar:
      return THPUtils_packInt64(*static_cast<int8_t*>(data));
    case at::kShort:
      return THPUtils_packInt64(*static_cast<int16_t*>(data));
    case at::kInt:
      return THPUtils_packInt64(*static_cast<int32_t*>(data));
    case at::kLong:
      return THPUtils_packInt64(*static_cast<int64_t*>(data));
    case at::kHalf:
      return PyFloat_FromDouble(
          at::convert<double, at::Half>(*static_cast<at::Half*>(data)));
    case at::kFloat:
      return PyFloat_FromDouble(*static_cast<float*>(data));
    case at::kDouble:
      return PyFloat_FromDouble(*static_cast<double*>(data));
    case at::kComplexHalf: {
      auto* value = static_cast<c10::complex<at::Half>*>(data);
      return PyComplex_FromDoubles(value->real(), value->imag());
    }
    case at::kComplexFloat: {
      auto* value = static_cast<c10::complex<float>*>(data);
      return PyComplex_FromDoubles(value->real(), value->imag());
    }
    case at::kComplexDouble: {
      // Build the Python complex from the two components, exactly like the
      // other complex cases, instead of reinterpreting the c10::complex
      // object as a Py_complex: the two are unrelated struct types, so
      // reading one through the other relies on type punning.
      auto* value = static_cast<c10::complex<double>*>(data);
      return PyComplex_FromDoubles(value->real(), value->imag());
    }
    case at::kBool:
      return PyBool_FromLong(*static_cast<bool*>(data));
    case at::kBFloat16:
      return PyFloat_FromDouble(
          at::convert<double, at::BFloat16>(*static_cast<at::BFloat16*>(data)));
    default:
      throw std::runtime_error("invalid type");
  }
}

} // namespace utils
} // namespace torch