File: kernel_ir_dispatch.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (144 lines) | stat: -rw-r--r-- 4,528 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#pragma once

#include <torch/csrc/jit/codegen/cuda/dispatch.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Forward declaration: the declarations in this header only refer to Expr
// through pointers (Expr*, const Expr*), so the full definition is not needed
// here.
class Expr;

namespace kir {
// Forward declarations of types in namespace kir. ForLoop, IfThenElse and
// Scope are used by the visitor classes in this header through pointers only;
// Predicate and TensorIndex are declared here but not otherwise referenced by
// the declarations visible in this header.
class Predicate;
class TensorIndex;
class ForLoop;
class IfThenElse;
class Scope;

// Base visitor class that visits all nodes in provided vector<Expr*>.
//
// Includes visiting through scopes like IfThenElse and ForLoop, and tracks
// them in scope_ and for_loops_.
//
// Makes a copy of exprs at exprs_ which could be used to modify and return.
//
// When traversing through ITE/FLs it will use a copy
// of the provided expressions to make it safe to insert/delete nodes.
//
// Provides a simple base class to inherit from for typical lowering passes on
// Expr list
// Visitor over a list of mutable Expr*, built on OptOutDispatch. Adds a
// vector-level handle() entry point plus overrides for the two scope-carrying
// nodes (ForLoop and IfThenElse), and exposes traversal bookkeeping to derived
// passes through protected members.
class TORCH_CUDA_CU_API IrVisitor : public OptOutDispatch {
 public:
  // Entry point for a pass: takes the list of expressions to visit and returns
  // a list of expressions by value.
  // NOTE(review): presumably the returned list is exprs_ after traversal —
  // confirm against the definition in the .cpp.
  std::vector<Expr*> handle(const std::vector<Expr*>& expr);

 protected:
  // Declaring handle() overloads in this class would hide the overloads
  // inherited from OptOutDispatch; re-expose them.
  using OptOutDispatch::handle;

  // Overrides of the base-class handlers for the scope-carrying nodes.
  // `override` already implies `virtual`, so only `override` is spelled out.
  void handle(ForLoop*) override;
  void handle(IfThenElse*) override;

  // Traversal state available to derived passes. The declaration order of
  // these members is part of the class layout and is intentionally unchanged.

  // ForLoop nodes tracked during traversal.
  std::vector<ForLoop*> for_loops_;
  // Scope objects tracked during traversal.
  std::vector<Scope*> scope_;
  // NOTE(review): presumably the expressions that own the entries of scope_ —
  // confirm against the .cpp.
  std::vector<Expr*> scope_exprs_;
  // Copy of the expression list handed to handle(const std::vector<Expr*>&).
  // NOTE(review): copy semantics are stated by the header comments, not shown
  // in this file — confirm against the .cpp.
  std::vector<Expr*> exprs_;
};

// Const version of IrVisitor
// Visitor over a list of const Expr*, built on OptOutConstDispatch. Declares
// the same entry point, scope-node overrides and bookkeeping members as
// IrVisitor, but with every node pointer const-qualified.
class TORCH_CUDA_CU_API ConstIrVisitor : public OptOutConstDispatch {
 public:
  // Entry point for a pass: takes the list of expressions to visit and returns
  // a list of expressions by value.
  // NOTE(review): presumably the returned list is exprs_ after traversal —
  // confirm against the definition in the .cpp.
  std::vector<const Expr*> handle(const std::vector<const Expr*>& expr);

 protected:
  // Declaring handle() overloads in this class would hide the overloads
  // inherited from OptOutConstDispatch; re-expose them.
  using OptOutConstDispatch::handle;

  // Overrides of the base-class handlers for the scope-carrying nodes.
  // `override` already implies `virtual`, so only `override` is spelled out.
  void handle(const ForLoop*) override;
  void handle(const IfThenElse*) override;

  // Traversal state available to derived passes. The declaration order of
  // these members is part of the class layout and is intentionally unchanged.

