File: ir_iostream.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (160 lines) | stat: -rw-r--r-- 4,378 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#pragma once

#include <c10/macros/Export.h>

#include <torch/csrc/jit/codegen/cuda/dispatch.h>

#include <c10/util/irange.h>

#include <iostream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Forward declarations: this header refers to these types only through
// pointers and references, so their full definitions are not needed here.
class Fusion;

namespace kir {
class Scope;
class Kernel;
} // namespace kir

//! Define pretty printing functions for IR nodes
//!
//! This class is intended for debug printing, so it attempts
//! to handle invalid states as well.
//!
//! The printer stores only a reference to the std::ostream passed to its
//! constructor, so that stream must outlive the printer.
//!
class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
  //! One level of indentation; indent() emits it once per level.
  static constexpr char const* kTab = "  ";

 public:
  explicit IrPrinter(std::ostream& os) : os_(os) {}

  //! Write the current indentation (indent_size_ copies of kTab) to the
  //! output stream and return the stream so callers can continue with <<.
  std::ostream& indent() {
    for (const auto i : c10::irange(indent_size_)) {
      (void)i; // Suppress unused variable warning
      os_ << kTab;
    }
    return os_;
  }

  //! Reset the indentation level back to zero.
  void resetIndent() {
    indent_size_ = 0;
  }

  //! True while a print_inline() call is in progress. How handlers react
  //! to this flag is defined out of line (presumably they print nested
  //! statements inline instead of on their own lines — confirm in the .cpp).
  bool printInline() const {
    return print_inline_;
  }

  using OptInConstDispatch::handle;

  virtual void handle(Fusion* f);

  // handle calls some non const fusion ops,
  // even though fusion should remain unchanged.
  // Need to look into this.
  virtual void handle(const Fusion* f) {
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    handle(const_cast<Fusion*>(f));
  }

  //! Reference overload; forwards to handle(Fusion*).
  virtual void handle(Fusion& f) {
    handle(&f);
  }

  virtual void handle(const kir::Kernel* kernel);
  virtual void handle(kir::Kernel& kernel);

  void handleScope(const kir::Scope& scope);

  void handle(const Statement* s) final;
  void handle(const Val* v) final;
  void handle(const Expr* e) final;

  void handle(const IterDomain*) final;
  void handle(const TensorDomain*) final;
  void handle(const TensorView*) final;

  void handle(const Bool*) final;
  void handle(const Double*) final;
  void handle(const Int*) final;
  void handle(const ComplexDouble*) final;
  void handle(const NamedScalar*) final;

  void handle(const ARangeOp*) final;
  void handle(const UnaryOp*) final;
  void handle(const BinaryOp*) final;
  void handle(const TernaryOp*) final;
  void handle(const RNGOp*) final;
  void handle(const ReductionOp*) final;
  void handle(const GroupedReductionOp*) final;
  void handle(const WelfordOp*) final;
  void handle(const GroupedWelfordOp*) final;
  void handle(const LoadStoreOp*) final;
  void handle(const MmaOp*) final;
  void handle(const BroadcastOp*) final;
  void handle(const TransposeOp*) final;
  void handle(const ExpandOp*) final;
  void handle(const ShiftOp*) final;
  void handle(const GatherOp*) final;
  void handle(const ViewAsScalar*) final;
  void handle(const ViewOp*) final;

  void handle(const kir::Predicate*) final;
  void handle(const kir::TensorIndex*) final;
  void handle(const kir::IntPair*) final;

  void handle(const kir::GridBroadcast*) final;
  void handle(const kir::GridReduction*) final;
  void handle(const kir::GroupedGridReduction*) final;
  void handle(const kir::GridWelford*) final;
  void handle(const kir::GroupedGridWelford*) final;
  void handle(const kir::ForLoop*) final;
  void handle(const kir::IfThenElse*) final;
  void handle(const kir::Allocate*) final;
  void handle(const kir::BlockSync*) final;
  void handle(const kir::GridSync*) final;
  void handle(const kir::CpAsyncWait*) final;
  void handle(const kir::CpAsyncCommit*) final;
  void handle(const kir::InitMagicZero*) final;
  void handle(const kir::UpdateMagicZero*) final;
  void handle(const kir::AllocateFusedReduction*) final;
  void handle(const kir::Swizzle2DInt*) final;
  void handle(const kir::PairSelect*) final;

  // IR math printer overrides these to prevent them from printing, keep
  // override
  void handle(const Split*) override;
  void handle(const Merge*) override;
  void handle(const Swizzle2D*) override;

  //! Dispatch `stmt` with inline printing enabled, then restore the
  //! previous inline mode. The restore is performed by a guard's destructor
  //! so the flag is reset even if a handler throws; otherwise the printer
  //! would be left permanently in inline mode after a failed print.
  void print_inline(const Statement* stmt) {
    struct RestoreInlineFlag {
      bool& flag;
      const bool saved;
      ~RestoreInlineFlag() {
        flag = saved;
      }
    } restore{print_inline_, print_inline_};
    print_inline_ = true;
    handle(stmt);
  }

 protected:
  std::ostream& os() {
    return os_;
  }

 private:
  std::ostream& os_;
  bool print_inline_ = false;
  int indent_size_ = 0;
};

//! Stream insertion for a single IR statement (defined out of line).
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os,
                                           const Statement* stmt);

//! Stream insertion for a whole fusion, by pointer or by reference
//! (defined out of line).
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os, Fusion* f);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os, Fusion& f);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch