File: ir_builder.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (140 lines) | stat: -rw-r--r-- 4,579 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#pragma once

#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_container.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Forward declaration of kir::Kernel; only the name is introduced here, the
// full definition is not required by this header.
namespace kir {
class Kernel;
}

// Forward declaration of IrCloner, which IrBuilder::clone() takes by pointer.
class IrCloner;

// Passkey for builder to register properties with statements, and to call
// functions in IrContainer
//
// The constructor is private and IrBuilder is the only friend declared in
// this class, so code that is neither can only obtain a passkey by receiving
// one from IrBuilder. IrBuilder::create() hands a freshly constructed passkey
// to the node constructor and to registerStmt() on the container.
class TORCH_CUDA_CU_API IrBuilderPasskey {
  // Grants IrBuilder access to the private constructor below.
  friend class IrBuilder;

 public:
  // TODO: Collapse ir_container and Kernel once Kernel inherits from
  // IrContainer
  //
  // Container associated with this passkey. The pointer itself is const, so
  // it cannot be reseated after construction; the in-class default is
  // nullptr. Presumably initialized from the constructor argument -- the
  // constructor definition is not in this header, confirm there.
  IrContainer* const ir_container_ = nullptr;

 private:
  // Creates a passkey for `ir_container`. Declared here, defined out of line.
  explicit IrBuilderPasskey(IrContainer* ir_container);
};

//! IR builder interface
//!
//! Every member is static. The create() overloads allocate IR nodes and
//! register them with a container; the *Expr helpers below are only declared
//! in this header.
class TORCH_CUDA_CU_API IrBuilder {
 public:
  //! Build a node of type T inside the container returned by
  //! FusionGuard::getCurFusion().
  //!
  //! \param args forwarded to T's constructor, after a leading
  //!   IrBuilderPasskey
  //! \return the new node, already passed to the container's registerStmt()
  //!
  //! Trips TORCH_INTERNAL_ASSERT when getCurFusion() yields nullptr.
  template <typename T, typename... Args>
  static T* create(Args&&... args) {
    auto container = FusionGuard::getCurFusion();
    // NOTE(review): forwarding to the container overload is left disabled and
    // the body is repeated instead; presumably the container pointer could be
    // deduced into Args and re-select this overload -- confirm before merging.
    // return create<T>(container, std::forward<Args>(args)...);
    TORCH_INTERNAL_ASSERT(
        container != nullptr, "Need an active container to build IR.");

    T* const stmt =
        new T(IrBuilderPasskey(container), std::forward<Args>(args)...);
    container->registerStmt(IrBuilderPasskey(container), stmt);
    return stmt;
  }

  //! Build a node of type T inside an explicitly supplied container.
  //!
  //! \param container destination container; must not be nullptr
  //! \param args forwarded to T's constructor, after a leading
  //!   IrBuilderPasskey
  //! \return the new node, already passed to container->registerStmt()
  //!
  //! Trips TORCH_INTERNAL_ASSERT when `container` is nullptr.
  template <typename T, typename... Args>
  static T* create(IrContainer* container, Args&&... args) {
    TORCH_INTERNAL_ASSERT(
        container != nullptr, "Need an active container to build IR.");

    T* const stmt =
        new T(IrBuilderPasskey(container), std::forward<Args>(args)...);
    container->registerStmt(IrBuilderPasskey(container), stmt);
    return stmt;
  }

  //! Clone an IR node, forwarding the arguments to the IrCloner constructor.
  //! Register clones with IrCloner's target container.
  //! Only declared here; the definition lives outside this header.
  template <typename T>
  static T* clone(const T* src, IrCloner* ir_cloner);

  // ---- Unary operations: one operand in, resulting Val* out ----
  static Val* negExpr(Val* val);
  static Val* notExpr(Val* val);
  static Val* setExpr(Val* val);
  static Val* setExprNamedScalar(const std::string& name, Val* val);
  static Val* addressExprNamedScalar(const std::string& name, Val* val);

  // ---- Binary operations: (lhs, rhs) in, resulting Val* out ----
  static Val* andExpr(Val* lhs, Val* rhs);
  static Val* eqExpr(Val* lhs, Val* rhs);
  static Val* gtExpr(Val* lhs, Val* rhs);
  static Val* ltExpr(Val* lhs, Val* rhs);
  static Val* leExpr(Val* lhs, Val* rhs);
  static Val* geExpr(Val* lhs, Val* rhs);
  static Val* addExpr(Val* lhs, Val* rhs);
  static Val* subExpr(Val* lhs, Val* rhs);
  static Val* mulExpr(Val* lhs, Val* rhs);
  static Val* divExpr(Val* lhs, Val* rhs);
  static Val* ceilDivExpr(Val* lhs, Val* rhs);
  static Val* modExpr(Val* lhs, Val* rhs);
  static Val* maxExpr(Val* lhs, Val* rhs);
  static Val* minExpr(Val* lhs, Val* rhs);

  // ---- Ternary operations ----
  static Val* whereExpr(Val* pred, Val* lhs, Val* rhs);

  // ---- Swizzle operations ----
  static Val* swizzle2DIntExpr(
      Val* x,
      Val* y,
      Val* extent_x,
      Val* extent_y,
      Swizzle2DType swizzle_type);
  static Val* pairSelectExpr(Val* in, kir::PairSelect::Selection sel);

 private:
  // Internal helpers, declared here and defined out of line.
  // NOTE(review): presumably shared by the public *Expr functions -- confirm
  // in the implementation file.
  static Val* newResult(DataType dtype);
  static Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs);
  static Val* newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs);
};

//! A wrapper builder with static expression simplification
//!
//! Example:
//! - addExpr(new Int(1), new Int(2)) -> Int(3)
//! - addExpr(new Int(0), new NamedScalar("foo")) -> NamedScalar("foo")
//!
//! Designed to be used to simplify predicate and index expressions in
//! generated code. Also, the shift validation may fail without
//! this simplification.
//!
//! Publicly derives from IrBuilder. Each static function declared below has
//! the same name as an IrBuilder member, so a call qualified with
//! SimplifyingIrBuilder:: resolves to these declarations rather than the
//! base-class ones. IrBuilder members that are not redeclared here (for
//! example eqExpr or create) stay reachable through the inheritance.
class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder {
 public:
  // Unary operations.
  static Val* negExpr(Val* val);
  static Val* notExpr(Val* val);

  // Addition. Besides the generic (Val*, Val*) form there are overloads for
  // Int* operands and for a right-hand side passed as a plain
  // Int::ScalarType value instead of an IR node.
  static Val* addExpr(Int* lhs, Int::ScalarType rhs);
  static Val* addExpr(Val* lhs, Int::ScalarType rhs);
  static Val* addExpr(Int* lhs, Int* rhs);
  static Val* addExpr(Val* lhs, Val* rhs);
  // Subtraction; only the generic (Val*, Val*) form is declared.
  static Val* subExpr(Val* lhs, Val* rhs);
  // Multiplication, with the same overload set as addExpr.
  static Val* mulExpr(Int* lhs, Int::ScalarType rhs);
  static Val* mulExpr(Val* lhs, Int::ScalarType rhs);
  static Val* mulExpr(Int* lhs, Int* rhs);
  static Val* mulExpr(Val* lhs, Val* rhs);
  // Logical and, max, min on generic values.
  static Val* andExpr(Val* lhs, Val* rhs);
  static Val* maxExpr(Val* lhs, Val* rhs);
  static Val* minExpr(Val* lhs, Val* rhs);
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch