1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
|
#pragma once
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
namespace tensorexpr {
// Scans a statement exactly once, from the constructor, and remembers whether
// any rand() intrinsic (IntrinsicsOp::kRand) appears anywhere inside it.
class HasRand : public IRVisitor {
 public:
  HasRand(Stmt* stmt) : stmt_(stmt) {
    stmt_->accept(this);
  }

  // Result of the construction-time scan: true if a kRand intrinsic was seen.
  bool has_rand() const {
    return has_rand_;
  }

 private:
  void visit(const Intrinsics* v) override {
    const bool is_rand = v->op_type() == IntrinsicsOp::kRand;
    if (!is_rand) {
      // Some other intrinsic: its operands may still contain a rand() call.
      IRVisitor::visit(v);
      return;
    }
    // Found one; no need to descend into this node's operands.
    has_rand_ = true;
  }

  Stmt* stmt_;
  bool has_rand_ = false;
};
// Collects every IR node of type `Node` reachable from a statement, in the
// order the IRVisitor traversal encounters them (nested matches included).
template <typename Node>
class NodeFinder : public IRVisitor {
 public:
  void visit(const Node* v) override {
    // The visitor hands out const nodes; callers of find() expect mutable ones.
    nodes.push_back(const_cast<Node*>(v));
    // Keep descending so matches nested inside this node are collected too.
    IRVisitor::visit(v);
  }

  // Convenience entry point: runs a fresh finder over `s` and returns the hits.
  static std::vector<Node*> find(Stmt* s) {
    NodeFinder<Node> finder;
    s->accept(&finder);
    return finder.nodes;
  }

  std::vector<Node*> nodes;
};
class VarFinder : public IRVisitor {
public:
virtual void visit(const Var* v) override {
vars_.insert(v);
IRVisitor::visit(v);
}
static std::unordered_set<const Var*> find(Stmt* s) {
VarFinder nf;
s->accept(&nf);
return nf.vars();
}
static std::unordered_set<const Var*> find(const Expr* e) {
VarFinder nf;
e->accept(&nf);
return nf.vars();
}
const std::unordered_set<const Var*>& vars() {
return vars_;
}
private:
std::unordered_set<const Var*> vars_;
};
// Finds all kinds of write operations (Store and AtomicAdd statements) to the
// provided Buf. Buffers are matched by pointer identity, not by name.
class WritesToBuf : public IRVisitor {
 public:
  WritesToBuf(const Buf* target) : target_(target) {}

  // Writes collected so far, in traversal order. Const so it can be queried
  // through a const finder.
  std::vector<const Stmt*> writes() const {
    return writes_;
  }

  // Returns every Store/AtomicAdd inside `s` whose target buffer is `b`.
  static std::vector<const Stmt*> find(Stmt* s, const Buf* b) {
    WritesToBuf finder(b);
    s->accept(&finder);
    // `finder` is about to be destroyed; move the result out instead of copying.
    return std::move(finder.writes_);
  }

 private:
  void visit(const Store* v) override {
    if (v->buf() == target_) {
      writes_.push_back(v);
    }
  }

  void visit(const AtomicAdd* v) override {
    if (v->buf() == target_) {
      writes_.push_back(v);
    }
  }

  const Buf* target_;
  std::vector<const Stmt*> writes_;
};
// Traverses the IR to determine if a particular Var is modified within it.
// The Var counts as modified when it is:
//   - the base handle of the buffer written by a Store or AtomicAdd,
//   - the variable bound by a Let, or
//   - the index variable of a For loop.
class ModifiesVarChecker : public IRVisitor {
 public:
  ModifiesVarChecker(const Var* v) : var_(v) {}

  // Returns true if statement `s` modifies `v` in any of the ways listed above.
  static bool check(const Stmt* s, const Var* v) {
    ModifiesVarChecker checker(v);
    s->accept(&checker);
    return checker.found();
  }

  // Result of the traversal. Const so it can be queried through a const checker.
  bool found() const {
    return found_;
  }

 private:
  void visit(const Store* v) override {
    if (v->buf()->base_handle() == var_) {
      found_ = true;
      // Already found; the subtree of this node need not be inspected.
      return;
    }
    IRVisitor::visit(v);
  }

  void visit(const AtomicAdd* v) override {
    if (v->buf()->base_handle() == var_) {
      found_ = true;
      return;
    }
    IRVisitor::visit(v);
  }

  void visit(const Let* v) override {
    if (v->var() == var_) {
      found_ = true;
      return;
    }
    IRVisitor::visit(v);
  }

  void visit(const For* v) override {
    if (v->var() == var_) {
      found_ = true;
      return;
    }
    IRVisitor::visit(v);
  }

  const Var* var_;
  bool found_{false};
};
// Analysis pass for the Block backend. The original description says it maps
// multi-dimensional buffers to their flat versions. NOTE(review): the code
// itself only pairs buffers across Store nodes whose value is a Load or a
// FunctionCall; confirm the multi-dim/flat interpretation with the backend.
class CreateBufferMap : public IRVisitor {
 public:
  // Name-keyed mapping gathered during traversal. When the same key is seen
  // more than once, the first pairing is kept (emplace never overwrites).
  const std::unordered_map<std::string, const Buf*>& getBufferMap() const {
    return map_input_to_tensor_bufs_;
  }

 private:
  void visit(const Store* v) override {
    const auto* stored = v->value();
    const auto* as_load = dynamic_cast<const Load*>(stored);
    const auto* as_call = dynamic_cast<const FunctionCall*>(stored);
    if (as_load != nullptr) {
      TORCH_INTERNAL_ASSERT(as_call == nullptr);
      // dst[...] = src[...]: key by the source buffer's name, point at dst.
      map_input_to_tensor_bufs_.emplace(as_load->buf()->name_hint(), v->buf());
    } else if (as_call != nullptr) {
      // dst[...] = tensor(...): key by dst's name, point at the tensor's buffer.
      map_input_to_tensor_bufs_.emplace(
          v->buf()->name_hint(), as_call->tensor()->buf());
    }
    // Only the stored value is traversed further (not the store's indices).
    stored->accept(this);
  }

  std::unordered_map<std::string, const Buf*> map_input_to_tensor_bufs_;
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch
|