#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <sstream>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
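
// Maps each thread/block ParallelType (TIDx/y/z, BIDx/y/z) to a Val
// representing its launch dimension, and tracks which types are
// "exact", i.e., every CA-mapped concrete domain parallelized with the
// type is known to have the same extent as the launch dimension.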
void ParallelDimensionMap::build(Fusion* fusion) {
// Scan all TVs to build ParallelType maps
auto all_vals = fusion->usedMathVals();
for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
for (auto id : tv->domain()->domain()) {
registerConstantExtent(id);
if (!isParallelTypeThread(id->getParallelType())) {
continue;
}
handleParallelDomain(id);
}
}
// Populate the dimension map for each parallel type
for (const auto& kv : concrete_dom_map_) {
auto pt = kv.first;
const auto& concrete_dom_set = kv.second;
TORCH_INTERNAL_ASSERT(!concrete_dom_set.empty());
if (concrete_dom_set.size() == 1) {
populateDimensionMapWithSingleCASet(pt, concrete_dom_set);
} else {
populateDimensionMapWithMultipleCASet(pt, concrete_dom_set);
}
}
adjustMappingsForWarpPadding();
}
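
// Records the extent of id when it is a compile-time constant. The
// constant is stored in constant_extent_map_, keyed by the CA-mapped
// concrete domain, so all exact-mapped domains share one entry.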
void ParallelDimensionMap::registerConstantExtent(IterDomain* id) {
if (!id->extent()->isConstScalar()) {
// Nothing to do if not constant
return;
}
ExpressionEvaluator ee(id->fusion());
auto extent_int = ee.evaluate(id->extent());
TORCH_INTERNAL_ASSERT(
extent_int.has_value(),
"Extent of ",
id->toString(),
" should have been constant, but could not be evaluated at compile time.");
auto const_extent = extent_int->as<int64_t>();
  // Uses the exact compute-at map to find the representative domain
  auto concrete_id = getCAMappedConcreteDomain(id);
  auto existing_it = constant_extent_map_.find(concrete_id);
// Adds the constant extent to the set for the concrete domain. If
// multiple constants are found, this concrete domain has multiple
// distinctive extents, which can happen with broadcast.
if (existing_it == constant_extent_map_.end()) {
constant_extent_map_.insert({concrete_id, {const_extent}});
} else {
existing_it->second.insert(const_extent);
}
}
// Adds the concrete domain of id to the mapped set for its
// parallel type
void ParallelDimensionMap::handleParallelDomain(IterDomain* id) {
auto pt = id->getParallelType();
TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt));
auto concrete_id = getCAMappedConcreteDomain(id);
auto it = concrete_dom_map_.find(pt);
if (it == concrete_dom_map_.end()) {
concrete_dom_map_.insert({pt, {concrete_id}});
} else {
it->second.insert(concrete_id);
}
}
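
// Case where pt is used by exactly one concrete domain. The mapping is
// exact by construction: use the registered constant extent if one
// exists, otherwise fall back to the blockDim/gridDim scalar.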
void ParallelDimensionMap::populateDimensionMapWithSingleCASet(
ParallelType pt,
const std::unordered_set<IterDomain*>& dom_set) {
TORCH_INTERNAL_ASSERT(dom_set.size() == 1);
// pt is used by only one concrete domain
auto id = *dom_set.begin();
auto it = constant_extent_map_.find(id);
if (it != constant_extent_map_.end()) {
    TORCH_INTERNAL_ASSERT(
        it->second.size() == 1,
        "Only one value found mapped to parallel type ",
        stringifyThread(pt),
        " yet it is bound to multiple extents.");
dim_map_.insert({pt, IrBuilder::create<Int>(*(it->second.begin()))});
exact_types_.insert(pt);
} else {
// Prefer to use blockDim/gridDim if not constant
dim_map_.insert({pt, NamedScalar::getParallelDim(pt)});
exact_types_.insert(pt);
}
}
void ParallelDimensionMap::populateDimensionMapWithMultipleCASet(
ParallelType pt,
const std::unordered_set<IterDomain*>& dom_set) {
TORCH_INTERNAL_ASSERT(dom_set.size() > 1);
bool all_equal = true;
  // Use nullptr to signal it's not initialized yet
  Val* known_dimension = nullptr;
  // Use -1 to signal it's not initialized yet
  int64_t known_const = -1;
  // Check all concrete domains to see if they all match.
for (auto concrete_id : dom_set) {
if (concrete_id->isBroadcast()) {
// Broadcasted concrete id's don't specify anything about shape
continue;
}
// If this concrete domain has a constant extent, check if it
// matches with the known constant extent.
auto it = constant_extent_map_.find(concrete_id);
if (it != constant_extent_map_.end()) {
const auto& const_extent_set = it->second;
// If multiple constants are detected, it's not exact.
if (const_extent_set.size() > 1) {
all_equal = false;
break;
}
auto this_const = *(const_extent_set.begin());
// known_const is initialized to -1
if (known_const == -1) {
known_const = this_const;
      } else if (known_const == this_const) {
        // Matched the previously known const; the extent of this
        // domain equals the one seen before.
        continue;
      } else {
        // Unmatched. The extents in this dom_set are not unique.
        all_equal = false;
        break;
      }
}
    // At this point, it remains undetermined whether this id matches
    // the ones previously looked at. The constant check was
    // inconclusive, but symbolic matching may still succeed.
auto this_dimension = concrete_id->extent();
if (known_dimension == nullptr) {
// No previous dimension found yet
known_dimension = this_dimension;
} else {
if (!equalDim(known_dimension, this_dimension)) {
all_equal = false;
break;
}
}
}
  // If all_equal is still true, the dimension of this parallel type
  // must be exact.
if (all_equal) {
exact_types_.insert(pt);
}
// Use the const value, if found, as its dimension
if (all_equal && known_const != -1) {
dim_map_.insert({pt, IrBuilder::create<Int>(known_const)});
} else {
dim_map_.insert({pt, NamedScalar::getParallelDim(pt)});
}
}
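
// When TIDx is padded to a multiple of the warp size for warp
// reductions, the actual blockDim.x may exceed the extents observed in
// the fusion, so the TIDx mapping has to be revisited here.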
void ParallelDimensionMap::adjustMappingsForWarpPadding() {
const auto gpu_lower = GpuLower::current();
// If TIDx is padded to a multiple of the warp size, mark it as
// non-exact.
auto& warp_info = gpu_lower->getWarpPaddedParallelInfo();
// TIDx isn't really padded if there isn't a warp reduction (this could
// change)
if (!(warp_info.is_tidx_padded && warp_info.has_warp_reduction)) {
return;
}
const auto tidx_pt = ParallelType::TIDx;
auto warp_size = at::cuda::warp_size();
  // If the dimension of TIDx is actually a multiple of the warp size
  // before padding, it can be left as exact
if (isExact(tidx_pt)) {
auto tidx_dim = dynamic_cast<Int*>(get(tidx_pt));
if (tidx_dim && tidx_dim->isConst()) {
auto tidx_dim_val = tidx_dim->value().value();
if (tidx_dim_val % warp_size == 0) {
// Dimension of TIDx is a multiple of the warp size
return;
}
}
    // If tidx is strictly defined as blockDim.x, then it must already
    // be set to a multiple of the warp size and can be considered exact
bool tidx_def_trivial = true;
for (auto entry : concrete_dom_map_.at(tidx_pt)) {
if (!entry->isA<NamedScalar>() ||
!entry->as<NamedScalar>()->sameAs(
NamedScalar::getParallelDim(tidx_pt))) {
tidx_def_trivial = false;
}
}
if (tidx_def_trivial) {
return;
}
}
  // TIDx is padded to a multiple of the warp size. If it's known to be
  // a single warp, use the constant warp size as the dimension of
  // TIDx. Otherwise, just use blockDim.x.
if (warp_info.is_tidx_single_warp) {
dim_map_.at(ParallelType::TIDx) = IrBuilder::create<Int>(warp_size);
} else {
dim_map_.at(ParallelType::TIDx) =
NamedScalar::getParallelDim(ParallelType::TIDx);
}
// TIDx is no longer exact
exact_types_.erase(ParallelType::TIDx);
}
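
// Returns the dimension mapped to pt, or nullptr if pt is unused.
//
// Illustrative usage during lowering (a sketch only; assumes the usual
// GpuLower::current()->parallelDimensionMap() accessor):
//
//   const auto& pdm = GpuLower::current()->parallelDimensionMap();
//   if (Val* bdimx = pdm.get(ParallelType::TIDx)) {
//     // bdimx is a constant Int when the extent is known at compile
//     // time, or the NamedScalar blockDim.x otherwise.
//     bool exact = pdm.isExact(ParallelType::TIDx);
//   }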
Val* ParallelDimensionMap::get(ParallelType pt) const {
TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt), "Invalid ParallelType: ", pt);
auto it = dim_map_.find(pt);
if (it == dim_map_.end()) {
return nullptr;
} else {
return it->second;
}
}
bool ParallelDimensionMap::isExact(ParallelType pt) const {
return exact_types_.find(pt) != exact_types_.end();
}
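
// Returns the representative (concrete) domain that id maps to in the
// exact compute-at map.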
IterDomain* ParallelDimensionMap::getCAMappedConcreteDomain(IterDomain* id) {
return GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT);
}
// Symbolically compares equality of two KIR vals. Comparison is done
// conservatively, so returning false does not guarantee non-equality.
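// For example, two NamedScalars both named "blockDim.x" compare equal,
// and extents defined by matching BinaryOps (e.g., two ceilDiv results
// with pairwise-equal inputs) match through the recursive check, while
// unrelated symbolic extents conservatively compare unequal.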
bool ParallelDimensionMap::equalDim(Val* dim1, Val* dim2) {
TORCH_INTERNAL_ASSERT(dim1 != nullptr && dim2 != nullptr);
if (dim1 == dim2) {
return true;
}
  // When both are Int, they are the same if both have the same
  // constant value
auto dim1_int = dynamic_cast<Int*>(dim1);
auto dim2_int = dynamic_cast<Int*>(dim2);
if (dim1_int && dim2_int) {
if (dim1_int->isConst() && dim2_int->isConst()) {
return dim1_int->value() == dim2_int->value();
}
}
  // When both are NamedScalar, they are the same if both have the
  // same name
auto dim1_ns = dynamic_cast<NamedScalar*>(dim1);
auto dim2_ns = dynamic_cast<NamedScalar*>(dim2);
if (dim1_ns && dim2_ns) {
return dim1_ns->name() == dim2_ns->name();
}
  // Recursively check their definitions
auto dim1_def = dim1->definition();
auto dim2_def = dim2->definition();
if (dim1_def == nullptr || dim2_def == nullptr) {
return false;
}
// If both are BinaryOp or UnaryOp, check their inputs. Since these
// Vals are IterDomain extents, UnaryOp should not occur, but
// checking shouldn't be harmful.
  // TODO:
  // We might be able to replace this with dim1->toInlineString() ==
  // dim2->toInlineString()
  // If we want this to be less conservative, we could build an "exact
  // map", which could be another mode in compute-at that maps all iter
  // domains, but not concretized broadcast axes, and only forwards
  // through non-concretized broadcast axes.
if ((dim1_def->isA<BinaryOp>() && dim2_def->isA<BinaryOp>() &&
(dim1_def->as<BinaryOp>()->getBinaryOpType() ==
dim2_def->as<BinaryOp>()->getBinaryOpType())) ||
(dim1_def->isA<UnaryOp>() && dim2_def->isA<UnaryOp>() &&
(dim1_def->as<UnaryOp>()->getUnaryOpType() ==
dim2_def->as<UnaryOp>()->getUnaryOpType()))) {
    for (const auto i : c10::irange(dim1_def->inputs().size())) {
      if (!equalDim(dim1_def->inputs()[i], dim2_def->inputs()[i])) {
        return false;
      }
    }
return true;
}
return false;
}
std::string ParallelDimensionMap::toString() const {
std::stringstream ss;
for (auto pt : kParallelTypeThreads) {
ss << pt << ": ";
auto dim = get(pt);
if (dim != nullptr) {
ss << dim->toString();
if (isExact(pt)) {
ss << ", exact";
} else {
ss << ", non-exact";
}
} else {
ss << "unused";
}
ss << "\n";
}
return ss.str();
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch