1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
|
#include <c10/util/irange.h>
#include <torch/csrc/jit/passes/value_refinement_utils.h>
namespace torch {
namespace jit {
// [value refinement algorithm]
// When a comparison like `cond = len(x) == 4` or `cond = len(x) != 4` is made,
// `cond` value carries information (refinements) about the len of `x`.
// When `cond` is used as the conditional of an if statement, the information
// it carries for its true value can be inserted into the true block
// and the same for its false value.
// For something like `y = len(x) if len(x) == 1 else 1`, in the true branch
// we can replace len(x) with 1 because the true refinements from `len(x) == 1`
// will be present in the true block.
// Additionally, we can optimize something like:
// if len(x) != 4:
// raise Exception(...)
// return len(x)
// Because the true block always throws, whatever refinements exist in the false
// block become present in the owning block of the if node. We can also merge
// refinements carried by two different booleans across an if node join by
// taking the intersections of their refinements.
// if cond:
// z = len(x) == 4 and len(y) == 5
// else:
// z = len(x) == 4
// Here, z's true value will refine the len(x) to 4, but not len(y).
// If the code was written as:
// if cond:
// z = len(x) == 4 and len(y) == 5
// else:
// z = False
//
// Then z's true value would refine x and y, because if z is true it had to have
// come from the true block. Code that is written with `and` or `or` will
// desugar to something similar. Additionally, any True refinements that were
// present on `cond` can also be associated with the if node True output value.
// The intersection of two refinements is the set of Value*s that appear in
// both and are refined to the same length in each.
// For example, in:
// if cond:
// x = len(a) == 4 and len(b) == 5
// else:
// x = len(a) == 4
// For the x output of the node we take the intersection between
// the refinements stored on each block output, which will result
// in only the refinement of len(a) == 4
// Returns the refinements shared by `ref1` and `ref2`: an entry survives only
// when its Value* is present in both maps with an identical refined length.
ListRefinement intersectRefinements(
    const ListRefinement& ref1,
    const ListRefinement& ref2) {
  ListRefinement common;
  for (const auto& entry : ref1) {
    const auto match = ref2.find(entry.first);
    const bool agrees = match != ref2.end() && match->second == entry.second;
    if (agrees) {
      // Keys of ref1 are unique, so this always inserts a fresh entry.
      common.emplace(entry.first, entry.second);
    }
  }
  return common;
}
// To union, just take all refinements from both inputs. We do not need to worry
// about len refinements disagreeing, because a path like
// `if len(x) == 4 and len(x) == 5` will never be taken.
// For example, in:
// if len(a) == 5:
//   x = len(b) == 4
// else:
//   x = False
// if the output value x is true, then the refinements present in the true
// block must also hold, so we take the union of `len(a) == 5` and
// `len(b) == 4` and assign it as the true refinements of x. This is a very
// common pattern in the desugaring of `and` / `or` boolean expressions.
// Returns every refinement from both inputs. When the same Value* appears in
// both maps, the entry from `ref1` is kept, because
// unordered_map::insert never overwrites an existing key.
ListRefinement unionRefinements(
    const ListRefinement& ref1,
    const ListRefinement& ref2) {
  ListRefinement merged(ref1);
  for (const auto& entry : ref2) {
    merged.insert(entry);
  }
  return merged;
}
// Joins the refinement information of the then/else blocks of `if_node`.
//
// - If both blocks always throw, the owning block of `if_node` is recorded in
//   `throwing_blocks` and nothing else is updated.
// - If exactly one block throws, the other block's refinements are added to
//   `curr_block_refinements`. Every if output whose corresponding
//   non-throwing block output has boolean refinements inherits them.
// - Otherwise every Bool-typed if output receives refinements derived from
//   the matching then/else block outputs (details in the loop below).
//
// Params:
//   if_node: the if node whose blocks are being joined.
//   throwing_blocks: blocks known to always throw; may gain the owning block.
//   curr_block_refinements: refinements of the block that owns `if_node`.
//   true_block_refinements / false_block_refinements: refinements that hold
//     inside the then / else block.
//   boolean_value_refinements: true/false refinements per boolean Value.
//     Entries for the if outputs are written here. operator[] is used on
//     block outputs, so a looked-up output without an entry gets a default
//     (empty) BooleanRefinementMapping inserted.
void joinIfRefinements(
    Node* if_node,
    std::unordered_set<Block*>& throwing_blocks,
    ListRefinement& curr_block_refinements,
    ListRefinement& true_block_refinements,
    ListRefinement& false_block_refinements,
    std::unordered_map<Value*, BooleanRefinementMapping>&
        boolean_value_refinements) {
  IfView if_n(if_node);
  Block* b = if_node->owningBlock();

  bool true_block_throws = throwing_blocks.count(if_n.thenBlock());
  bool false_block_throws = throwing_blocks.count(if_n.elseBlock());

  // If one block throws, the refinements of the other block become present
  // in the current block, and all bool outputs of the if node take their
  // refinements from the non-throwing block's outputs.
  if (true_block_throws || false_block_throws) {
    if (true_block_throws && false_block_throws) {
      throwing_blocks.insert(b);
      return;
    }
    if (true_block_throws) {
      curr_block_refinements.insert(
          false_block_refinements.begin(), false_block_refinements.end());
    } else {
      curr_block_refinements.insert(
          true_block_refinements.begin(), true_block_refinements.end());
    }
    Block* non_throwing_block =
        true_block_throws ? if_node->blocks().at(1) : if_node->blocks().at(0);
    for (const auto i : c10::irange(if_n.outputs().size())) {
      if (boolean_value_refinements.count(
              non_throwing_block->outputs().at(i))) {
        boolean_value_refinements[if_node->outputs().at(i)] =
            boolean_value_refinements[non_throwing_block->outputs().at(i)];
      }
    }
    return;
  }

  for (const auto i : c10::irange(if_n.outputs().size())) {
    // Outputs are handled independently. Skip a non-bool output instead of
    // aborting the whole loop, so later bool outputs are still refined.
    if (!(if_n.outputs().at(i)->type() == BoolType::get())) {
      continue;
    }
    Value* true_v = if_n.thenOutputs().at(i);
    Value* false_v = if_n.elseOutputs().at(i);
    // Nothing is known about this output; move on to the next one.
    if (!boolean_value_refinements.count(true_v) &&
        !boolean_value_refinements.count(false_v) &&
        !constant_as<bool>(true_v) && !constant_as<bool>(false_v)) {
      continue;
    }
    // If one block yields a constant bool `c`, then the output value `!c` can
    // only come from the other block. The refinements for `output == !c` are
    // therefore that block's refinements, unioned with the matching
    // refinements of that block's output. For example:
    //   if len(a) == 5:
    //     x = len(b) == 4
    //   else:
    //     x = False
    // If x is true, then both len(a) == 5 and len(b) == 4 hold.
    //
    // If neither block yields a constant bool, and both outputs carry
    // refinements, take the intersection of the outputs' refinements:
    //   if cond:
    //     x = len(a) == 4 and len(b) == 5
    //   else:
    //     x = len(a) == 4
    // If x is true, len(a) == 4 holds, but len(b) == 5 does not necessarily,
    // because that refinement is only present on the true block's output.
    // Otherwise the output gets an empty mapping.
    // TODO: could also take the intersection of refinements present in both
    // blocks, but it's not a real use case.
    // boolean_value_refinements[value] is safe to access because
    // BooleanRefinementMapping has a default constructor.
    BooleanRefinementMapping out;
    if (auto maybe_bool = constant_as<bool>(true_v)) {
      if (*maybe_bool) {
        out = BooleanRefinementMapping::FalseRefinements(unionRefinements(
            boolean_value_refinements[false_v].false_refine(),
            false_block_refinements));
      } else {
        out = BooleanRefinementMapping::TrueRefinements(unionRefinements(
            boolean_value_refinements[false_v].true_refine(),
            false_block_refinements));
      }
    } else if (auto maybe_bool = constant_as<bool>(false_v)) {
      if (*maybe_bool) {
        out = BooleanRefinementMapping::FalseRefinements(unionRefinements(
            boolean_value_refinements[true_v].false_refine(),
            true_block_refinements));
      } else {
        out = BooleanRefinementMapping::TrueRefinements(unionRefinements(
            boolean_value_refinements[true_v].true_refine(),
            true_block_refinements));
      }
    } else if (
        boolean_value_refinements.count(true_v) &&
        boolean_value_refinements.count(false_v)) {
      out = boolean_value_refinements[true_v].intersectBooleanRefinementMapping(
          boolean_value_refinements[false_v]);
    }
    boolean_value_refinements[if_n.outputs().at(i)] = out;
  }
}
// Handles the operators whose refinement behavior is shared by all users of
// these utilities:
// - prim::RaiseException marks its owning block as throwing.
// - aten::__not__ on a bool input swaps the true/false refinements of that
//   input (if it has any) onto the output.
// - aten::eq / aten::ne(bool, bool) where one side is a constant forwards
//   the other side's refinements, swapped when the comparison negates it.
// Returns true when `n` is one of these operators, even if no refinement
// ended up being recorded. Returns false for any other node.
bool handleCommonRefinentOperators(
    Node* n,
    std::unordered_set<Block*>& throwing_blocks,
    std::unordered_map<Value*, BooleanRefinementMapping>& info) {
  if (n->kind() == prim::RaiseException) {
    throwing_blocks.insert(n->owningBlock());
    return true;
  }

  const bool is_bool_not = n->kind() == aten::__not__ &&
      n->inputs().at(0)->type()->cast<BoolType>();
  if (is_bool_not) {
    auto found = info.find(n->input());
    if (found != info.end()) {
      const BooleanRefinementMapping& src = found->second;
      info[n->output()] =
          BooleanRefinementMapping(src.false_refine(), src.true_refine());
    }
    return true;
  }

  const bool is_bool_cmp = n->matches("aten::eq(bool a, bool b) -> bool") ||
      n->matches("aten::ne(bool a, bool b) -> bool");
  if (!is_bool_cmp) {
    return false;
  }

  for (size_t const_index = 0; const_index < 2; ++const_index) {
    Value* maybe_const = n->input(const_index);
    if (maybe_const->node()->kind() != prim::Constant) {
      continue;
    }
    const bool const_value = constant_as<bool>(maybe_const).value();
    Value* other = n->input(1 - const_index);
    auto found = info.find(other);
    if (found == info.end()) {
      continue;
    }
    // value == False / value != True -> equivalent to __not__ value
    // value == True / value != False -> equivalent to value
    const bool negates =
        (n->kind() == aten::eq) ? !const_value : const_value;
    const BooleanRefinementMapping& src = found->second;
    info[n->output()] = negates
        ? BooleanRefinementMapping(src.false_refine(), src.true_refine())
        : BooleanRefinementMapping(src.true_refine(), src.false_refine());
  }
  return true;
}
} // namespace jit
} // namespace torch
|