File: block_codegen.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (149 lines) | stat: -rw-r--r-- 4,192 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#pragma once

#include <memory>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <ATen/ATen.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>

namespace torch {
namespace jit {
namespace tensorexpr {

// A class that analyzes the given program relevant for Block backend.
class BlockAnalysis : public IRVisitor {
 public:
  bool is_buf_store_target(const Buf* buf) const {
    return store_targets_.count(buf) > 0;
  }

  const std::unordered_set<const Buf*>& loads() const {
    return loads_;
  }

  const std::unordered_set<const Buf*>& stores() const {
    return store_targets_;
  }

  int block_size() const {
    return block_size_;
  }

  bool areBufsInMap(const std::unordered_set<const Buf*>& bufs) const;

  const Buf* getMultiDimBuf(const Buf* buf) const;

  std::string getInputName(const Buf* buf) const;

  std::string getFlatInputName(const Buf* buf) const {
    return getInputName(buf) + "_flat";
  }

  std::unordered_map<std::string, const Buf*> getBufferMap() const {
    return map_input_to_tensor_bufs_;
  }

 private:
  void visit(const Store* v) override;
  void visit(const Load* v) override;
  void visit(const For* v) override;

  std::unordered_map<std::string, const Buf*> map_input_to_tensor_bufs_;
  std::unordered_set<const Buf*> store_targets_;
  std::unordered_set<const Buf*> loads_;
  int block_size_ = 32;
};

// A class that overrides the underlying IRPrinter to produce Block.
class BlockPrinter : public IRPrinter {
 public:
  BlockPrinter(std::ostream* os, const BlockAnalysis* block_analysis)
      : IRPrinter(*os), block_analysis_(block_analysis) {}

  using IRPrinter::name_manager;
  using IRPrinter::visit;

 private:
  const BlockAnalysis* block_analysis_;
  std::unordered_map<std::string, int> dim_values_map;
  std::vector<std::string> dim_names = {"N", "H", "W", "C"};
  std::vector<std::string> flat_dim_names = {"N", "NH", "NHW", "NHWC"};
  void PrintTensorInfo(const std::unordered_set<const Buf*>& bufs);
  void PrintArguments(const std::unordered_set<const Buf*>& bufs);
  void PrintBufferInfo(const std::unordered_set<const Buf*>& bufs);
  void PrintDistribution(const std::unordered_set<const Buf*>& bufs);
  void PrintLoop(
      const std::unordered_set<const Buf*>& bufs,
      bool block_idx = true);
  void PrintReshapeInfo(
      const std::unordered_set<const Buf*>& bufs,
      bool reverse = false);
  void PrintDMAs(const std::unordered_set<const Buf*>& bufs);
  void PrintAdjustBuffers(const std::unordered_set<const Buf*>& bufs);

  void visit(const For* v) override;
  void visit(const Load* v) override;
  void visit(const Store* v) override;
  void visit(const Block* v) override;
  void visit(const Add* v) override;
  void visit(const Mul* v) override;
};

// CodeGen backend for the Block target. Both constructors forward to CodeGen
// and then call Initialize(); getCodeText() returns whatever has been written
// to the internal string stream (presumably the generated kernel text --
// confirm in Initialize()).
class TORCH_API BlockCodeGen : public CodeGen {
 public:
  // Convenience constructor: each trailing argument is wrapped in a
  // BufferArg and the device is fixed to CPU.
  template <typename... Ts>
  /* implicit */
  BlockCodeGen(Stmt* stmt, Ts... ts)
      : CodeGen(
            stmt,
            std::vector<BufferArg>({BufferArg(ts)...}),
            at::Device(at::kCPU)) {
    Initialize();
  }

  BlockCodeGen(
      Stmt* stmt,
      const std::vector<BufferArg>& buffer_args,
      at::Device device = at::Device(at::kCPU))
      : CodeGen(stmt, buffer_args, device) {
    Initialize();
  }

  ~BlockCodeGen() override;

  // Defined out of line.
  void call(const std::vector<CallArg>& args) override;

  // Defined out of line; invoked by both constructors.
  void Initialize();

  // Snapshot of everything written to oss_ so far.
  std::string getCodeText() override {
    return oss_.str();
  }

 private:
  // Single point of access to printer_. Throws std::runtime_error instead of
  // dereferencing a null printer, so name_manager() and os() share the same
  // guard (previously only name_manager() checked).
  BlockPrinter& printer() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return *printer_;
  }

  UniqueNameManager* name_manager() {
    return printer().name_manager();
  }

  std::ostream& os() {
    return printer().os();
  }

  std::ostringstream oss_;
  std::unique_ptr<BlockPrinter> printer_;
  std::unique_ptr<BlockAnalysis> block_analysis_;

  // Defined out of line.
  std::string GetUniqueFuncName(const std::string& func_prefix);
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch