File: ir_graphviz.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (128 lines) | stat: -rw-r--r-- 3,795 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#pragma once

#include <c10/macros/Export.h>
#include <torch/csrc/jit/codegen/cuda/dispatch.h>

#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Generates a DOT (https://www.graphviz.org) graph
// representation of a fuser IR
//
// Usage:
// 1) Add calls to IrGraphGenerator::print(), for example:
//  `IrGraphGenerator::print(&fusion, "ir.dot")`
//
// 2) Call IrGraphGenerator::print() from a debugger. Using gdb for example:
//  `call IrGraphGenerator::print(&fusion, "ir.dot",
//      IrGraphGenerator::DetailLevel::Explicit, 0)`
//
// Notes:
//  - When called from the debugger, every argument (including detail_level
//    and expr_color_map) must be explicitly passed in, since most debuggers
//    don't support default arguments; pass 0 (a null pointer, matching its
//    default) for expr_color_map
//
//  - The output dot file path can't include shell specific notations,
//    for example you can't use "~/temp/ir.dot" ("/home/user/temp/ir.dot"
//    must be used instead)
//
class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch {
 public:
  // Selects how much of the IR ends up in the generated graph.
  enum class DetailLevel {
    ComputeOnly, // Dataflow (compute) nodes only
    Basic, // Compute + schedule with minimal details (the default level)
    Explicit, // Extra details (ex. symbolic names for scalar constants)
    Verbose, // Every value, dead definitions included
  };

  // Per-expression numeric color value, keyed by expression pointer.
  // NOTE(review): how the number is turned into an actual color is decided
  // by the implementation and is not visible in this header.
  using ExprColorMap = std::unordered_map<const Expr*, size_t>;

  // Generates the graph for `fusion` and emits it to the file `filename`.
  // `detail_level` falls back to DetailLevel::Basic and `expr_color_map`
  // to nullptr when omitted.
  static void print(
      const Fusion* fusion,
      const char* filename,
      DetailLevel detail_level = DetailLevel::Basic,
      ExprColorMap* expr_color_map = nullptr);

  // Same inputs as print() minus the file name; the result is handed back
  // as a string (presumably the DOT text - confirm against the definition).
  static std::string toGraphviz(
      const Fusion* fusion,
      DetailLevel detail_level,
      ExprColorMap* expr_color_map = nullptr);

 private:
  // Construction and destruction are private: instances are not meant to be
  // created by users of this class, only the static entry points are public.
  IrGraphGenerator(
      const Fusion* fusion,
      DetailLevel detail_level,
      ExprColorMap* expr_color_map = nullptr);
  ~IrGraphGenerator() override = default;

  // Produces the complete graph description as a string.
  std::string generate();

  // Separate generation steps for the compute part and the schedule part of
  // the graph (see DetailLevel for what "compute" and "schedule" refer to).
  void generateComputeGraph();
  void generateScheduleGraph();

  // OptInConstDispatch overrides, one per IR node kind that is handled.

  // Generic statement / value / expression entry points
  void handle(const Statement* stmt) override;
  void handle(const Val* val) override;
  void handle(const Expr* expr) override;

  // Tensor related nodes
  void handle(const TensorDomain* td) override;
  void handle(const TensorView* tv) override;
  void handle(const IterDomain* id) override;

  // Scalar values
  void handle(const Bool* b) override;
  void handle(const Double* d) override;
  void handle(const Int* i) override;
  void handle(const ComplexDouble* c) override;
  void handle(const NamedScalar* ns) override;

  // Operations
  void handle(const ARangeOp* op) override;
  void handle(const UnaryOp* op) override;
  void handle(const BinaryOp* op) override;
  void handle(const TernaryOp* op) override;
  void handle(const RNGOp* op) override;
  void handle(const BroadcastOp* op) override;
  void handle(const ReductionOp* op) override;

  // Split / Merge nodes
  void handle(const Split* split) override;
  void handle(const Merge* merge) override;

  // Returns the graph id associated with `stm`; a new id is created when
  // the statement doesn't have one yet
  std::string getid(const Statement* stm);

  // True when `s` is already a member of visited_
  bool visited(const Statement* s) const {
    return visited_.count(s) != 0;
  }

  // Adds an arc going from `src` to `dst`; `style` is optional and empty
  // unless the caller provides one
  void addArc(
      const Statement* src,
      const Statement* dst,
      const std::string& style = "");

  // Emit a single expression / value node using the given label
  void printExpr(const Expr* expr, const std::string& label);
  void printValue(const Val* val, const std::string& label);

  // State (declaration order is kept as-is since it determines the member
  // initialization order)
  const DetailLevel detail_level_; // requested level of detail
  const Fusion* const fusion_; // the fusion being visualized (not owned)
  std::stringstream graph_def_; // buffer for the graph text
  std::unordered_map<const Statement*, std::string> id_map_; // see getid()
  std::unordered_set<const Statement*> visited_; // see visited()
  std::unordered_set<const Val*> inputs_;
  std::unordered_set<const Val*> outputs_;
  std::vector<const TensorView*> tensor_views_;
  std::vector<std::string> arcs_;
  int next_id_{1}; // presumably the next id handed out by getid() - confirm
  ExprColorMap* expr_color_map_{nullptr}; // optional, may be null
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch