1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742
|
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>
#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <algorithm>
#include <unordered_map>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace {

// constexpr int64_t x_grid_limit = ((int64_t)1 << (int64_t)31) - (int64_t)1;
// Unused at the moment, commenting for clang tidy

// Default number of threads along TIDx used by the pointwise heuristics and by
// the 1D pointwise schedule.
constexpr int64_t kThreadX = 128;

// Pointwise-scheduler extension of pointwise_utils::DomainMap that knows how
// to select the fusion output used as the scheduling reference.
class DomainMap : public pointwise_utils::DomainMap {
 public:
  using pointwise_utils::DomainMap::DomainMap;

  // Returns the fusion output to use as the scheduling reference, or nullptr
  // if none qualifies. A candidate must:
  //   * pass isValidReference,
  //   * have strictly more than `minimum_num_axes` rfactor (or root) axes
  //     (no requirement when `minimum_num_axes` <= 0), and
  //   * not itself be a fusion input.
  // Among candidates, the one with the most root dims
  // (pointwise_utils::nRootDims) wins; on ties the first output in
  // fusion-output order is kept.
  //
  // The pointwise schedule needs at least one axis to the right of its break
  // point, which is why schedulePointwise passes params.break_point here.
  TensorView* findReferenceTensorView(int minimum_num_axes = 0) const {
    TensorView* result = nullptr;
    int64_t max_dims = -1;
    for (auto output_tv :
         ir_utils::filterByType<TensorView>(fusion_->outputs())) {
      if (isValidReference(output_tv) &&
          hasMinimumSize(output_tv, minimum_num_axes) &&
          !output_tv->isFusionInput()) {
        // nRootDims returns an unsigned count; widen instead of narrowing to
        // int so the comparison against max_dims is well defined.
        const auto n_dims =
            static_cast<int64_t>(pointwise_utils::nRootDims(output_tv));
        if (n_dims > max_dims) {
          result = output_tv;
          max_dims = n_dims;
        }
      }
    }
    return result;
  }

  // Returns true if `fusion` has at least one output usable as a reference
  // (with no minimum-axis requirement).
  static bool hasReferenceTensorView(Fusion* fusion) {
    FusionGuard fg(fusion);
    DomainMap domain_map(fusion);
    return domain_map.findReferenceTensorView() != nullptr;
  }

