File: DispatchKeySet.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (228 lines) | stat: -rw-r--r-- 9,158 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#pragma once

#include <c10/core/DispatchKey.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/Exception.h>
#include <ostream>

namespace c10 {

// A representation of a set of DispatchKeys.  A tensor may have multiple
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
// DispatchKeySet specifies what type ids apply.  The internal representation is
// as a 64-bit bit set (this means only 64 tensor type ids are supported).
//
// Note that DispatchKeys are ordered; thus, we can ask questions like "what is
// the highest priority DispatchKey in the set"?  (The set itself is not
// ordered; two sets with the same ids will always have the ids ordered in the
// same way.)
//
// At the moment, there are no nontrivial uses of this set; tensors are always
// singletons.  In the near future, this set will represent variable? + tensor
// type id.  In the far future, it will be requires grad? + profiling? +
// tracing? + lazy? + tensor type id.
//
// (The difference between variable and requires grad is that
// there are currently three states a tensor can be in:
//  1. Not a variable
//  2. Variable with requires_grad=False
//  3. Variable with requires_grad=True
// Eventually, we want to kill state (1), and only dispatch to autograd
// handling code if one of the inputs requires grad.)
//
// An undefined tensor is one with an empty tensor type set.
class DispatchKeySet final {
public:
  enum Full { FULL };
  enum FullAfter { FULL_AFTER };
  enum Raw { RAW };

  // NB: default constructor representation as zero is MANDATORY as
  // use of DispatchKeySet in TLS requires this.
  constexpr DispatchKeySet()
    : repr_(0) {}
  constexpr DispatchKeySet(Full)
    : repr_(std::numeric_limits<decltype(repr_)>::max()) {}
  constexpr DispatchKeySet(FullAfter, DispatchKey t)
    // LSB after t are OK, but not t itself.
    : repr_((1ULL << (static_cast<uint8_t>(t) - 1)) - 1) {}
  // Public version of DispatchKeySet(uint64_t) API; external users
  // must be explicit when they do this!
  constexpr DispatchKeySet(Raw, uint64_t x)
    : repr_(x) {}
  explicit constexpr DispatchKeySet(DispatchKey t)
    : repr_(t == DispatchKey::Undefined
              ? 0
              : 1ULL << (static_cast<uint8_t>(t) - 1)) {}
  explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
  : repr_(0) {
    for (auto k : ks) {
      repr_ |= DispatchKeySet(k).repr_;
    }
  }
  // Test if a DispatchKey is in the set
  bool has(DispatchKey t) const {
    TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
    return static_cast<bool>(repr_ & DispatchKeySet(t).repr_);
  }
  // Test if DispatchKeySet is a superset of ks.
  bool isSupersetOf(DispatchKeySet ks) const {
    return (repr_ & ks.repr_) == ks.repr_;
  }
  // Perform set union
  constexpr DispatchKeySet operator|(DispatchKeySet other) const {
    return DispatchKeySet(repr_ | other.repr_);
  }
  // Perform set intersection
  DispatchKeySet operator&(DispatchKeySet other) const {
    return DispatchKeySet(repr_ & other.repr_);
  }
  // Compute the set difference self - other
  DispatchKeySet operator-(DispatchKeySet other) const {
    return DispatchKeySet(repr_ & ~other.repr_);
  }
  // Perform set equality
  bool operator==(DispatchKeySet other) const {
    return repr_ == other.repr_;
  }
  // Add a DispatchKey to the DispatchKey set.  Does NOT mutate,
  // returns the extended DispatchKeySet!
  C10_NODISCARD DispatchKeySet add(DispatchKey t) const {
    return *this | DispatchKeySet(t);
  }
  // Remove a DispatchKey from the DispatchKey set.  This is
  // generally not an operation you should be doing (it's
  // used to implement operator<<)
  C10_NODISCARD DispatchKeySet remove(DispatchKey t) const {
    return DispatchKeySet(repr_ & ~DispatchKeySet(t).repr_);
  }
  // Is the set empty?  (AKA undefined tensor)
  bool empty() const {
    return repr_ == 0;
  }
  uint64_t raw_repr() { return repr_; }
  // Return the type id in this set with the highest priority (i.e.,
  // is the largest in the DispatchKey enum).  Intuitively, this
  // type id is the one that should handle dispatch (assuming there
  // aren't any further exclusions or inclusions).
  DispatchKey highestPriorityTypeId() const {
    // TODO: If I put Undefined as entry 64 and then adjust the
    // singleton constructor to shift from the right, we can get rid of the
    // subtraction here.  It's modestly more complicated to get right so I
    // didn't do it for now.
    return static_cast<DispatchKey>(64 - llvm::countLeadingZeros(repr_));
  }

  DispatchKey highestPriorityBackendTypeId() const {
    return (*this & ((1ULL << static_cast<uint8_t>(DispatchKey::EndOfBackendKeys)) - 1))
      .highestPriorityTypeId();
  }
private:
  constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
  uint64_t repr_ = 0;

public:
  // STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the
  // set. The iterator is only invalidated by the destruction of the underlying
  // DispatchKeySet as the iterator stores a pointer to the raw represenation of
  // the DispatchKeySet.
  class iterator {
   public:
    using self_type = iterator;
    using iterator_category = std::input_iterator_tag;
    using value_type = DispatchKey;
    using difference_type = ptrdiff_t;

    explicit iterator(const uint64_t *data_ptr, uint8_t i=0) : data_ptr_(data_ptr), i_(i) {
      // Go to the first key in the set
      ++(*this);
    }

    self_type& operator++() {
      TORCH_INTERNAL_ASSERT(i_ <= static_cast<uint8_t>(DispatchKey::NumDispatchKeys));

      // Create a masked version of the set representation to ignore previous
      // keys that we've iterated through.
      uint64_t masked_data = llvm::maskTrailingZeros<uint64_t>(i_) & *data_ptr_;
      uint64_t firstKeyIndex = llvm::findFirstSet(masked_data);

      // If there are no keys, set to end iterator value
      if (firstKeyIndex == std::numeric_limits<uint64_t>::max() ||
         i_ == static_cast<uint8_t>(DispatchKey::NumDispatchKeys)) {
        i_ = static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
        return *this;
      }

      i_ = static_cast<uint8_t>(firstKeyIndex) + 1;
      return *this;
    }

    self_type operator++(int) {
      self_type previous_iterator =  *this;
      ++(*this);
      return previous_iterator;
    }

    bool operator==(const self_type& rhs) const { return i_ == rhs.i_; }
    bool operator!=(const self_type& rhs) const { return i_ != rhs.i_; }
    DispatchKey operator*() const { return static_cast<DispatchKey> (i_); }

   private:
    const uint64_t *data_ptr_;
    uint8_t i_;
  };

 public:
  // Returns iterator to the first key in the set. If no keys are in the
  // set, then will return the end iterator.
  iterator begin() const { return iterator(&repr_); }

  // We do not need to iterate beyond NumDispatchKeys so we will treat this as
  // the end iterator. NumDispatchKeys will always be strictly less than 64.
  iterator end() const { return iterator(&repr_, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)); }

};

// String conversion and stream output for a DispatchKeySet.  Only declared
// in this header; the definitions are provided elsewhere.
// NOTE(review): the exact output format is not visible from here.
C10_API std::string toString(DispatchKeySet);
C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet);

// autograd_dispatch_keyset should include all runtime autograd keys.
// Alias key DispatchKey::Autograd maps to autograd_dispatch_keyset.
// Assembled as the union of one singleton set per runtime autograd key.
constexpr DispatchKeySet autograd_dispatch_keyset =
    DispatchKeySet(DispatchKey::AutogradCPU) |
    DispatchKeySet(DispatchKey::AutogradCUDA) |
    DispatchKeySet(DispatchKey::AutogradXLA) |
    DispatchKeySet(DispatchKey::AutogradPrivateUse1) |
    DispatchKeySet(DispatchKey::AutogradPrivateUse2) |
    DispatchKeySet(DispatchKey::AutogradPrivateUse3) |
    DispatchKeySet(DispatchKey::AutogradOther);

// Resolve an alias dispatch key to the DispatchKeySet of runtime keys it
// stands for, if applicable.
// NOTE(review): what is returned for a non-alias key is not visible from
// this header -- confirm against the definition.
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);

// Returns a DispatchKeySet of all backend keys mapped to the Autograd
// dispatch key t. The returned DispatchKeySet is empty if t is not an alias
// of DispatchKey::Autograd.
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);

// This API exists because we have a use case for checking
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
// in OperatorEntry.cpp, but DispatchKeySet::has() disallows querying
// DispatchKey::Undefined (it asserts on it).
C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias);

// Historically, every tensor only had a single DispatchKey, and it was always
// something like CPU, and there wasn't any of this business where TLS
// could cause the DispatchKey of a tensor to change.  But we still have some
// legacy code that is still using DispatchKey for things like instanceof
// checks; if at all possible, refactor the code to stop using DispatchKey in
// those cases.
static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
  // NB: Any extra key that can be stored in TensorImpl on top of the
  // existing "normal" keys like CPU/CUDA has to be stripped here as well.
  // At the moment, RequiresGrad (replacement for Variable) is the most
  // likely key that will need this treatment.
  // After Autograd keys are moved from the globally enabled set to
  // TensorImpl, all Autograd keys must be removed before taking the
  // highest-priority key -- which is what the subtraction below does.
  const DispatchKeySet without_autograd = s - autograd_dispatch_keyset;
  return without_autograd.highestPriorityTypeId();
}
}