1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
|
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/integer_value_refinement.h>
#include <torch/csrc/jit/passes/value_refinement_utils.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {
using IntegerRefinement = std::unordered_map<Value*, int64_t>;

// See [value refinement algorithm] for a full explanation.
// When a comparison like `cond = x == 4` or `cond = x != 4` is made, the
// `cond` value carries information (refinements) about the value of `x`.
// In an example like:
//   if x == 1:
//     ...
// we can substitute all uses of x dominated by the true block with 1.
struct IntegerValueRefiner {
  explicit IntegerValueRefiner(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  // Runs the pass over the whole graph. Returns true if the graph was
  // modified.
  bool run() {
    // Skip the traversal entirely when no comparison could yield a useful
    // refinement.
    if (!blockHasIntComparisons(graph_->block())) {
      return false;
    }
    IntegerRefinement refinements;
    RefineIntegerValues(graph_->block(), refinements);
    return changed_;
  }

  // Returns true if `b`, or any block nested inside it, contains an int
  // `aten::eq` / `aten::ne` with a prim::Constant operand whose other operand
  // has more than one use (i.e. it is used somewhere besides the comparison
  // itself, so a refinement could rewrite something).
  bool blockHasIntComparisons(Block* b) {
    for (Node* n : b->nodes()) {
      if (n->matches("aten::eq(int a, int b) -> bool") ||
          n->matches("aten::ne(int a, int b) -> bool")) {
        for (size_t const_index : {0, 1}) {
          auto non_const_index = 1 - const_index;
          if (n->inputs().at(const_index)->node()->kind() == prim::Constant &&
              n->inputs().at(non_const_index)->uses().size() > 1) {
            return true;
          }
        }
      }
      for (Block* block : n->blocks()) {
        if (blockHasIntComparisons(block)) {
          return true;
        }
      }
    }
    return false;
  }

  // We are looking for cases where we can replace both block outputs with the
  // same value, which opens up further optimization opportunities. The pass
  // already handles the case where both outputs are refined to the same
  // constant. Here, we look for cases where one block output has been refined
  // in the other block to be equal to the constant value that the other block
  // outputs:
  //   graph(%y.1 : int):
  //     %one_constant : int = prim::Constant[value=1]()
  //     %3 : bool = aten::eq(%y.1, %one_constant)
  //     %15 : int = prim::If(%3)
  //       block0():
  //         -> (%one_constant)
  //       block1():
  //         -> (%y.1)
  //     return (%15)
  // %15 can always be safely replaced with %y.1.
  // This is an important case for symbolic shape analysis.
  //
  // `true_block_refinements` / `false_block_refinements` are the refinements
  // that hold at the end of the then / else block respectively.
  void removeIfNodeOutputsWithRefinements(
      Node* if_node,
      IntegerRefinement& true_block_refinements,
      IntegerRefinement& false_block_refinements) {
    for (size_t block_index : {0, 1}) {
      Block* if_block = if_node->blocks().at(block_index);
      Block* other_if_block = if_node->blocks().at(1 - block_index);
      for (size_t i = 0; i < if_node->outputs().size(); ++i) {
        Value* block_output = if_block->outputs().at(i);
        if (!block_output->type()->cast<IntType>()) {
          continue;
        }
        // The value must be in scope for both blocks; in the example above,
        // %y.1 cannot be defined inside block1.
        if (!if_node->isDominatedBy(block_output->node())) {
          continue;
        }
        // One output must be a constant and the other must not: we are looking
        // for the pattern where %y.1 is refined to the existing block output
        // %one_constant.
        auto other_output = other_if_block->outputs().at(i);
        auto other_const_value = other_output->type()->cast<IntType>()
            ? constant_as<int64_t>(other_output)
            : c10::nullopt;
        if (!other_const_value ||
            block_output->node()->kind() == prim::Constant) {
          continue;
        }
        // Look up refinements of the current output in the *other* block. In
        // the example we look for a refinement of %y.1 in `block0` and check
        // that it equals the constant value of %one_constant.
        const auto& other_block_refinements =
            block_index == 0 ? false_block_refinements : true_block_refinements;
        auto refinement = other_block_refinements.find(block_output);
        if (refinement == other_block_refinements.end() ||
            refinement->second != *other_const_value) {
          continue;
        }
        Value* if_output = if_node->outputs().at(i);
        // With no uses there is nothing to rewrite. Reporting a change here
        // would make the pass claim progress on every run.
        if (if_output->uses().empty()) {
          continue;
        }
        if_output->replaceAllUsesWith(block_output);
        changed_ = true;
      }
    }
  }

  // Recursively walks block `b`, recording refinements produced by int
  // comparisons and replacing int Values (node inputs and block outputs) that
  // have a known refinement with a constant. `block_refinements` are the
  // refinements known to hold on entry to `b` (and in every block dominated by
  // it). It is handed to joinIfRefinements after each prim::If, which may
  // extend it (TODO confirm against value_refinement_utils). The resulting set
  // is returned so the caller can use it for the enclosing If node.
  IntegerRefinement RefineIntegerValues(
      Block* b,
      IntegerRefinement block_refinements) {
    // Popped before returning, so the pointer never outlives this frame.
    active_refinements_.push_back(&block_refinements);
    for (Node* n : b->nodes()) {
      if (n->matches("aten::eq(int a, int b) -> bool") ||
          n->matches("aten::ne(int a, int b) -> bool")) {
        for (size_t const_index : {0, 1}) {
          Value* refined_input = n->inputs().at(1 - const_index);
          // Refining a value that is itself a constant enables no useful
          // substitution. It would only swap one constant for another.
          if (refined_input->node()->kind() == prim::Constant) {
            continue;
          }
          if (auto ival = constant_as<int64_t>(n->inputs().at(const_index))) {
            IntegerRefinement refine;
            refine[refined_input] = *ival;
            info_[n->output()] = n->kind() == aten::eq
                ? BooleanRefinementMapping::TrueRefinements(std::move(refine))
                : BooleanRefinementMapping::FalseRefinements(std::move(refine));
          }
        }
      }
      for (size_t input = 0; input < n->inputs().size(); ++input) {
        Value* input_v = n->inputs().at(input);
        if (!input_v->type()->cast<IntType>()) {
          continue;
        }
        if (auto refine = tryFindRefinement(input_v)) {
          WithInsertPoint guard(n);
          auto refine_constant =
              graph_->insertConstant(static_cast<int64_t>(*refine));
          n->replaceInputWith(input_v, refine_constant);
          changed_ = true;
        }
      }
      if (n->kind() == prim::If) {
        IfView if_n(n);
        bool has_cond_ref = info_.count(if_n.cond()) != 0;
        IntegerRefinement empty;
        auto true_block_refinements = RefineIntegerValues(
            if_n.thenBlock(),
            has_cond_ref ? info_[if_n.cond()].true_refine() : empty);
        auto false_block_refinements = RefineIntegerValues(
            if_n.elseBlock(),
            has_cond_ref ? info_[if_n.cond()].false_refine() : empty);
        removeIfNodeOutputsWithRefinements(
            n, true_block_refinements, false_block_refinements);
        joinIfRefinements(
            n,
            throwing_blocks_,
            block_refinements,
            true_block_refinements,
            false_block_refinements,
            info_);
      } else {
        handleCommonRefinentOperators(n, throwing_blocks_, info_);
      }
    }
    // Iterating over the nodes of a block does not visit the block's outputs,
    // so handle them separately:
    //   %3 : int = prim::Constant[value=3]()
    //   %4 : bool = aten::eq(%y.1, %3)
    //   %a : int = prim::If(%4)
    //     block0():
    //       -> (%y.1)
    // Here, we can replace %y.1 with 3.
    for (size_t i = 0; i < b->outputs().size(); ++i) {
      Value* output_v = b->outputs().at(i);
      if (!output_v->type()->cast<IntType>()) {
        continue;
      }
      if (auto refine = tryFindRefinement(output_v)) {
        WithInsertPoint guard(b);
        auto refine_constant =
            graph_->insertConstant(static_cast<int64_t>(*refine));
        b->replaceOutput(i, refine_constant);
        changed_ = true;
      }
    }
    active_refinements_.pop_back();
    return block_refinements;
  }

  // Returns the first refinement of `v` found in the stack of active
  // refinements, searching from the outermost block inward, or nullopt.
  c10::optional<int64_t> tryFindRefinement(Value* v) const {
    for (const auto& ref : active_refinements_) {
      auto maybe_refinement = ref->find(v);
      if (maybe_refinement != ref->end()) {
        return maybe_refinement->second;
      }
    }
    return c10::nullopt;
  }

  std::shared_ptr<Graph> graph_;
  // A stack of active refinements, one per block currently being visited.
  std::vector<IntegerRefinement*> active_refinements_;
  // A map from Boolean Value* -> associated refinements.
  std::unordered_map<Value*, BooleanRefinementMapping> info_;
  std::unordered_set<Block*> throwing_blocks_;
  // Set whenever the graph is rewritten; returned by run().
  bool changed_ = false;
};
bool RefineIntegerValues(const std::shared_ptr<Graph>& graph) {
return IntegerValueRefiner(graph).run();
}
} // namespace jit
} // namespace torch
|