File: lower_trivial_reductions.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (126 lines) | stat: -rw-r--r-- 4,427 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id);

// Checks the producer of tv to see if the given root IterDomain can be
// traced back through an rfactor reduction to a producer root domain,
// and if so, whether that producer domain is derived from a trivial
// (size-one) reduction.
bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
  TORCH_INTERNAL_ASSERT(
      root_id->definition() == nullptr, "Not root IterDomain: ", root_id);

  auto* reduction_expr = tv->definition();

  // Input tensors have no defining expression, so there is no rfactor
  // tensor to traverse into.
  if (reduction_expr == nullptr) {
    return false;
  }

  // Only traverse through reduction expressions; MmaOp is excluded.
  if (!ir_utils::isReductionOp(reduction_expr) ||
      reduction_expr->isA<MmaOp>()) {
    return false;
  }

  TORCH_INTERNAL_ASSERT(
      reduction_expr->inputs().size() == reduction_expr->outputs().size(),
      "This logic block assumes number of inputs is the same as number of outputs of reduction ops.");

  // A reduction expr may have multiple inputs; any TV input suffices
  // here. In theory a GroupedReductionOp could mix rfactor and
  // non-rfactor inputs, in which case picking the input that actually
  // corresponds to tv would matter. In practice such a grouping is
  // never created, as reductions of rfactor and non-rfactor tensors
  // are not grouped together.
  auto* producer = ir_utils::getTvInput(reduction_expr);

  TORCH_INTERNAL_ASSERT(producer != nullptr);

  if (!producer->hasRFactor()) {
    return false;
  }

  // Map the consumer root domain to the producer's domain.
  const auto c2p_map = PairwiseRootDomainMap(producer, tv)
                           .mapConsumerToProducer(tv->domain(), producer->domain());

  const auto mapped_it = c2p_map.find(root_id);
  if (mapped_it == c2p_map.end()) {
    // No matching producer is found. Stop traversing.
    return false;
  }

  // Recurse into the producer's corresponding root domain.
  return analyzeIfDerivedFromTrivialReduction(producer, mapped_it->second);
}

// Returns true when every root input of id is provably a trivial
// (size-one) reduction domain, either directly or by tracing back
// through rfactor producer tensors.
bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) {
  const auto input_vals = InputsOf::output(id->fusion(), id);
  for (auto* input_id : ir_utils::filterByType<IterDomain>(input_vals)) {
    const bool is_trivial_here =
        input_id->isReduction() && input_id->extent()->isOneInt();
    if (is_trivial_here) {
      continue;
    }
    // Could not prove this root ID trivial directly. The iteration
    // domain may have been merged or split in another expression via
    // rfactor, so trace back through rfactor expressions to the
    // original roots and decide triviality there.
    if (!traverseToRFactorTensor(tv, input_id)) {
      return false;
    }
  }
  return true;
}

} // namespace

// Scans every used TensorView's leaf domains and records which
// IterDomains are (derived from) trivial reductions.
void TrivialReductionInfo::build(Fusion* fusion) {
  // Keep the container alive: filterByType views it, not a copy.
  const auto used_vals = fusion->usedMathVals();

  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
    for (auto leaf_id : tv->domain()->domain()) {
      if (analyzeIfDerivedFromTrivialReduction(tv, leaf_id)) {
        // A trivial leaf domain implies every ancestor val between the
        // root domain and this leaf is also a trivial reduction.
        const std::unordered_set<Val*> root_vals(
            tv->getRootDomain().begin(), tv->getRootDomain().end());
        for (auto dep : DependencyCheck::getAllValsBetween(root_vals, {leaf_id})) {
          auto* dep_id = dep->as<IterDomain>();
          domains_.insert(dep_id);
          domains_derived_from_root_.insert(dep_id);
        }
      } else if (leaf_id->isReduction() && leaf_id->extent()->isOneInt()) {
        // The leaf domain itself is trivial but its root axes are not.
        // Example: a non-trivial domain split by one — the inner output
        // axis is trivial while the outer is not. Since the root axis
        // is non-trivial, a for-loop still needs to be generated, so
        // only the leaf is recorded (not domains_derived_from_root_).
        domains_.insert(leaf_id);
      }
    }
  }
}

// Membership query: was id recorded as (derived from) a trivial reduction?
bool TrivialReductionInfo::isDerived(IterDomain* id) const {
  return domains_.count(id) != 0;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch