File: block_codegen.h

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (146 lines) | stat: -rw-r--r-- 4,290 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#pragma once

#include <cstdint>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>

namespace torch::jit::tensorexpr {

// IRVisitor that analyzes a program for the Block backend.
// Results are exposed through loads(), stores(), getBufferMap() and
// block_size(). The visit overrides that fill them are defined out of line.
class BlockAnalysis : public IRVisitor {
 public:
  // True if `buf` is in the recorded set of store targets.
  bool is_buf_store_target(const BufPtr& buf) const {
    return store_targets_.count(buf) > 0;
  }

  // Set of buffers recorded as loaded.
  const std::unordered_set<BufPtr>& loads() const {
    return loads_;
  }

  // Set of buffers recorded as store targets.
  const std::unordered_set<BufPtr>& stores() const {
    return store_targets_;
  }

  // Block size for the generated code. Initialized to 32.
  // NOTE(review): presumably updated by one of the visit overrides; confirm
  // in the .cpp.
  int64_t block_size() const {
    return block_size_;
  }

  // Defined out of line.
  bool areBufsInMap(const std::unordered_set<BufPtr>& bufs) const;

  // Defined out of line.
  BufPtr getMultiDimBuf(const BufPtr& buf) const;

  // Defined out of line.
  std::string getInputName(const BufPtr& buf) const;

  // Name of the flattened form of `buf`: getInputName(buf) + "_flat".
  std::string getFlatInputName(const BufPtr& buf) const {
    return getInputName(buf) + "_flat";
  }

  // Mapping from input name to buffer.
  // Returned by const reference, like loads()/stores(), so callers do not
  // copy the whole map on every call. The reference stays valid for the
  // lifetime of this BlockAnalysis. Callers needing an independent copy can
  // still write `auto m = getBufferMap();`.
  const std::unordered_map<std::string, BufPtr>& getBufferMap() const {
    return map_input_to_tensor_bufs_;
  }

 private:
  // IRVisitor hooks, defined out of line.
  // NOTE(review): presumably these populate the members below.
  void visit(const StorePtr& v) override;
  void visit(const LoadPtr& v) override;
  void visit(const ForPtr& v) override;

  std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
  std::unordered_set<BufPtr> store_targets_;
  std::unordered_set<BufPtr> loads_;
  int64_t block_size_ = 32;
};

// A class that overrides the underlying IRPrinter to produce Block.
class BlockPrinter : public IRPrinter {
 public:
  BlockPrinter(std::ostream* os, BlockAnalysis* block_analysis)
      : IRPrinter(*os), block_analysis_(block_analysis) {}

  using IRPrinter::name_manager;
  using IRPrinter::visit;

 private:
  BlockAnalysis* block_analysis_;
  std::unordered_map<std::string, int> dim_values_map;
  std::vector<std::string> dim_names = {"N", "H", "W", "C"};
  std::vector<std::string> flat_dim_names = {"N", "NH", "NHW", "NHWC"};
  void PrintTensorInfo(const std::unordered_set<BufPtr>& bufs);
  void PrintArguments(const std::unordered_set<BufPtr>& bufs);
  void PrintBufferInfo(const std::unordered_set<BufPtr>& bufs);
  void PrintDistribution(const std::unordered_set<BufPtr>& bufs);
  void PrintLoop(const std::unordered_set<BufPtr>& bufs, bool block_idx = true);
  void PrintReshapeInfo(
      const std::unordered_set<BufPtr>& bufs,
      bool reverse = false);
  void PrintDMAs(const std::unordered_set<BufPtr>& bufs);
  void PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs);

  void visit(const ForPtr& v) override;
  void visit(const LoadPtr& v) override;
  void visit(const StorePtr& v) override;
  void visit(const BlockPtr& v) override;
  void visit(const AddPtr& v) override;
  void visit(const MulPtr& v) override;
};

// Code generator for the Block backend.
// Generated text is accumulated in `oss_` and returned by getCodeText().
class TORCH_API BlockCodeGen : public CodeGen {
 public:
  // Convenience constructor that wraps each trailing argument in a BufferArg
  // and targets the CPU device. Intentionally not `explicit`.
  template <typename... Ts>
  /* implicit */
  BlockCodeGen(StmtPtr stmt, Ts... ts)
      : CodeGen(
            // `stmt` is a by-value sink parameter: move it instead of copying,
            // as the BufferArg-vector constructor below already does.
            std::move(stmt),
            std::vector<BufferArg>({BufferArg(ts)...}),
            at::Device(at::kCPU)) {
    Initialize();
  }

  BlockCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& buffer_args,
      at::Device device = at::Device(at::kCPU),
      const std::string& kernel_func_name = "func")
      : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) {
    Initialize();
  }

  ~BlockCodeGen() override;

  void call(const std::vector<CallArg>& args) override;
  void call_raw(const std::vector<void*>& args) override;

  // Defined out of line. Both constructors call it.
  void Initialize();

  // Returns all text written to `oss_` so far.
  // `attr` exists to match the CodeGen interface and is not used here.
  std::string getCodeText(
      [[maybe_unused]] const std::string& attr = "") override {
    return oss_.str();
  }

 private:
  // Returns the printer.
  // Throws std::runtime_error if `printer_` has not been created, so that
  // name_manager() and os() fail the same way instead of dereferencing null.
  BlockPrinter& checked_printer() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return *printer_;
  }

  UniqueNameManager* name_manager() {
    return checked_printer().name_manager();
  }

  // Output stream of the printer (checked like name_manager()).
  std::ostream& os() {
    return checked_printer().os();
  }

  std::ostringstream oss_;
  std::unique_ptr<BlockPrinter> printer_;
  std::unique_ptr<BlockAnalysis> block_analysis_;

  std::string GetUniqueFuncName(const std::string& func_prefix);
};
} // namespace torch::jit::tensorexpr