File: SymInt.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (246 lines) | stat: -rw-r--r-- 7,841 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
#pragma once

#include <c10/core/SymIntNodeImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>

#include <memory>
#include <numeric>

namespace c10 {

class SymFloat;

// `SymInt` is a C++ wrapper class around an int64_t field named `data_` and is
// used to represent concrete dimension values.
//
// `SymInt` is also a data type in Pytorch that can be used in function schemas
// to enable tracing.
//
// `SymInt` is introduced to enable tracing arithmetic
// operations on symbolic integers (e.g. sizes). Tracing symbolic sizes will
// allow LTC and AOTAutograd representing dynamic shapes in expression graphs
// faithfully without baking in concrete dimension values.
//
// To trace the operations, SymInt will overload arithmetic operators (e.g. +,
// -, *) and will provide overloads taking SymInt for commonly used math
// functions.
//
// SymInt represents a union structure Union[int64_t, SymIntNodeImpl*], which
// is implemented as a single packed int64_t field named data_.

// SKIP_IS_SYMBOLIC_ON_MOBILE(COND): precondition check used by SymInt.
// - Regular builds: forwards COND to TORCH_CHECK.
// - C10_MOBILE builds: expands to an empty statement; COND is not evaluated.
#ifndef C10_MOBILE
#define SKIP_IS_SYMBOLIC_ON_MOBILE(COND) TORCH_CHECK(COND)
#else
#define SKIP_IS_SYMBOLIC_ON_MOBILE(COND) \
  do {                                   \
  } while (0)
#endif

// Packed value that is either a plain int64_t or an owning pointer to a
// SymIntNodeImpl. The two cases are distinguished by the top two bits of
// `data_` (see the "Constraints on the internal representation" notes in the
// private section). When the value is symbolic, this object holds one
// reference to the node and gives it back in release_() / the destructor.
class C10_API SymInt {
 public:
  // Tag type selecting the constructor that stores `data_` without checking
  // whether the bit pattern falls in the symbolic range.
  enum Unchecked {
    UNCHECKED,
  };

  // Wraps a concrete integer. On non-mobile builds this raises (TORCH_CHECK)
  // when `d` lies in the range reserved for symbolic values, i.e. when
  // check_range(d) would be false.
  /*implicit*/ SymInt(int64_t d) : data_(d) {
    SKIP_IS_SYMBOLIC_ON_MOBILE(!is_symbolic());
  }
  SymInt() : data_(0) {}

  // unchecked c-tor accepting raw `data_`
  // One appropriate use for this is when you are constructing a symint
  // in a situation where you know it is non-negative (or, if it is negative,
  // the negative value is -1; i.e., not user controlled)
  SymInt(Unchecked, int64_t d) : data_(d) {}

  // TODO: these implementations are not optimal because they allocate a
  // temporary and then use the move constructor/assignment
  SymInt(const SymInt& s) : data_(0) {
    if (s.is_symbolic()) {
      *this = SymInt::toSymInt(s.toSymIntNodeImpl());
    } else {
      data_ = s.data_;
    }
  }
  // Steals the bits of `s` and leaves `s` as the plain integer 0, so the
  // moved-from object releases nothing when destroyed.
  SymInt(SymInt&& s) noexcept : data_(s.data_) {
    s.data_ = 0;
  }

  SymInt& operator=(const SymInt& s) {
    if (this != &s) {
      if (s.is_symbolic()) {
        // Goes through move assignment from a temporary, which releases the
        // node currently held by `*this` (if any).
        *this = SymInt::toSymInt(s.toSymIntNodeImpl());
      } else {
        // Give back our own node (if any) before overwriting the bits;
        // otherwise the reference held by `*this` would be leaked.
        release_();
        data_ = s.data_;
      }
    }
    return *this;
  }
  // noexcept: only moves an integer and runs the same release_() path that
  // the (implicitly noexcept) destructor runs.
  SymInt& operator=(SymInt&& s) noexcept {
    if (this != &s) {
      release_(); // release the current SymIntNode if any
      data_ = s.data_;
      if (s.is_symbolic())
        s.data_ = 0;
    }
    return *this;
  }

  // Returns a copy of this value. For a symbolic value (non-mobile only) the
  // underlying node is asked to clone() itself; a plain integer is copied.
  SymInt clone() const {
#ifndef C10_MOBILE
    if (is_symbolic()) {
      return toSymIntNodeImplUnowned()->clone()->toSymInt();
    }
#else
    TORCH_INTERNAL_ASSERT(!is_symbolic());
#endif
    return *this;
  }

#ifndef C10_MOBILE
  // Decodes the node pointer stored in `data_` without touching ownership.
  // The two tag bits are cleared and the remaining 62-bit value is
  // sign-extended from bit 61 back to a full 64-bit pointer.
  SymIntNodeImpl* toSymIntNodeImplUnowned() const {
    uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
    uint64_t sign_bit_mask = 1ULL << (62 - 1);
    // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
    uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
    return static_cast<SymIntNodeImpl*>(
        reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits)));
  }

  // Gives back the reference held by this object, if it is symbolic.
  // `data_` is left unchanged; callers overwrite it or are being destroyed.
  void release_() {
    if (is_symbolic()) {
      SymIntNode::reclaim(toSymIntNodeImplUnowned()); // steal
    }
  }

  // Hands the raw node pointer to the caller and resets this object to the
  // plain integer 0 so that it no longer releases the node. Asserts that the
  // value is symbolic.
  SymIntNodeImpl* release() && {
    TORCH_INTERNAL_ASSERT(is_symbolic());
    auto* r = toSymIntNodeImplUnowned();
    data_ = 0; // transfer ownership
    return r;
  }
#else
  // Mobile builds never hold a node, so there is nothing to release.
  void release_() {}

  SymIntNodeImpl* release() && {
    TORCH_INTERNAL_ASSERT(false);
  }
#endif

  SymIntNode toSymIntNodeImpl() const;
  static c10::SymInt toSymInt(SymIntNode sin);

  ~SymInt() {
    release_();
  }

  // Require the int to be non-symbolic, and if it is symbolic raise an
  // error.  This is safe to use for C++ code that doesn't work for symbolic
  // shapes, and you don't have time to fix it immediately, as if we
  // try to trigger the path in C++ you'll appropriately get an error
  int64_t expect_int() const {
    SKIP_IS_SYMBOLIC_ON_MOBILE(!is_symbolic());
    return data_;
  }

  // Insert a guard for the int to be its concrete value, and then return
  // that value.  This operation always works, even if the int is symbolic,
  // so long as we know what the underlying value is (e.g., this won't work
  // if you call it on the size of nonzero output).  Don't blindly put this
  // everywhere; you can cause overspecialization of PyTorch programs with
  // this method.
  //
  // It should be called as guard_int(__FILE__, __LINE__).  The file and line
  // number can be used to diagnose overspecialization.
  int64_t guard_int(const char* file, int64_t line) const;

  // N.B. It's important to keep this definition in the header
  // as we expect if checks to be folded for mobile builds
  // where `is_symbolic` is always false
  C10_ALWAYS_INLINE bool is_symbolic() const {
#ifdef C10_MOBILE
    return false;
#else
    return (MASK & static_cast<uint64_t>(this->data_)) == IS_SYM;
#endif
  }

  SymInt operator+(SymInt sci) const;
  SymInt operator-(SymInt sci) const;
  SymInt operator*(SymInt sci) const;
  SymInt operator/(SymInt sci) const;
  SymInt operator%(SymInt sci) const;
  bool operator==(SymInt sci) const;
  bool operator!=(SymInt p2) const;
  bool operator<(SymInt sci) const;
  bool operator<=(SymInt sci) const;
  bool operator>(SymInt sci) const;
  bool operator>=(SymInt sci) const;
  void operator*=(SymInt sci);
  void operator+=(SymInt sci);

  SymInt operator*(int64_t sci) const;
  bool operator<(int64_t sci) const;
  bool operator==(int64_t sci) const;
  bool operator!=(int64_t sci) const;
  bool operator<=(int64_t sci) const;
  bool operator>(int64_t sci) const;
  bool operator>=(int64_t sci) const;

  operator SymFloat() const;

  // Returns the raw `data_` bits with no symbolic check; for a symbolic value
  // this is the tagged pointer encoding, not an integer value.
  int64_t as_int_unchecked() const {
    return data_;
  }

  // Return whether the integer is representable as a SymInt.
  // MIN_INT is -2^62 - 1, so every i >= -2^62 is accepted.
  static bool check_range(int64_t i) {
    return i > MIN_INT;
  }

 private:
  // Constraints on the internal representation:
  // - Should represent positive and small negative ints
  // - No conversion necessary for operations on ints.
  // - Must represent valid 64-bit pointers
  //
  // So, the scheme is to reserve large negative numbers:
  // - 0b0.... means we are a positive int (following two's complement)
  // - 0b11... means we are a negative int (following two's complement)
  // - 0b10... means we are a pointer. This means that
  //           [-2^63, -2^62-1] are not representable as ints.
  //           We don't actually need all of this space as on x86_64
  //           as the top 16bits aren't used for anything
  static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62;
  static constexpr uint64_t IS_SYM = 1ULL << 63;
  // Since we use the top two bits to determine whether something is symbolic,
  // we cannot represent symbolic indices that are large enough to use those
  // bits. This will probably never happen.
  static constexpr uint64_t MAX_SYM_IDX = 1ULL << 62;
  // Since 0b10... is reserved for symbolic indices, any integers lower than
  // this value would collide with our representation.
  static constexpr int64_t MIN_INT = -1LL & static_cast<int64_t>(~(1ULL << 62));
  int64_t data_;
};

#undef SKIP_IS_SYMBOLIC_ON_MOBILE

/// Product of a list of SymInt; accumulates into the c10::SymInt expression.
///
/// Only participates in overload resolution when `C::value_type` is
/// `c10::SymInt`. The fold starts from `SymInt(1)`, so an empty container
/// yields 1.
template <
    typename C,
    typename std::enable_if<
        std::is_same<typename C::value_type, c10::SymInt>::value,
        int>::type = 0>
inline c10::SymInt multiply_integers(const C& container) {
  return std::accumulate(
      container.begin(),
      container.end(),
      c10::SymInt(1),
      // Operands are taken by const reference: SymInt has a non-trivial copy
      // constructor, so by-value parameters would copy both operands on
      // every step of the fold.
      [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
}

// Stream-insertion operator for SymInt. Only declared here; the definition
// lives outside this header. Takes `s` by value, so the argument is copied.
// NOTE(review): the printed format for symbolic values is not visible from
// this header -- check the definition before relying on it.
C10_API std::ostream& operator<<(std::ostream& os, SymInt s);
} // namespace c10