File: parallel_dimension_map.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (347 lines) | stat: -rw-r--r-- 11,056 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>

#include <sstream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Builds the map for the given fusion in three steps: gather the
// concrete domains used by each thread parallel type, derive a
// dimension for every such type, then adjust for warp padding.
void ParallelDimensionMap::build(Fusion* fusion) {
  // Step 1: visit the iteration domains of every TensorView among the
  // used math values. Constant extents are recorded for all of them;
  // only thread-parallelized domains are registered per parallel type.
  auto used_vals = fusion->usedMathVals();
  for (auto tensor_view : ir_utils::filterByType<TensorView>(used_vals)) {
    for (auto iter_dom : tensor_view->domain()->domain()) {
      registerConstantExtent(iter_dom);
      if (isParallelTypeThread(iter_dom->getParallelType())) {
        handleParallelDomain(iter_dom);
      }
    }
  }

  // Step 2: derive the dimension of each parallel type from the set of
  // concrete domains gathered for it.
  for (const auto& entry : concrete_dom_map_) {
    auto parallel_type = entry.first;
    const auto& concrete_doms = entry.second;
    TORCH_INTERNAL_ASSERT(!concrete_doms.empty());
    if (concrete_doms.size() > 1) {
      populateDimensionMapWithMultipleCASet(parallel_type, concrete_doms);
    } else {
      populateDimensionMapWithSingleCASet(parallel_type, concrete_doms);
    }
  }

  // Step 3: TIDx may need to be demoted to non-exact.
  adjustMappingsForWarpPadding();
}

// Records the extent of id, when it is a compile-time constant, in the
// set of constant extents kept for the concrete domain of id.
//
// Domains whose extent is not a constant scalar are ignored. An
// internal assert fails if an extent that reports itself as a constant
// scalar cannot be evaluated.
void ParallelDimensionMap::registerConstantExtent(IterDomain* id) {
  if (!id->extent()->isConstScalar()) {
    // Nothing to do if not constant
    return;
  }

  ExpressionEvaluator ee(id->fusion());
  auto extent_int = ee.evaluate(id->extent());
  TORCH_INTERNAL_ASSERT(
      extent_int.has_value(),
      "Extent of ",
      id->toString(),
      " should have been constant, but could not be evaluated at compile time.");

  auto const_extent = extent_int->as<int64_t>();

  // constant_extent_map_ is keyed by the concrete domain, not by id
  // itself.
  auto concrete_id = getCAMappedConcreteDomain(id);

  // The lookup must use the same key as the insertion below. Searching
  // with id would miss an existing entry whenever id differs from its
  // concrete domain, and the subsequent insert() would then be a no-op
  // that silently drops const_extent.
  auto existing_it = constant_extent_map_.find(concrete_id);

  // Adds the constant extent to the set for the concrete domain. If
  // multiple constants are found, this concrete domain has multiple
  // distinctive extents, which can happen with broadcast.
  if (existing_it == constant_extent_map_.end()) {
    constant_extent_map_.insert({concrete_id, {const_extent}});
  } else {
    existing_it->second.insert(const_extent);
  }
}

// Adds the concrete domain of id to the set of concrete domains
// recorded for the parallel type of id. That parallel type must
// satisfy isParallelTypeThread.
void ParallelDimensionMap::handleParallelDomain(IterDomain* id) {
  const auto parallel_type = id->getParallelType();
  TORCH_INTERNAL_ASSERT(isParallelTypeThread(parallel_type));

  // Resolve the concrete domain before touching the map.
  IterDomain* const concrete_dom = getCAMappedConcreteDomain(id);

  // operator[] creates an empty set the first time a parallel type is
  // seen and returns the existing set otherwise.
  concrete_dom_map_[parallel_type].insert(concrete_dom);
}

// Sets the dimension of pt when exactly one concrete domain uses it.
// The dimension is the constant extent recorded for that domain, if
// any, and the parallel dimension scalar of pt otherwise. In both
// cases pt is marked as exact.
void ParallelDimensionMap::populateDimensionMapWithSingleCASet(
    ParallelType pt,
    const std::unordered_set<IterDomain*>& dom_set) {
  TORCH_INTERNAL_ASSERT(dom_set.size() == 1);

  // The only concrete domain parallelized with pt
  IterDomain* const sole_dom = *dom_set.begin();
  const auto const_it = constant_extent_map_.find(sole_dom);

  if (const_it == constant_extent_map_.end()) {
    // No constant extent known: prefer blockDim/gridDim
    dim_map_.insert({pt, NamedScalar::getParallelDim(pt)});
  } else {
    const auto& extents = const_it->second;
    TORCH_INTERNAL_ASSERT(
        extents.size() == 1,
        "Only one value found mapped to parallel type ",
        stringifyThread(pt),
        " yet its bound to multiple extents.");
    dim_map_.insert({pt, IrBuilder::create<Int>(*extents.begin())});
  }

  // A parallel type used by a single concrete domain is always exact.
  exact_types_.insert(pt);
}

