File: lower_predicate.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (209 lines) | stat: -rw-r--r-- 7,851 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>

#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/index_compute.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

// Kernel-IR pass that turns each abstract kir::Predicate attached to an
// expression into a concrete Bool conditional (stored via setValue), fills in
// write predicates, and inverts the predicate of cp.async initializations.
// Only usable through the static fillPredicates entry point.
class ConditionalFromPredicateModifier : public kir::IrVisitor {
 public:
  ConditionalFromPredicateModifier() = delete;

  // Runs the pass over `exprs` and returns the visitor's exprs_ list.
  // NOTE(review): exprs_ is a kir::IrVisitor member not visible here;
  // presumably it is populated from `exprs` by kir::IrVisitor::handle —
  // confirm against kernel_ir_dispatch.
  static std::vector<Expr*> fillPredicates(const std::vector<Expr*>& exprs) {
    ConditionalFromPredicateModifier cfpm(exprs);
    return cfpm.exprs_;
  }

 private:
  // Performs the whole traversal at construction time; private so the only
  // way in is fillPredicates. `explicit` prevents accidental implicit
  // conversion from a std::vector<Expr*>.
  explicit ConditionalFromPredicateModifier(const std::vector<Expr*>& exprs) {
    FUSER_PERF_SCOPE(
        "GpuLower::Lower::ConditionalFromPredicateModifier::process");
    kir::IrVisitor::handle(exprs);
  }

  using kir::IrVisitor::handle;

  // Visits one expression. If it has a predicate, the predicate's value is
  // replaced by the Bool produced by generateConditional, and the write
  // predicate (if any) is processed. Then cp.async initialization predicates
  // are inverted, and traversal continues through kir::IrVisitor.
  void handle(Expr* expr) final {
    if (expr != nullptr && expr->predicate() != nullptr) {
      // Replace expr predicate with bool conditional
      auto conditional = generateConditional(expr->predicate());
      if (expr->predicate()->predicate_type() == PredicateType::Vectorize) {
        // TODO: This logic doesn't seem to fit well here, for unswitch the
        // logic is in the unroll loop to set the thread predicate to the expr.
        // I didn't have a quick way to do that so placing this here for now.
        TORCH_INTERNAL_ASSERT(
            expr->isA<kir::IfThenElse>(),
            "Predicate handling expects ITE statement.");
        auto ite = expr->as<kir::IfThenElse>();

        TORCH_INTERNAL_ASSERT(
            ite->thenBody().size() == 1,
            "Expecting predicated body to only have one vectorized expression.");
        auto vec_expr = ite->thenBody()[0];
        TORCH_INTERNAL_ASSERT(
            vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>(),
            "Vectorize predicate exprs only supported on set operations.");
        TORCH_INTERNAL_ASSERT(
            ir_utils::isTvOp(vec_expr),
            "Vectorize predicate exprs only supported on tensor view operations.");
        // Unless the vectorized copy reads a constant scalar, also require the
        // thread predicate of its output tensor.
        if (!vec_expr->inputs()[0]->isConstScalar()) {
          conditional = SimplifyingIrBuilder::andExpr(
                            conditional,
                            GpuLower::current()->threadPredMap().getPredicate(
                                ir_utils::getTvOutput(vec_expr)))
                            ->as<Bool>();
        }
      }
      TORCH_INTERNAL_ASSERT(conditional != nullptr);
      expr->predicate()->setValue(conditional);
      TORCH_INTERNAL_ASSERT(expr->predicate()->value() != nullptr);
      setWritePredicate(expr);
    }

    // Note: [Predicate Inversion for CpAsync]
    // Today for vectorized support the pattern is:
    // Initialize buffer -> predicated load
    // For memcpy async:
    //    If we initialized and then loaded (without sync) it would be undefined
    //    behavior.
    // Initialize only the "virtual out of boundary" accesses.
    //  Memory allocated, but outside the virtual tensor space.
    //  Virtual tensor space today is effectively what would be allocated in
    //  global memory. Then only copy the "within bound" accesses.
    // This is a WAR today based on how our system is set up.
    //    We would want to have a separate concept of SMEM space from Virtual or
    //    GMEM space, so that we know we're only working with the allocated
    //    SMEM.
    //  If we hit outside the allocated SMEM bad things happen.
    // Today asserting in predicate removal making sure that the virtual and
    // SMEM boundaries line up based on the IterDomains.
    //
    // TODO: in a follow up we need to extend the predicate
    //  infrastructure to generate predicate for both gmem
    //  and smem, and the predicate removal will need to
    //  be extended as well for the perf critical regions.
    if (isPredicatedInitForCpAsync(expr)) {
      invertPredicateForGmemToSharedMemInitialize(expr);
    }

    kir::IrVisitor::handle(expr);
  }

  // Replaces the value of expr's predicate with its logical negation.
  // Requires expr->predicate() to already hold a value.
  void invertPredicateForGmemToSharedMemInitialize(Expr* expr) {
    auto pred = expr->predicate()->value();
    auto invert = SimplifyingIrBuilder::notExpr(pred);
    expr->predicate()->setValue(invert->as<Bool>());
  }

  // Detect if this expr is an initialization for vectorized
  //  cp async with predicates.
  bool isPredicatedInitForCpAsync(Expr* expr) {
    // Match the pattern:
    //  If(pred)
    //    TV = 0;
    //  where TV is the output of cp async.
    auto maybe_init = ir_utils::getMaybePredicatedSingleton(expr);
    return maybe_init.has_value() &&
        ir_utils::isCpAsyncInit(maybe_init.value());
  }

  // If expr has a write predicate, replaces its value with a freshly
  // generated conditional. When generation yields nullptr, the write
  // predicate is removed from expr entirely.
  void setWritePredicate(Expr* expr) {
    if (expr->writePredicate() != nullptr) {
      auto write_cond = generateConditional(expr->writePredicate());
      if (write_cond) {
        expr->writePredicate()->setValue(write_cond);
      } else {
        // If generateConditional returns null, it means no specific
        // predicate needs to be used.
        expr->setWritePredicate(nullptr);
      }
    }
  }

  // Ensures an IfThenElse has a Bool conditional, then visits its bodies.
  // NOTE(review): handle(Expr*) above already sets the value for any expr
  // with a predicate, so this generation branch presumably only fires for
  // ITEs that reach this overload without passing through handle(Expr*) —
  // confirm against the dispatch order.
  void handle(kir::IfThenElse* ite) final {
    TORCH_INTERNAL_ASSERT(ite->predicate() != nullptr);

    // If ite already has Bool conditional, handle internal expressions
    // Otherwise, generate conditional and update predicate
    if (!ite->predicate()->hasValue()) {
      auto conditional = generateConditional(ite->predicate());
      TORCH_INTERNAL_ASSERT(conditional != nullptr);
      TORCH_INTERNAL_ASSERT(conditional->isA<Bool>());

      // Update bool conditional in-place
      ite->predicate()->setValue(conditional);
      TORCH_INTERNAL_ASSERT(ite->predicate()->value() != nullptr);
    }
    kir::IrVisitor::handle(ite);
  }

  // Generate conditional according to PredicateType:
  //  - Inline / ReductionWrite / Misaligned / Shift / Padding:
  //      PredicateCompute::getInlinePredicate over the current loop nest.
  //  - Vectorize: UnswitchPredicate over the loops enclosing the first loop
  //      parallelized with ParallelType::Vectorize (asserts one exists).
  //  - Unswitch: UnswitchPredicate over the current loop nest and the
  //      predicate's unrolled loop.
  //  - Manual: the value already stored in the predicate.
  //  - Anything else: nullptr.
  Bool* generateConditional(kir::Predicate* pred) {
    switch (pred->predicate_type()) {
      case PredicateType::Inline:
      case PredicateType::ReductionWrite:
      case PredicateType::Misaligned:
      case PredicateType::Shift:
      case PredicateType::Padding: {
        return PredicateCompute::getInlinePredicate(
            pred->expr(),
            for_loops_,
            pred->thread_pred(),
            pred->predicate_type());
      }
      case PredicateType::Vectorize: {
        // Split the loop nest at the first vectorized loop: everything before
        // it is "outer".
        std::vector<kir::ForLoop*> outer_loops;
        kir::ForLoop* vectorized_loop = nullptr;
        for (auto loop : for_loops_) {
          if (loop->iter_domain()->getParallelType() ==
              ParallelType::Vectorize) {
            vectorized_loop = loop;
            break;
          }
          outer_loops.emplace_back(loop);
        }
        TORCH_INTERNAL_ASSERT(
            vectorized_loop != nullptr, "Should be unreachable.");
        return UnswitchPredicate::get(outer_loops, vectorized_loop);
      }
      case PredicateType::Unswitch: {
        return UnswitchPredicate::get(for_loops_, pred->unrolled_loop());
      }
      case PredicateType::Manual: {
        return pred->value();
      }
      default:
        break;
    }
    return nullptr;
  }
};

} // namespace

// Public entry point of this lowering step: hands the expression list to the
// file-local ConditionalFromPredicateModifier, which fills in the Bool
// conditionals of the attached predicates, and returns the list it produces.
std::vector<Expr*> generateConditionalFromPredicate(
    const std::vector<Expr*>& exprs) {
  std::vector<Expr*> predicated =
      ConditionalFromPredicateModifier::fillPredicates(exprs);
  return predicated;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch