File: lower_thread_predicate.cpp

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (283 lines) | stat: -rw-r--r-- 8,556 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283

#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>

#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>

namespace torch {
namespace jit {
namespace fuser {

namespace {

// Builds the predicate value for a single parallel type.
//
// For BIDx / BIDy / BIDz the predicate is a named boolean scalar whose name is
// produced by kir::GridReduction::getPredicateFlagName for the one tensor
// recorded in `source_map` under that parallel type. For every other parallel
// type the predicate is the comparison `parallel index of pt == 0`.
//
// Asserts that exactly one source tensor is recorded when a BID* type is
// requested.
// NOTE(review): source_map.at(pt) presumably throws when `pt` has no entry at
// all (standard map semantics) - confirm against SourceMapType.
Val* getPredicatePerParallelType(
    ParallelType pt,
    const ThreadPredicateMap::SourceMapType& source_map) {
  kir::IrBuilder ir_builder(GpuLower::current()->kernel());

  const bool is_bid = pt == ParallelType::BIDx || pt == ParallelType::BIDy ||
      pt == ParallelType::BIDz;
  if (!is_bid) {
    return ir_builder.eqExpr(
        kir::NamedScalar::getParallelIndex(pt), ir_builder.create<kir::Int>(0));
  }

  // Bind by reference: the set is only inspected here, so copying it would be
  // wasted work.
  const auto& source = source_map.at(pt);
  TORCH_INTERNAL_ASSERT(!source.empty(), "No predicate source found");
  TORCH_INTERNAL_ASSERT(source.size() == 1, "Multiple sources detected");

  // Exactly one element is present, so begin() is the source tensor.
  auto src = *source.begin();
  auto flag_name = kir::GridReduction::getPredicateFlagName(src);
  return ir_builder.create<kir::NamedScalar>(flag_name, DataType::Bool);
}

// Folds the per-parallel-type predicates of every bit set in `bits` into one
// boolean value by AND-ing them in the iteration order of bits.getMap().
//
// Returns a constant `true` when no bit is set. Asserts that the combined
// value exists and has DataType::Bool before casting it to kir::Bool.
kir::Bool* getPredicate(
    const ir_utils::ParallelTypeBitmap& bits,
    const ThreadPredicateMap::SourceMapType& source_map) {
  kir::IrBuilder ir_builder(GpuLower::current()->kernel());

  // No parallel type is predicated: the predicate is trivially true.
  if (bits.none()) {
    return ir_builder.create<kir::Bool>(true);
  }

  Val* combined = nullptr;

  for (const auto& entry : bits.getMap()) {
    if (!entry.second) {
      continue;
    }
    Val* term = getPredicatePerParallelType(entry.first, source_map);
    if (combined == nullptr) {
      // First active parallel type seeds the conjunction.
      combined = term;
    } else {
      combined = ir_builder.andExpr(combined, term);
    }
  }

  // bits.none() was false, so at least one term must have been produced.
  TORCH_INTERNAL_ASSERT(combined != nullptr);

  TORCH_INTERNAL_ASSERT(
      combined->getDataType().value() == DataType::Bool,
      "Tried to return a predicate that is not a bool val.");

  return combined->as<kir::Bool>();
}

// Unions every tensor set stored in `src` into the set stored under the same
// key in `dst`. A key that `dst` does not have yet is created, even when the
// corresponding set in `src` is empty.
void mergeSourceMap(
    ThreadPredicateMap::SourceMapType& dst,
    const ThreadPredicateMap::SourceMapType& src) {
  for (const auto& entry : src) {
    // operator[] yields the existing destination set or a fresh empty one.
    auto& merged = dst[entry.first];
    merged.insert(entry.second.begin(), entry.second.end());
  }
}

// Records `tv` as a source tensor in `dst` under every parallel type whose bit
// is set in `reducton_pred`. Parallel types with a cleared bit are not touched.
void addToSouceMap(
    ThreadPredicateMap::SourceMapType& dst,
    const TensorView* tv,
    const ir_utils::ParallelTypeBitmap& reducton_pred) {
  for (const auto& entry : reducton_pred.getMap()) {
    if (!entry.second) {
      continue;
    }
    dst[entry.first].insert(tv);
  }
}

// Empties the source set of every parallel type whose bit is cleared in
// `mask`; sets belonging to parallel types with the bit set are kept as-is.
void maskSouceMap(
    ThreadPredicateMap::SourceMapType& src_map,
    const ir_utils::ParallelTypeBitmap& mask) {
  for (const auto& entry : mask.getMap()) {
    if (entry.second) {
      // Bit is set: this parallel type keeps its sources.
      continue;
    }
    src_map[entry.first].clear();
  }
}

// Workaround (for now) aimed at GEMM tiling, so that tiles are not fetched
// more than once. Doing this is safe; there may just be a better place for it.
//
// When `out_tv` lives in shared memory, sets the bit in `pred` for the
// parallel type of each compute-at axis that is a thread dimension while the
// tensor's own axis at that position is a broadcast. Tensors in any other
// memory type leave `pred` unchanged.
void avoidRedundantWritesToSmem(
    TensorView* out_tv,
    ir_utils::ParallelTypeBitmap& pred) {
  if (out_tv->getMemoryType() != MemoryType::Shared) {
    return;
  }

  for (size_t axis = 0; axis < out_tv->nDims(); ++axis) {
    auto ca_id = out_tv->getComputeAtAxis(axis).first;
    const bool broadcast_on_thread_dim =
        out_tv->axis(axis)->isBroadcast() && ca_id->isThreadDim();
    if (broadcast_on_thread_dim) {
      pred.set(ca_id->getParallelType(), true);
    }
  }
}

} // namespace

// Derives the thread predicate bits and the predicate source tensors for each
// TensorView output of `expr` from the entries already recorded for its
// TensorView inputs, then inserts one entry per output into this map.
//
// Non-TensorView inputs and outputs are skipped. Asserts that every TensorView
// input already has an entry and that no TensorView output has one yet.
void ThreadPredicateMap::updateBitSet(Expr* expr) {
  FUSER_PERF_SCOPE("ThreadPredicateMap::updateBitSet");

  // Which predicates were set for the inputs
  ir_utils::ParallelTypeBitmap input_preds;

  // Which dims are reductions in inputs
  ir_utils::ParallelTypeBitmap input_reductions;

  // Which dims are bcast in inputs
  ir_utils::ParallelTypeBitmap input_bcasts;

  // Union of the source maps of all inputs, plus the inputs that are
  // themselves parallelized reductions.
  SourceMapType src_map;

  // Run through inputs and update bitsets
  for (const auto* inp : expr->inputs()) {
    if (!ir_utils::isTV(inp))
      continue;

    auto tv_inp = ir_utils::asConstTV(inp);

    // Look the input up once and reuse the entry for both the predicate bits
    // and the source map.
    const auto inp_entry = thread_predicates_.find(tv_inp);
    TORCH_INTERNAL_ASSERT(
        inp_entry != thread_predicates_.end(),
        "Thread predicate map was not initialized, couldn't find ",
        inp);

    input_preds |= inp_entry->second.first;

    mergeSourceMap(src_map, inp_entry->second.second);

    // Per-input view of which parallel types are used at all, and which of
    // those sit on reduction / broadcast domains.
    ir_utils::ParallelTypeBitmap id_reductions;
    ir_utils::ParallelTypeBitmap id_bcasts;
    ir_utils::ParallelTypeBitmap id_ptypes;

    for (auto id : tv_inp->domain()->domain()) {
      if (id->isThread()) {
        id_ptypes.set(id->getParallelType(), true);
        if (id->isReduction())
          id_reductions.set(id->getParallelType(), true);
        if (id->isBroadcast())
          id_bcasts.set(id->getParallelType(), true);
      }
    }

    // Validate the combination of ptypes, reductions, bcasts. Only parallel
    // types that an earlier input already reduced on are checked, because
    // input_reductions is accumulated after this loop.
    for (size_t i = 0; i < ir_utils::ParallelTypeBitmap::num_p_type; i++) {
      if (input_reductions[i]) {
        if (id_ptypes[i]) {
          TORCH_INTERNAL_ASSERT(
              id_reductions[i],
              "Mismatched parallelized reductions found on inputs of expr: ",
              expr);
          TORCH_CHECK(
              !id_bcasts[i],
              "Invalid broadcast and reduction combination, tried to parallelize both with the same thread dim: ",
              inp);
        }
      }
    }

    // Accumulate
    input_reductions |= id_reductions;
    input_bcasts |= id_bcasts;

    if (id_reductions.any()) {
      // add tv_inp as a source
      addToSouceMap(src_map, tv_inp, id_reductions);
    }
  }

  // Update map for this tv, before accumulating to other inputs
  // Add any reductions this id has to any input predicates
  auto output_preds = input_preds | input_reductions;

  // Figure out which dims bcast wants to reset
  auto bcast_reset_map = output_preds & input_bcasts;

  // Flip it to make a bit mask
  bcast_reset_map = ~bcast_reset_map;

  // Get rid of any reductions which are bcasted
  output_preds &= bcast_reset_map;
  // Similarly, drop non-relevant source tensors
  maskSouceMap(src_map, bcast_reset_map);

  // Run through outputs and set bitset predicates
  for (auto* out : expr->outputs()) {
    if (!ir_utils::isTV(out))
      continue;
    TORCH_INTERNAL_ASSERT(find(ir_utils::asConstTV(out)) == end());
    // Work on a copy: the shared-memory adjustment is specific to each output.
    auto pred_for_this_out = output_preds;
    avoidRedundantWritesToSmem(ir_utils::asTV(out), pred_for_this_out);
    insert(ir_utils::asConstTV(out), pred_for_this_out, src_map);
  }
}

// Builds the predicate map for `_fusion`: every TensorView fusion input is
// registered with an empty predicate bitmap and an empty source map, then
// updateBitSet is run on each expression returned by fusion_->exprs(true).
//
// TODO(kir): revisit this - can we build it from the kernel IR?
ThreadPredicateMap::ThreadPredicateMap(Fusion* _fusion) : fusion_(_fusion) {
  FUSER_PERF_SCOPE("ThreadPredicateMap");

  // Seed the map with the fusion inputs.
  for (auto fusion_input : fusion_->inputs()) {
    if (!ir_utils::isTV(fusion_input)) {
      continue;
    }
    const auto no_preds = ir_utils::ParallelTypeBitmap();
    const auto no_sources = SourceMapType();
    insert(ir_utils::asConstTV(fusion_input), no_preds, no_sources);
  }

  // Propagate predicates through the expressions.
  for (auto fusion_expr : fusion_->exprs(true)) {
    updateBitSet(fusion_expr);
  }
}

// Looks `tv` up in thread_predicates_ and returns the resulting iterator.
// Callers in this file compare the result against end() to test for presence.
ThreadPredicateMap::const_iterator ThreadPredicateMap::find(
    const TensorView* tv) const {
  return thread_predicates_.find(tv);
}

// Returns the end iterator of thread_predicates_, for comparison with find().
ThreadPredicateMap::const_iterator ThreadPredicateMap::end() const {
  return thread_predicates_.end();
}

// Read-only access to the entry stored for `tv`; forwards to
// thread_predicates_.at(tv).
// NOTE(review): presumably throws when `tv` has no entry (standard map
// semantics) - confirm against MapType.
const ThreadPredicateMap::MapType::mapped_type& ThreadPredicateMap::at(
    const TensorView* tv) const {
  return thread_predicates_.at(tv);
}

// Mutable access to the entry stored for `tv`; forwards to
// thread_predicates_.at(tv).
// NOTE(review): presumably throws when `tv` has no entry (standard map
// semantics) - confirm against MapType.
ThreadPredicateMap::MapType::mapped_type& ThreadPredicateMap::at(
    const TensorView* tv) {
  return thread_predicates_.at(tv);
}

// Mutable access to the entry for `tv`; forwards to thread_predicates_[tv].
// NOTE(review): if MapType is a standard map this creates a default entry when
// `tv` is absent, unlike at() - confirm against MapType.
ThreadPredicateMap::MapType::mapped_type& ThreadPredicateMap::operator[](
    const TensorView* tv) {
  return thread_predicates_[tv];
}

// Convenience overload: bundles `pred` and `src_map` into a pair and forwards
// to the pair-taking insert overload.
void ThreadPredicateMap::insert(
    const TensorView* tv,
    const ir_utils::ParallelTypeBitmap& pred,
    const SourceMapType& src_map) {
  const std::pair<ir_utils::ParallelTypeBitmap, SourceMapType> entry(
      pred, src_map);
  insert(tv, entry);
}

void ThreadPredicateMap::insert(
    const TensorView* tv,
    const std::pair<ir_utils::ParallelTypeBitmap, SourceMapType>&
        pred_and_src) {
  thread_predicates_.insert(std::make_pair(tv, pred_and_src));
}

// Registers `copy` with the same predicate bitmap and source map that are
// recorded for `origin`. Does nothing when `origin` has no entry.
void ThreadPredicateMap::duplicate(
    const TensorView* copy,
    const TensorView* origin) {
  const auto origin_it = find(origin);
  if (origin_it == end()) {
    return;
  }
  const auto& origin_entry = origin_it->second;
  insert(copy, origin_entry.first, origin_entry.second);
}

// Builds the boolean predicate expression for `out_tv` from its recorded
// predicate bitmap and source map. Asserts that `out_tv` has an entry.
kir::Bool* ThreadPredicateMap::getExpr(const TensorView* out_tv) const {
  const auto entry_it = find(out_tv);
  TORCH_INTERNAL_ASSERT(entry_it != end(), "Couldn't find ", out_tv);
  const auto& entry = entry_it->second;
  return getPredicate(entry.first, entry.second);
}

} // namespace fuser
} // namespace jit
} // namespace torch