File: ir_base_nodes.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (357 lines) | stat: -rw-r--r-- 11,480 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
#pragma once

#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <torch/csrc/WindowsTorchApiMacro.h>

#include <torch/csrc/jit/codegen/cuda/type.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>

#include <cstdint>
#include <deque>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <vector>

// TODO: Add more types (int32, int64)
// TODO: sameAs should have better logic to check against any type and return
// gracefully

/*
 * This file defines the base IR structure. Any IR node in this system will
 * inherit from one of the following classes: Statement, Expr, Val,
 * IrInputOutput IR is any information that the code generation stack may need
 * for analysis. By analysis we're refering to anything done in response to a
 * user facing call of this stack. This could be careful tracking of user calls,
 * and any transformation including optimizing transformations, user declared
 * transformations, and lowering the IR.
 */

namespace torch {
namespace jit {
namespace fuser {

using StmtNameType = unsigned int;

constexpr StmtNameType UNINITIALIZED_STMTNAMETYPE =
    std::numeric_limits<unsigned int>::max();

class Fusion;
class FusionGuard;
class Expr;
class Val;
class UnaryOp;
class BinaryOp;
class IterDomain;
class IrCloner;

/*
 * Statement is the highest level node representation. Everything that is
 * considered "IR" will be derived from this class at some point. Both Values
 * and Expr's are a Statement. If there will ever be any more fundamental types,
 * they will also derive from Statement.
 *
 * We use Statements to pass around nodes of unknown compile type. Therefore it
 * is also important for the design to have a dispatch system for a Statment.
 * Basically beinng able to succienctly traverse down the inhereitance stack of
 * a Statment at runtime. This is currently implemented in dispatch.h
 */
class TORCH_CUDA_API Statement : public NonCopyable, public PolymorphicBase {
  friend void swap(Fusion&, Fusion&) noexcept;

 public:
  Statement() = default;

  // Cloning constructor
  Statement(const Statement* src, IrCloner* ir_cloner);

  // Dispatch functions, definitions in dispatch.cpp
  template <typename T>
  static void dispatch(T handler, Statement*);

  template <typename T>
  static void constDispatch(T handler, const Statement* const);

  template <typename T>
  static Statement* mutatorDispatch(T mutator, Statement*);

  // Accessor functions to types. Vals always have a DataType, Exprs never do
  virtual c10::optional<ValType> getValType() const {
    return c10::nullopt;
  }
  virtual c10::optional<DataType> getDataType() const {
    return c10::nullopt;
  }
  virtual c10::optional<ExprType> getExprType() const {
    return c10::nullopt;
  }

  // Short cut to figure out if it is a value/expression
  bool isVal() const {
    return getValType() != c10::nullopt;
  }
  bool isExpr() const {
    return getExprType() != c10::nullopt;
  }

  // Make sure this is a Val and return it as a Val*
  Val* asVal();

  // Make sure this is an Expr and return it as an Expr*
  Expr* asExpr();

  // Return the fusion this statement belongs to
  Fusion* fusion() const {
    return fusion_;
  }

  // Return the int that represents its name
  StmtNameType name() const {
    return name_;
  }

  virtual bool sameType(const Statement* const other) {
    if (isVal() && other->isVal())
      return getValType().value() == other->getValType().value();
    if (isExpr() && other->isExpr())
      return getExprType().value() == other->getExprType().value();
    return false;
  }

  // Return if this statement is the same as another statement
  // TODO: should this run through dispatch on this and other?
  bool sameAs(const Statement* const other) const {
    return this == other;
  }

  void print() const;

 protected:
  StmtNameType name_ = UNINITIALIZED_STMTNAMETYPE;
  Fusion* fusion_ = nullptr;
};

/*
 * A Val represents a "value." These are objects, like tensors, scalars, and
 * memory locations, that are inputs and outputs of computations (represented
 * by Exprs, below). Vals are constant and unique and should always be passed
 * around as a pointer. Val can generally be thought of as representing any type
 * of data. Some examples: a constant size like convolution filter width a
 * runtime constant like batch normalizations momentum a "symbolic" tensor like
 * one passed down from the JIT a memory buffer used in device code
 *
 * Adding a Val:
 * Right now adding a Val is quite involved. Val's can be defined in ir.h or in
 * their own header file. The following is what is currently needed to add a new
 * Val:
 * 1) Definition inheriting from Val
 *     - Members must be private or protected
 *     - Accessor functions for members
 *     - Must call Val constructor, Val constructor registers with fusion
 *     - Implementation of bool sameAs(...)
 *     - Must implement a "cloning" constructor, ex.
 *        Int::Int(const Int* src, IrCloner* ir_cloner)
 * 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
 * 3) Default mutator function should be added to mutator.cpp
 * 4a) Printing functions should be added to ir_iostream.h/.cpp
 * 4b) Graphviz generation must be added to ir_graphviz.h/.cpp
 * 5) An enum value must be added to ValType in type.h
 * 6) A string entry must be added in val_type_string_map
 */
class TORCH_CUDA_API Val : public Statement {
 public:
  virtual ~Val() = default;

  Val() = delete;

  // We may not want to register this value during Val's constructor. The reason
  // for this is that if we register the val, then ina derived constructor try
  // to throw, fusion's destructor will get called, but the pointer to this Val
  // will be invalid. When fusion tries to delete this value it will cause a seg
  // fault, instead of showing the thrown error.
  explicit Val(
      ValType _vtype,
      DataType _dtype = DataType::Null,
      bool register_val = true,
      bool lowered = false);

  // Lowers an existing Fusion IR node into a Kernel IR counterpart
  explicit Val(const Val* fusion_ir_node);

  Val(const Val* src, IrCloner* ir_cloner);

  // TODO: Values are unique and not copyable
  Val(const Val& other) = delete;
  Val& operator=(const Val& other) = delete;

  Val(Val&& other) = delete;
  Val& operator=(Val&& other) = delete;

  c10::optional<ValType> getValType() const override {
    return vtype_;
  }

  // Throws if no DataType is found. Vals must have a DataType
  c10::optional<DataType> getDataType() const override;

  bool isScalar() const {
    return vtype_ == ValType::Scalar || vtype_ == ValType::NamedScalar ||
        vtype_ == ValType::KirScalar || vtype_ == ValType::KirNamedScalar;
  }

  bool isConstScalar() const;

  bool isAnInt() const {
    return isScalar() && dtype_ == DataType::Int;
  }

  c10::optional<int64_t> getInt() const;

  bool isZeroInt() const;
  bool isOneInt() const;

  // Returns the Expr that this value is an output of, returns nullptr if none
  // was found
  Expr* getOrigin();
  const Expr* getOrigin() const;

  virtual bool sameType(const Statement* other) {
    return Statement::sameType(other) &&
        getDataType() == other->as<Val>()->getDataType();
  }

  // TODO: Make this more sophisticated. A value being the same as another value
  // should be evaluated based on the DAG that created it, and that DAGs leaf
  // nodes
  bool sameAs(const Val* const other) const {
    return this == other;
  }

  // Dispatch functions, definitions in dispatch.cpp
  template <typename T>
  static void dispatch(T handler, Val*);

  template <typename T>
  static void constDispatch(T handler, const Val* const);

  template <typename T>
  static Statement* mutatorDispatch(T mutator, Val*);

 protected:
  const ValType vtype_;
  const DataType dtype_;
};

//  A Expr represents a "computation." These are functions that takes inputs
//  and produce outputs, inputs and outputs all being Vals. There are
//  specializations of BinaryOp which takes 2 inputs and produces 1 output, and
//  UnaryOp which takes 1 input and produces 1 output. Exprs are unique and
//  immutable. Conceptually, Exprs could always be manipulated using unique
//  pointers, and we could add this later. However, for now Exprs can be
//  replaced in a fusion, but they cannot be modified in place.

//  The IR is static single assignment (SSA). Values can only be defined as an
//  output of an Expr once. If they are re-defined the original definition is
//  deleted from the program, as opposed to an ordered redefinition of the value
//  in the program.

//  Note: Registering an Expr with a Fusion is actually 2 parts, one part is
//  done in the Expr constructor, so that should be called on anything that
//  inherits Expr. The issue with having registration in Expr's constructor, is
//  that the constructor of an Expr will set ouputs and inputs. This information
//  is important for registration with Fuser, so it can track the dependency
//  chain.

//  Adding an Expr:
//  Right now adding an Expr is quite involved. Expr's can be defined in ir.h or
//  in their own header file. The following is what is currently needed for Expr
//  definitions:
//  1) Definition inheriting from Expr.
//      - Members must be private or protected
//      - Accessor functions for members
//      - Constructors need to register with the Fusion after inputs/outputs are
//         defined
//      - Implementation of bool sameAs(...)
//  2) dispatch.h/.cpp must be updated to include dispatch of the new Val
//  3) Default mutator function should be added to mutator.h/.cpp
//  4) Printing functions should be added to ir_iostream.h/.cpp
//  5) Lower case convenience functions should be added to arith.h/.cpp (If user
//   facing)
//  6) An enum value must be added to ExprType in type.h
//  7) A string entry must be added in expr_type_string_map
//  8) Entry added to ir_graphviz .cpp/.h

class TORCH_CUDA_API Expr : public Statement {
 public:
  Expr() = delete;
  explicit Expr(ExprType _type);
  Expr(const Expr* src, IrCloner* ir_cloner);
  virtual ~Expr() = default;

  Expr(const Expr& other) = delete;
  Expr& operator=(const Expr& other) = delete;

  Expr(Expr&& other) = delete;
  Expr& operator=(Expr&& other) = delete;

  c10::optional<ExprType> getExprType() const override {
    return type_;
  }

  ExprType type() const {
    return type_;
  }

  bool sameAs(const Expr* const other) const;

  // Input/output accessors
  const auto& inputs() const {
    return inputs_;
  }

  const auto& outputs() const {
    return outputs_;
  }

  auto input(size_t index) const {
    return inputs_[index];
  }

  auto output(size_t index) const {
    return outputs_[index];
  }

  // Dispatch functions, definitions in dispatch.cpp
  template <typename T>
  static void dispatch(T handler, Expr*);

  template <typename T>
  static void constDispatch(T handler, const Expr* const);

  template <typename T>
  static Statement* mutatorDispatch(T mutator, Expr*);

 protected:
  void addInput(Val* input) {
    TORCH_INTERNAL_ASSERT(input != nullptr);
    inputs_.push_back(input);
  }

  void addOutput(Val* output) {
    TORCH_INTERNAL_ASSERT(output != nullptr);
    outputs_.push_back(output);
  }

 private:
  ExprType type_ = ExprType::Invalid;
  std::vector<Val*> inputs_;
  std::vector<Val*> outputs_;
};

} // namespace fuser
} // namespace jit
} // namespace torch