File: lower_shift.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (235 lines) | stat: -rw-r--r-- 8,268 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
#pragma once

#include <c10/macros/Export.h>

#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>

#include <array>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! Forward declaration. Within this header LoopIndexing is only used
//! as a const-reference parameter, so the full definition is not
//! required here.
class LoopIndexing;

//! Per-axis record of the halo widths attached to the two ends of an
//! axis
class AxisHaloInfo {
 public:
  //! Halo width of one side of the axis.
  //!
  //! pos must be either 0 or 1. The width of the halo at offset zero
  //! is returned when pos is 0; pos 1 selects the other end.
  int width(int pos) const;

  //! Total halo width, i.e., the sum of the widths of both sides
  int width() const;

  //! Read-only view of the per-side widths, indexed by pos
  const std::array<int, 2>& widths() const {
    return widths_;
  }

  //! Assign the halo width of one side.
  //!
  //! pos must be either 0 or 1. The width of the halo at offset zero
  //! is set when pos is 0; pos 1 selects the other end.
  void setWidth(int pos, int width);

  //! Extend the halo width of the side selected by pos to account for
  //! another axis. NOTE(review): `other` is presumably the width of
  //! that other axis on the same side -- confirm in the definition.
  void merge(int pos, int other);

  //! Extend the halo widths to account for another axis described by
  //! `other`.
  void merge(const AxisHaloInfo& other);

  //! True when halo may be attached
  bool hasHalo() const;

  //! Textual representation of this halo information
  std::string toString() const;

 private:
  //! Halo sizes of the two sides of the axis. An axis without halo
  //! has zero for both entries. A non-zero widths_[0] means the axis
  //! has halo at offset zero and gives the size of that halo;
  //! likewise, a non-zero widths_[1] means the axis has halo at the
  //! other end of the axis.
  std::array<int, 2> widths_{{0, 0}};
};

//! Helper class for lowering tensors with halo. Only valid at the
//! lowering time.
//!
//! Keeps per-IterDomain halo bookkeeping: the AxisHaloInfo of root
//! axes, halo-extended extents, overall halo widths, and which child
//! domains inherit the halo of a root domain.
class TORCH_CUDA_CU_API HaloInfo {
 public:
  //! Scan a fusion and collect all information for lowering
  void build(Fusion* fusion);

  //! Build mappings of extent information of a TensorDomain
  void build(TensorDomain* td);

  //! Almost exact duplicate of build(TensorDomain* td), except that
  //!  the traversal is done on loop indexing expressions.
  //!
  //! The result is returned as a map from IterDomain to a Val rather
  //! than being taken from the return-less build(TensorDomain*)
  //! overload. NOTE(review): the mapped values are presumably the
  //! halo-extended extents -- confirm in the definition.
  std::unordered_map<IterDomain*, Val*> buildConcreteHaloExtentMap(
      const LoopIndexing& loop_indexing);

  //! Set initial AxisHaloInfo of a root axis
  //!
  //! The axis does not need to be a root domain in the case of
  //! reference tensors. Reference tensors get halo information from
  //! consumer root domains, which may correspond to rfactor domains
  //! of tensors from which reference tensors are derived.
  void setRootAxisInfo(IterDomain* id, const AxisHaloInfo& root_axis_info);

  //! Returns true if id has the root halo information set by
  //! setRootAxisInfo.
  bool hasRootAxisInfo(IterDomain* id) const;

  //! Returns the registered AxisHaloInfo of a root axis.
  //!
  //! This is only for root axes. It is an error to query with
  //! non-root axes. Both a const and a mutable overload are provided.
  const AxisHaloInfo& getRootAxisInfo(IterDomain* id) const;
  AxisHaloInfo& getRootAxisInfo(IterDomain* id);

  //! Query if an axis has a halo width.
  //!
  //! See the comment at halo_width_map_.
  bool hasHaloWidth(IterDomain* id) const;

  //! Return the halo width of an axis.
  //!
  //! It's an error if queried for an axis with no halo width
  //! information.
  int getHaloWidth(IterDomain* id) const;

  //! Returns an extent if id is extended for halo. Nullptr is
  //! returned otherwise.
  Val* getExtent(IterDomain* id) const;

  //! Returns all child domains of a root domain that inherit the
  //! halo of the root domain.
  //!
  //! If a root domain is split, only the inner domain inherits the
  //! halo, so the inner domain is included but not the outer domain.
  const std::unordered_set<IterDomain*>& getChildDomains(
      IterDomain* root_id) const;

  //! Returns all root domains from which the halo of a domain
  //! originates.
  std::unordered_set<IterDomain*> getRootDomains(IterDomain* id) const;

  //! Returns true if a domain inherits halo associated with a root
  //! domain.
  bool isHaloInherited(IterDomain* root_id, IterDomain* id) const;