  // ForLoop nodes tracked during traversal.
  std::vector<const ForLoop*> for_loops_;
  // Scope objects tracked during traversal.
  std::vector<const Scope*> scope_;
  // NOTE(review): presumably the expressions that own the entries of scope_ —
  // confirm against the .cpp.
  std::vector<const Expr*> scope_exprs_;
  // NOTE(review): presumably a copy of the expression list handed to the
  // vector-level handle() — confirm against the .cpp.
  std::vector<const Expr*> exprs_;
};

// Base Expr Mutator class that visits all nodes with IrVisitor, and then
// inserts new expressions, replaces expressions based on the insertion/replace
// maps provided, or removes existing expressions. These mutation maps are
// expected to accumulate during an initial traversal; the registered
// mutations are then applied after the overloaded traversal has finished.
//
// Order of mutations may be important, mutations are ordered according to the
// following rules:
//   Before/After insertions are ordered as registered when reverse_order ==
//   false,
//
//   Before/After insertions are in reverse order as registered when
//   reverse_order == true,
//
//   Before/After insertions are done before Expr replacements, so reference for
//   insertions must be on pre-replaced Exprs
//
//   Removal of expressions is done after replacements.
//
// To place in a scope that is empty, simply provide a nullptr reference.
// Since insertions are done in order, it's possible to insert an expression in
// an empty scope, and then use that inserted expression as a reference for
// subsequent mutations.
class ExprMutator : public IrVisitor {
 protected:
  // Keep every inherited handle() overload reachable from derived passes.
  using IrVisitor::handle;

  // Traversal entry point for derived passes; returns a list of expressions
  // by value. `reverse_order` selects whether before/after insertions are
  // applied in registration order (false) or in reverse (true).
  // NOTE(review): presumably traverses `expr` and then applies the registered
  // mutations via mutate() — confirm against the .cpp.
  std::vector<Expr*> traverseAndInsert(
      const std::vector<Expr*>& expr,
      bool reverse_order = false);

  // NOTE(review): presumably applies the mutations registered so far and
  // returns the resulting expression list — confirm against the .cpp.
  std::vector<Expr*> mutate(bool reverse_order = false);

  // "In place" registration: these overloads take no Scope*, so they must be
  // called while visiting the expression in question, i.e. from within
  // handle(Expr*) of the Expr you want to insert before/after, replace or
  // remove.
  void registerInsertBefore(Expr* reference, Expr* new_expr);
  void registerInsertAfter(Expr* reference, Expr* new_expr);
  void registerReplace(Expr* reference, Expr* new_expr);
  void registerRemove(Expr* expr_to_remove);

  // Explicit-scope registration: the caller names the target Scope*, so these
  // overloads do *not* need to be called in place during visiting.
  void registerInsertBefore(Expr* reference, Expr* new_expr, Scope* scope);
  void registerInsertAfter(Expr* reference, Expr* new_expr, Scope* scope);
  void registerReplace(Expr* reference, Expr* new_expr, Scope* scope);
  void registerRemove(Expr* expr_to_remove, Scope* scope);

 private:
  // The kind of mutation a registered record describes.
  enum class MutationMode { BEFORE, AFTER, REPLACE, REMOVE };

  // One registered mutation: the reference expression, the new expression (if
  // any), the scope it applies to, and the kind of mutation.
  struct MutationInformation {
    Expr* reference = nullptr;
    Expr* new_expr = nullptr;
    Scope* scope = nullptr;
    MutationMode mode = MutationMode::BEFORE;
  };

  // Common registration helper taking the mutation kind explicitly.
  void registerMutation(
      Expr* ref,
      Expr* new_expr,
      Scope* scope,
      MutationMode mode);

  // Registered insertions, in registration order.
  std::vector<MutationInformation> insertions_;

  // Registered replacements, in registration order.
  std::vector<MutationInformation> replacements_;

  // Registered removals, in registration order.
  std::vector<MutationInformation> removal_;
};

} // namespace kir
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch