File: fusion.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (234 lines) | stat: -rw-r--r-- 7,441 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#pragma once

#include <c10/util/Exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>

#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>

#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {

/*
 * Usage: FusionGuard and Fusion are required user interfaces for any operation
 * underlying the code generator. In order to create values, expressions, and
 * generate code a Fusion instance must be active. It is the responsibility of
 * the user to create a Fusion instance and register it with the fusion guard.
 * The simplest example of this is: Fusion fusion; FusionGuard fg(&fusion); Once
 * a fusion is active all values and operations will be registered with it.
 *
 * FusionGuard and Fusion are critical to the lifetime model of the IR system.
 * FusionGuard is a convenient way to set what base container instance holds the
 * defined IR. Statements that are defined are registered through the
 * FusionGuard with a particular Fusion. FusionGuard provides convenient methods
 * to access the active fusion so it doesn't need to be passed around
 * constantly. Any IR node derived classes from Statement must register with
 * Fusion to avoid memory leaks.
 *
 * Fusion is generally thought of as a translated fusion group from the JIT. It
 * is likely a single kernel, although, we don't have to stick to this in the
 * future and could in theory generate multiple kernels with an executor to run
 * them.
 *
 * Fusion also allows users to set input/output values that will allow us to
 * figure out how to hook up runtime data to and from the JIT as well as provide
 * us mechanisms for dependency analysis and DCE including safety checks.
 */

class Fusion;
class TensorView;

// FusionGuard is our "context manager". It holds the active fusion and allows
// it to be accessed anywhere through FusionGuard::getCurFusion().
class TORCH_CUDA_API FusionGuard {
 public:
  // Makes `fusion` the active fusion so that it can be manipulated.
  explicit FusionGuard(Fusion* fusion);

  // Ends the scope of this guard.
  // NOTE(review): presumably re-activates prev_fusion; the definition is not
  // visible in this header - confirm against the .cpp.
  ~FusionGuard();

  // Returns the fusion that is currently active.
  static Fusion* getCurFusion();

  // NOTE(review): judging by the name, the fusion that was active before this
  // guard was constructed - confirm against the constructor definition.
  // NOTE(review): the guard is implicitly copyable; copying a guard looks
  // unintended, consider deleting the copy operations in a follow-up.
  Fusion* prev_fusion;
};

/*
 * Fusion is mutable but unique. Nodes cannot be copied in any way from one
 * Fusion to another. If anything like that is desired, it would require
 * duplicating all associated values and exprs. Fusion is considered to be in
 * SSA form, though this could also change in the future if there is a good
 * reason to do so.
 *
 * The Fusion owns the whole IR graph (Vals and Exprs)
 */
class TORCH_CUDA_API Fusion final {
 public:
  Fusion() = default;

  // Copy and move operations are defined out of line.
  // NOTE(review): since nodes cannot be shared between fusions, copying
  // presumably duplicates all Vals/Exprs - confirm against the definitions.
  Fusion(const Fusion& other);
  Fusion(Fusion&& other) noexcept;

  Fusion& operator=(const Fusion& other);
  Fusion& operator=(Fusion&& other) noexcept;

  ~Fusion();

  // Non-throwing swap of two fusions, found through ADL.
  friend void swap(Fusion& a, Fusion& b) noexcept;

  // NOTE(review): presumably resets this fusion to an empty state; the
  // definition is not visible in this header - confirm.
  void clear() noexcept;

  // Break dependency chains associated with expr, remove references to expr
  // and delete expr.
  void removeExpr(Expr* expr);

  // Completely remove val from the fusion, breaking all dependencies
  // associated with it.
  void removeVal(Val* val);

  // Register input as an input of the fusion.
  void addInput(Val* input);

  // Register output as an output of the fusion.
  void addOutput(Val* output);

  // Check if stmt is properly registered with this fusion.
  bool inFusion(const Statement* stmt) const;

  // Throw an error if stmt is not in this fusion. Message will be:
  // msg + " it was not found in the active fusion."
  void assertInFusion(const Statement* stmt, const std::string& msg = "") const;

  /*
   * Return a list of topologically sorted expressions. We can start
   * by only traversing back from registered outputs, or from all terminating
   * Vals.
   *
   * from_outputs_only:
   *   True - Sort from DAG associated with registered outputs
   *   False - Sort from all terminating Vals.
   */
  std::vector<Expr*> exprs(bool from_outputs_only = false);