// Sets the dimension of pt when more than one concrete domain uses it.
//
// The non-broadcast concrete domains in dom_set are compared against
// each other, first through their recorded constant extents and then
// symbolically through equalDim. If no mismatch is detected, pt is
// marked as exact. The dimension is a constant only when all domains
// matched and a constant extent was found; otherwise it is the
// parallel dimension scalar of pt.
void ParallelDimensionMap::populateDimensionMapWithMultipleCASet(
    ParallelType pt,
    const std::unordered_set<IterDomain*>& dom_set) {
  TORCH_INTERNAL_ASSERT(dom_set.size() > 1);

  bool all_equal = true;
  // Use nullptr to signal it's not initialized yet
  Val* known_dimension = nullptr;
  // Use -1 to signal it's not initialized yet
  int64_t known_const = -1;

  // Check all of the concrete domains to see if they all match.
  for (auto concrete_id : dom_set) {
    if (concrete_id->isBroadcast()) {
      // Broadcasted concrete id's don't specify anything about shape
      continue;
    }
    // If this concrete domain has a constant extent, check if it
    // matches with the known constant extent.
    auto it = constant_extent_map_.find(concrete_id);
    if (it != constant_extent_map_.end()) {
      const auto& const_extent_set = it->second;
      // If multiple constants are detected, it's not exact.
      if (const_extent_set.size() > 1) {
        all_equal = false;
        break;
      }
      auto this_const = *(const_extent_set.begin());
      // known_const is initialized to -1
      if (known_const == -1) {
        // First constant seen. Control falls through to the symbolic
        // comparison below.
        known_const = this_const;
      } else if (known_const == this_const) {
        // Matched with previously known const. The extent of this
        // domain must be equal to the one previously known.
        continue;
      } else {
        // Unmatched. This dom_set extents may not be unique.
        all_equal = false;
        break;
      }
    }

    // At this point, it still remains undetermined whether this id
    // matches with those previously looked at: either it has no
    // constant extent or its constant is the first one seen. Symbolic
    // matching may still succeed.
    auto this_dimension = concrete_id->extent();
    if (known_dimension == nullptr) {
      // No previous dimension found yet
      known_dimension = this_dimension;
    } else {
      if (!equalDim(known_dimension, this_dimension)) {
        all_equal = false;
        break;
      }
    }
  }

  // If all_equal is still true, the dimension of this parallel type
  // must be exact.
  if (all_equal) {
    exact_types_.insert(pt);
  }
  // Use the const value, if found, as its dimension
  if (all_equal && known_const != -1) {
    dim_map_.insert({pt, IrBuilder::create<Int>(known_const)});
  } else {
    dim_map_.insert({pt, NamedScalar::getParallelDim(pt)});
  }
}

// Adjusts the TIDx entry when TIDx is padded to a multiple of the warp
// size. Unless one of the early-return conditions below holds, the
// dimension of TIDx is replaced (by the warp size if TIDx is known to
// be a single warp, by the parallel dimension scalar of TIDx
// otherwise) and TIDx is removed from the set of exact types.
void ParallelDimensionMap::adjustMappingsForWarpPadding() {
  const auto gpu_lower = GpuLower::current();

  // If TIDx is padded to a multiple of the warp size, mark it as
  // non-exact.

  auto& warp_info = gpu_lower->getWarpPaddedParallelInfo();
  // TIDx isn't really padded if there isn't a warp reduction (this could
  // change)
  if (!(warp_info.is_tidx_padded && warp_info.has_warp_reduction)) {
    return;
  }

  const auto tidx_pt = ParallelType::TIDx;
  auto warp_size = at::cuda::warp_size();

  // If the dimension of TIDx is actually a multiple of the warp size
  // before padding, it can be left as exact
  if (isExact(tidx_pt)) {
    auto tidx_dim = dynamic_cast<Int*>(get(tidx_pt));
    if (tidx_dim && tidx_dim->isConst()) {
      auto tidx_dim_val = tidx_dim->value().value();
      if (tidx_dim_val % warp_size == 0) {
        // Dimension of TIDx is a multiple of the warp size
        return;
      }
    }
    // If tidx is strictly defined as blockDim.x then it must be set to a
    // multiple of the warp and can be considered exact
    // NOTE(review): the entries iterated here are the concrete domains
    // stored by handleParallelDomain, so the NamedScalar test is applied
    // to the domain itself rather than to its extent. Presumably
    // entry->extent() was intended -- confirm before relying on this
    // early return.
    bool tidx_def_trivial = true;
    for (auto entry : concrete_dom_map_.at(tidx_pt)) {
      if (!entry->isA<NamedScalar>() ||
          !entry->as<NamedScalar>()->sameAs(
              NamedScalar::getParallelDim(tidx_pt))) {
        tidx_def_trivial = false;
      }
    }
    if (tidx_def_trivial) {
      return;
    }
  }

  // TIDx is padded to a multiple of warp. If it's known to be a
  // single warp, use the constant warp size as the dimension of
  // TIDx. Otherwise, just use blockDim.x.
  // at() is used, so an entry for TIDx must already exist in dim_map_.
  if (warp_info.is_tidx_single_warp) {
    dim_map_.at(ParallelType::TIDx) = IrBuilder::create<Int>(warp_size);
  } else {
    dim_map_.at(ParallelType::TIDx) =
        NamedScalar::getParallelDim(ParallelType::TIDx);
  }

  // TIDx is no longer exact
  exact_types_.erase(ParallelType::TIDx);
}