  //! True when the extent of id1 is guaranteed to be less than or
  //! equal to that of id2. False when it *may* not be.
  bool extentLessEqual(IterDomain* id1, IterDomain* id2) const;
  //! True when the extent of id1 is guaranteed to be equal to that of
  //! id2. False when it *may* not be.
  bool extentEqual(IterDomain* id1, IterDomain* id2) const;

  //! Check if expr must be predicated based on boundary conditions
  //! directly or indirectly induced by shift expressions.
  //!
  //! When yes, the expression needs two predications: one for
  //! interior and another for padding. Predicate insertion is done in
  //! the ShiftPredicateInserter class below.
  bool needsShiftPredicate(Expr* expr) const;

  //! Textual representation of the collected halo information. The
  //! exact format is determined by the definition.
  std::string toString() const;

 private:
  //! Propagate root axis information from outputs to inputs of an
  //! expression
  void propagateRootAxisInfo(Expr* expr);

  //! Adds a domain to the halo inheritance map.
  //!
  //! A domain, child, is added to the same set as domain parent. Both
  //! domains must be part of TensorDomain td.
  void insertToInheritanceMap(
      TensorDomain* td,
      IterDomain* parent,
      IterDomain* child);

  //! Propagate root axis information from consumer to producer
  void propagateRootAxisInfo(
      TensorView* producer,
      TensorView* consumer,
      Expr* expr);

  //! Initialize mappings for a given root domain. The given domain
  //! must be previously given to setRootAxisInfo.
  void initializeFromRootAxisInfo(IterDomain* id);

  //! Validate shift usage
  void validate(TensorView* td) const;

  //! Record the halo width of an axis. NOTE(review): presumably
  //! stored in halo_width_map_ -- confirm in the definition.
  void setHaloWidth(IterDomain* id, int halo_width);

 private:
  //! Halo information of root axes
  std::unordered_map<IterDomain*, AxisHaloInfo> root_axis_map_;

  //! Halo-extended extents. No mapping for axes without halo extension
  std::unordered_map<IterDomain*, Val*> extent_map_;

  //! The halo width of an axis.
  //!
  //! The mapped value is a sum of two widths of both sides of an
  //! axis. For root axes, it is equivalent to AxisHaloInfo.widths_[0]
  //! + AxisHaloInfo.widths_[1] (or AxisHaloInfo.width()). For
  //! example, when a root axis is extended by 1 for both sides, it'd
  //! be mapped to 2. For axes with no halo, they are mapped to zero.
  //!
  //! When an axis is split, its halo is only propagated to the inner
  //! output axis, so the value of this map for the inner output is
  //! the same as the input of split, while the outer output is mapped
  //! to zero.
  //!
  //! When an axis is merged, no mapping is created for its
  //! output at this point primarily because it isn't clear what the
  //! "halo width" for a merged axis should mean. Perhaps, for a merged
  //! axis of (N+a)*(M+b), where N and M correspond to the original
  //! extents of two axes, and a and b correspond to their halo widths,
  //! it might make sense to set the halo width of this merged axis as
  //! (N+a)*(M+b)-N*M. Currently, however, this isn't necessary, so no
  //! particular mapping is created for merged axes.
  //!
  //! This is currently used only for conservatively comparing the
  //! overall extents of axes. See HaloInfo::extentLessEqual and
  //! HaloInfo::extentEqual.
  //!
  //! Example: Suppose a root axis has {0, 1} of
  //! AxisHaloInfo.widths_. The root axis is mapped to 1. When it is
  //! split, say, by 4, into the output axes [N / 4] and [4], where N
  //! is the extent of the root axis, the outer axis is mapped to 0,
  //! whereas the inner axis is mapped to 1. Further, suppose the
  //! inner axis is merged with another axis of extent M, we know that
  //! the extent of the resulting output axis is 5*M, but we don't
  //! create its mapping.
  std::unordered_map<IterDomain*, int> halo_width_map_;

  //! Mappings from root domains to child domains that inherit halo
  std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
      inheritance_map_;
};

//! Generates and inserts predicates for expressions that need shift
//! predication. Exposes only a static entry point, insert().
class ShiftPredicateInserter {
 public:
  //! Works mostly the same way as
  //! PredicateCompute::getInlinePredicate but does the insertion of
  //! the generated predicate. The branch structure is different from
  //! the usual predicated expression, so the insertion is also done
  //! here.
  //!
  //! NOTE(review): parameter semantics are not visible in this
  //! header; the following are assumptions to confirm against the
  //! definition.
  //! \param expr Expression to predicate
  //! \param loops Presumably the for-loops enclosing expr
  //! \param thread_pred Presumably a thread predicate to combine with
  //!   the generated predicate
  //! \param within_unswitch Presumably true when expr is located
  //!   inside an unswitched loop nest
  static void insert(
      Expr* expr,
      const std::vector<kir::ForLoop*>& loops,
      Bool* thread_pred,
      bool within_unswitch);
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch