File: index_compute.h

#pragma once

#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/reference_tensor.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>

/*
 * Index compute takes in a list of indices typically generated from the
 * surrounding for loop nest. The number of indices is intended to match the
 * number of dimensions of the incoming TensorView, which may have fewer or
 * more dimensions than its root due to split/merge operations.
 * Split/merge operations are then replayed backwards to produce resulting
 * indices (based on input indices) that match the root dimensions.
 *
 * For example with GLOBAL tensor:
 * TV[I, J]
 * TV[Io, Ii{4}, J] = TV.split(I, factor=4)
 * ALLOC: NONE
 * INDEX: indexCompute {i, j, k} -> {i * 4 + j, k}
 * FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * J + k}
 * PREDICATE: {i * 4 + j, k} -> i * 4 + j < I
 *
 *
 * For example with SHARED tensor:
 *
 * global_TV[I, J]
 * global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4)
 * smem_TV.compute_at(global_TV, 1)
 * global_TV.parallelize(1, threadIdx.x)
 *
 * ALLOC: alloc(smem_TV, 4 x J)
 * INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k}
 * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * J + k}
 * PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if
 * global
 *
 *
 * For example with LOCAL tensor:
 * global_TV[I, J, K]
 * global_TV[Io, Ii{4}, J, K] = global_TV.split(I, factor=4)
 * reg_TV.compute_at(global_TV, 1)
 * global_TV.parallelize(1, threadIdx.x)
 * global_TV{i, j, k, l} -> { i * 4 + j, k, l }
 * global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * J * K  +  k * K  +  l }
 *
 * ALLOC: alloc(reg_TV, J x K)
 * INDEX: {k, l} -> {k, l}
 * FLATTENED_INDEX: {k, l} -> {k * K + l}
 * PREDICATE: {i, j, k, l} -> i * 4 + j < I && k < J && l < K // Same as if global
 *
 * These indices can then be flattened later based on strides.
 */
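
/*
 * Worked example (illustrative only, with concrete values not taken from the
 * code above): using the GLOBAL case with I = 8, J = 16 and factor = 4, the
 * loop indices {i, j, k} = {1, 2, 3} replay back to the root indices
 * {i * 4 + j, k} = {6, 3}. The flattened offset is 6 * 16 + 3 = 99, and the
 * predicate 6 < I (= 8) holds, so the access is in bounds.
 */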

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class ContigIDs;
class LoopIndexing;

class IndexCompute : public BackwardVisitor {
 protected:
  using BackwardVisitor::handle;

  void handle(Split*) override;
  void handle(Merge*) override;
  void handle(Expr*) override;
  void handle(Swizzle2D*) override;

  // return extent_map_[id] if exists, else return id->extent()
  Val* getExtent(IterDomain* id) const;

  //! True if a domain is not used to index
  bool isZero(IterDomain* id) const;
  //! True if any dependent of a domain is not used to index
  bool hasZeroMerged(IterDomain* id) const;

  //! Returns the concrete ID from the compute at EXACT mode map if
  //! concrete_id_pass == true, otherwise returns id passed in.
  //! Helps unify the expr handling logic in reference domain and concrete id
  //! based traversal.
  IterDomain* maybeGetExactMapConcreteID(IterDomain* id);

  //! (Concrete indexing pass only)
  //!  Collect permissive index binding from the given expression.
  //! See also permissive_map_ and LoopIndexing::getBackwardOutOfLineExprList.
  void collectIndexIntoPermissiveMap(const LoopIndexing& loop_indexing);

  //! (Concrete indexing pass only)
  //!  Iterate through id_expr's input and pull index vals from permissive
  //! map, when both of the following are true:
  //!    1. the output id is missing in index_map_.
  //!    2. the output id is found in permissive map.
  void updateIndexMapFromPermissiveMap(const Expr* id_expr);

  // Tensor domain we're mapping back to root
  const TensorDomain* td_; // NOLINT

  // Map we update as we propagate backward, containing all IDs in the
  // propagation. Initial indices are mapped with this map at tv->domain()
  // and are back-propagated to tv->getRootDomain(). This index_map_ keeps the
  // indices at intermediate IterDomains in that back propagation.
  std::unordered_map<IterDomain*, Val*> index_map_; // NOLINT

  // Map from IterDomains to their broadcasted extents. If a TV has I0*I1 but
  // its producer has B0*I1, this map will contain a mapping from the ID{B0*I1}
  // to the extent I0*I1. Also contains updated extents if we merge in a 0 index.
  // See zero_merged_in_.
  std::unordered_map<IterDomain*, Val*> extent_map_; // NOLINT

  // Keeps track of domains that do not contribute to indexing
  std::unordered_set<IterDomain*> zero_domains_; // NOLINT

  // This set keeps track of IterDomains that have had a zero index merged into
  // them. This happens if we do something like tv->axis(0)->split(4) and then
  // tv->computeAt(1, ...). If this tensor is in smem or lmem, the backward
  // indexing would be (0, i); when we do the backward computation, that zero
  // and i would attempt to be merged together. We handle indices like these
  // specially.
  std::unordered_set<IterDomain*> zero_merged_in_;

  // IDs that are a result of contiguous merges
  std::unordered_set<IterDomain*> contig_ids_;

  // Map from root to indexed domains
  std::unordered_map<IterDomain*, IterDomain*> root_to_indexed_id_;

  // Indicates whether we should propagate an index down a particular
  // IterDomain path when there is a choice
  std::unordered_set<IterDomain*> preferred_paths_;

  // Map from IterDomains to halo-extended extents in corresponding
  // reference tensor
  std::unordered_map<IterDomain*, Val*> reference_halo_extent_map_;

  // Temporary flag which tells IndexCompute to use concrete IDs from the exact
  // map rather than the actual IDs used in the ID expressions.
  bool concrete_id_pass_ = false;

  // Mode of swizzle that is activated in this index compute
  //  instance. Swizzles of a different mode are treated as no-ops.
  // Currently data-mode swizzles are handled the same as before in the
  //  IndexSwizzle pass, while loop-mode swizzles are handled early on in the
  //  concrete indexing pass. See also [Note on swizzle mode]
  SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle;

  // (Concrete id pass only)
  // Contains the indexing math that can be resolved with only the
  //  IterDomains on the right of the consumer_tv's compute-at (ca) axis, i.e.
  //  the ones corresponding to the loops that consumer_tv would not
  //  share with any of its consumers.
  // These indexing vals should be kept separate from index_map_ and
  //  should only be used when the indexing traversal follows the
  //  order defined in LoopIndexingAnalysis::traverseFromDomainVals.
  std::unordered_map<IterDomain*, Val*> permissive_index_map_;

 public:
  const std::unordered_map<IterDomain*, Val*>& indexMap() const {
    return index_map_;
  }

  const std::unordered_map<IterDomain*, Val*>& extentMap() const {
    return extent_map_;
  }

  const std::unordered_set<IterDomain*>& zeroDomains() const {
    return zero_domains_;
  }

  const std::unordered_set<IterDomain*>& zeroMergedIn() const {
    return zero_merged_in_;
  }

  const std::unordered_map<IterDomain*, IterDomain*>& rootToContigID() const {
    return root_to_indexed_id_;
  }

  // Propagate back from _td using initial_index_map
  IndexCompute(
      const TensorDomain* _td,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> _extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> _zero_merged_in,
      std::unordered_set<IterDomain*> preferred_paths = {},
      std::unordered_map<IterDomain*, Val*> reference_halo_extent_map = {});

  IndexCompute(
      const TensorDomain* _td,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> _extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> _zero_merged_in,
      const ContigIDs& contig_finder,
      std::unordered_set<IterDomain*> preferred_paths = {},
      std::unordered_map<IterDomain*, Val*> reference_halo_extent_map = {});

  // Entry point for concrete-ID-based traversal. This traversal is
  // assumed to start at the leaf IDs provided by initial_index_map.
  IndexCompute(
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> preferred_paths,
      std::unordered_map<IterDomain*, Val*> concrete_halo_extent_map);

  // Updates index_map, extent_map, and zero_merged_in based on id_map and
  // returns a new IndexCompute ready to be used.
  IndexCompute updateIndexCompute(
      const TensorDomain* new_td,
      const std::unordered_map<IterDomain*, IterDomain*>& id_map,
      const ContigIDs& contig_finder,
      const std::unordered_map<IterDomain*, Val*>& reference_halo_extent_map =
          {}) const;

  // Interface to run index traversal through a loop indexing analysis result;
  // to be used with the entry point for concrete-ID-based traversal.
  void run(const LoopIndexing& loop_indexing);

  virtual void run();
};
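
//! Illustrative usage sketch (hedged; `td`, `initial_index_map`, and
//! `contig_finder` are hypothetical values that the caller, typically the
//! Index class below, would construct):
//!
//!   IndexCompute compute(
//!       td,                            // TensorDomain of the indexed tensor
//!       std::move(initial_index_map),  // leaf IterDomain* -> loop index Val*
//!       {},                            // extent map
//!       {},                            // zero domains
//!       {},                            // zero merged in
//!       contig_finder);                // contiguity analysis (ContigIDs)
//!   compute.run();
//!   // Root-domain indices are now available through compute.indexMap(),
//!   // and can be remapped to another domain with updateIndexCompute().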

//! Apply swizzle and update root indices accordingly
class IndexSwizzle : public IndexCompute {
 public:
  IndexSwizzle(
      const TensorView* tv,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> zero_merged_in);

  IndexSwizzle(
      const TensorView* tv,
      const TensorDomain* domain,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> zero_merged_in);

  void run() override;

 protected:
  using IndexCompute::handle;

  void handle(Expr* e) override;

  void handle(Swizzle2D* swizzle_2d) override;

 private:
  const TensorView* tv_ = nullptr;
  SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
  std::vector<IterDomain*> ids_to_swizzle_;
  std::unordered_set<IterDomain*> swizzled_ids_;
};

//! Predicate information of a root or contiguous merged domain
class RootPredicateInfo {
  friend class Index;

 public:
  const auto& startPredicate() const {
    return start_predicate_;
  }

  auto& startPredicate() {
    return start_predicate_;
  }

  const auto& startOffset() const {
    return start_offset_;
  }

  const auto& stopPredicate() const {
    return stop_predicate_;
  }

  const auto& stopOffset() const {
    return stop_offset_;
  }

  const auto& rootIds() const {
    return root_ids_;
  }

  //! Return a false RootPredicateInfo, i.e., both start and stop
  //! predicates are false.
  static RootPredicateInfo getFalseInfo();

 private:
  // predicate for lower end
  Bool* start_predicate_ = nullptr;
  // predicate for upper end
  Bool* stop_predicate_ = nullptr;
  // Offset of the start predicate
  Val* start_offset_ = nullptr;
  // Offset of the stop predicate
  Val* stop_offset_ = nullptr;
  // Track which roots have been handled by the generated predicates
  std::unordered_set<IterDomain*> root_ids_;
};
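
//! Illustrative sketch (hypothetical `info` obtained from
//! Index::getReferenceRootPredicates below):
//!
//!   Bool* start = info.startPredicate();
//!   Bool* stop = info.stopPredicate();
//!   // Guard the indexed access with both bounds, e.g. start && stop,
//!   // for each of the root IDs reported by info.rootIds().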

// Simple interface for IndexCompute
// If getComputeAtAxis, and more generally the TensorView const model, is
// fixed, we can make the below TensorViews const.
class Index {
 private:
  // Producer indexing if it's in shared or local memory
  static std::vector<Val*> getNonGlobalProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Consumer indexing if it's in shared or local memory
  static std::vector<Val*> getNonGlobalConsumerStridedIndices(
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Producer indexing if it's in global memory
  static std::vector<Val*> getGlobalProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Consumer indexing if it's in global memory
  static std::vector<Val*> getGlobalConsumerStridedIndices(
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

 public:
  // Indexing functions
  // Consumer = Producer
  // i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
  // Producer indexing dispatch
  static kir::TensorIndex* getProducerIndex(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Consumer index dispatch
  static kir::TensorIndex* getConsumerIndex(
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  //! Returns a vector of strided indices mapped onto the (rfactor)
  //! root domain of a producer tensor. The size of the returned
  //! vector is guaranteed to be equal to the number of axes of the
  //! indexing root domain.
  static std::vector<Val*> getProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  //! Returns a vector of strided indices mapped onto the (rfactor)
  //! root domain of a consumer tensor. The size of the returned
  //! vector is guaranteed to be equal to the number of axes of the
  //! indexing root domain.
  static std::vector<Val*> getConsumerStridedIndices(
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  //! Returns a vector of strided indices mapped onto the (rfactor)
  //! root domain of a consumer tensor. The returned index is intended
  //! to be used to index into arange or Philox pseudo random sequences
  static std::vector<Val*> getLinearIndex(
      TensorView* consumer_tv,
      const std::vector<kir::ForLoop*>& loops);

  //! Takes a consumer TensorView and loop nest and generates predicates
  //! associated with the concrete roots of the loop nest. Returns a list of
  //! predicates and a list of concrete roots they're associated with. It is
  //! assumed that no predicate is required if index[i] is an index taken
  //! directly from a for loop. This will not catch all cases if we actually
  //! have static size information, for example:
  //!
  //! TV[I].split(4)
  //! would produce the code:
  //! for(i : I/4)
  //!   for(j : 4)
  //!     if( i * 4 + j < TV.size(0))
  //!       TV[i * 4 + j]...
  //!
  //! However, if we had TV.size[0] = 16 at "compile time", then we wouldn't
  //! need the predicate. This will be caught by canOmitPredicate in the
  //! predicate lowering.
  //!
  //! unswitch_or_vec_loop is the for loop at which to start the unswitch-like
  //! predicate. It is not a bool value because, if we have an unswitch loop
  //! with a vectorized loop inside, we only want to base the "unswitch"-like
  //! predicate on the vectorized loop.
  static std::vector<RootPredicateInfo> getReferenceRootPredicates(
      TensorView* consumer_tv,
      const std::vector<kir::ForLoop*>& loops,
      kir::ForLoop* unswitch_or_vec_loop,
      bool padding_predicate);
};
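
//! Illustrative usage sketch (hypothetical lowering context; `producer_tv`,
//! `consumer_tv`, and `for_loops` are assumed to come from the surrounding
//! loop-nest generation):
//!
//!   // For an expression T_consumer = op(T_producer), index both sides:
//!   kir::TensorIndex* in_index =
//!       Index::getProducerIndex(producer_tv, consumer_tv, for_loops);
//!   kir::TensorIndex* out_index =
//!       Index::getConsumerIndex(consumer_tv, for_loops);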

// Used for local and shared index mapping. Returns a map from loops
// to loop indices as well as a set of loops that do not contribute to
// indexing.
// TODO: could be cleaned up further.
std::pair<
    std::unordered_map<kir::ForLoop*, Val*>,
    std::unordered_set<kir::ForLoop*>>
indexMapFromTV(
    const TensorView* tv,
    const std::vector<kir::ForLoop*>& loops,
    kir::ForLoop* alloc_loop,
    bool as_consumer,
    kir::ForLoop* double_buffer_loop = nullptr);
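
// Illustrative usage sketch (hypothetical names):
//
//   auto loop_index_info =
//       indexMapFromTV(tv, for_loops, alloc_loop, /*as_consumer=*/true);
//   const auto& loop_to_index = loop_index_info.first;  // loop -> index Val*
//   const auto& zero_loops = loop_index_info.second;    // non-contributing loops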

//! Set "pragma unroll" required for loops that indexing of Local
//! tensors depends on.
//!
//! \param tv Indexed tensor
//! \param alloc_loop Allocation loop of tv
//! \param loops The current loop structure
//! \param id_map Producer-to-consumer map in case of indexing as producer
void ensureStaticIndexing(
    const TensorView* tv,
    kir::ForLoop* alloc_loop,
    const std::vector<kir::ForLoop*>& loops,
    const std::unordered_map<IterDomain*, IterDomain*>& id_map = {});
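
//! Illustrative usage sketch (hypothetical names): after allocating a Local
//! tensor, request unrolling of any loop its index depends on so that the
//! index remains static:
//!
//!   ensureStaticIndexing(tv, alloc_loop, for_loops);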

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch