1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
|
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <ATen/ATen.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>
namespace torch {
namespace jit {
namespace tensorexpr {
// A class that analyzes the given program relevant for Block backend.
class BlockAnalysis : public IRVisitor {
public:
bool is_buf_store_target(BufPtr buf) const {
return store_targets_.count(buf) > 0;
}
const std::unordered_set<BufPtr>& loads() const {
return loads_;
}
const std::unordered_set<BufPtr>& stores() const {
return store_targets_;
}
int block_size() const {
return block_size_;
}
bool areBufsInMap(const std::unordered_set<BufPtr>& bufs) const;
BufPtr getMultiDimBuf(BufPtr buf) const;
std::string getInputName(BufPtr buf) const;
std::string getFlatInputName(BufPtr buf) const {
return getInputName(buf) + "_flat";
}
std::unordered_map<std::string, BufPtr> getBufferMap() const {
return map_input_to_tensor_bufs_;
}
private:
void visit(StorePtr v) override;
void visit(LoadPtr v) override;
void visit(ForPtr v) override;
std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
std::unordered_set<BufPtr> store_targets_;
std::unordered_set<BufPtr> loads_;
int block_size_ = 32;
};
// IRPrinter subclass used to emit Block-backend output. It overrides a subset
// of the visit() hooks and keeps a pointer to a BlockAnalysis.
class BlockPrinter : public IRPrinter {
 public:
  // `os` is dereferenced immediately to construct the base printer, so it
  // must not be null. `block_analysis` is stored as a raw pointer; this class
  // does not delete it.
  BlockPrinter(std::ostream* os, BlockAnalysis* block_analysis)
      : IRPrinter(*os), block_analysis_{block_analysis} {}

  // Make the base-class name_manager() and visit() overloads accessible and
  // visible through this class.
  using IRPrinter::name_manager;
  using IRPrinter::visit;

 private:
  // Node kinds whose printing is customized; bodies are defined out of line.
  void visit(ForPtr v) override;
  void visit(LoadPtr v) override;
  void visit(StorePtr v) override;
  void visit(BlockPtr v) override;
  void visit(AddPtr v) override;
  void visit(MulPtr v) override;

  // Emission helpers, each operating on a set of buffers; bodies are defined
  // out of line.
  void PrintTensorInfo(const std::unordered_set<BufPtr>& bufs);
  void PrintArguments(const std::unordered_set<BufPtr>& bufs);
  void PrintBufferInfo(const std::unordered_set<BufPtr>& bufs);
  void PrintDistribution(const std::unordered_set<BufPtr>& bufs);
  void PrintLoop(const std::unordered_set<BufPtr>& bufs, bool block_idx = true);
  void PrintReshapeInfo(
      const std::unordered_set<BufPtr>& bufs,
      bool reverse = false);
  void PrintDMAs(const std::unordered_set<BufPtr>& bufs);
  void PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs);

  // Analysis object supplied at construction (not owned here).
  BlockAnalysis* block_analysis_;
  // Integer values keyed by name; presumably dimension name -> extent --
  // TODO confirm against the implementation file.
  std::unordered_map<std::string, int> dim_values_map;
  // Fixed name tables; the second looks like cumulative prefixes of the
  // first -- confirm how they are indexed in the implementation file.
  std::vector<std::string> dim_names = {"N", "H", "W", "C"};
  std::vector<std::string> flat_dim_names = {"N", "NH", "NHW", "NHWC"};
};
// CodeGen implementation for the Block backend. Both constructors forward to
// the CodeGen base and then call Initialize(); generated text is read back
// from an internal string stream via getCodeText().
class TORCH_API BlockCodeGen : public CodeGen {
 public:
  // Variadic convenience constructor: each extra argument is wrapped in a
  // BufferArg, and the device passed to the base is at::kCPU.
  template <typename... Ts>
  /* implicit */
  BlockCodeGen(StmtPtr stmt, Ts... ts)
      : CodeGen(
            stmt,
            std::vector<BufferArg>({BufferArg(ts)...}),
            at::Device(at::kCPU)) {
    Initialize();
  }

  // Explicit constructor taking a prepared argument list, with optional
  // device (default CPU) and kernel function name (default "func").
  BlockCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& buffer_args,
      at::Device device = at::Device(at::kCPU),
      const std::string& kernel_func_name = "func")
      : CodeGen(stmt, buffer_args, device, kernel_func_name) {
    Initialize();
  }

  ~BlockCodeGen() override;

  // CodeGen entry points; defined out of line.
  void call(const std::vector<CallArg>& args) override;
  void call_raw(const std::vector<void*>& args) override;

  // Invoked by both constructors; defined out of line.
  void Initialize();

  // Returns the current contents of the internal string stream. The
  // attribute argument is accepted for interface compatibility and ignored.
  std::string getCodeText(const std::string& /*attr*/ = "") override {
    return oss_.str();
  }

 private:
  // Returns the printer, throwing std::runtime_error if it has not been
  // created. Shared by name_manager() and os() so that neither dereferences
  // a null unique_ptr.
  BlockPrinter& checkedPrinter() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return *printer_;
  }

  // Name manager of the underlying printer; throws if the printer is null.
  UniqueNameManager* name_manager() {
    return checkedPrinter().name_manager();
  }

  // Output stream of the underlying printer; throws if the printer is null
  // (previously this dereferenced printer_ unchecked).
  std::ostream& os() {
    return checkedPrinter().os();
  }

  // Declaration order is kept as-is: members are destroyed in reverse order,
  // so the printer goes away before the stream it was (presumably) built on.
  std::ostringstream oss_;
  std::unique_ptr<BlockPrinter> printer_;
  std::unique_ptr<BlockAnalysis> block_analysis_;

  // Defined out of line.
  std::string GetUniqueFuncName(const std::string& func_prefix);
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch
|