 private:
  // True if `tv` has strictly more than `num_axes` rfactor (or root) axes.
  // A non-positive `num_axes` means "no requirement". The explicit check
  // avoids a signed/unsigned comparison in which a negative value would wrap
  // to a huge size_t and reject every tensor.
  bool hasMinimumSize(TensorView* tv, int num_axes) const {
    TORCH_INTERNAL_ASSERT(tv != nullptr);
    if (num_axes <= 0) {
      return true;
    }
    return tv->getMaybeRFactorDomain().size() >
        static_cast<size_t>(num_axes);
  }
};

} // namespace
// Convenience overload: builds a SchedulerRuntimeInfo for `fusion` from the
// concrete `runtime_inputs`, then defers to the SchedulerRuntimeInfo-based
// overload, which contains the actual heuristic logic.
std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue>& runtime_inputs,
    HeuristicSummary* data_cache) {
  SchedulerRuntimeInfo info_from_inputs(fusion, runtime_inputs, true);
  auto heuristics =
      getPointwiseHeuristics(fusion, info_from_inputs, data_cache);
  return heuristics;
}
// Computes pointwise scheduling heuristics for `fusion` using prebuilt
// runtime information.
//
// Outline:
//   1. Select the reference output (DomainMap::findReferenceTensorView) and
//      evaluate the runtime extent of each of its rfactor (or root) axes.
//   2. For a 0-dim or zero-size reference, return default PointwiseParams.
//   3. Derive a maximum unroll factor from the widest input dtype and the
//      number of vectorizable inputs/outputs.
//   4. When there is enough parallelism, search for a 2D "break point" that
//      lowers the estimated transfer cost given the broadcast pattern.
//   5. Pick a vectorize/unroll factor and fill in launch/grid parameters.
//
// `data_cache` is passed to every HeuristicSummaryEntry below (DomainMap,
// ReferenceTensors, VectorizableInputsAndOutputs, BroadcastMultiples).
// NOTE(review): presumably a null `data_cache` means "compute on the fly";
// confirm against HeuristicSummaryEntry.
std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicSummary* data_cache) {
  FUSER_PERF_SCOPE("getPointwiseHeuristics");

  FusionGuard fg(fusion);

  // In case any buffer is of type DataType::Index, resolve the concrete index
  // dtype from the runtime index mode so dataTypeSize can size it.
  DataType index_type = indexModeToDtype(runtime_info.getIndexMode());

  auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());

  auto domain_map_entry =
      HeuristicSummaryEntry<HeuristicCompileTime::DomainMap>(
          data_cache,
          [fusion]() { return std::make_unique<DomainMap>(fusion); });
  const auto& domain_map = dynamic_cast<DomainMap&>(domain_map_entry.get());

  // Reference output with no minimum-axis requirement; stored as a
  // one-element vector in the cache entry.
  auto largest_out_entry =
      HeuristicSummaryEntry<HeuristicCompileTime::ReferenceTensors>(
          data_cache, [&domain_map]() {
            std::vector<TensorView*> data{domain_map.findReferenceTensorView()};
            return std::make_unique<std::vector<TensorView*>>(std::move(data));
          });
  TensorView* largest_out = largest_out_entry.get()[0];

  TORCH_INTERNAL_ASSERT(largest_out != nullptr);

  const int64_t device_multiprocessor_count =
      (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  // Widest input element size, with a floor of 2.
  // TODO: Set to 1?
  int64_t max_input_dtype_size = 2;

  for (auto inp : in_tvs) {
    max_input_dtype_size = std::max(
        max_input_dtype_size,
        (int64_t)dataTypeSize(inp->getDataType().value(), index_type));
  }

  // Evaluate every rfactor (or root) extent of the reference and the total
  // element count. All extents must be resolvable by the runtime evaluator.
  auto ref_root = largest_out->getMaybeRFactorDomain();
  std::vector<int64_t> elem_counts(ref_root.size(), 1);
  int64_t n_elems = 1;
  for (size_t ref_i = 0; ref_i < ref_root.size(); ref_i++) {
    auto inferred_val =
        runtime_info.expressionEvaluator().evaluate(ref_root[ref_i]->extent());
    TORCH_INTERNAL_ASSERT(
        inferred_val.has_value(),
        "Error inferring size for pointwise scheduler: ",
        ref_root[ref_i]->extent()->toInlineString());
    elem_counts[ref_i] = inferred_val->as<int64_t>();
    n_elems *= elem_counts[ref_i];
  }

  // If zero dimensional or zero size, return default parameters
  if (TensorDomain::noReductions(
          TensorDomain::noBroadcasts(largest_out->domain()->domain()))
              .size() == 0 ||
      n_elems == 0) {
    // The remaining compile-time cache entries are still created (with empty
    // contents) on this early-exit path. NOTE(review): presumably so every
    // path populates the same cache entries; confirm against HeuristicSummary.
    auto vectorizable_inputs_outputs_entry = HeuristicSummaryEntry<
        HeuristicCompileTime::VectorizableInputsAndOutputs>(data_cache, []() {
      return std::make_unique<std::vector<TensorView*>>();
    });
    vectorizable_inputs_outputs_entry.get();

    auto broadcast_byte_multiples_entry =
        HeuristicSummaryEntry<HeuristicCompileTime::BroadcastMultiples>(
            data_cache, []() {
              return std::make_unique<
                  std::vector<scheduler_utils::BroadcastMultiple>>();
            });
    broadcast_byte_multiples_entry.get();

    return std::make_shared<PointwiseParams>("Pointwise heuristics");
  }

  // Find all vectorizable inputs/outputs
  auto vectorizable_inputs_outputs_entry =
      HeuristicSummaryEntry<HeuristicCompileTime::VectorizableInputsAndOutputs>(
          data_cache, [&largest_out]() {
            return std::make_unique<std::vector<TensorView*>>(
                scheduler_utils::getInputsOutputsWithInnerDim(
                    largest_out, true, true));
          });

  // Unroll budget: kSixteen divided by the widest input dtype size (i.e. how
  // many of the widest elements fit in 16 units of dataTypeSize).
  constexpr int64_t kSixteen = 16; // clang tidy

  auto max_unroll_factor = ceilDiv(
      // Available unrolling based on size of data type
      (int64_t)kSixteen / max_input_dtype_size,
      // Reduce max unrolling factor if we have many inputs/outputs to unroll
      // as it could start consuming a lot of registers.
      std::max(
          (scheduler_utils::lastPow2(
               (int64_t)vectorizable_inputs_outputs_entry.get().size()) >>
           2),
          (int64_t)1));

  // Don't unroll at the cost of getting a full wave on the GPU
  if (n_elems < device_multiprocessor_count * kThreadX &&
      max_unroll_factor > 1) {
    max_unroll_factor = std::min(
        max_unroll_factor,
        ceilDiv(n_elems, device_multiprocessor_count * kThreadX));
  }

  auto params = std::make_shared<PointwiseParams>("Pointwise heuristics");

  /*
   * 2D pointwise scheduling logic. What is expected is there's some
   * broadcasting pattern which would make scheduling as a 2D problem more
   * efficient than scheduling simply as a 1D problem.
   *
   * The broadcast byte multiples hold how many bytes are in each dimension
   * for both inputs and outputs relative to the reference tensor. What we're
   * looking for is a break point in the reference tensor's dimensions which
   * separates the outer dimension and inner dimension of the problem mapped
   * to 2D.
   *
   * break_point is computed assuming no reuse, ignoring parallelization
   * limitations, and simply figures out which point best separates broadcasted
   * dimensions. In other words, where's the point where we isolate the most
   * broadcasted elements to one side.
   *
   * Once a break point is found, simply schedule the pointwise op as 2D
   * balancing parallelization as best as possible.
   */

  // Ideal break point location (0 means: no break point, 1D schedule)
  int break_point = 0;

  // If break_point, mark if BIDy and BIDx should be positionally reversed
  // relative to root domains
  bool flip_grid_binding = false;

  // Elements on the right of break point (without break point all are on the
  // right). Stays 0 when no break point is chosen.
  int64_t right_elem_count = 0;

  int64_t bdimx = kThreadX;

  // bdimy may be used if the right side of the break point is not large and we
  // need to expand block level parallelism into the left side of the break
  // point.
  int64_t bdimy = 1;

  // In 2D scheduler gdim_left is used to parallelize the left side of the break
  // point.
  int64_t gdim_left = 1;

  // gdim_right is the number of blocks needed along the right side of the
  // break point once a single block's bdimx * unroll tile is exhausted
  // (ceilDiv(right elems, bdimx * max_unroll_factor) in the search below).
  int64_t gdim_right = 1;

  auto broadcast_byte_multiples_entry =
      HeuristicSummaryEntry<HeuristicCompileTime::BroadcastMultiples>(
          data_cache, [&largest_out, &index_type]() {
            return std::make_unique<
                std::vector<scheduler_utils::BroadcastMultiple>>(
                scheduler_utils::getBroadcastMultiples(
                    largest_out, index_type));
          });

  auto& broadcast_byte_multiples = broadcast_byte_multiples_entry.get();

  // One multiple entry per reference axis.
  TORCH_INTERNAL_ASSERT(broadcast_byte_multiples.size() == ref_root.size());

  // Sum of element sizes over every TensorView input and output.
  int64_t dtype_sum = 0;
  for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {
    dtype_sum += dataTypeSize(inp->getDataType().value(), index_type);
  }
  for (auto out : ir_utils::filterByType<TensorView>(fusion->outputs())) {
    dtype_sum += dataTypeSize(out->getDataType().value(), index_type);
  }

  { // Figure out break point position. Empty scope, consider moving to a
    // separate function.
    //
    // How much would this transfer cost if it was done as a 1-D schedule
    // NOTE(review): dtype_sum is multiplied in once per reference axis rather
    // than once overall, mirroring how the lhs/rhs byte multiples are applied
    // per axis below. For high-rank references these products can overflow
    // int64_t.
    int64_t transfer_size_1d = 1;

    for (const auto i : c10::irange(ref_root.size())) {
      transfer_size_1d = transfer_size_1d * elem_counts[i] * dtype_sum;
    }

    // Only search for a 2D break point when there is enough parallelism
    // (more than half a wave of kThreadX-thread blocks); otherwise keep the
    // 1D schedule (break_point == 0).
    if (n_elems * 2 > device_multiprocessor_count * kThreadX) {
      int64_t min_total_transfer = std::numeric_limits<int64_t>::max();

      for (const auto break_point_i : c10::irange(ref_root.size())) {
        // Number of elements in the right side of reference tv with
        // break_point_i
        int64_t cur_right_elem_count = 1;
        for (const auto right_i : c10::irange(break_point_i, ref_root.size())) {
          cur_right_elem_count = cur_right_elem_count * elem_counts[right_i];
        }

        // n_elems != 0 here (the zero-size case returned early), so
        // cur_right_elem_count is non-zero.
        auto cur_left_elem_count = n_elems / cur_right_elem_count;
        // Nothing to parallelize on the left: not a useful break point.
        if (cur_left_elem_count <= 1) {
          continue;
        }

        auto lhs_byte_multiple =
            broadcast_byte_multiples[break_point_i].lhs_multiple;
        auto rhs_byte_multiple =
            broadcast_byte_multiples[break_point_i].rhs_multiple;

        // Estimate transfer cost with this break point
        int64_t cur_transfer_size = 1;
        int64_t right_transfer_size = 1;

        for (const auto left_i : c10::irange(break_point_i)) {
          cur_transfer_size =
              cur_transfer_size * elem_counts[left_i] * lhs_byte_multiple;
        }

        for (const auto right_i : c10::irange(break_point_i, ref_root.size())) {
          right_transfer_size =
              right_transfer_size * elem_counts[right_i] * rhs_byte_multiple;
        }
        cur_transfer_size *= right_transfer_size;

        //  Continue if this break point doesn't save at least 10% of 1D
        //  scheduling or isn't better than previous break_points found.
        if (cur_transfer_size >= min_total_transfer ||
            cur_transfer_size * 10 >= transfer_size_1d * 9) {
          continue;
        }

        // Need to be able to parallelize, don't use break if there's not
        // at least an unrolled warp.
        if (ceilDiv(cur_right_elem_count, max_unroll_factor) <=
            at::cuda::getCurrentDeviceProperties()->warpSize) {
          continue;
        }

        // If outer broadcast, or balanced broadcast:
        if (lhs_byte_multiple <= rhs_byte_multiple &&
            // If right transfer size is bigger than half of L2
            at::cuda::getCurrentDeviceProperties()->l2CacheSize <
                right_transfer_size * 2) {
          // flip BIDx and BIDy bindings
          flip_grid_binding = true;
        } else {
          flip_grid_binding = false;
        }
        // Min transfer found, start setting values
        bdimx = std::min(
            ceilDiv(cur_right_elem_count, max_unroll_factor), kThreadX);
        bdimy = 1;
        gdim_right = 1;
        // Put remainder in bdimy if there's at least a wave of grid level
        // parallelism.
        if (cur_left_elem_count > device_multiprocessor_count) {
          bdimy = kThreadX / bdimx;
        }
        auto remainder_left = ceilDiv(cur_left_elem_count, bdimy);
        auto remainder_right =
            ceilDiv(cur_right_elem_count, bdimx * max_unroll_factor);

        // Use this break point
        break_point = static_cast<int>(break_point_i);
        min_total_transfer = cur_transfer_size;
        right_elem_count = cur_right_elem_count;

        gdim_left = remainder_left;
        gdim_right = remainder_right;
      }
    }
  }

  // Vectorizing innermost domains

  // Default to no unrolling; overwritten below once the factor is known.
  params->unroll_factor = 1;

  // Compute maximum vectorize factor that can be used: start at
  // max_unroll_factor and clamp by every vectorizable tensor's inner-dim width.
  size_t vectorize_factor = max_unroll_factor;

  auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_entry.get();

  for (auto tv : vectorizable_inputs_outputs) {
    const auto tv_vectorize_factor =
        runtime_info.getInnerDimVectorizableWidth(tv);
    vectorize_factor = std::min(vectorize_factor, tv_vectorize_factor);
  }

  // Try expanding vectorization to contig merged domains
  auto expanded_vector_word_size =
      scheduler_utils::expandVectorizationToContigMergedDomains(
          fusion,
          runtime_info,
          vectorizable_inputs_outputs,
          largest_out,
          break_point,
          vectorize_factor);

  // The expanded word size is still capped by the unroll budget and is only
  // adopted if it is larger than the per-tensor factor found above.
  expanded_vector_word_size = std::min(
      static_cast<size_t>(max_unroll_factor), expanded_vector_word_size);

  if (expanded_vector_word_size > vectorize_factor) {
    vectorize_factor = expanded_vector_word_size;
  }

  // Without vectorization, fall back to plain unrolling by max_unroll_factor.
  if (vectorize_factor == 1) {
    params->vectorize = false;
    params->unroll_factor = max_unroll_factor;
  } else {
    params->vectorize = true;
    params->unroll_factor = vectorize_factor;
  }

  TORCH_INTERNAL_ASSERT(right_elem_count > 0 || break_point == 0);
  // bdimy > 1 only when bdimx < kThreadX, i.e. bdimx == ceilDiv(right elems,
  // max_unroll_factor); then one block's bdimx * unroll tile covers the whole
  // right side and gdim_right is 1.
  TORCH_INTERNAL_ASSERT(!(bdimy > 1 && gdim_right > 1));

  params->break_point = break_point;
  params->flip_grid_binding = flip_grid_binding;
  params->split_block = bdimy > 1;

  params->lparams.bind(bdimx, ParallelType::TIDx);
  if (params->split_block) {
    params->lparams.bind(bdimy, ParallelType::TIDy);
  }
  // Whichever grid extent ends up on BIDy (gdim_right when flipped,
  // gdim_left otherwise) is split once it exceeds 65535 blocks (presumably
  // the CUDA gridDim.y limit).
  if ((flip_grid_binding && gdim_right > 65535) ||
      (!flip_grid_binding && gdim_left > 65535)) {
    params->split_grid_y_dim = true;
  }

  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
    std::cerr << "\n===== Pointwise Stats ========\n"
              << "num_elems: " << n_elems << "\n"
              << "elem_counts: " << elem_counts << "\n"
              << "max_input_dtype_size: " << max_input_dtype_size << "\n"
              << "vectorize_factor: " << vectorize_factor << std::endl;
    std::cerr << "broadcast_byte_multiples: ";
    for (auto multiple : broadcast_byte_multiples) {
      std::cerr << "(" << multiple.lhs_multiple << ", " << multiple.rhs_multiple
                << "), ";
    }
    std::cerr << "LHS elems: "
              << (right_elem_count > 0 ? n_elems / right_elem_count : 0)
              << " RHS elems: " << right_elem_count << std::endl;
    std::cerr << std::endl;
    std::cerr << params->toString() << std::endl;
  }

  return params;
}
// TODO: remove or return launch parameters
//
// Convenience entry point: computes pointwise heuristics from concrete runtime
// inputs, applies them to `fusion`, and returns the launch parameters chosen
// by the heuristics.
LaunchParams schedulePointwise(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue>& runtime_inputs) {
  FUSER_PERF_SCOPE("scheduleFusion");
  const auto heuristics = getPointwiseHeuristics(fusion, runtime_inputs);
  TORCH_INTERNAL_ASSERT(
      heuristics != nullptr, "Could not schedule pointwise operation.");
  const PointwiseParams& pointwise_params = *heuristics;
  schedulePointwise(fusion, pointwise_params);
  return pointwise_params.lparams;
}
// Public entry point: reports whether `fusion` has an output that the
// pointwise scheduler can use as its reference, delegating the check to the
// file-local DomainMap helper.
bool hasReferenceTensorView(Fusion* fusion) {
  const bool has_reference = DomainMap::hasReferenceTensorView(fusion);
  return has_reference;
}
// TODO: Inline intermediate operations (avoid inlining unrolled/vectorized
// input/output caches)
//
// Applies the pointwise schedule described by `params` to `fusion`:
//   * clears memory types, caches inputs, and caches/forks outputs;
//   * picks a reference output with more than params.break_point axes;
//   * merges the reference's axes into a left group [0, break_point) and a
//     right group [break_point, nDims), then splits and parallelizes them
//     with either a 2D (break_point != 0) or a 1D scheme;
//   * propagates the transformations and parallelization to all tensors;
//   * marks vectorized tensors when params.vectorize is set;
//   * inlines in two passes (at the unswitch position, then innermost).
// Returns early, without further scheduling, if every used input and every
// output has zero root dims.
void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
  FusionGuard fg(fusion);

  // Make sure we don't have global memory set on intermediate tensors from
  // fusion segmentation
  scheduler_utils::clearMemorySpace(fusion);

  // maybe has_reduction for scheduling should be done on a per output tensor
  // basis.
  TORCH_INTERNAL_ASSERT(
      ir_utils::getReductionOps(fusion /*, ignore_trivial=true */).empty(),
      "This scheduler only handles pointwise ops.");

  // Cache inputs
  auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);

  // Cache and fork outputs
  auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true);

  std::vector<TensorView*> input_tvs;
  {
    auto filtered_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
    // Remove hanging tensor views (inputs with no uses)
    for (auto tv : filtered_tvs) {
      if (tv->uses().empty()) {
        continue;
      }
      input_tvs.push_back(tv);
    }
  }
  auto output_tvs = ir_utils::filterByType<TensorView>(fusion->outputs());

  size_t max_dims = 0;
  for (auto inp : input_tvs) {
    max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims);
  }

  for (auto out : output_tvs) {
    max_dims = std::max(pointwise_utils::nRootDims(out), max_dims);
  }

  // If everything is zero dim tensors, just return.
  if (max_dims == 0) {
    return;
  }

  // The reference must have more axes than the break point so that the
  // right-hand side of the break is non-empty.
  DomainMap domain_map(fusion);
  TensorView* reference_tv =
      domain_map.findReferenceTensorView(params.break_point);

  TORCH_INTERNAL_ASSERT(
      reference_tv != nullptr,
      "Could not find a fully broadcasted output to reference schedule on.");

  // Collected after caching, so the cache tensors created above are included.
  auto all_tvs = ir_utils::allTvs(fusion);

  // Merge right side of break point: axes [break_point, nDims) are folded,
  // right to left, into a single axis whose index ends up in rhs_i.
  int rhs_i = -1;
  for (int i = (int)reference_tv->nDims(); i > (int)params.break_point; i--) {
    auto axis_i = i - 1;
    if (rhs_i == -1) {
      rhs_i = axis_i;
    } else {
      reference_tv->merge(axis_i, rhs_i);
      rhs_i = axis_i;
    }
  }
  if (rhs_i >= 0) {
    // If there's an rhs
    reference_tv->reorder({{rhs_i, -1}});
  }

  // Merge left side of break point: axes [0, break_point) are folded, right
  // to left, into a single axis whose index ends up in lhs_i (-1 if none).
  int lhs_i = -1;
  for (int i = (int)params.break_point; i > 0; i--) {
    auto axis_i = i - 1;
    if (lhs_i == -1) {
      lhs_i = axis_i;
    } else {
      reference_tv->merge(axis_i, lhs_i);
      lhs_i = axis_i;
    }
  }

  // Inline position used by the first inline pass below: the axis index just
  // past the reference's Unswitch axis. Assigned in every branch.
  int64_t unswitch_pos;
  // Set in both vectorize branches; only dereferenced when params.vectorize.
  IterDomain* vectorize_id = nullptr;
  if (params.break_point) {
    // 2D parallelization scheme
    TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);

    // Right (inner merged) dimension is at inner most position, left (outer
    // merged) dimension is at lhs_i. Order as [lhs_i, rhs_i, unmerged...]
    reference_tv->reorder({{lhs_i, 0}, {-1, 1}});

    if (params.vectorize) {
      reference_tv->split(1, params.unroll_factor);
      reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
      reference_tv->split(0, 1);
      // [outer, Unswitch | i-remainder, TIDx, Vectorization]
      reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
      reference_tv->axis(3)->parallelize(ParallelType::TIDx);
      // Vectorization are propagated separately
      vectorize_id = reference_tv->axis(4);

      // [outer, Unswitch | i-remainder, TIDx, Vectorization]
      // To make consistent with unrolling:
      reference_tv->reorder({{1, 2}, {2, 1}, {3, 4}, {4, 3}});
      //[outer | i-remainder, Unswitch, Vectorization, TIDx]
    } else {
      reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
      reference_tv->split(1, params.unroll_factor);
      reference_tv->split(0, 1);
      // [outer, unswitch | i-remainder, unroll, TIDx ]
      reference_tv->reorder({{1, 2}});
      // [outer, i-remainder, unswitch, unroll, TIDx ]
      reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
      // Here we do not set axis(3)->parallelize(Unroll) because we do not want
      // it to be propagated. We manually unroll by splitting the inline
      // propagation process into two steps:
      // step 1: inline at the unswitch position for cached inputs and outputs
      // step 2: inline at the inner most dim for the rest of the graph
      reference_tv->axis(4)->parallelize(ParallelType::TIDx);

      //[outer | i-remainder, Unswitch, Unroll, TIDx]
    }

    // Move out of the way to furthest left point
    reference_tv->reorder({{1, 0}});

    //[i-remainder | outer | Unswitch, Unroll, TIDx]
    // axis 0 now covers the right side (gdim_right), axis 1 the left side
    // (gdim_left); flip_grid_binding decides which of them gets BIDy.
    if (params.split_block) {
      reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
      if (params.flip_grid_binding) {
        // [BIDy | BIDx, TIDy | Unswitch, Unroll, TIDx]
        reference_tv->axis(1)->parallelize(ParallelType::BIDx);
        reference_tv->axis(2)->parallelize(ParallelType::TIDy);
        if (params.split_grid_y_dim) {
          // [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx]
          reference_tv->split(0, 65535);
          reference_tv->axis(1)->parallelize(ParallelType::BIDy);
          unswitch_pos = 5;
        } else {
          reference_tv->axis(0)->parallelize(ParallelType::BIDy);
          unswitch_pos = 4;
        }
      } else {
        // [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx]
        reference_tv->axis(0)->parallelize(ParallelType::BIDx);
        reference_tv->axis(2)->parallelize(ParallelType::TIDy);
        if (params.split_grid_y_dim) {
          // [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx]
          reference_tv->split(1, 65535);
          reference_tv->axis(2)->parallelize(ParallelType::BIDy);
          unswitch_pos = 5;
        } else {
          reference_tv->axis(1)->parallelize(ParallelType::BIDy);
          unswitch_pos = 4;
        }
      }
    } else {
      // [BIDy | BIDx | Unswitch, Unroll, TIDx]
      if (params.flip_grid_binding) {
        // [BIDy | BIDx | Unswitch, Unroll, TIDx]
        reference_tv->axis(1)->parallelize(ParallelType::BIDx);
        if (params.split_grid_y_dim) {
          // [i-remainder, BIDy{65535} | BIDx | Unswitch, Unroll, TIDx]
          reference_tv->split(0, 65535);
          reference_tv->axis(1)->parallelize(ParallelType::BIDy);
          unswitch_pos = 4;
        } else {
          reference_tv->axis(0)->parallelize(ParallelType::BIDy);
          unswitch_pos = 3;
        }
      } else {
        // [BIDx | BIDy | Unswitch, Unroll, TIDx]
        reference_tv->axis(0)->parallelize(ParallelType::BIDx);
        if (params.split_grid_y_dim) {
          // [BIDx | i-remainder, BIDy{65535} | Unswitch, Unroll, TIDx]
          reference_tv->split(1, 65535);
          reference_tv->axis(2)->parallelize(ParallelType::BIDy);
          unswitch_pos = 4;
        } else {
          reference_tv->axis(1)->parallelize(ParallelType::BIDy);
          unswitch_pos = 3;
        }
      }
    }
  } else {
    // 1D Scheduler
    TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i == -1);

    // right hand side exists and is the only axis we care to schedule, move
    // it from the inner most position to left most. Order as [rhs_i,
    // unmerged...]
    reference_tv->reorder({{-1, 0}});

    // Note: the 1D scheme splits by the constant kThreadX, whereas the 2D
    // scheme above splits by the runtime blockDim.x named scalar.
    if (params.vectorize) {
      // Vectorize
      reference_tv->split(0, params.unroll_factor);
      // Unswitch
      reference_tv->split(0, 1);
      // Threads
      reference_tv->split(0, kThreadX);

      reference_tv->axis(0)->parallelize(ParallelType::BIDx);
      reference_tv->axis(1)->parallelize(ParallelType::TIDx);
      reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
      // Vectorization are propagated separately
      vectorize_id = reference_tv->axis(3);

      //[BIDx, TIDx, Unswitch, Vectorization]
      // To make consistent with unrolling:
      reference_tv->reorder({{1, 3}, {2, 1}, {3, 2}});
      //[BIDx, Unswitch, Vectorization, TIDx]
    } else {
      // Threads
      reference_tv->split(0, kThreadX);
      // Unroll
      reference_tv->split(0, params.unroll_factor);
      // Unswitch
      reference_tv->split(0, 1);

      // [BIDx, Unswitch, Unroll, TIDx]
      reference_tv->axis(0)->parallelize(ParallelType::BIDx);
      reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
      // Here we do not set axis(2)->parallelize(Unroll) because we do not want
      // it to be propagated. We manually unroll by splitting the inline
      // propagation process into two steps:
      // step 1: inline at the unswitch position for cached inputs and outputs
      // step 2: inline at the inner most dim for the rest of the graph
      reference_tv->axis(3)->parallelize(ParallelType::TIDx);
    }
    unswitch_pos = 2;
  }

  // Replay the reference's transformations onto every tensor reachable in the
  // spanning tree, then copy its parallelization.
  TransformPropagator propagator(reference_tv);
  MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);
  spanning_tree.traverse(&propagator);
  scheduler_utils::parallelizeAllLike(reference_tv);

  if (params.vectorize) {
    // Grab all tensor views that should be vectorized
    auto inputs_outputs =
        scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true, true);
    std::vector<TensorView*> vectorized_tvs;
    bool should_vectorize_reference_tv = false;
    for (auto tv : inputs_outputs) {
      if (tv == reference_tv) {
        should_vectorize_reference_tv = true;
      }
      if (!tv->isFusionInput()) {
        vectorized_tvs.emplace_back(tv);
        continue;
      }
      // move inputs to consumers of inputs
      auto consumer_tvs = ir_utils::consumerTvsOf(tv);
      vectorized_tvs.insert(
          vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
    }
    // Aggressively mark with vectorized and cleanup later. That way we
    // don't have to manually specify parallelization outside the reference.
    vectorize_id->parallelize(ParallelType::Vectorize);
    scheduler_utils::parallelizeAllLike(
        reference_tv, vectorized_tvs, {ParallelType::Vectorize});
    // Undo the temporary marking on the reference if it is not itself one of
    // the vectorizable inputs/outputs.
    if (!should_vectorize_reference_tv) {
      vectorize_id->parallelize(ParallelType::Serial);
    }
  }

  // Begin by inlining at the unswitch position for the entire DAG. The cached
  // inputs, and outputs will keep this inline position, but other tensors will
  // get a higher position in later inline propagation. We need this separate
  // step because we were not using ParallelType::Unroll, so we have to do
  // unrolling manually.
  InlinePropagator inline_unswitch(
      reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
  spanning_tree.traverse(&inline_unswitch);

  // Inline at the inner most position. The CA position of all tensors except
  // inputs, cached inputs and outputs will be updated.
  std::unordered_set<TensorView*> inner_most_tensors(
      all_tvs.begin(), all_tvs.end());
  for (auto cached_input : cached_inputs) {
    inner_most_tensors.erase(cached_input);
  }
  // Only the second element of each cached_outputs entry is excluded here.
  for (auto entry : cached_outputs) {
    auto output = entry.second;
    inner_most_tensors.erase(output);
  }
  InlinePropagator inline_inner_most(
      reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors);
  spanning_tree.traverse(&inline_inner_most);
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
|