  // Return the set of fusion inputs that feed this Val.
  std::unordered_set<Val*> inputsOf(Val* val);

  // Assert that all leaves found from outputs are registered as an input.
  void validateInputs();

  // Print this fusion to cout.
  void print();

  // Print Arith exprs used in outputs.
  void printMath();

  // Print transformations used in fusion (can be very verbose).
  void printTransforms();

  // Lower the fusion and print a kernel.
  void printKernel();

  // Register the Val with this fusion.
  StmtNameType registerVal(Val* val);

  // Register expr with this fusion.
  // When we register an expression, we want to update the dependency tracking
  // of Vals. We add expr to our general expr_set_, we add use tracking for
  // inputs and origin tracking for outputs.
  StmtNameType registerExpr(Expr* expr);

  // Register stmt with this fusion.
  StmtNameType registerStatement(Statement* stmt);

  // Lowered nodes
  // TODO(kir): to be removed
  StmtNameType registerLoweredVal(Val* val);
  StmtNameType registerLoweredExpr(Expr* expr);

  // Lowered counterpart to inFusion()
  // TODO(kir): to be removed
  bool inKernelIr(const Statement* stmt) const;

  // Check if val is used in this fusion. Not equivalent to DCE.
  bool used(Val* val) const;

  // Return the set of Vals registered with this fusion.
  const std::unordered_set<Val*>& vals() const noexcept;
  // Return the registered Vals in insertion order.
  const std::deque<Val*>& deterministic_vals() const noexcept;

  // Return the set of Exprs registered with this fusion.
  const std::unordered_set<Expr*>& unordered_exprs() const noexcept;

  // Return all Exprs that use val.
  std::unordered_set<Expr*> unordered_uses(Val* val) const;

  // Return the Expr that produces val.
  Expr* origin(const Val* val) const;

  // Indicate to kernel to set itself up to generate random numbers.
  bool isStochastic();

  // TODO(kir): revisit to see how many of these are still needed
  bool hasReduction();
  bool hasBlockReduction();
  bool hasGridReduction();
  bool hasBlockBroadcast();
  bool hasBroadcast();
  size_t gridReductionTempBufferSize();

  // Registered fusion inputs (the inputs_ member).
  const std::vector<Val*>& inputs() const {
    return inputs_;
  }

  // Registered fusion outputs (the outputs_ member).
  const std::vector<Val*>& outputs() const {
    return outputs_;
  }

  // NOTE(review): presumably the registered outputs that are not consumed by
  // any other expression in the fusion - confirm against the definition.
  std::vector<Val*> getTerminatingOutputs();

  // NOTE(review): presumably membership tests against inputs_ / outputs_ -
  // confirm against the definitions.
  bool hasInput(const Val* val) const;
  bool hasOutput(const Val* val) const;

  // NOTE(review): presumably substitute `with` for `replace` in the registered
  // inputs / outputs - confirm against the definitions.
  void replaceInput(Val* replace, Val* with);
  void replaceOutput(Val* replace, Val* with);

 private:
  // Return an int that monotonically increases for each val/expr, some are
  // explicitly incremented by type.
  StmtNameType getValName(ValType vtype);
  StmtNameType getExprName();

  // Sets of all Vals/Exprs registered with this fusion
  // (val_deque_ is not owning the objects)
  std::unordered_set<Val*> val_set_;
  std::deque<Val*> val_deque_;
  std::unordered_set<Expr*> expr_set_;

  // Values names counters
  std::unordered_map<ValType, StmtNameType, TypeHash> val_type_name_map_;

  // Expression names counter
  StmtNameType expr_name_counter_ = 0;

  // Dependency tracking for Vals. Where did it come from? Where is it used?
  std::unordered_map<const Val*, Expr*> origin_;
  std::unordered_map<Val*, std::unordered_set<Expr*>> uses_;

  // Fusion inputs and outputs
  std::vector<Val*> inputs_;
  std::vector<Val*> outputs_;

  // Lowered IR
  std::unordered_set<Val*> lowered_val_set_;
  std::unordered_set<Expr*> lowered_expr_set_;
  std::unordered_map<const Val*, Expr*> lowered_origin_;
};

} // namespace fuser
} // namespace jit
} // namespace torch