1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
#pragma once
#include <limits>
#include <vector>

#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
struct IfView {
explicit IfView(Node* node) : node_(node) {
AT_ASSERT(node->kind() == ::c10::prim::If);
}
Value* cond() const {
return node_->input(0);
}
Block* thenBlock() const {
return node_->blocks().at(0);
}
Block* elseBlock() const {
return node_->blocks().at(1);
}
ArrayRef<Value*> thenOutputs() const {
return thenBlock()->outputs();
}
ArrayRef<Value*> elseOutputs() const {
return elseBlock()->outputs();
}
ArrayRef<Value*> outputs() const {
return node_->outputs();
}
Node* node() const {
return node_;
}
operator Node*() const {
return node_;
}
void permuteOutputs(const std::vector<size_t>& new_output_order) {
node_->permuteOutputs(new_output_order);
thenBlock()->permuteOutputs(new_output_order);
elseBlock()->permuteOutputs(new_output_order);
}
private:
Node* node_;
};
/// Non-owning view over a loop node (`prim::Loop` or `onnx::Loop`) that names
/// the positional values the loop is built from.
///
/// Layout used by the accessors below:
///   node inputs:   [max_trip_count, input_cond, carried...]
///   node outputs:  [carried...]
///   body inputs:   [current_trip_count, carried...]
///   body outputs:  [next_cond, carried...]
struct LoopView {
  /// Wraps `node`; asserts that it is a `prim::Loop` or an `onnx::Loop`.
  explicit LoopView(Node* node) : node_(node) {
    AT_ASSERT(
        node->kind() == ::c10::prim::Loop || node->kind() == ::c10::onnx::Loop);
  }
  /// The loop body (block 0).
  Block* bodyBlock() const {
    return node_->blocks().at(0);
  }
  /// NOTE(review): this returns node input 0, i.e. the same value as
  /// maxTripCount(), not a loop condition. Kept unchanged so existing callers
  /// are unaffected; use inputCond() / nextCond() for the condition values.
  Value* cond() const {
    return node_->input(0);
  }
  /// Maximum trip count (node input 0).
  Value* maxTripCount() const {
    return node_->input(0);
  }
  /// Condition supplied to the loop node (node input 1).
  Value* inputCond() const {
    return node_->input(1);
  }
  /// Condition produced by the body for the next iteration (body output 0).
  Value* nextCond() const {
    return bodyBlock()->outputs().at(0);
  }
  /// Trip-count value visible inside the body (body input 0).
  Value* currentTripCount() const {
    return bodyBlock()->inputs().at(0);
  }
  /// Loop-carried node inputs, without the trip count and the input cond.
  ArrayRef<Value*> carriedInputs() const {
    // skip max trip count (input 0) and input cond (input 1)
    return node_->inputs().slice(2);
  }
  /// Loop-carried node inputs preceded by the input cond.
  ArrayRef<Value*> carriedInputsWithCond() const {
    // skip only the max trip count (input 0); the input cond is kept
    return node_->inputs().slice(1);
  }
  /// All node outputs; every output is a loop-carried value.
  ArrayRef<Value*> carriedOutputs() const {
    return node_->outputs();
  }
  /// Loop-carried body inputs.
  ArrayRef<Value*> bodyCarriedInputs() const {
    // skip only the current trip count (body input 0)
    return bodyBlock()->inputs().slice(1);
  }
  /// Loop-carried body outputs.
  ArrayRef<Value*> bodyCarriedOutputs() const {
    // skip the next cond (body output 0)
    return bodyBlock()->outputs().slice(1);
  }
  /// The wrapped node.
  Node* node() const {
    return node_;
  }
  /// Implicit conversion back to the wrapped node.
  operator Node*() const {
    return node_;
  }
  /// Reorders the loop-carried values. `new_output_order` permutes the node
  /// outputs. The same permutation is applied to the carried node inputs and
  /// the carried body inputs/outputs, shifted past their leading non-carried
  /// entries so those entries keep their positions.
  void permuteLoopCarried(const std::vector<size_t>& new_output_order) {
    node_->permuteOutputs(new_output_order);
    // node inputs: keep max trip count and input cond in place
    node_->permuteInputs(adjustIndices(2, new_output_order));
    // body inputs/outputs: keep trip count / next cond in place
    auto adjusted_block_order = adjustIndices(1, new_output_order);
    bodyBlock()->permuteOutputs(adjusted_block_order);
    bodyBlock()->permuteInputs(adjusted_block_order);
  }
  /// Replaces the max trip count (node input 0).
  void replaceMaxTripCount(Value* new_max_trip_count) {
    node_->replaceInput(0, new_max_trip_count);
  }
  /// Replaces the input condition (node input 1).
  void replaceInputCondition(Value* new_input_condition) {
    node_->replaceInput(1, new_input_condition);
  }
  // Our way of encoding loops makes them difficult to turn back into Python
  // syntax. We have to check properties of the condition and trip count
  // inputs to figure out which one it initially was. ModifiedLoops are not
  // directly mappable to either For or While.
  enum LoopType { While, For, ModifiedLoop };
  /// Classifies the loop by inspecting constant conditions and the trip count.
  /// Const: only reads the node.
  LoopType loopType() const {
    auto trip_count = toIValue(maxTripCount());
    auto cond_input = toIValue(inputCond());
    auto cond_next = toIValue(nextCond());
    const bool condition_is_always_true =
        cond_input && cond_input->toBool() && cond_next && cond_next->toBool();
    const bool trip_count_is_specified =
        // the trip count is not a constant ...
        !trip_count ||
        // ... or it is a constant, but not the default one ...
        trip_count->toInt() != std::numeric_limits<int64_t>::max() ||
        // ... or it is actually used in the body.
        !currentTripCount()->uses().empty();
    if (condition_is_always_true) {
      // if the trip count was not specified this was a user-written while True:
      return trip_count_is_specified ? For : While;
    }
    return trip_count_is_specified ? ModifiedLoop : While;
  }

 private:
  Node* node_;
  // Builds a permutation for a list that has `adjust` leading entries which
  // must not move: indices 0 .. adjust-1 are kept in place, and every entry of
  // `index_ordering` is shifted up by `adjust`.
  static std::vector<size_t> adjustIndices(
      size_t adjust,
      const std::vector<size_t>& index_ordering) {
    std::vector<size_t> adjusted;
    adjusted.reserve(adjust + index_ordering.size());
    for (const auto i : c10::irange(adjust)) {
      adjusted.push_back(i);
    }
    for (auto index : index_ordering) {
      adjusted.push_back(index + adjust);
    }
    return adjusted;
  }
};
} // namespace jit
} // namespace torch
|