File: llvm_codegen.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (143 lines) | stat: -rw-r--r-- 3,850 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#pragma once

#ifdef TORCH_ENABLE_LLVM
#include <torch/csrc/Export.h>

#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>

#include <c10/util/Optional.h>

#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace tensorexpr {

// Forward declarations only: this header refers to these types solely through
// pointers (std::unique_ptr members and a raw pointer parameter), so their
// full definitions are not required here.
class LLVMCodeGenImpl;
class LLVMCodeGenCallee;

// CodeGen subclass for the LLVM backend. This declaration is only visible when
// TORCH_ENABLE_LLVM is defined. State is split across two owned objects:
// `callee_` (used to look up the kernel address) and `impl_`.
class TORCH_API LLVMCodeGen : public CodeGen {
 public:
  // Full constructor. Everything after `args` is optional:
  //   device           - defaults to at::kCPU
  //   kernel_func_name - defaults to "func"
  //   dtype            - defaults to kInt
  //   triple/cpu/attrs - default to c10::nullopt
  // NOTE(review): triple/cpu/attrs presumably select the LLVM target and fall
  // back to some default when unset - confirm against the implementation.
  explicit LLVMCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& args,
      at::Device device = at::kCPU,
      const std::string& kernel_func_name = "func",
      Dtype dtype = kInt,
      c10::optional<std::string> triple = c10::nullopt,
      c10::optional<std::string> cpu = c10::nullopt,
      c10::optional<std::string> attrs = c10::nullopt);

  // Overload that takes only the statement.
  explicit LLVMCodeGen(StmtPtr stmt);

  LLVMCodeGen() = delete;

  // Declared here and defined out of line: the unique_ptr members point to
  // types that are only forward-declared in this header.
  ~LLVMCodeGen() override;

  // Cleans up all the memory used during LLVM code generation pass except
  // the generated kernel. After calling this method, users should not call
  // methods like `getCodeText` that require the LLVMCodeGenImpl data. However,
  // users can continue to call this kernel using `call` and `call_raw`.
  void cleanup_memory();

  // CodeGen entry points overridden by this backend.
  TORCH_API void call(const std::vector<CallArg>& args) override;
  TORCH_API void call_raw(const std::vector<void*>& args) override;
  TORCH_API void call_with_numel(void** args, int64_t numel) override;

  at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      c10::optional<c10::ScalarType> dtype_opt,
      c10::optional<c10::Layout> layout_opt,
      c10::optional<c10::Device> device_opt,
      c10::optional<bool> pin_memory_opt) override;

  // Calls the kernel with a null argument array and returns its result as T.
  template <typename T>
  T value() {
    return value<T>(nullptr);
  }

  // Calls the kernel with `args.data()` as the argument array and returns its
  // result as T.
  template <typename T>
  T value(std::vector<void*>& args) {
    return value<T>(args.data());
  }

  // Fetches the kernel address via getKernelAddress(callee_.get()), treats it
  // as a function of type `T(void**)`, calls it with `args`, and returns the
  // result. The caller is responsible for choosing a T that matches the
  // kernel's actual return type; nothing here checks it.
  template <typename T>
  T value(void** args) {
    using KernelFn = T (*)(void**);
    // The address comes back as an untyped void*; a named cast makes the
    // object-pointer -> function-pointer conversion explicit.
    const auto kernel =
        reinterpret_cast<KernelFn>(getKernelAddress(callee_.get()));
    // Returning the call expression directly (rather than through a local)
    // keeps this well-formed when T is void.
    return kernel(args);
  }

  std::string getCodeText(const std::string& attr = "") override;

 private:
  // Returns the address of the kernel associated with `callee` as an untyped
  // pointer; `value<T>` casts it to the appropriate function-pointer type.
  void* getKernelAddress(LLVMCodeGenCallee* callee);

  std::unique_ptr<LLVMCodeGenCallee> callee_;
  std::unique_ptr<LLVMCodeGenImpl> impl_;
};

// Fluent builder for LLVMCodeGen. The statement and buffer arguments are
// mandatory and given at construction; each remaining option has a setter that
// returns *this so calls can be chained, and keeps its default if never set:
//   device         - at::kCPU
//   kernelFuncName - "func"
//   dtype          - kInt
//   triple/cpu/attrs - c10::nullopt
struct TORCH_API LLVMCodeGenBuilder {
  using BufferArg = CodeGen::BufferArg;

  // Both parameters are taken by value and moved into the builder, so callers
  // can pass temporaries or std::move their own copies without an extra copy.
  LLVMCodeGenBuilder(StmtPtr stmt, std::vector<BufferArg> args)
      : stmt_(std::move(stmt)), args_(std::move(args)) {}

  // Sets the device passed to the LLVMCodeGen constructor.
  LLVMCodeGenBuilder& device(at::Device device) {
    device_ = device;
    return *this;
  }

  // Sets the kernel function name passed to the LLVMCodeGen constructor.
  LLVMCodeGenBuilder& kernelFuncName(std::string name) {
    kernelFuncName_ = std::move(name);
    return *this;
  }

  // Sets the dtype passed to the LLVMCodeGen constructor.
  LLVMCodeGenBuilder& dtype(Dtype d) {
    dtype_ = d;
    return *this;
  }

  // Sets the `triple` argument passed to the LLVMCodeGen constructor.
  LLVMCodeGenBuilder& triple(std::string triple) {
    triple_ = std::move(triple);
    return *this;
  }

  // Sets the `cpu` argument passed to the LLVMCodeGen constructor.
  LLVMCodeGenBuilder& cpu(std::string cpu) {
    cpu_ = std::move(cpu);
    return *this;
  }

  // Sets the `attrs` argument passed to the LLVMCodeGen constructor.
  LLVMCodeGenBuilder& attrs(std::string attrs) {
    attrs_ = std::move(attrs);
    return *this;
  }

  // Creates a new LLVMCodeGen from the current settings. The members are
  // passed as lvalues (not moved), so the builder keeps its state and build()
  // may be called more than once.
  std::unique_ptr<LLVMCodeGen> build() {
    return std::make_unique<LLVMCodeGen>(
        stmt_, args_, device_, kernelFuncName_, dtype_, triple_, cpu_, attrs_);
  }

 private:
  StmtPtr stmt_;
  std::vector<BufferArg> args_;
  at::Device device_ = at::kCPU;
  std::string kernelFuncName_ = "func";
  Dtype dtype_ = kInt;
  c10::optional<std::string> triple_ = c10::nullopt;
  c10::optional<std::string> cpu_ = c10::nullopt;
  c10::optional<std::string> attrs_ = c10::nullopt;
};

// Accessors that each return a mutable (non-const) reference, so a caller can
// both read the current value and assign a new one through the result.
// NOTE(review): presumably these expose process-wide settings that act as the
// fallback target triple / CPU / attribute string for LLVMCodeGen, plus a flag
// toggling an ahead-of-time workflow - confirm against the definitions in the
// corresponding .cpp file, including any thread-safety expectations.
TORCH_API c10::optional<std::string>& LLVMTargetTriple();
TORCH_API c10::optional<std::string>& LLVMTargetCPU();
TORCH_API c10::optional<std::string>& LLVMTargetAttrs();
TORCH_API bool& LLVMAOTWorkflow();

} // namespace tensorexpr
} // namespace jit
} // namespace torch

#endif // TORCH_ENABLE_LLVM