File: ir_base_nodes.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (509 lines) | stat: -rw-r--r-- 15,757 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
#pragma once

#include <c10/core/ScalarType.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

#include <torch/csrc/jit/codegen/cuda/type.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>

#include <cstdint>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <vector>

// TODO: Add more types (int32, int64)
// TODO: sameAs should have better logic to check against any type and return
// gracefully

/*
 * This file defines the base IR structure. Any IR node in this system will
 * inherit from one of the following classes: Statement, Expr, Val,
 * IrInputOutput. IR is any information that the code generation stack may need
 * for analysis. By analysis we're referring to anything done in response to a
 * user facing call of this stack. This could be careful tracking of user calls,
 * and any transformation including optimizing transformations, user declared
 * transformations, and lowering the IR.
 */

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Integer identifier type for values (not referenced elsewhere in this
// header).
using ValueId = int32_t;

// Integer type of the name carried by every Statement (see Statement::name()).
using StmtNameType = unsigned int;

// Sentinel name a Statement holds until a real name is assigned to it. The
// limit is taken from StmtNameType itself, rather than from a repeated
// spelling of its underlying type, so the sentinel stays in sync if the alias
// is ever changed.
constexpr StmtNameType kInvalidStmName =
    std::numeric_limits<StmtNameType>::max();

// Forward declarations, so that the declarations below can refer to these
// classes by name before (or without) seeing their definitions. Expr and Val
// are defined further down in this header; the others are not defined here.
class Fusion;
class FusionGuard;
class Expr;
class Val;
class UnaryOp;
class BinaryOp;
class RNGOp;
class IterDomain;
class IrCloner;
class IrContainer;
class IrBuilderPasskey;
class IrContainerPasskey;

// Forward declarations of classes that live in the nested kir namespace.
namespace kir {
class Kernel;
class Predicate;
} // namespace kir

// Passkey tied to Expr: the constructor is private and Expr is the only
// friend, so only Expr can default-construct an instance.
class ExprPasskey {
 private:
  friend class Expr;

  explicit ExprPasskey() {}
};

// Free swap over two Fusion objects. It is declared ahead of Statement so that
// Statement can name it as a friend; the definition is not part of this header.
TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept;

//! Statement is the highest level node representation. Everything that is
//! considered "IR" will be derived from this class at some point. Both Values
//! and Expr's are a Statement. If there will ever be any more fundamental
//! types, they will also derive from Statement.
//!
//! We use Statements to pass around nodes whose concrete type is unknown at
//! compile time. Therefore it is also important for the design to have a
//! dispatch system for a Statement, i.e. being able to succinctly traverse
//! down the inheritance stack of a Statement at runtime. This is currently
//! implemented in dispatch.h
class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase {
  friend void swap(Fusion&, Fusion&) noexcept;
  friend void swap(IrContainer& a, IrContainer& b) noexcept;

 public:
  // No default construction: a Statement is built either by the cloning
  // constructor below or by the protected IrBuilderPasskey constructor.
  Statement() = delete;

  // Cloning constructor
  Statement(const Statement* src, IrCloner* ir_cloner);

  // Dispatch functions, definitions in dispatch.cpp
  template <typename T>
  static void dispatch(T handler, Statement*);

  template <typename T>
  static void constDispatch(T handler, const Statement* const);

  template <typename T>
  static void mutatorDispatch(T mutator, Statement*);

  // Type accessors. The base implementations all report "no type"
  // (c10::nullopt); subclasses override the ones that apply to them. Vals
  // always have a DataType, Exprs never do.
  virtual c10::optional<ValType> getValType() const {
    return c10::nullopt;
  }
  virtual c10::optional<DataType> getDataType() const {
    return c10::nullopt;
  }
  virtual c10::optional<ExprType> getExprType() const {
    return c10::nullopt;
  }

  // True when this statement reports a ValType, i.e. it is a value
  bool isVal() const {
    return getValType().has_value();
  }
  // True when this statement reports an ExprType, i.e. it is an expression
  bool isExpr() const {
    return getExprType().has_value();
  }

  // Make sure this is a Val and return it as a Val*
  Val* asVal();

  // Make sure this is an Expr and return it as an Expr*
  Expr* asExpr();

  // Return the fusion this statement belongs to
  Fusion* fusion() const;

  // Return the kernel this statement belongs to
  kir::Kernel* kernel() const;

  // Return the container this statement belongs to
  IrContainer* container() const {
    return ir_container_;
  }

  // Return the int that represents its name
  StmtNameType name() const {
    return name_;
  }

  // Set the statement's name. Typically the container will set the name,
  // however if we're dealing with cloning, IrBuilder will set the name. This
  // maybe should be from IrCloner, however I didn't want to add another
  // passkey.
  void setName(IrContainerPasskey, StmtNameType name);
  void setName(IrBuilderPasskey, StmtNameType name);

  // Two statements have the same type when both are Vals with equal ValType,
  // or both are Exprs with equal ExprType. Any other combination is false.
  virtual bool sameType(const Statement* const other) {
    const bool both_vals = isVal() && other->isVal();
    if (both_vals) {
      return getValType().value() == other->getValType().value();
    }

    const bool both_exprs = isExpr() && other->isExpr();
    if (both_exprs) {
      return getExprType().value() == other->getExprType().value();
    }

    return false;
  }

  // Return if this statement is the same as another statement. The base
  // implementation is plain pointer identity.
  // TODO: should this run through dispatch on this and other?
  virtual bool sameAs(const Statement* other) const {
    return other == this;
  }

  std::string toString() const;
  std::string toInlineString() const;

 protected:
  Statement(IrBuilderPasskey);

  // Name of this statement; stays at the invalid sentinel until setName()
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  StmtNameType name_ = kInvalidStmName;

  // Container this statement belongs to, as returned by container()
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  IrContainer* ir_container_ = nullptr;
};

//! A Val represents a "value." These are objects, like tensors, scalars, and
//! memory locations, that are inputs and outputs of computations (represented
//! by Exprs, below)
//!
//! Vals are constant and unique and should always be passed
//! around as a pointer. Val can generally be thought of as representing any
//! type of data. Some examples:
//!   - a constant size, like a convolution filter width
//!   - a runtime constant, like batch normalization's momentum
//!   - a "symbolic" tensor, like one passed down from the JIT
//!   - a memory buffer used in device code
//!
//! Adding a Val:
//! Right now adding a Val is quite involved. Val's can be defined in ir.h or in
//! their own header file. The following is what is currently needed to add a
//! new Val:
//!
//! 1) Definition inheriting from Val
//!     - Members must be private or protected
//!     - Accessor functions for members
//!     - Must call Val constructor, Val constructor registers with fusion
//!     - Implementation of bool sameAs(...)
//!     - Must implement a "cloning" constructor, ex.
//!        Int::Int(const Int* src, IrCloner* ir_cloner)
//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
//! 3) Default mutator function should be added to mutator.cpp
//! 4a) Printing functions should be added to ir_iostream.h/.cpp
//! 4b) Graphviz generation must be added to ir_graphviz.h/.cpp
//! 5) An enum value must be added to ValType in type.h
//! 6) A string entry must be added in val_type_string_map
//!
class TORCH_CUDA_CU_API Val : public Statement {
 public:
  // Builds a Val of value type _vtype; _dtype defaults to DataType::Null.
  explicit Val(
      IrBuilderPasskey,
      ValType _vtype,
      DataType _dtype = DataType::Null);

  // Cloning constructor
  Val(const Val* src, IrCloner* ir_cloner);

  // Dispatch functions, definitions in dispatch.cpp
  template <typename T>
  static void dispatch(T handler, Val*);

  template <typename T>
  static void constDispatch(T handler, const Val* const);

  template <typename T>
  static void mutatorDispatch(T mutator, Val*);

  // A Val always reports a ValType (wrapped in an optional)
  c10::optional<ValType> getValType() const override {
    return vtype_;
  }

  // Direct, non-optional access to the value type
  ValType vtype() const {
    return vtype_;
  }

  // Direct, non-optional access to the data type
  DataType dtype() const {
    return dtype_;
  }

  // Throws if no DataType is found. Vals must have a DataType
  c10::optional<DataType> getDataType() const override;

  // True for the two scalar value types: Scalar and NamedScalar
  bool isScalar() const {
    if (vtype_ == ValType::Scalar) {
      return true;
    }
    return vtype_ == ValType::NamedScalar;
  }

  // Returns if all dependencies are constant scalars
  bool isConstScalar() const;

  // Returns if all dependencies are constant integers
  bool isConstInt() const;

  // True for a scalar whose data type is DataType::Int
  bool isAnInt() const {
    if (!isScalar()) {
      return false;
    }
    return dtype_ == DataType::Int;
  }

  // True for a scalar whose data type is DataType::Double
  bool isADouble() const {
    if (!isScalar()) {
      return false;
    }
    return dtype_ == DataType::Double;
  }

  // If this Val is an integer with a direct constant value associated with it,
  // will return the value of that constant integer. If this integer has
  // defining expressions it will return a c10::nullopt. Those values should be
  // inferred using evaluateInt.
  c10::optional<int64_t> getInt() const;

  // If this Val is a double with a direct constant value associated with it,
  // will return the value of that constant double. If this double has
  // defining expressions it will return a c10::nullopt. Those values should be
  // inferred using evaluateDouble.
  c10::optional<double> getDouble() const;

  // If this Val is a constant integer, and its history is comprised only of
  // constant values, will return the value of that constant integer. Cannot
  // make constant as expression evaluator takes non-constant Vals.
  int64_t evaluateInt();

  // If this Val is a constant double, and its history is comprised only of
  // constant values, will return the value of that constant double. Cannot
  // make constant as expression evaluator takes non-constant Vals.
  double evaluateDouble();

  // Returns if no dependencies and is a constant scalar. The base
  // implementation always answers false.
  virtual bool isConst() const {
    return false;
  }

  bool isZeroInt() const;
  bool isOneInt() const;

  // Returns the Expr that this value is an output of, returns nullptr if none
  // was found. A value flagged as a fusion input always reports nullptr,
  // whatever definition is stored.
  Expr* definition() const {
    return is_fusion_input_ ? nullptr : definition_;
  }

  // Determine if value definition matches given expression type
  bool isDefinitionType(ExprType expression_type) const;

  const std::vector<Expr*>& uses() const;

  // Whether this value is flagged as a fusion input
  bool isFusionInput() const {
    return is_fusion_input_;
  }

  // Whether this value is flagged as a fusion output
  bool isFusionOutput() const {
    return is_fusion_output_;
  }

  //! Returns true when other is a producer of this
  bool isProducerOf(const Val* other) const;

  //! Returns true when other is a consumer of this
  bool isConsumerOf(const Val* other) const;

  // Same type as another statement when the Statement-level check passes and
  // both report the same data type. `other` is only viewed as a Val after the
  // Statement-level check has succeeded.
  bool sameType(const Statement* other) override {
    if (!Statement::sameType(other)) {
      return false;
    }
    return getDataType() == other->as<Val>()->getDataType();
  }

  // TODO: Make this more sophisticated. A value being the same as another value
  // should be evaluated based on the DAG that created it, and that DAGs leaf
  // nodes
  bool sameAs(const Statement* other) const override {
    return other == this;
  }

  // Records the evaluator index. It may only be set while the index still has
  // its initial value of -1; otherwise the internal assert fires.
  void setEvaluatorIndex(int to) {
    TORCH_INTERNAL_ASSERT(evaluator_index_ == -1);
    evaluator_index_ = to;
  }

  // Returns the evaluator index, -1 if it was never set
  int evaluatorIndex() const {
    return evaluator_index_;
  }

  // Following is managed by Fusion (or kirIrBuilder) and can change.
  // TODO: Protect with a passkey.
  void setDefinition(Expr* expr) {
    definition_ = expr;
  }

  void resolveIndexDtype();

 protected:
  friend Fusion;

  // Value type, fixed at construction
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const ValType vtype_;

  // TODO: Add fusion passkey for this
  void setIsFusionInput(bool is_fusion_input) {
    is_fusion_input_ = is_fusion_input;
  }

  // TODO: Add fusion passkey for this
  void setIsFusionOutput(bool is_fusion_output) {
    is_fusion_output_ = is_fusion_output;
  }

  // Replaces the stored list of uses with a copy of the given one
  // TODO: Add fusion or container passkey for this
  void setUses(const std::vector<Expr*>& uses) {
    uses_ = uses;
  }

 private:
  // There's only one instance where dtype can change, and that's through
  // resolving the index data type from nvfuser to either Int or Int32 for
  // welford operations.
  DataType dtype_;

  // Following is managed by Fusion and can change.
  bool is_fusion_input_ = false;
  bool is_fusion_output_ = false;

  // Expr stored by setDefinition(); read back through definition()
  Expr* definition_ = nullptr;
  // Exprs stored by setUses()
  std::vector<Expr*> uses_;

  // Expr evaluator idx; -1 means not set yet
  int evaluator_index_ = -1;
};

//!  An Expr represents a "computation." These are functions that take inputs
//!  and produce outputs, inputs and outputs all being Vals. There are
//!  specializations of BinaryOp which takes 2 inputs and produces 1 output, and
//!  UnaryOp which takes 1 input and produces 1 output. Exprs are unique and
//!  immutable. Conceptually, Exprs could always be manipulated using unique
//!  pointers, and we could add this later. However, for now Exprs can be
//!  replaced in a fusion, but they cannot be modified in place.
//!
//!  The IR is static single assignment (SSA). Values can only be defined as an
//!  output of an Expr once. If they are re-defined the original definition is
//!  deleted from the program, as opposed to an ordered redefinition of the
//!  value in the program.
//!
//!  Note: Registering an Expr with a Fusion is actually 2 parts, one part is
//!  done in the Expr constructor, so that should be called on anything that
//!  inherits Expr. The issue with having registration in Expr's constructor, is
//!  that the constructor of an Expr will set outputs and inputs. This
//!  information is important for registration with Fuser, so it can track the
//!  dependency chain.
//!
//!  Adding an Expr:
//!  Right now adding an Expr is quite involved. Expr's can be defined in ir.h
//!  or in their own header file. The following is what is currently needed for
//!  Expr definitions:
//!
//! 1) Definition inheriting from Expr.
//!      - Members must be private or protected
//!      - Accessor functions for members
//!      - Constructors need to register with the Fusion after inputs/outputs
//!         are defined
//!      - Implementation of bool sameAs(...)
//!  2) dispatch.h/.cpp must be updated to include dispatch of the new Expr
//!  3) Default mutator function should be added to mutator.h/.cpp
//!  4) Printing functions should be added to ir_iostream.h/.cpp
//!  5) Lower case convenience functions should be added to arith.h/.cpp (If
//!     user facing)
//!  6) An enum value must be added to ExprType in type.h
//!  7) A string entry must be added in expr_type_string_map
//!  8) Entry added to ir_graphviz .cpp/.h
//!
class TORCH_CUDA_CU_API Expr : public Statement {
 public:
  // Builds an Expr of the given expression type
  explicit Expr(IrBuilderPasskey, ExprType type);

  // Cloning constructor
  Expr(const Expr* src, IrCloner* ir_cloner);

  // An Expr always reports an ExprType (wrapped in an optional)
  c10::optional<ExprType> getExprType() const override {
    return etype_;
  }

  // Direct, non-optional access to the expression type
  ExprType etype() const {
    return etype_;
  }

  bool sameAs(const Statement* other) const override;

  // Input/output accessors. The whole-list accessors return references to the
  // vectors held by this Expr.
  const auto& inputs() const {
    return inputs_;
  }

  const auto& outputs() const {
    return outputs_;
  }

  // Returns the input at position `index`. The access is bounds checked: an
  // index past the end throws std::out_of_range instead of reading outside
  // the vector.
  auto input(size_t index) const {
    return inputs_.at(index);
  }

  // Returns the output at position `index`. The access is bounds checked: an
  // index past the end throws std::out_of_range instead of reading outside
  // the vector.
  auto output(size_t index) const {
    return outputs_.at(index);
  }

  // Dispatch functions, definitions in dispatch.cpp
  template <typename T>
  static void dispatch(T handler, Expr*);

  template <typename T>
  static void constDispatch(T handler, const Expr* const);

  template <typename T>
  static void mutatorDispatch(T mutator, Expr*);

  // TODO: Protect based on being in kernel container
  kir::Predicate* predicate() const;

  // TODO: Protect based on being in kernel container
  void setPredicate(kir::Predicate* predicate);

  // TODO: Protect based on being in kernel container
  kir::Predicate* writePredicate() const;

  // TODO: Protect based on being in kernel container
  void setWritePredicate(kir::Predicate* write_predicate);

 protected:
  // Appends a non-null Val to the input list
  // TODO: Add Fusion passkey
  void addInput(Val* input) {
    TORCH_INTERNAL_ASSERT(input != nullptr);
    inputs_.push_back(input);
  }

  // Appends a non-null Val to the output list
  // TODO: Add Fusion passkey
  void addOutput(Val* output) {
    TORCH_INTERNAL_ASSERT(output != nullptr);
    outputs_.push_back(output);
  }

  // Creates an ExprPasskey. Expr is a friend of ExprPasskey, and this helper
  // is protected, so it is reachable from Expr and its subclasses.
  ExprPasskey exprPasskey() {
    return ExprPasskey();
  }

 private:
  ExprType etype_ = ExprType::Invalid;
  std::vector<Val*> inputs_;
  std::vector<Val*> outputs_;

  kir::Predicate* predicate_ = nullptr;

  // Only used for reduction-related expressions
  kir::Predicate* write_predicate_ = nullptr;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch