File: index_compute.h

#pragma once

#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>

/*
 * Index compute takes in a list of indices typically generated from the
 * surrounding for loop nest. The number of indices is intended to match the
 * number of dimensions of the incoming TensorView, which may have fewer or
 * more dimensions than its root due to split/merge operations.
 * Split/merge operations are then replayed backwards to produce resulting
 * indices (based on input indices) that match the root dimensions.
 *
 * For example with GLOBAL tensor:
 * TV[I, J]
 * TV[Io, Ii{4}, J] = TV.split(I, factor=4)
 * ALLOC: NONE
 * INDEX: indexCompute {i, j, k} -> {i * 4 + j, k}
 * FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * J + k}
 * PREDICATE: {i * 4 + j, k} -> i * 4 + j < I
 *
 *
 * For example with SHARED tensor:
 *
 * global_TV[I, J]
 * global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4)
 * smem_TV.compute_at(global_TV, 1)
 * global_TV.parallelize(1, threadIdx.x)
 *
 * ALLOC: alloc(smem_TV, 4 x J)
 * INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k}
 * FLATTENED_INDEX: {threadIdx.x, k} -> {threadIdx.x * J + k}
 * PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if
 * global
 *
 *
 * For example with LOCAL tensor:
 * global_TV[I, J, K]
 * global_TV[Io, Ii{4}, J, K] = global_TV.split(I, factor=4)
 * reg_TV.compute_at(global_TV, 1)
 * global_TV.parallelize(1, threadIdx.x)
 * global_TV{i, j, k, l} -> { i * 4 + j, k, l }
 * global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * J * K  +  k * K  +  l }
 *
 * ALLOC: alloc(reg_TV, J x K)
 * INDEX: {k, l} -> {k, l}
 * FLATTENED_INDEX: {k, l} -> {k * K + l}
 * PREDICATE: i * 4 + j < I && k < J && l < K  // Same as if global
 *
 * These indices can then be flattened later based on strides.
 */
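
/*
 * A minimal, self-contained sketch (not part of this header) of the index
 * arithmetic described above for the GLOBAL example. The function name
 * flattenGlobalTV and the plain-int types are illustrative only. It shows
 * the backward replay of split(I, factor=4): loop indices {i, j, k} are
 * folded back into the root index {i * 4 + j, k}, which is flattened with
 * the root strides and guarded by the predicate i * 4 + j < I.
 *
 *   #include <cassert>
 *
 *   // i iterates Io, j iterates Ii{4}, k iterates J; root shape is I x J.
 *   int flattenGlobalTV(int i, int j, int k, int I, int J) {
 *     const int root_i = i * 4 + j; // backward replay of the split
 *     assert(root_i < I);           // PREDICATE: guards the partial last tile
 *     return root_i * J + k;        // FLATTENED_INDEX over the root domain
 *   }
 */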

namespace torch {
namespace jit {
namespace fuser {

class IndexCompute : public BackwardVisitor {
 private:
  using BackwardVisitor::handle;
  void handle(Split*) override;
  void handle(Merge*) override;
  void handle(Expr*) override;

  // Returns extent_map_[id] if it exists, otherwise id->extent()
  Val* getExtent(kir::IterDomain* id);

  bool hasZeroMerged(kir::IterDomain* id);

  // Tensor domain we're mapping back to root
  const TensorDomain* td_;

  // Map we update as we propagate backward; it contains all IDs visited in
  // the propagation. Initial indices are entered into this map at
  // tv->domain() and are back-propagated to tv->rootDomain(). index_map_
  // also keeps the indices of intermediate IterDomains along the way.
  std::unordered_map<kir::IterDomain*, Val*> index_map_;

  // Map from IterDomains to their broadcasted extents. If a TV has I0*I1 but
  // its producer has B0*I1, this map will contain a mapping from the
  // ID{B0*I1} to the extent I0*I1. Also contains updated extents if we merge
  // in a zero index; see zero_merged_in_.
  std::unordered_map<kir::IterDomain*, Val*> extent_map_;

  // This set keeps track of IterDomains that have had a zero index merged
  // into them. This happens if we do something like tv->axis(0)->split(4)
  // followed by tv->computeAt(1, ...): if this tensor is in smem or lmem, the
  // backward indexing would start as (0, i), and the backward computation
  // would then attempt to merge that zero together with i. We handle indices
  // like these specially (see the sketch after this class).
  std::unordered_set<kir::IterDomain*> zero_merged_in_;

  // IDs that are a result of contiguous merges
  std::unordered_set<kir::IterDomain*> contig_ids;

 public:
  const std::unordered_map<kir::IterDomain*, Val*> indexMap() const {
    return index_map_;
  }

  const std::unordered_map<kir::IterDomain*, Val*> extentMap() const {
    return extent_map_;
  }

  std::unordered_set<kir::IterDomain*> zeroMergedIn() const {
    return zero_merged_in_;
  }

  // Propagate back from _td using initial_index_map
  IndexCompute(
      const TensorDomain* _td,
      std::unordered_map<kir::IterDomain*, Val*> initial_index_map,
      std::unordered_map<kir::IterDomain*, Val*> _extent_map,
      std::unordered_set<kir::IterDomain*> _zero_merged_in,
      const std::vector<bool>& _root_contiguity);

  // Updates index_map, extent_map, and zero_merged_in based on id_map and
  // returns a new IndexCompute ready to be used. new_index_entries are not
  // mapped, but are added to index_map.
  IndexCompute updateIndexCompute(
      const TensorDomain* new_td,
      const std::unordered_map<IterDomain*, IterDomain*>& id_map,
      std::unordered_map<kir::IterDomain*, Val*> new_index_entries,
      const std::vector<bool>& _root_contiguity);

  // Map producer contiguity information to the consumer; entries that don't
  // match are marked false.
  static std::vector<bool> contiguityPasC(
      TensorDomain* producer,
      TensorDomain* consumer);

  static std::vector<bool> contiguityAnd(
      const std::vector<bool>& contig1,
      const std::vector<bool>& contig2);
};
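
/*
 * Sketch referenced by the zero_merged_in_ comment above (illustrative only;
 * not taken from index_compute.cpp). For a shared/local tensor produced with
 * split(4) followed by computeAt(1, ...), backward indexing starts from
 * {0, i}: the outer index is zero because the allocation only covers the
 * region below the compute-at point. Replaying the split backward merges the
 * zero in, and the effective extent tracked in extent_map_ shrinks to the
 * inner factor.
 *
 *   // Backward replay of split(I, factor=4) with a zeroed outer index.
 *   int zeroMergedRootIndex(int i) {
 *     const int outer = 0;   // allocation-local: the outer loop contributes 0
 *     return outer * 4 + i;  // == i; effective extent is 4, not I
 *   }
 */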

// Simple interface for IndexCompute.
// If getComputeAtAxis and, more generally, the TensorView const model are
// fixed, we can make the TensorViews below const.
class Index {
 private:
  // Producer indexing if it's in shared or local memory
  static kir::TensorIndex* getProducerIndex_impl(
      TensorView* producer,
      TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Consumer indexing if it's in shared or local memory
  static kir::TensorIndex* getConsumerIndex_impl(
      TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Producer indexing if it's in global memory
  static kir::TensorIndex* getGlobalProducerIndex(
      TensorView* producer,
      TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Consumer indexing if it's in global memory
  static kir::TensorIndex* getGlobalConsumerIndex(
      TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

 public:
  // Indexing functions
  // Consumer = Producer
  // i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
  // Producer indexing dispatch
  static kir::TensorIndex* getProducerIndex(
      TensorView* producer,
      TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Consumer index dispatch
  static kir::TensorIndex* getConsumerIndex(
      TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Consumer indices for predicates; keeps all indices matching the root
  // domain, even those not used for physical addressing. Returns a pair
  // <root indices, whether the indices are mapped to the rfactor domain>.
  static std::pair<std::vector<Val*>, bool> getConsumerRootPredIndices(
      TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops,
      const std::vector<bool>& root_contiguity,
      bool unroll = false);
};
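
/*
 * A plausible sketch of how the public entry points could dispatch between
 * the global and shared/local implementations declared above. This is
 * assumed from the comments in this header (including the use of
 * TensorView::getMemoryType()); the actual logic lives in index_compute.cpp.
 *
 *   kir::TensorIndex* Index::getProducerIndex(
 *       TensorView* producer,
 *       TensorView* consumer,
 *       const std::vector<kir::ForLoop*>& loops) {
 *     if (producer->getMemoryType() == MemoryType::Global) {
 *       return getGlobalProducerIndex(producer, consumer, loops);
 *     }
 *     // Shared/local memory: indices are relative to the allocation made at
 *     // the compute-at point rather than to the full root domain.
 *     return getProducerIndex_impl(producer, consumer, loops);
 *   }
 */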

} // namespace fuser
} // namespace jit
} // namespace torch