1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
|
#pragma once
#ifdef TORCH_ENABLE_LLVM
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <unordered_map>
#include <vector>
namespace torch {
namespace jit {
namespace tensorexpr {
class LLVMCodeGenImpl;
class TORCH_API LLVMCodeGen : public CodeGen {
public:
explicit LLVMCodeGen(
Stmt* stmt,
const std::vector<BufferArg>& args,
at::Device device = at::kCPU,
Dtype dtype = kInt);
explicit LLVMCodeGen(Stmt* stmt);
LLVMCodeGen() = delete;
~LLVMCodeGen() override;
TORCH_API void call(const std::vector<CallArg>& args) override;
template <typename T>
T value() {
return value<T>(nullptr);
}
template <typename T>
T value(std::vector<void*>& args) {
return value<T>(args.data());
}
template <typename T>
T value(void** args) {
T (*fp)(void**) = (T(*)(void**))getKernelAddress(impl_.get());
T rv = fp(args);
return rv;
}
private:
void* getKernelAddress(LLVMCodeGenImpl* impl);
std::unique_ptr<LLVMCodeGenImpl> impl_;
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch
#endif // TORCH_ENABLE_LLVM
|