#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_alias_memory.h>
#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
#include <torch/csrc/jit/codegen/cuda/lower_expr_sort.h>
#include <torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h>
#include <torch/csrc/jit/codegen/cuda/lower_index.h>
#include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
#include <torch/csrc/jit/codegen/cuda/lower_instrument.h>
#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_replace_size.h>
#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
#include <torch/csrc/jit/codegen/cuda/lower_warp_reduce.h>
#include <list>
#include <unordered_map>
#include <unordered_set>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
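// Pointer to the GpuLower instance currently running lower() on this thread.
// Set and cleared by the LowerGuard in GpuLower::lower() and exposed through
// GpuLower::current().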
thread_local GpuLower* active_gpu_lower = nullptr; // NOLINT
namespace {
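//! Prunes no-op expressions, i.e., empty for-loops and if-then-else blocks,
//! from the lowered loop nests.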
class KIRCleaner : public OptOutDispatch {
public:
//! Remove nop IR nodes
static std::vector<Expr*> cleanUp(const std::vector<Expr*>& loop_nests) {
KIRCleaner cleaner;
std::vector<Expr*> out_loop_nests;
for (auto loop_nest : loop_nests) {
cleaner.handle(loop_nest);
// No need to keep the loop nest if it's determined to be nop
if (!cleaner.is_nop_) {
out_loop_nests.push_back(loop_nest);
}
}
return out_loop_nests;
}
private:
using OptOutDispatch::handle;
void handle(Expr* expr) final {
if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
OptOutDispatch::handle(expr);
} else {
// Any non-scoping expr is not considered nop
is_nop_ = false;
}
}
void handle(kir::ForLoop* fl) final {
auto exprs = fl->body().exprs();
fl->body().clear();
for (auto expr : exprs) {
handle(expr);
// Add the expr to the loop body only when the expr is not nop
if (!is_nop_) {
fl->body().push_back(expr);
}
}
// The loop is nop when no expr exists in the body
is_nop_ = fl->body().empty();
}
void handle(kir::IfThenElse* ite) final {
const auto conditional = ite->predicate()->value();
// Visit the then block
auto then_exprs = ite->thenBody().exprs();
ite->thenBody().clear();
if (!conditional->isConst() || conditional->value().value()) {
for (auto expr : then_exprs) {
handle(expr);
if (!is_nop_) {
ite->thenBody().push_back(expr);
}
}
}
const bool then_nop = ite->thenBody().empty();
// Visit the else block
auto else_exprs = ite->elseBody().exprs();
ite->elseBody().clear();
if (!conditional->isConst() || !conditional->value().value()) {
for (auto expr : else_exprs) {
handle(expr);
if (!is_nop_) {
ite->elseBody().push_back(expr);
}
}
}
const bool else_nop = ite->elseBody().empty();
// If the then block is nop but the else is not, invert the
// conditional and move the exprs in the else block to the then
// block.
if (then_nop && !else_nop) {
Bool* pred = ite->predicate()->value();
Bool* not_pred = SimplifyingIrBuilder::notExpr(pred)->as<Bool>();
ite->predicate()->setValue(not_pred);
for (auto expr : ite->elseBody().exprs()) {
ite->thenBody().push_back(expr);
}
ite->elseBody().clear();
}
// This IfThenElse is nop if both the then and else blocks are nop
is_nop_ = then_nop && else_nop;
}
private:
//! True if the last visited expr is nop
bool is_nop_ = false;
};
} // namespace
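// Record which TIDx-parallelized iteration domains are padded to a multiple
// of the warp size, whether any warp reductions are present, and whether
// TIDx can be treated as a single warp.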
void GpuLower::collectPaddedParallelDims() {
ExpressionEvaluator ee(fusion_);
bool can_be_single_warp = true;
auto warp_size = at::cuda::warp_size();
auto used_vals = fusion_->usedMathVals();
for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
for (auto id : tv->domain()->domain()) {
if (tv->definition()) {
// TODO: Support GroupedReductionOp
if (auto reduction = dynamic_cast<ReductionOp*>(tv->definition())) {
if (ir_utils::getMaybeWarpReductionDim(
reduction->out(), reduction->in())
.has_value()) {
warp_pad_info_.has_warp_reduction = true;
}
}
}
// Check if TIDx is padded in this kernel
if (id->hasPaddingToMultipleOfWarp()) {
TORCH_INTERNAL_ASSERT(
id->getParallelType() == ParallelType::TIDx,
"Padded types supported only on TIDx");
warp_pad_info_.is_tidx_padded = true;
}
// Check all possible bindings of TIDx to see
// if TIDx will eventually be bound to a single warp.
if (id->getParallelType() == ParallelType::TIDx) {
auto eval_dim = ee.evaluate(id->extent());
auto size_after_padding = id->getMaybeSizeAfterPadding();
bool padding_to_single_warp = size_after_padding.has_value() &&
size_after_padding.value() == warp_size;
if ((!eval_dim.has_value() || eval_dim.value() > warp_size) &&
!padding_to_single_warp) {
// If we see any other TIDx binding that's larger than
// a warp or unknown, we shouldn't lower warp reduce
// to a single warp type.
can_be_single_warp = false;
warp_pad_info_.is_tidx_single_warp = false;
} else if (can_be_single_warp) {
if (padding_to_single_warp ||
(eval_dim.has_value() && eval_dim.value() == warp_size)) {
warp_pad_info_.is_tidx_single_warp = true;
}
}
}
}
}
}
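// Give every RNGOp in the fusion a distinct, increasing RNG offset so that
// separate random-number expressions do not reuse the same offset.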
void assignRNGOffset(Fusion* fusion) {
int counter = 0;
for (auto expr : fusion->exprs()) {
if (expr->isA<RNGOp>()) {
auto rop = expr->as<RNGOp>();
rop->setRNGOffset(counter++);
}
}
}
void GpuLower::lower(Fusion* fusion, DataType index_type) {
FUSER_PERF_SCOPE("GpuLower::lower");
TORCH_INTERNAL_ASSERT(fusion != nullptr);
TORCH_INTERNAL_ASSERT(
active_gpu_lower == nullptr, "Nested lowering passes are not supported");
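// RAII guard that registers this GpuLower as the active instance for the
// duration of lowering and clears it on exit.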
struct LowerGuard {
LowerGuard(GpuLower* gpu_lower) {
active_gpu_lower = gpu_lower;
}
~LowerGuard() {
active_gpu_lower = nullptr;
}
} lower_guard(this);
// Copy fusion into a new kernel for processing
kernel_ = std::make_unique<kir::Kernel>(fusion, index_type);
// Alias the fusion the kernel carries around as a view of itself.
fusion_ = kernel_.get();
// Convert tensor views of DataType::Index type to either Int or Int32
for (auto tv : ir_utils::allTvs(fusion_)) {
if (tv->dtype() == DataType::Index) {
tv->resolveIndexDtype();
}
}
assignRNGOffset(fusion_);
FusionGuard fg(fusion_);
// prepare for lowering
validateIr(fusion_);
// Checks if any TIDx dim is marked as padded to a warp. Also checks if we can
// determine the padding is explicitly a single warp.
collectPaddedParallelDims();
// Replaces integers that are tensor sizes with named scalars such as "T0.size[0]"
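// e.g. the extents of a 2-D fusion input T0 become the named scalars
// T0.size[0] and T0.size[1].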
replaceSymbolicSizes(fusion_);
// Traverse through reductions and determine if any iteration domains are
// trivial reductions. Add these iteration domains to trivial_reduction_info_
// which simply holds a map of which axes are trivial and which are not.
trivial_reduction_info_.build(fusion_);
// Replaces trivial reduction expressions (all IDs being reduced are trivial)
// with a set unary op
trivialReductionReplacement(fusion_, trivial_reduction_info_);
// Build what's referred to as the compute at map. This map contains the
// mappings of all iteration domains across the fusion. There are three types
// of mappings: Permissive, Exact, and Loop; see compute_at_map.h/cpp for more
// information.
compute_at_map_ = std::make_unique<ComputeAtMap>(fusion_);
if (isDebugDumpEnabled(DebugDumpOption::ComputeAtMap)) {
std::cout << compute_at_map_->toString() << std::endl;
}
compute_at_map_->validateAndPropagatePType();
// Used in parallel dimension map
concretized_broadcast_domains_.build(fusion_);
parallelDimensionMap().build(fusion_);
if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) {
std::cout << "Parallel dimension map:" << std::endl;
std::cout << parallel_dimension_map_.toString() << std::endl;
}
// Validate mma data format and compatibility, if any, on the fusion.
validateMma(fusion_);
// Validate swizzle usage on the fusion schedule.
validateSwizzle(fusion_);
// Compute thread predicates. Depends on parallel_dimension_map_
thread_pred_map_.build(fusion_);
// Fuse certain patterns of reductions, such as a grid reduction
// followed by a grid broadcast. Only depends on parallelization and
// thread predicate map.
fuseReductionsAndBroadcasts(fusion_);
// Scan the whole fusion and build mappings about halo extensions of
// all IterDomains
haloInfo().build(fusion_);
// Want to run this after parallel map and halo info map are
// created. vectorized_accesses_ and vectorized_set_info_ are filled.
validateAndCollectVectorizeInfo(fusion_);
// Depends on ComputeAtMap and HaloInfo.
validateAndConvertIterDomainGrouping(fusion_);
// Assumes all grouped reductions are converted to
// GroupedReductionOp, which is done by
// validateAndConvertIterDomainGrouping
validateGroupedReductions(fusion_);
// Depends on thread_pred_map_; validates parallelization and collects which
// tensor views need WAR or RAW syncs
sync_map_.build(fusion_);
partialSplitMap().build(fusion_);
validatePartialSplit(fusion_);
nonDivisibleSplitInfo().build(fusion_);
// Detects all expressions that don't need predicates. Depends on
// nonDivisibleSplitInfo.
predicateElimination().build(fusion_);
doubleBufferInfo().build(fusion_);
compute_at_map_->allocateIndexVariables();
// Run our passes, keeping the lowered expressions and forwarding them
// Reorder expressions for loop-nest generation respecting computeAt
// relationships
const auto exprs_sorted = reorderExprsForComputeAt();
// Generate loop-nests and place each expression at its
// corresponding loop
const auto exprs_lowered = LoopNestGenerator::loweredExprs(exprs_sorted);
// Replace trivial reductions, Transpose, Shift, Gather, and View ops with
// unary ops since they're not separately processed in lowering.
const auto exprs_unary_replaced = unarySetOpInserter(exprs_lowered);
// Insert allocations
const auto exprs_alloced = insertAllocations(exprs_unary_replaced);
// Insert read after write smem syncs
const auto exprs_raw_sync = insertRawThreadSynchronization(exprs_alloced);
// Reuse memory locations
const auto exprs_reuse_mem = reuseMemoryAllocations(exprs_raw_sync);
// Insert SyncThreads at end of for-loop to avoid WAR race condition
const auto exprs_war_sync = insertWarThreadSynchronization(exprs_reuse_mem);
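// Apply the double-buffering transformation to loops marked for it
// (see lower_double_buffer.h).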
const auto exprs_double_buffered = DoubleBufferPass::run(exprs_war_sync);
// This pass inserts predicates as well as branches in the code. Up until now
// the code is explicitly single-shot, for-loop based. Be careful in later
// passes when doing any kind of insertion in the loop nest structure, as
// insertions could land in a then or else block instead of directly in a for
// loop.
const auto exprs_unrolled_loops =
UnrollPass::runPass(fusion_, exprs_double_buffered);
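// Lower vectorized accesses that may be misaligned
// (see lower_misaligned_vectorization.h).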
const auto exprs_unrolled_mv_loops =
processMisalignedVectorization(exprs_unrolled_loops);
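// Generate concrete indexing for each tensor access.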
const auto exprs_indexed_loops =
IndexLowering::getIndexedExprs(exprs_unrolled_mv_loops);
// TODO: It seems this type of optimization would be far easier to implement
// on fusion ir than kernel ir. We should likely refactor this to at least run
// before allocation insertion.
const auto exprs_with_fused_broadcast = fuseWarpReduce(exprs_indexed_loops);
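// Convert predicate placeholders into concrete boolean conditionals.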
const auto exprs_conditional_loops =
generateConditionalFromPredicate(exprs_with_fused_broadcast);
const auto exprs_common_index_allocated =
allocateCommonIndices(exprs_conditional_loops);
// Insert fake zero updates to make sure nvrtc doesn't blow out register use
// on index and predicate reuse
const auto exprs_register_adjusted =
insertMagicZero(exprs_common_index_allocated);
const auto exprs_cleaned_up_loops =
KIRCleaner::cleanUp(exprs_register_adjusted);
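// Add kernel instrumentation (see lower_instrument.h).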
const auto exprs_instrumented = instrumentKernel(exprs_cleaned_up_loops);
// We now have the lowered expressions, finalize the kernel IR. This function
// will also copy over some relevant information for code generation from
// GpuLower.
kernel_->finalize(exprs_instrumented);
}
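// A minimal usage sketch (assuming the GpuLower constructor invokes lower()
// and that `fusion` is a fully scheduled Fusion):
//
//   GpuLower gpulw(&fusion);
//   kir::Kernel* kernel = gpulw.kernel();
//
// The resulting kir::Kernel holds the finalized lowered expressions.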
kir::Kernel* GpuLower::kernel() const {
TORCH_CHECK(kernel_);
return kernel_.get();
}
GpuLower* GpuLower::current() {
TORCH_INTERNAL_ASSERT(
active_gpu_lower != nullptr, "No active GpuLower available");
return active_gpu_lower;
}
bool GpuLower::hasCurrent() {
return active_gpu_lower != nullptr;
}
void GpuLower::propagateExprInfo(const Expr* old_expr, const Expr* new_expr) {
pred_elimination_.propagateRemovalInfo(old_expr, new_expr);
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch