File: fusion.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (284 lines) | stat: -rw-r--r-- 9,475 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
#pragma once

#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>

#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_container.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! Usage: FusionGuard and Fusion are required user interfaces for any operation
//! underlying the code generator. In order to create values, expressions, and
//! generate code a Fusion instance must be active. It is the responsibility of
//! the user to create a Fusion instance and register it with the fusion guard.
//! The simplest example of this is:
//!
//!     Fusion fusion;
//!     FusionGuard fg(&fusion);
//!
//! Once a fusion is active all values and operations will be registered with
//! it.
//!
//! FusionGuard and Fusion are critical to the lifetime model of the IR system.
//! FusionGuard is a convenient way to set what base container instance holds
//! the defined IR. Statements that are defined are registered through the
//! FusionGuard with a particular Fusion. FusionGuard provides convenient
//! methods to access the active fusion so it doesn't need to be passed around
//! constantly. Any IR node class derived from Statement must register with
//! Fusion to avoid memory leaks.
//!
//! Fusion is generally thought of as a translated fusion group from the JIT. It
//! is likely a single kernel, although, we don't have to stick to this in the
//! future and could in theory generate multiple kernels with an executor to run
//! them.
//!
//! Fusion also allows users to set input/output values that will allow us to
//! figure out how to hook up runtime data to and from the JIT as well as
//! provide us mechanisms for dependency analysis and DCE including safety
//! checks.

class Fusion;
class TensorView;
class WelfordResult;

class SegmentCandidateFinder;
class SegmentedFusion;
class KernelArgumentHolder;

//! Fusion Guard is our "context manager". It holds the active fusion and
//! allows it to be accessed anywhere through FusionGuard::getCurFusion()
//!
//! Typical usage is scoped to the region in which IR is built:
//!
//!     Fusion fusion;
//!     FusionGuard fg(&fusion);
class TORCH_CUDA_CU_API FusionGuard {
 public:
  //! Fusion pointer recorded by this guard.
  //! NOTE(review): presumably the fusion that was active before this guard was
  //! constructed, so it can be restored on destruction -- the constructor and
  //! destructor are defined outside this header; confirm against fusion.cpp.
  Fusion* prev_fusion;

  //! Set the active fusion so it can be manipulated.
  explicit FusionGuard(Fusion* fusion);

  //! NOTE(review): expected to undo the effect of the constructor (definition
  //! not visible here). The class is implicitly copyable; copying a guard is
  //! probably never intended -- confirm before relying on it.
  ~FusionGuard();

  //! Return the currently active fusion.
  static Fusion* getCurFusion();
  //! Set the currently active fusion directly; static, so no guard instance
  //! is required.
  static void setCurFusion(Fusion* fusion);
};

//! Fusion is mutable but unique. Nodes cannot be copied in any way from one
//! Fusion to another. If anything like that is desired, it would require
//! duplicating all associated values and exprs. Fusion is considered to be SSA,
//! though this could also change in the future if there is a good reason to do
//! so.
//!
//! The Fusion owns the whole IR graph (Vals and Exprs)
//!
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API Fusion : public IrContainer {
  //! Maps the index of a fusion input/output to the permutation registered
  //! for it.
  using PermutationMap = std::unordered_map<int, std::vector<int64_t>>;

 public:
  Fusion() = default;

  Fusion(const Fusion& other);
  Fusion(Fusion&& other) noexcept;

  Fusion& operator=(const Fusion& other);
  Fusion& operator=(Fusion&& other) noexcept;

  ~Fusion();

  friend void swap(Fusion& a, Fusion& b) noexcept;

  void clear() noexcept;

  //! Break dependency chains associated with Expr, remove references to expr
  //! and delete expr
  void removeExpr(Expr* expr) override;

  //! Completely remove val from the fusion, break all dependencies associated
  //! with it
  void removeVal(Val* val) override;

  //! Register input as an input of the fusion
  void addInput(Val* input);

  //! Register output as an output of the fusion
  void addOutput(Val* output);

  //! Deregister input as an input of the fusion
  void removeInput(Val* input);

  //! Deregister output as an output of the fusion
  void removeOutput(Val* output);

  //! Replace output with another value
  void replaceOutput(Val* output, Val* replacement);

  //! Assert that all leaves found from outputs are registered as an input
  void validateInputs();

  //! Print this fusion to the console
  void print();

  //! Print Arith exprs
  //! \param from_outputs_only Only print exprs reachable from outputs
  void printMath(bool from_outputs_only = true);

  //! Print transformations used in fusion (can be very verbose)
  void printTransforms();

  //! Lower the fusion and print a kernel
  //! \param index_type Data type passed to the lowering; defaults to
  //! DataType::Int
  void printKernel(DataType index_type = DataType::Int);

  //! Return a list of topologically sorted expressions. This only includes
  //! exprs required to generate registered outputs.
  std::vector<Expr*> exprs();

  //! Return a vector of fusion inputs that feed this Val
  std::vector<Val*> inputsOf(Val* val);

  //! Return all Vals in math expressions that cannot be eliminated.
  //!
  //! It is generally equivalent to vals that are used to generate
  //! outputs, however, when a multi-output expression exists, and only
  //! some of the outputs are used, the remaining unused outputs are
  //! also included as they must show up in the final code.
  std::vector<Val*> usedMathVals();

  //! Returns all vals that are produced by used math expressions and
  //!  also do not have further consumers.
  //!
  //! In the case of an active multi-output expression, the returned vector
  //!  will include the expression outputs that did not lead to a fusion
  //!  output.
  std::vector<Val*> terminatingMathVals();

  //! Return all Exprs that use val
  std::unordered_set<Expr*> unordered_uses(const Val* val) const;

  //! Return the Expr that produces val
  Expr* definition(const Val* val) const;

  //! Indicate to kernel to set itself up to generate random numbers
  bool isStochastic();

  //! Run fusion segmentation algorithm to create a segmented fusion
  std::unique_ptr<SegmentedFusion> segment(const KernelArgumentHolder& args);

  //! Return the registered fusion inputs (read-only view of inputs_)
  const auto& inputs() const {
    return inputs_;
  }

  std::vector<Val*> inputsAndCreated();

  //! Return the registered fusion outputs (read-only view of outputs_)
  const auto& outputs() const {
    return outputs_;
  }

  std::vector<Val*> getTerminatingOutputs() const;

  // Aliasing output to input value, this is a WAR to allow inplace update on
  // input tensor.
  // Note: this is not always safe and should be used with extra caution.
  // Currently the only place it's used is in the running stats update for batch
  // normalization.
  // TODO: alias should be made aware to segmentation, so we'll always include
  // the input tensor to the section where output is produced.
  void aliasOutputToInput(Val* output, Val* input);
  Val* getOutputAlias(Val* output);
  std::unordered_set<int> getOutputAliasIndices() const;
  std::vector<std::pair<int, int>> getInputAliasIndices() const;

  // mark input at index to be permuted by permutation
  //
  // The permutation is taken by value and moved into the map, so callers that
  // pass a temporary (or std::move their vector) pay for no copy at all. If a
  // permutation is already registered for index, the existing entry is kept.
  void setPermutationOnInput(int index, std::vector<int64_t> permutation) {
    permuted_input_map_.emplace(index, std::move(permutation));
  }

  // mark output at index to be restored by permutation
  //
  // Same ownership and "first registration wins" behavior as
  // setPermutationOnInput.
  void setPermutationOnOutput(int index, std::vector<int64_t> permutation) {
    permuted_output_map_.emplace(index, std::move(permutation));
  }

  // return a map of indices to permutation, which indicates all input tensors
  // that need to be permuted
  const PermutationMap& getPermutationInputMap() const {
    return permuted_input_map_;
  }

  // return a map of indices to permutation, which indicates all output tensors
  // that need to be permuted
  const PermutationMap& getPermutationOutputMap() const {
    return permuted_output_map_;
  }

  // return whether the use data currently recorded in the IR nodes is valid
  // (reports all_tv_uses_valid_)
  bool isTVUseInfoValid() const {
    return all_tv_uses_valid_;
  }

  // return the is_during_update_uses_ flag
  bool isUpdatingTVUseInfo() const {
    return is_during_update_uses_;
  }

  // return the map of io aliases, pointing from output to input
  const auto& ioAlias() const {
    return io_alias_;
  }

 protected:
  friend SegmentCandidateFinder;
  friend SegmentedFusion;
  friend class TranslateApplicableWelford;
  friend Val;

  static IrCloner copy(const Fusion* from, Fusion* to);

  //! Register the Val with this fusion
  void registerVal(Val* val) override;

  //! Register expr with this fusion.
  //! When we register an expression, we want to update the dependency tracking
  //! of Vals. If this container is not a Kernel, it will remove previous
  //! definitions of outputs and register this Expr as the definition. Otherwise
  //! will update definition if not previously set, but will not remove old
  //! definitions.
  void registerExpr(Expr* expr) override;

  //! Clear Expr's from TV uses that are not required to produce outputs from
  //! inputs. Only other place this is used (other than Fusion) is in
  //! Val::uses()
  void resetTvUses();

 private:
  // Determine if the two values are compatible for aliasing
  // Same DataType, ValType, and number of dimensions
  bool isAliasCompatible(Val* left, Val* right);

  // Fusion inputs and outputs
  std::vector<Val*> inputs_;
  std::vector<Val*> outputs_;

  // io alias pointing from output to input
  std::unordered_map<Val*, Val*> io_alias_;

  // See Note [ Permutation support in nvfuser ]
  // map from indices of input tensor to permutation
  PermutationMap permuted_input_map_;
  // map from indices of output tensor to permutation
  PermutationMap permuted_output_map_;

  // Records if the current use data in the IR nodes are valid
  //  the states are either all valid or all invalid
  bool all_tv_uses_valid_ = false;
  bool is_during_update_uses_ = false;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch