#pragma once
#ifdef TORCH_ENABLE_LLVM
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <c10/util/Optional.h>
#include <unordered_map>
#include <vector>
namespace torch {
namespace jit {
namespace tensorexpr {
class LLVMCodeGenImpl;
class LLVMCodeGenCallee;
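// Compiles a TensorExpr statement to native machine code through LLVM and
// exposes the result as a callable kernel.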
class TORCH_API LLVMCodeGen : public CodeGen {
public:
explicit LLVMCodeGen(
StmtPtr stmt,
const std::vector<BufferArg>& args,
at::Device device = at::kCPU,
const std::string& kernel_func_name = "func",
Dtype dtype = kInt,
c10::optional<std::string> triple = c10::nullopt,
c10::optional<std::string> cpu = c10::nullopt,
c10::optional<std::string> attrs = c10::nullopt);
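// Convenience overload taking only the statement to compile.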
explicit LLVMCodeGen(StmtPtr stmt);
LLVMCodeGen() = delete;
~LLVMCodeGen() override;
// Cleans up all the memory used during the LLVM code generation pass, except
// for the generated kernel itself. After calling this method, users should
// not call methods like `getCodeText` that require the LLVMCodeGenImpl data.
// However, users can continue to invoke the generated kernel via `call` and
// `call_raw`.
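//
// Typical lifetime (a sketch; `stmt`, `args`, and `call_args` are assumed to
// be built elsewhere):
//   LLVMCodeGen cg(stmt, args);
//   std::string ir = cg.getCodeText(); // requires LLVMCodeGenImpl data
//   cg.cleanup_memory();               // frees codegen-time state
//   cg.call(call_args);                // the kernel itself remains callable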
void cleanup_memory();
TORCH_API void call(const std::vector<CallArg>& args) override;
TORCH_API void call_raw(const std::vector<void*>& args) override;
TORCH_API void call_with_numel(void** args, int64_t numel) override;
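// Allocates an uninitialized tensor with the given sizes and strides,
// mirroring the signature of at::empty_strided.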
at::Tensor empty_strided(
c10::IntArrayRef size,
c10::IntArrayRef stride,
c10::optional<c10::ScalarType> dtype_opt,
c10::optional<c10::Layout> layout_opt,
c10::optional<c10::Device> device_opt,
c10::optional<bool> pin_memory_opt) override;
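// Invokes the compiled kernel through a raw function pointer and returns its
// result as T. The caller must ensure T matches the kernel's actual return
// type; the cast below performs no checking.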
template <typename T>
T value() {
return value<T>(nullptr);
}
template <typename T>
T value(std::vector<void*>& args) {
return value<T>(args.data());
}
template <typename T>
T value(void** args) {
T (*fp)(void**) = (T(*)(void**))getKernelAddress(callee_.get());
T rv = fp(args);
return rv;
}
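// Returns a textual representation of the generated code; `attr` selects the
// flavor (e.g. "asm" for assembly, with LLVM IR as the default).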
std::string getCodeText(const std::string& attr = "") override;
private:
void* getKernelAddress(LLVMCodeGenCallee* callee);
std::unique_ptr<LLVMCodeGenCallee> callee_;
std::unique_ptr<LLVMCodeGenImpl> impl_;
};
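// Fluent builder over the LLVMCodeGen constructor. Usage sketch (`stmt` and
// `args` are assumed to be built elsewhere; the target strings are
// illustrative, not defaults):
//   auto cg = LLVMCodeGenBuilder(stmt, args)
//                 .device(at::kCPU)
//                 .kernelFuncName("my_kernel")
//                 .triple("x86_64-unknown-linux-gnu")
//                 .cpu("skylake")
//                 .build();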
struct TORCH_API LLVMCodeGenBuilder {
using BufferArg = CodeGen::BufferArg;
LLVMCodeGenBuilder(StmtPtr stmt, std::vector<BufferArg> args)
: stmt_(stmt), args_(std::move(args)) {}
LLVMCodeGenBuilder& device(at::Device device) {
device_ = device;
return *this;
}
LLVMCodeGenBuilder& kernelFuncName(std::string name) {
kernelFuncName_ = std::move(name);
return *this;
}
LLVMCodeGenBuilder& dtype(Dtype d) {
dtype_ = d;
return *this;
}
LLVMCodeGenBuilder& triple(std::string triple) {
triple_ = std::move(triple);
return *this;
}
LLVMCodeGenBuilder& cpu(std::string cpu) {
cpu_ = std::move(cpu);
return *this;
}
LLVMCodeGenBuilder& attrs(std::string attrs) {
attrs_ = std::move(attrs);
return *this;
}
std::unique_ptr<LLVMCodeGen> build() {
return std::make_unique<LLVMCodeGen>(
stmt_, args_, device_, kernelFuncName_, dtype_, triple_, cpu_, attrs_);
}
private:
StmtPtr stmt_;
std::vector<BufferArg> args_;
at::Device device_ = at::kCPU;
std::string kernelFuncName_ = "func";
Dtype dtype_ = kInt;
c10::optional<std::string> triple_ = c10::nullopt;
c10::optional<std::string> cpu_ = c10::nullopt;
c10::optional<std::string> attrs_ = c10::nullopt;
};
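// Mutable references to the global LLVM target settings consulted by the
// codegen; assigning through them overrides the defaults, e.g. (values are
// illustrative):
//   torch::jit::tensorexpr::LLVMTargetCPU() = std::string("skylake");
//   torch::jit::tensorexpr::LLVMAOTWorkflow() = true;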
TORCH_API c10::optional<std::string>& LLVMTargetTriple();
TORCH_API c10::optional<std::string>& LLVMTargetCPU();
TORCH_API c10::optional<std::string>& LLVMTargetAttrs();
TORCH_API bool& LLVMAOTWorkflow();
} // namespace tensorexpr
} // namespace jit
} // namespace torch
#endif // TORCH_ENABLE_LLVM