// Returns the dimension recorded for pt, or nullptr when pt has no
// entry in the dimension map. pt must satisfy isParallelTypeThread.
Val* ParallelDimensionMap::get(ParallelType pt) const {
  TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt), "Invalid ParallelType: ", pt);
  const auto found = dim_map_.find(pt);
  return found != dim_map_.end() ? found->second : nullptr;
}

// Returns true if pt is in the set of exact parallel types.
bool ParallelDimensionMap::isExact(ParallelType pt) const {
  return exact_types_.count(pt) != 0;
}

// Returns the concrete domain that id maps to in the compute-at map of
// the current GpuLower, using the EXACT mapping mode.
IterDomain* ParallelDimensionMap::getCAMappedConcreteDomain(IterDomain* id) {
  const auto& ca_map = GpuLower::current()->caMap();
  return ca_map->getConcreteMappedID(id, IdMappingMode::EXACT);
}

// Symbolically compares equality of two KIR vals. Comparison is done
// conservatively, so returning false does not guarantee non-equality.
//
// Two vals are considered equal when they are the same object, when
// both are constant Ints with the same value, when both are
// NamedScalars with the same name, or when both are defined by the
// same kind of BinaryOp/UnaryOp whose corresponding inputs are
// recursively equal. Neither argument may be nullptr.
bool ParallelDimensionMap::equalDim(Val* dim1, Val* dim2) {
  TORCH_INTERNAL_ASSERT(dim1 != nullptr && dim2 != nullptr);

  if (dim1 == dim2) {
    return true;
  }

  // When both are Int, they are the same if both have the same constant
  auto dim1_int = dynamic_cast<Int*>(dim1);
  auto dim2_int = dynamic_cast<Int*>(dim2);
  if (dim1_int && dim2_int) {
    if (dim1_int->isConst() && dim2_int->isConst()) {
      return dim1_int->value() == dim2_int->value();
    }
  }

  // When both are NamedScalar, they are the same if both have the same
  // name
  auto dim1_ns = dynamic_cast<NamedScalar*>(dim1);
  auto dim2_ns = dynamic_cast<NamedScalar*>(dim2);
  if (dim1_ns && dim2_ns) {
    return dim1_ns->name() == dim2_ns->name();
  }

  // Check recursively their definitions

  auto dim1_def = dim1->definition();
  auto dim2_def = dim2->definition();

  if (dim1_def == nullptr || dim2_def == nullptr) {
    return false;
  }

  // If both are BinaryOp or UnaryOp, check their inputs. Since these
  // Vals are IterDomain extents, UnaryOp should not occur, but
  // checking shouldn't be harmful.
  // TODO:
  //   We might be able to replace this with dim1->toInlineString() ==
  //   dim2->toInlineString()
  //   If we want this less conservative we could make an "exact map" which
  //   could be another mode in compute at that maps all iter domains, but not
  //   concretized broadcast axes and only forwards through non-concretized
  //   broadcast axes.
  if ((dim1_def->isA<BinaryOp>() && dim2_def->isA<BinaryOp>() &&
       (dim1_def->as<BinaryOp>()->getBinaryOpType() ==
        dim2_def->as<BinaryOp>()->getBinaryOpType())) ||
      (dim1_def->isA<UnaryOp>() && dim2_def->isA<UnaryOp>() &&
       (dim1_def->as<UnaryOp>()->getUnaryOpType() ==
        dim2_def->as<UnaryOp>()->getUnaryOpType()))) {
    // Differing input counts cannot be matched pairwise; stay
    // conservative and report them as not equal.
    if (dim1_def->inputs().size() != dim2_def->inputs().size()) {
      return false;
    }
    // Every pair of corresponding inputs must match. Comparing only the
    // first input would treat e.g. a * b and a * c as equal.
    for (const auto i : c10::irange(dim1_def->inputs().size())) {
      if (!equalDim(dim1_def->inputs()[i], dim2_def->inputs()[i])) {
        return false;
      }
    }
    return true;
  }

  return false;
}

// Produces one line per entry of kParallelTypeThreads: the parallel
// type followed by either its dimension and exactness, or "unused"
// when no dimension is recorded for it.
std::string ParallelDimensionMap::toString() const {
  std::stringstream out;
  for (auto parallel_type : kParallelTypeThreads) {
    out << parallel_type << ": ";
    const auto dimension = get(parallel_type);
    if (dimension == nullptr) {
      out << "unused";
    } else {
      out << dimension->toString();
      out << (isExact(parallel_type) ? ", exact" : ", non-exact");
    }
    out << "\n";
  }

  return out.str();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch