File: cuda_codegen.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (250 lines) | stat: -rw-r--r-- 6,947 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
#pragma once

#include <unordered_map>
#include <unordered_set>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>

namespace torch {
namespace jit {
namespace tensorexpr {

// A class that analyzes the given program relevant for Cuda backends.
class CudaAnalysis : public IRVisitor {
 public:
  CudaAnalysis() {
    gpu_block_extents_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
    gpu_thread_extents_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
  }
  bool is_buf_store_target(const Buf* buf) const {
    return store_targets_.count(buf) > 0;
  }

  const std::unordered_set<const Var*>& thread_local_bufs() const {
    return thread_local_bufs_;
  }

  const std::unordered_set<const Var*>& cross_block_bufs() const {
    return cross_block_bufs_;
  }

  const std::vector<const Expr*>& gpu_block_extents() const {
    return gpu_block_extents_;
  }

  const std::vector<const Expr*>& gpu_thread_extents() const {
    return gpu_thread_extents_;
  }

 private:
  void visit(const Store* v) override {
    store_targets_.insert(v->buf());
  }

  void visit(const Allocate* v) override;
  void visit(const Free* v) override;
  void visit(const For* v) override;

  std::unordered_set<const Buf*> store_targets_;
  std::unordered_set<const Var*> thread_local_bufs_;
  std::unordered_set<const Var*> cross_block_bufs_;

  std::vector<const Expr*> gpu_block_extents_;
  std::vector<const Expr*> gpu_thread_extents_;
};

// An IRMutator that replaces binding loop options with Cuda metavars, and masks
// statements blocks which should execute with less reach than the launch
// parameter extent.
//
// We do this by segmenting each block into chunks which should have the same
// execution parameters, then if those params differ from the max mask each dim.
class GPUMetaVarRewriter : public IRMutator {
 public:
  explicit GPUMetaVarRewriter(const CudaAnalysis* cuda_analysis)
      : cuda_analysis_(cuda_analysis) {
    gpu_block_vars_ = {new Var("blockIdx.x", kInt),
                       new Var("blockIdx.y", kInt),
                       new Var("blockIdx.z", kInt)};
    gpu_thread_vars_ = {new Var("threadIdx.x", kInt),
                        new Var("threadIdx.y", kInt),
                        new Var("threadIdx.z", kInt)};

    current_block_reach_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
    current_thread_reach_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
  }

  Stmt* mutate(const For* v) override;
  Stmt* mutate(const Block* v) override;

  const std::vector<const Var*>& gpu_block_vars() const {
    return gpu_block_vars_;
  }

  const std::vector<const Var*>& gpu_thread_vars() const {
    return gpu_thread_vars_;
  }

  const std::vector<const Expr*>& gpu_block_extents() const {
    return cuda_analysis_->gpu_block_extents();
  }

  const std::vector<const Expr*>& gpu_thread_extents() const {
    return cuda_analysis_->gpu_thread_extents();
  }

 private:
  // When processing a block, stores the contents of each sub-segment.
  class Segment {
   public:
    void reset(bool mask) {
      stmts_.clear();
      mask_ = mask;
    }

    bool empty() const {
      return stmts_.empty();
    }

    std::vector<Stmt*>& stmts() {
      return stmts_;
    }
    bool mask() {
      return mask_;
    }

   private:
    std::vector<Stmt*> stmts_;
    bool mask_{true};
  };

  // Returns true if the current execution scope is equivalent to the launch
  // parameters.
  bool isFullExtent();

  std::vector<const Var*> gpu_block_vars_;
  std::vector<const Var*> gpu_thread_vars_;

  std::vector<const Expr*> current_block_reach_;
  std::vector<const Expr*> current_thread_reach_;

  const CudaAnalysis* cuda_analysis_;
};

// A class that overrides the underlying IRPrinter to produce Cuda C.
class CudaPrinter : public IRPrinter {
 public:
  // Builds a printer that writes to `*os`.
  //
  // os:            output stream; dereferenced immediately, so it must be
  //                non-null.
  // cuda_analysis: non-owning pointer to the analysis results; stored as-is.
  // has_random:    when true, a handle-typed Var named "rand" is created and
  //                exposed through rand_func(); when false, rand_func()
  //                returns nullptr.
  explicit CudaPrinter(
      std::ostream* os,
      const CudaAnalysis* cuda_analysis,
      bool has_random)
      : IRPrinter(*os), cuda_analysis_(cuda_analysis) {
    if (has_random) {
      rand_func_ = new Var("rand", kHandle);
    }
  }

  // Cuda-specific printing for the node kinds below; defined out of line.
  void visit(const Cast* v) override;
  void visit(const Intrinsics* v) override;
  void visit(const For* v) override;

  void visit(const Load* v) override;
  void visit(const Store* v) override;
  void visit(const AtomicAdd* v) override;
  void visit(const Max* v) override;
  void visit(const Min* v) override;
  void visit(const IfThenElse* v) override;
  void visit(const Block* v) override;
  void visit(const Allocate* v) override;
  void visit(const Free* v) override;
  void visit(const Let* v) override;

  // The "rand" Var created by the constructor, or nullptr if the printer was
  // built with has_random == false.
  const Var* rand_func() const {
    return rand_func_;
  }

  // Make the base-class name manager reachable through this class, and keep
  // the IRPrinter::visit overloads that are not overridden above from being
  // hidden by the ones declared here.
  using IRPrinter::name_manager;
  using IRPrinter::visit;

 private:
  // Defaults to nullptr so that rand_func() never returns an indeterminate
  // pointer when has_random is false.
  const Var* rand_func_{nullptr};
  // Not owned.
  const CudaAnalysis* cuda_analysis_;
};

// Construct Cuda C from the buffer and tensor input, and invoke the kernel
// when real arguments are provided.
class TORCH_CUDA_API CudaCodeGen : public CodeGen {
 public:
  // Convenience constructor: every element of `ts` is wrapped in a BufferArg,
  // and the device is the current CUDA device at construction time.
  template <typename... Ts>
  CudaCodeGen(Stmt* stmt, Ts... ts)
      : CodeGen(
            stmt,
            std::vector<BufferArg>({BufferArg(ts)...}),
            at::Device(at::kCUDA, at::cuda::current_device())) {
    Initialize();
  }

  // Constructs from an explicit argument list. `device` defaults to the
  // current CUDA device at construction time.
  CudaCodeGen(
      Stmt* stmt,
      const std::vector<BufferArg>& buffer_args,
      at::Device device = at::Device(at::kCUDA, at::cuda::current_device()))
      : CodeGen(stmt, buffer_args, device) {
    Initialize();
  }

  ~CudaCodeGen() override;

  // Defined out of line.
  void call(const std::vector<CallArg>& args) override;

  // Shorthand for call(): wraps each argument in a CallArg and forwards.
  template <typename... Ts>
  void operator()(const Ts&... ts) {
    call(std::vector<CallArg>({CallArg(ts)...}));
  }

  // Block extents as recorded by the owned CudaAnalysis.
  const std::vector<const Expr*>& gpu_block_extents() const {
    return cuda_analysis_->gpu_block_extents();
  }

  // Thread extents as recorded by the owned CudaAnalysis.
  const std::vector<const Expr*>& gpu_thread_extents() const {
    return cuda_analysis_->gpu_thread_extents();
  }

 private:
  // Called by both constructors; defined out of line.
  void Initialize();

  void CompileToNVRTC(const std::string& code, const std::string& func_name);

  // Returns the printer's name manager.
  // Throws std::runtime_error if the printer has not been created.
  UniqueNameManager* name_manager() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return printer_->name_manager();
  }

  // Returns the printer's output stream.
  // Throws std::runtime_error if the printer has not been created, matching
  // name_manager() instead of dereferencing a null pointer.
  std::ostream& os() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return printer_->os();
  }

  std::ostringstream oss_;
  std::unique_ptr<CudaPrinter> printer_;
  std::unique_ptr<CudaAnalysis> cuda_analysis_;
  std::unique_ptr<GPUMetaVarRewriter> metavar_rewriter_;
  // Value-initialized so the handle is never read while indeterminate
  // (for example if initialization throws before a kernel is compiled).
  CUfunction function_{};
  bool has_random_ = false;

  std::string GetUniqueFuncName(const std::string& func_prefix);
};

} // namespace tensorexpr
} // namespace jit
} // namespace torch