#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>
#include <array>
#include <iostream>
#include <limits>
#include <sstream>
namespace c10 {
// backend_dispatch_keyset includes all dispatch keys that map to backends.
// Alias key DispatchKey::CompositeExplicitAutograd maps to
// backend_dispatch_keyset
constexpr DispatchKeySet backend_dispatch_keyset =
autogradother_backends | DispatchKeySet(DispatchKey::Dense);
// See Note [CompositeExplicitAutogradNonFunctional Key]
// We have several types of decompositions in aten, each with its own
// alias key. You should register your decomposition to the
// `CompositeExplicitAutogradNonFunctional` key if:
// (1) It's an out-of-place op
// (2) It decomposes into one or more mutation ops
// (3) It has a derivative formula
// (In theory we could also have a separate key for
// "CompositeImplicitAutogradNonFunctional", but there isn't much of a use
// case for it currently).
// This key is important for "functional" backends like LazyTensor / XLA.
// If you're a backend that only expects to deal with "functional ops",
// then you don't want to decompose a functional op into an op that causes
// aliasing. You should just directly write a kernel for that functional op
// instead!
constexpr DispatchKeySet non_functional_backend_dispatch_keyset =
backend_dispatch_keyset
// XLA and LazyTensor are currently the only 2 backends in core
// that use the functionalization pass in eager mode.
.remove(DispatchKey::Sparse)
.remove_backend(BackendComponent::XLABit)
.remove_backend(BackendComponent::LazyBit);
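// Added commentary (illustrative, not upstream logic): because the XLA and
// Lazy backend bits are removed above, a kernel registered to
// CompositeExplicitAutogradNonFunctional is never selected for XLA / Lazy
// tensors; those backends functionalize the op and expect a dedicated
// functional kernel instead.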
bool isBackendDispatchKey(DispatchKey t) {
return t != DispatchKey::Undefined
// See Note [No Alias Keys in DispatchKeySet]
&& !isAliasDispatchKey(t)
// Note [NestedTensor Not Included in Backend Keys]
// NestedTensor has been explicitly removed from the "backend keyset" due
// to incompatibility with some kernels, so we don't want it to be
// included in CompositeExplicitAutograd kernels.
&& t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t);
}
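// Illustrative expectations (added commentary, derived from the checks above):
//   isBackendDispatchKey(DispatchKey::CPU)                       -> true
//   isBackendDispatchKey(DispatchKey::NestedTensor)              -> false
//   isBackendDispatchKey(DispatchKey::CompositeExplicitAutograd) -> false (alias key)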
// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
// autograd_dispatch_keyset. Alias key DispatchKey::CompositeImplicitAutograd
// maps to [math_dispatch_keyset x full_backend_mask].
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
autograd_dispatch_keyset |
// See Note [NestedTensor Not Included in Backend Keys]
// The caveat to that note is that nested_tensor is a special case:
// we would like to support composite implicit kernels but not
// explicit kernels for it, so we manually add the key to the
// math_dispatch_keyset.
DispatchKeySet{DispatchKey::NestedTensor};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(
{DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
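// Added commentary: nested_dispatch_keyset is the runtime keyset backing
// CompositeImplicitAutogradNestedTensor. ORing the NestedTensor and
// AutogradNestedTensor functionality bits with the full backend mask means
// iteration visits the per-backend runtime keys (e.g. NestedTensorCPU).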
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
switch (t) {
case DispatchKey::Autograd:
// See Note [autograd_dispatch_keyset Does Not Include Backend Bits]
// That's why we OR it with a mask of the backend bits here.
// getRuntimeDispatchKeySet() expects to return a keyset of runtime
// dispatch keys, like AutogradCPU, but that requires having backend bits.
return autograd_dispatch_keyset |
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
case DispatchKey::CompositeImplicitAutograd:
return math_dispatch_keyset;
case DispatchKey::CompositeImplicitAutogradNestedTensor:
return nested_dispatch_keyset;
case DispatchKey::CompositeExplicitAutograd:
return backend_dispatch_keyset;
case DispatchKey::CompositeExplicitAutogradNonFunctional:
return non_functional_backend_dispatch_keyset;
default:
return DispatchKeySet(t);
}
}
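// Illustrative examples (added commentary):
//   getRuntimeDispatchKeySet(DispatchKey::Autograd) iterates over the
//   per-backend autograd runtime keys (AutogradCPU, AutogradCUDA, ...).
//   getRuntimeDispatchKeySet(DispatchKey::CPU) is just {CPU}, since non-alias
//   keys fall through to the default case.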
bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
switch (t) {
case DispatchKey::Autograd:
return autograd_dispatch_keyset.has(toFunctionalityKey(k));
case DispatchKey::CompositeImplicitAutograd:
// See Note [NestedTensor Not Included in Backend Keys]
return math_dispatch_keyset.has(k);
case DispatchKey::CompositeImplicitAutogradNestedTensor:
// See Note [NestedTensor Not Included in Backend Keys]
return nested_dispatch_keyset.has(k);
case DispatchKey::CompositeExplicitAutograd:
// See Note [NestedTensor Not Included in Backend Keys]
return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
case DispatchKey::CompositeExplicitAutogradNonFunctional:
// See Note [NestedTensor Not Included in Backend Keys]
return k != DispatchKey::NestedTensor &&
non_functional_backend_dispatch_keyset.has(k);
default:
return t == k;
}
}
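// Illustrative contrast (added commentary), following the NestedTensor note:
//   runtimeDispatchKeySetHas(DispatchKey::CompositeImplicitAutograd,
//                            DispatchKey::NestedTensor) -> true
//   runtimeDispatchKeySetHas(DispatchKey::CompositeExplicitAutograd,
//                            DispatchKey::NestedTensor) -> false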
// For a given autograd key, return the (guaranteed nonempty) set of associated
// backend keys. For a non-autograd key, return the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
switch (t) {
case DispatchKey::AutogradCPU:
return DispatchKeySet(DispatchKey::CPU);
case DispatchKey::AutogradCUDA:
return DispatchKeySet(DispatchKey::CUDA);
case DispatchKey::AutogradXLA:
return DispatchKeySet(DispatchKey::XLA);
case DispatchKey::AutogradLazy:
return DispatchKeySet(DispatchKey::Lazy);
case DispatchKey::AutogradMeta:
return DispatchKeySet(DispatchKey::Meta);
case DispatchKey::AutogradMPS:
return DispatchKeySet(DispatchKey::MPS);
case DispatchKey::AutogradHPU:
return DispatchKeySet(DispatchKey::HPU);
case DispatchKey::AutogradIPU:
return DispatchKeySet(DispatchKey::IPU);
case DispatchKey::AutogradXPU:
return DispatchKeySet(DispatchKey::XPU);
case DispatchKey::AutogradPrivateUse1:
return DispatchKeySet(DispatchKey::PrivateUse1);
case DispatchKey::AutogradPrivateUse2:
return DispatchKeySet(DispatchKey::PrivateUse2);
case DispatchKey::AutogradPrivateUse3:
return DispatchKeySet(DispatchKey::PrivateUse3);
case DispatchKey::AutogradNestedTensor:
return DispatchKeySet(DispatchKey::NestedTensor) |
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
case DispatchKey::AutogradOther:
return autogradother_backends;
default:
return DispatchKeySet();
}
}
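// Illustrative examples (added commentary):
//   getBackendKeySetFromAutograd(DispatchKey::AutogradCUDA) -> {CUDA}
//   getBackendKeySetFromAutograd(DispatchKey::CUDA)         -> {} (not an autograd key)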
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k);
}
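// Illustrative example (added commentary): isIncludedInAlias(DispatchKey::CPU,
// DispatchKey::CompositeImplicitAutograd) is expected to be true, since CPU is
// one of the runtime keys covered by a CompositeImplicitAutograd registration.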
std::string toString(DispatchKeySet ts) {
std::stringstream ss;
ss << ts;
return ss.str();
}
std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
if (ts.empty()) {
os << "DispatchKeySet()";
return os;
}
os << "DispatchKeySet(";
bool first = true;
for (auto k : ts) {
if (!first) {
os << ", ";
}
os << k;
first = false;
}
os << ")";
return os;
}
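// Illustrative output (added commentary): streaming
// DispatchKeySet({DispatchKey::CPU, DispatchKey::CUDA}) is expected to print
// "DispatchKeySet(CPU, CUDA)"; an empty set prints as "DispatchKeySet()".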
DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() {
TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val);
TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends, next_backend_);
// Create a masked version of the set representation to ignore previous
// keys that we've iterated through.
uint64_t masked_functionality_bits =
llvm::maskTrailingZeros<uint64_t>(next_functionality_) & *data_ptr_;
uint64_t masked_backend_bits =
llvm::maskTrailingZeros<uint64_t>(next_backend_) & full_backend_mask &
*data_ptr_;
uint64_t first_functionality_idx =
llvm::findFirstSet(masked_functionality_bits);
uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits);
// If there are no keys, set to end iterator value
if (first_functionality_idx == std::numeric_limits<uint64_t>::max() ||
next_functionality_ == iterator::end_iter_mask_val) {
// Set up state to be the same as end()
next_functionality_ = iterator::end_iter_mask_val;
current_dispatchkey_idx_ = iterator::end_iter_key_val;
next_backend_ = 0;
current_backendcomponent_idx_ = iterator::end_iter_key_val;
return *this;
}
// The +1 is because of DispatchKey::Undefined and
// BackendComponent::InvalidBit
auto new_next_functionality = first_functionality_idx + 1;
auto new_backendcomponent_idx = first_backendcomponent_idx + 1;
// and the -num_backends is because the first <num_backends> bits in the
// keyset are not Dispatch Keys.
auto next_dispatchkey_idx = new_next_functionality - num_backends;
// If the current functionality bit is a per-backend bit, we need special
// handling
if (isPerBackendFunctionalityKey(
static_cast<DispatchKey>(next_dispatchkey_idx))) {
// case 1: if the current backend is undefined, then there is no valid
// backend instance of this functionality key so we can skip it.
if (first_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
// increment the functionality mask so we skip the current functionality
// bit on the next increment.
next_functionality_ = new_next_functionality;
++(*this);
return *this;
}
// Otherwise, at this point we know what the current backend and
// functionality bits are.
current_dispatchkey_idx_ = next_dispatchkey_idx;
current_backendcomponent_idx_ = new_backendcomponent_idx;
// Next, we need to set up the masks for the next increment.
uint64_t next_backendcomponent_bits =
llvm::maskTrailingZeros<uint64_t>(first_backendcomponent_idx + 1) &
full_backend_mask & *data_ptr_;
uint64_t next_backendcomponent_idx =
llvm::findFirstSet(next_backendcomponent_bits);
if (next_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
// case 2: the current backend is valid, but there is not another backend
// in the keyset. In this case, we need to bump the functionality mask and
// reset the backend mask for the next increment
next_functionality_ = new_next_functionality;
next_backend_ = 0;
} else {
// case 3: we have another backend to iterate over. We want to iterate
// over the same functionality bit next time, but a different backend bit.
next_backend_ = first_backendcomponent_idx + 1;
}
} else {
// Functionality bits that aren't per backend are simpler to handle. We can
// ignore the backend bits.
TORCH_INTERNAL_ASSERT(next_backend_ == 0);
current_dispatchkey_idx_ = next_dispatchkey_idx;
next_functionality_ = new_next_functionality;
}
return *this;
}
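// Iteration example (added commentary): for a keyset containing the Dense
// functionality bit plus the CPU and CUDA backend bits, the iterator yields
// CPU and then CUDA (case 3 keeps the functionality bit and advances the
// backend), and case 2 then bumps the functionality bit once the backends are
// exhausted.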
std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks() {
std::array<FunctionalityOffsetAndMask, num_functionality_keys>
offsets_and_masks;
// Manually set the first entry, which corresponds to Undefined.
offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0);
// loop through every functionality key (aside from Undefined).
for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) {
// functionality_idx should be Dense -> 1, ...
auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1];
auto k = static_cast<DispatchKey>(functionality_idx);
// If the previous functionality was not per-backend, then we can just
// increment the previous offset. Otherwise, the next offset =
// previous_offset + num_backends.
auto next_offset = prev_offset_and_mask.offset +
(prev_offset_and_mask.mask == 0 ? 1 : num_backends);
// the mask is used in the runtime index calculation to find the offset of
// the backend. For non-per-backend functionalities, this offset should
// always be 0. Otherwise, we need to get the index of the backend (which we
// can do using a backend mask).
auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0;
offsets_and_masks[functionality_idx] =
FunctionalityOffsetAndMask(next_offset, next_mask);
}
// Sanity check that the computed offset index of the last functionality key
// is correct. This assumes that the highest priority functionality key is not
// per backend.
TORCH_INTERNAL_ASSERT(
offsets_and_masks[num_functionality_keys - 1].offset ==
(num_runtime_entries - 1),
"num_runtime_entries: ",
num_runtime_entries,
"last_offset: ",
offsets_and_masks[num_functionality_keys - 1].offset);
return offsets_and_masks;
}
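// Worked example (added commentary): Undefined gets offset 0 and mask 0; Dense
// (a per-backend functionality) gets offset 1 with the full backend mask; the
// functionality after Dense then starts at offset 1 + num_backends, since
// Dense occupies one runtime-table slot per backend.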
} // namespace c10