#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <unordered_set>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace {
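
// A reduction IterDomain is considered "trivial" when its extent can be
// proven to be 1, i.e., it reduces exactly one element, so no reduction
// for-loop needs to be generated for it. The helpers below prove this
// property for an IterDomain, following producers through rfactor when the
// domain's own roots are not sufficient.
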
bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id);

// Checks the producer of tv to see if the given root_id, mapped into the
// producer's rfactor domain, is derived from a trivial reduction.
bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
  TORCH_INTERNAL_ASSERT(
      root_id->definition() == nullptr, "Not root IterDomain: ", root_id);

  auto def = tv->definition();

  if (def == nullptr) {
    // This is an input tensor, so no rfactor tensor to traverse.
    return false;
  }

  // Check the reduction expression that produces tv
  if (!ir_utils::isReductionOp(def) || def->isA<MmaOp>()) {
    return false;
  }

  TORCH_INTERNAL_ASSERT(
      def->inputs().size() == def->outputs().size(),
      "This logic block assumes number of inputs is the same as number of outputs of reduction ops.");

  // A reduction expr may have multiple inputs, so just grab any TV
  // input. Note that in theory it is possible that a
  // GroupedReductionOp has rfactor inputs as well as non-rfactor
  // inputs, so grabbing the one that actually corresponds to tv can
  // be important. In reality, though, such a GroupedReductionOp
  // should not happen as we do not group reductions of rfactor and
  // non-rfactor tensors.
  auto producer_tv = ir_utils::getTvInput(def);

  TORCH_INTERNAL_ASSERT(producer_tv != nullptr);

  if (!producer_tv->hasRFactor()) {
    return false;
  }
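
  // Map the consumer root_id to the corresponding IterDomain on the producer
  // side so the analysis can continue through the producer's rfactor domain.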
  auto c2p = PairwiseRootDomainMap(producer_tv, tv)
                 .mapConsumerToProducer(tv->domain(), producer_tv->domain());

  auto producer_id_it = c2p.find(root_id);
  if (producer_id_it == c2p.end()) {
    // No matching producer is found. Stop traversing.
    return false;
  }

  auto producer_root_id = producer_id_it->second;

  return analyzeIfDerivedFromTrivialReduction(producer_tv, producer_root_id);
}
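
// Returns true if every root IterDomain that the given id depends on is
// either a reduction domain of extent 1 or can be traced back through an
// rfactor producer to such a domain.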
bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) {
  auto id_inputs = InputsOf::output(id->fusion(), id);
  for (auto root_id : ir_utils::filterByType<IterDomain>(id_inputs)) {
    if (root_id->isReduction() && root_id->extent()->isOneInt()) {
      continue;
    }
    // If it is not possible to prove the root ID is trivial, see if the ID
    // is derived from an rfactor tensor. This may mean that the iteration
    // domain was merged or split in another expression through rfactor.
    // Trace back through the rfactor expressions to find the original roots
    // and determine there whether they are trivial.
    if (!traverseToRFactorTensor(tv, root_id)) {
      return false;
    }
  }
  return true;
}
} // namespace
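
// Scans every leaf IterDomain of every TensorView used in the fusion's math
// expressions and records the IterDomains that correspond to trivial
// reductions; the result is queried through isDerived().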
void TrivialReductionInfo::build(Fusion* fusion) {
  auto used_vals = fusion->usedMathVals();

  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
    for (auto id : tv->domain()->domain()) {
      if (analyzeIfDerivedFromTrivialReduction(tv, id)) {
        // If id is a trivial reduction, all of its ancestor vals are
        // also trivial reductions.
        for (auto dep_id : DependencyCheck::getAllValsBetween(
                 std::unordered_set<Val*>(
                     tv->getRootDomain().begin(), tv->getRootDomain().end()),
                 {id})) {
          domains_.insert(dep_id->as<IterDomain>());
          domains_derived_from_root_.insert(dep_id->as<IterDomain>());
        }
      } else if (id->isReduction() && id->extent()->isOneInt()) {
        // This happens when a leaf domain is trivial but its root
        // axes are not. For example, consider a non-trivial domain
        // split by one. The inner output axis is a trivial domain,
        // whereas the outer output axis is not. Since the root axis
        // is not trivial, a for-loop needs to be generated.
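        // Schematically, splitting a reduction root domain of extent N by a
        // factor of 1 produces leaves of extents [ceilDiv(N, 1), 1]: the
        // inner extent-1 leaf is recorded here, but the root domain itself
        // is not marked as trivial.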
        domains_.insert(id);
      }
    }
  }
}
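
// Returns true if the given IterDomain was recorded by build() as a trivial
// reduction domain or as derived from one.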
bool TrivialReductionInfo::isDerived(IterDomain* id) const {
  return domains_.find(id) != domains_.end();
}
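
// A minimal usage sketch (assumed call pattern; the actual call sites live
// in the lowering code):
//
//   TrivialReductionInfo info;
//   info.build(fusion);
//   for (auto id : tv->domain()->domain()) {
//     if (info.isDerived(id)) {
//       // This axis reduces a single element; no reduction loop is needed.
//     }
//   }
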
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch