File: lower_unroll.cpp

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (142 lines) | stat: -rw-r--r-- 4,718 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142

#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>

#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/index_compute.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>

namespace torch {
namespace jit {
namespace fuser {

// Looks up the thread predicate that should guard expressions writing `tv`.
// Returns nullptr (no thread predicate) when `tv` is produced by a BroadcastOp
// for which ir_utils::getParallelBroadcastDomains reports any domain, i.e. a
// parallel broadcast. Otherwise defers to thread_predicates_.
kir::Bool* UnrollPass::getThreadPredicate(TensorView* tv) {
  const auto producer = tv->getOrigin();
  const bool made_by_broadcast =
      producer != nullptr && producer->getExprType() == ExprType::BroadcastOp;

  if (made_by_broadcast) {
    const auto bcast_out = producer->as<BroadcastOp>()->out();
    if (ir_utils::getParallelBroadcastDomains(bcast_out, thread_predicates_)
            .any()) {
      // Outputs of a parallel broadcast need no thread predicate.
      return nullptr;
    }
  }

  return thread_predicates_.getExpr(tv);
}

// Custom dispatch for Expr. A TV op is guarded with an inline predicate when
// that predicate is not trivially true. Any other expression goes through the
// base-class, type-based dispatch.
void UnrollPass::handle(Expr* expr) {
  if (!ir_utils::isTVOp(expr)) {
    // Not a TV op: fall back to the base-class dispatch on its concrete type.
    OptOutDispatch::handle(expr);
    return;
  }

  // TV ops are only expected inside a loop nest. The guard is inserted into
  // the body of the innermost enclosing loop.
  TORCH_INTERNAL_ASSERT(!for_loops.empty());

  auto pred = PredicateCompute::getInlinePredicate(
      expr, for_loops, getThreadPredicate(ir_utils::getTVOutput(expr)));

  // A predicate that is a compile-time constant `true` needs no guard. Any
  // other predicate (non-constant, or constant `false`) gets one.
  const bool trivially_true = pred->isConst() && pred->value().value();
  if (trivially_true) {
    return;
  }

  // Record that a non-trivial predicate was required; handle(kir::ForLoop*)
  // reads this flag. Then wrap expr in `if (pred) { expr }`, placed where expr
  // used to sit in the innermost loop body.
  non_trivial_pred_found = true;
  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
  kir::IfThenElse* inline_ite =
      ir_builder.create<kir::IfThenElse>(pred, for_loops.back());
  inline_ite->thenBody().push_back(expr);
  for_loops.back()->body().insert_before(expr, inline_ite);
  for_loops.back()->body().erase(expr);
}

// Lowers a single for-loop.
//
// A loop that is not unrolled, or any loop visited while look_for_unroll is
// false, is simply traversed. It is pushed onto for_loops, and every
// expression in its body is passed to handle(Expr*), which adds inline
// predicates to TV ops where needed.
//
// An unrolled loop `fl` is instead cloned twice under a new IfThenElse guarded
// by the unroll predicate:
//   - then-branch: the unrolled nest. It is not visited, so it gets no inline
//     predicates.
//   - else-branch: the "inlined" nest. It is visited so its TV ops get inline
//     predicates.
// If the inlined nest needed no non-trivial predicate, the IfThenElse is not
// used and `fl` is replaced by the inlined nest alone. Otherwise `fl` is
// replaced by the IfThenElse. Replacements are recorded in loop_replacement_map
// and applied later by runPass, not applied in place here.
//
// NOTE: We should factor out the actual predicate generation from unrolling by
// inserting IR nodes "unroll_pred" or "inline_pred", then generate those later.
void UnrollPass::handle(kir::ForLoop* fl) {
  // Setup for loop scoping
  bool is_unroll = ir_utils::isUnrolledFor(fl);
  // If we're not looking for an unroll loop, or didn't find one, process as
  // normal: make fl the innermost scope and visit its body.
  if (!is_unroll || !look_for_unroll) {
    for_loops.push_back(fl);

    // Make a copy of the body's exprs because handle(Expr*) may replace
    // entries in place in fl (insert an IfThenElse, then erase the expr).
    std::vector<Expr*> exprs_copy = fl->body().exprs();
    for (auto expr : exprs_copy) {
      handle(expr);
    }
    for_loops.pop_back();

    return;
  }

  // Predicate guarding the unrolled (then) branch, computed from the
  // enclosing loops and fl.
  auto unroll_pred = UnrollPredicate::get(for_loops, fl, p2c_root_map);

  // fl's enclosing loop, or nullptr when fl is a top-level loop. Whatever
  // replaces fl is parented here.
  kir::ForLoop* parent_scope = for_loops.empty() ? nullptr : for_loops.back();

  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
  kir::IfThenElse* unroll_ite =
      ir_builder.create<kir::IfThenElse>(unroll_pred, parent_scope);

  // Get the loop nest for the unrolled path. It is never passed to handle(),
  // so it receives no inline predicates.
  kir::ForLoop* unrolled_loop_nest = scope_utils::cloneLoopNest(fl, unroll_ite);

  unroll_ite->thenBody().push_back(unrolled_loop_nest);

  // Loop nest for the inlined path (candidate else-branch).
  kir::ForLoop* inlined_loop = scope_utils::cloneLoopNest(fl, unroll_ite);

  // Add inline predicates for the inlined loop nest.
  // - look_for_unroll is cleared so handle(inlined_loop) takes the plain
  //   traversal path above rather than treating the clone as another unroll
  //   candidate.
  // - non_trivial_pred_found is reset so that, afterwards, it reflects only
  //   this nest; handle(Expr*) sets it.
  // - look_for_unroll is restored to true, the only value it can have on
  //   entry to this path (see the early-return condition above).
  look_for_unroll = false;
  non_trivial_pred_found = false;
  handle(inlined_loop);
  look_for_unroll = true;
  // With no non-trivial predicate inserted, the IfThenElse would be
  // redundant. In that case replace fl with the inlined nest re-parented to
  // fl's enclosing scope. Otherwise keep both paths and replace fl with the
  // IfThenElse.
  if (!non_trivial_pred_found) {
    inlined_loop->setParentScope(parent_scope);
    loop_replacement_map.insert({fl, inlined_loop});
  } else {
    unroll_ite->elseBody().push_back(inlined_loop);
    loop_replacement_map.insert({fl, unroll_ite});
  }
}

// Generate the loop nest structure and place it in lowered_exprs
void UnrollPass::computeMap() {
  FUSER_PERF_SCOPE("UnrollPass::computeMap");

  FusionGuard fg(fusion_);

  // Run through loop nests and further lower the expressions
  for (auto* expr : incoming_exprs_) {
    OptOutDispatch::handle(expr);
  }
}

// Runs the unroll/predication pass over `exprs` and returns the lowered
// top-level expression list.
//
// Each top-level expr that was rewritten (an unrolled loop) is replaced by its
// entry in loop_replacement_map. Every other scope expr has replacements
// applied to its nested contents in place. Non-scope exprs pass through
// unchanged. The relative order of `exprs` is preserved.
std::vector<Expr*> UnrollPass::runPass(
    Fusion* fusion,
    const std::vector<Expr*>& exprs,
    const ThreadPredicateMap& thread_predicates) {
  FUSER_PERF_SCOPE("UnrollPass::runPass");
  FusionGuard fg(fusion);
  UnrollPass up(fusion, exprs, thread_predicates);
  up.computeMap();

  // Exactly one output entry per input entry.
  std::vector<Expr*> mutated_exprs;
  mutated_exprs.reserve(exprs.size());
  for (Expr* expr : exprs) {
    // Single lookup; reuse the iterator instead of a second operator[] probe.
    const auto replacement = up.loop_replacement_map.find(expr);
    if (replacement != up.loop_replacement_map.end()) {
      mutated_exprs.push_back(replacement->second);
      continue;
    }
    if (ir_utils::isScope(expr)) {
      scope_utils::replaceExprsInScope(expr, up.loop_replacement_map);
    }
    mutated_exprs.push_back(expr);
  }
  return mutated_exprs;
}

} // namespace fuser
} // namespace jit
} // namespace torch