#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/Exception.h>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <ostream>
namespace c10 {
// A representation of a set of DispatchKeys. A tensor may have multiple
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
// DispatchKeySet specifies what type ids apply. The internal representation is
// as a 64-bit bit set (this means only 64 tensor type ids are supported).
//
// Note that DispatchKeys are ordered; thus, we can ask questions like "what is
// the highest priority DispatchKey in the set"? (The set itself is not
// ordered; two sets with the same ids will always have the ids ordered in the
// same way.)
//
// At the moment, there are no nontrivial uses of this set; tensors are always
// singletons. In the near future, this set will represent variable? + tensor
// type id. In the far future, it will be requires grad? + profiling? +
// tracing? + lazy? + tensor type id.
//
// (The difference between variable and requires grad is that
// there are currently three states a tensor can be in:
// 1. Not a variable
// 2. Variable with requires_grad=False
// 3. Variable with requires_grad=True
// Eventually, we want to kill state (1), and only dispatch to autograd
// handling code if one of the inputs requires grad.)
//
// An undefined tensor is one with an empty tensor type set.
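//
// A minimal usage sketch, kept in a comment so the header stays
// declaration-only (CPU, CUDA and AutogradCPU are existing DispatchKey
// values; the variable name is made up for illustration):
//
//   DispatchKeySet ks({DispatchKey::CPU, DispatchKey::AutogradCPU});
//   ks.has(DispatchKey::CPU);       // true
//   ks.has(DispatchKey::CUDA);      // false
//   ks.highestPriorityTypeId();     // AutogradCPU, since autograd keys are
//                                   // ordered after backend keys in the enum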
class DispatchKeySet final {
public:
enum Full { FULL };
enum FullAfter { FULL_AFTER };
enum Raw { RAW };
// NB: default constructor representation as zero is MANDATORY as
// use of DispatchKeySet in TLS requires this.
constexpr DispatchKeySet()
: repr_(0) {}
constexpr DispatchKeySet(Full)
: repr_(std::numeric_limits<decltype(repr_)>::max()) {}
constexpr DispatchKeySet(FullAfter, DispatchKey t)
// Bits below t's bit (i.e. all lower-priority keys) are included, but not t itself.
: repr_((1ULL << (static_cast<uint8_t>(t) - 1)) - 1) {}
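// For example, if t has enum value 3, repr_ becomes (1ULL << 2) - 1 = 0b011,
// i.e. the set of the keys with enum values 1 and 2, excluding t itself and
// everything above it.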
// Public version of DispatchKeySet(uint64_t) API; external users
// must be explicit when they do this!
constexpr DispatchKeySet(Raw, uint64_t x)
: repr_(x) {}
explicit constexpr DispatchKeySet(DispatchKey t)
: repr_(t == DispatchKey::Undefined
? 0
: 1ULL << (static_cast<uint8_t>(t) - 1)) {}
explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
: repr_(0) {
for (auto k : ks) {
repr_ |= DispatchKeySet(k).repr_;
}
}
// Test if a DispatchKey is in the set
bool has(DispatchKey t) const {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
return static_cast<bool>(repr_ & DispatchKeySet(t).repr_);
}
// Test if DispatchKeySet is a superset of ks.
bool isSupersetOf(DispatchKeySet ks) const {
return (repr_ & ks.repr_) == ks.repr_;
}
// Perform set union
constexpr DispatchKeySet operator|(DispatchKeySet other) const {
return DispatchKeySet(repr_ | other.repr_);
}
// Perform set intersection
DispatchKeySet operator&(DispatchKeySet other) const {
return DispatchKeySet(repr_ & other.repr_);
}
// Compute the set difference self - other
DispatchKeySet operator-(DispatchKeySet other) const {
return DispatchKeySet(repr_ & ~other.repr_);
}
// Perform set equality
bool operator==(DispatchKeySet other) const {
return repr_ == other.repr_;
}
// Add a DispatchKey to the DispatchKey set. Does NOT mutate,
// returns the extended DispatchKeySet!
C10_NODISCARD DispatchKeySet add(DispatchKey t) const {
return *this | DispatchKeySet(t);
}
// Remove a DispatchKey from the DispatchKey set. This is
// generally not an operation you should be doing (it's
// used to implement operator<<)
C10_NODISCARD DispatchKeySet remove(DispatchKey t) const {
return DispatchKeySet(repr_ & ~DispatchKeySet(t).repr_);
}
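// For example (a hypothetical sketch; any two real keys behave the same way):
//
//   DispatchKeySet ks(DispatchKey::CPU);
//   DispatchKeySet with_cuda = ks.add(DispatchKey::CUDA);            // {CPU, CUDA}
//   DispatchKeySet without_cpu = with_cuda.remove(DispatchKey::CPU); // {CUDA}
//   // ks itself still contains only CPU.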
// Is the set empty? (AKA undefined tensor)
bool empty() const {
return repr_ == 0;
}
uint64_t raw_repr() { return repr_; }
// Return the type id in this set with the highest priority (i.e.,
// the one with the largest value in the DispatchKey enum). Intuitively, this
// type id is the one that should handle dispatch (assuming there
// aren't any further exclusions or inclusions).
DispatchKey highestPriorityTypeId() const {
// TODO: If I put Undefined as entry 64 and then adjust the
// singleton constructor to shift from the right, we can get rid of the
// subtraction here. It's modestly more complicated to get right so I
// didn't do it for now.
return static_cast<DispatchKey>(64 - llvm::countLeadingZeros(repr_));
}
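// For example, if repr_ has bits 0 and 4 set (the keys with enum values 1
// and 5), countLeadingZeros(repr_) == 59, so this returns the DispatchKey
// with value 64 - 59 = 5, i.e. the highest-valued key in the set.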
DispatchKey highestPriorityBackendTypeId() const {
return (*this & ((1ULL << static_cast<uint8_t>(DispatchKey::EndOfBackendKeys)) - 1))
.highestPriorityTypeId();
}
private:
constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
uint64_t repr_ = 0;
public:
// STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the
// set. The iterator is only invalidated by the destruction of the underlying
// DispatchKeySet as the iterator stores a pointer to the raw representation of
// the DispatchKeySet.
class iterator {
public:
using self_type = iterator;
using iterator_category = std::input_iterator_tag;
using value_type = DispatchKey;
using difference_type = ptrdiff_t;
explicit iterator(const uint64_t *data_ptr, uint8_t i=0) : data_ptr_(data_ptr), i_(i) {
// Go to the first key in the set
++(*this);
}
self_type& operator++() {
TORCH_INTERNAL_ASSERT(i_ <= static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
// Create a masked version of the set representation to ignore previous
// keys that we've iterated through.
uint64_t masked_data = llvm::maskTrailingZeros<uint64_t>(i_) & *data_ptr_;
uint64_t firstKeyIndex = llvm::findFirstSet(masked_data);
// If there are no keys, set to end iterator value
if (firstKeyIndex == std::numeric_limits<uint64_t>::max() ||
i_ == static_cast<uint8_t>(DispatchKey::NumDispatchKeys)) {
i_ = static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
return *this;
}
i_ = static_cast<uint8_t>(firstKeyIndex) + 1;
return *this;
}
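// A worked trace of the pre-increment above, for a hypothetical set with
// *data_ptr_ == 0b101 (the keys with enum values 1 and 3): starting from
// i_ == 0, the first increment masks nothing off, finds bit 0 and sets
// i_ = 1; the next one masks off bits below index 1, finds bit 2 and sets
// i_ = 3; the one after that finds no remaining bits and parks i_ at
// NumDispatchKeys, the end() sentinel.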
self_type operator++(int) {
self_type previous_iterator = *this;
++(*this);
return previous_iterator;
}
bool operator==(const self_type& rhs) const { return i_ == rhs.i_; }
bool operator!=(const self_type& rhs) const { return i_ != rhs.i_; }
DispatchKey operator*() const { return static_cast<DispatchKey>(i_); }
private:
const uint64_t *data_ptr_;
uint8_t i_;
};
public:
// Returns an iterator to the first key in the set. If no keys are in the
// set, returns the end iterator.
iterator begin() const { return iterator(&repr_); }
// We do not need to iterate beyond NumDispatchKeys so we will treat this as
// the end iterator. NumDispatchKeys will always be strictly less than 64.
iterator end() const { return iterator(&repr_, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)); }
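// A usage sketch (the key choices are just for illustration): begin()/end()
// make the set usable in a range-based for loop, visiting keys from the
// lowest enum value to the highest:
//
//   for (DispatchKey k : DispatchKeySet({DispatchKey::CPU, DispatchKey::AutogradCPU})) {
//     // visits DispatchKey::CPU first, then DispatchKey::AutogradCPU
//   }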
};
C10_API std::string toString(DispatchKeySet);
C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet);
// autograd_dispatch_keyset should include all runtime autograd keys.
// Alias key DispatchKey::Autograd maps to autograd_dispatch_keyset.
constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
DispatchKey::AutogradCPU,
DispatchKey::AutogradCUDA,
DispatchKey::AutogradXLA,
DispatchKey::AutogradPrivateUse1,
DispatchKey::AutogradPrivateUse2,
DispatchKey::AutogradPrivateUse3,
DispatchKey::AutogradOther,
});
// Resolve alias dispatch key to DispatchKeySet if applicable
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);
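// For example, per the comments above, getRuntimeDispatchKeySet(DispatchKey::Autograd)
// resolves to autograd_dispatch_keyset.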
// Returns a DispatchKeySet of all backend keys mapped to the Autograd dispatch key t;
// the returned DispatchKeySet is empty if t is not an alias of DispatchKey::Autograd.
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
// This API exists because we have a use case for checking
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
// in OperatorEntry.cpp but Undefined is disallowed in the has() API.
C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias);
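// For example, isIncludedInAlias(DispatchKey::AutogradCPU, DispatchKey::Autograd)
// is true, since AutogradCPU is one of the runtime keys the Autograd alias
// expands to, while isIncludedInAlias(DispatchKey::CPU, DispatchKey::Autograd)
// is false.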
// Historically, every tensor only had a single DispatchKey, and it was always
// something like CPU, and there wasn't any of this business where TLS
// could cause the DispatchKey of a tensor to change. But we still have some
// legacy code that uses DispatchKey for things like instanceof
// checks; if at all possible, refactor the code to stop using DispatchKey in
// those cases.
static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
// NB: If you add any extra keys that can be stored in TensorImpl on
// top of existing "normal" keys like CPU/CUDA, you need to add it
// here. At the moment, RequiresGrad (the replacement for Variable)
// is the most likely key that will need this treatment.
// After Autograd keys are moved from the globally enabled set to TensorImpl,
// we should remove all Autograd keys before taking highestPriority.
return (s - autograd_dispatch_keyset).highestPriorityTypeId();
}
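// For example, for a tensor whose key set is {AutogradCPU, CPU}, this returns
// DispatchKey::CPU: the autograd keys are subtracted out first, and CPU is
// the highest priority key that remains.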
}