1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
|
#include <torch/csrc/jit/passes/requires_grad_analysis.h>

#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <algorithm>
#include <vector>
namespace torch {
namespace jit {
namespace {
// Free-function wrapper around Value::requires_grad() so it can be handed
// directly to algorithms such as std::any_of and fmap as a predicate.
bool getRequiresGrad(Value* value) {
  const bool flag = value->requires_grad();
  return flag;
}
// Replaces `value`'s type with a copy whose requires_grad bit is `req_value`.
// Values whose type is not a TensorType are left unchanged.
void setRequiresGrad(Value* value, bool req_value) {
  auto tensor_type = value->type()->cast<TensorType>();
  if (!tensor_type) {
    return;
  }
  value->setType(tensor_type->withRequiresGrad(req_value));
}
// Element-wise form: applies values[i] to outputs[i] for every position.
// Both sequences must have the same length (enforced with AT_ASSERT).
void setRequiresGrad(
    at::ArrayRef<Value*> outputs,
    const std::vector<bool>& values) {
  AT_ASSERT(outputs.size() == values.size());
  auto flag = values.begin();
  for (Value* output : outputs) {
    setRequiresGrad(output, static_cast<bool>(*flag));
    ++flag;
  }
}
// Applies `values` element-wise to the outputs of `node`; the number of
// flags must match the number of outputs.
void setRequiresGrad(Node* node, const std::vector<bool>& values) {
  const auto node_outputs = node->outputs();
  setRequiresGrad(node_outputs, values);
}
// Element-wise logical OR of two equally-sized flag vectors. `a` is taken by
// value and updated in place to form the result.
std::vector<bool> bitwiseOr(std::vector<bool> a, const std::vector<bool>& b) {
  AT_ASSERT(a.size() == b.size());
  for (const auto i : c10::irange(a.size())) {
    // a[i] || b[i] only differs from a[i] when b[i] is set.
    if (b[i]) {
      a[i] = true;
    }
  }
  return a;
}
void PropagateRequiresGradSimpleNode(Node* node) {
static const OperatorSet comparison_ops = {
"aten::lt(Tensor self, Tensor other) -> Tensor",
"aten::le(Tensor self, Tensor other) -> Tensor",
"aten::gt(Tensor self, Tensor other) -> Tensor",
"aten::ge(Tensor self, Tensor other) -> Tensor",
"aten::eq(Tensor self, Tensor other) -> Tensor",
"aten::ne(Tensor self, Tensor other) -> Tensor",
"aten::lt(Tensor self, Scalar other) -> Tensor",
"aten::le(Tensor self, Scalar other) -> Tensor",
"aten::gt(Tensor self, Scalar other) -> Tensor",
"aten::ge(Tensor self, Scalar other) -> Tensor",
"aten::eq(Tensor self, Scalar other) -> Tensor",
"aten::ne(Tensor self, Scalar other) -> Tensor",
};
// NOLINTNEXTLINE(bugprone-branch-clone)
if (node->isMemberOf(comparison_ops)) {
return setRequiresGrad(node->output(), false);
} else if (node->matches(
"aten::type_as(Tensor self, Tensor other) -> Tensor")) {
return setRequiresGrad(node->output(), node->input(0)->requires_grad());
} else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
return setRequiresGrad(node->output(), false);
} else if (node->kind() == aten::tensor) {
if (auto grad_index =
node->schema().argumentIndexWithName("requires_grad")) {
if (auto const_arg = constant_as<bool>(node->inputs().at(*grad_index))) {
return setRequiresGrad(node->output(), *const_arg);
}
}
if (auto type = node->output()->type()->cast<TensorType>()) {
if (type->scalarType()) {
setRequiresGrad(
node->output(),
autograd::isDifferentiableType(*type->scalarType()));
}
}
return;
}
auto inputs = node->inputs();
auto outputs = node->outputs();
bool should_require =
std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
for (Value* output : outputs) {
if (auto type = output->type()->cast<TensorType>()) {
if (type->scalarType()) {
setRequiresGrad(
output,
should_require &&
autograd::isDifferentiableType(*type->scalarType()));
}
}
}
}
void PropagateRequiresGrad(Block* block);
void PropagateRequiresGrad(Node* node) {
if (node->kind() == prim::If) {
auto blocks = node->blocks();
auto true_block = blocks.at(0);
auto false_block = blocks.at(1);
PropagateRequiresGrad(true_block);
PropagateRequiresGrad(false_block);
auto outputs_require = bitwiseOr(
fmap(true_block->outputs(), getRequiresGrad),
fmap(false_block->outputs(), getRequiresGrad));
setRequiresGrad(node, outputs_require);
} else if (node->kind() == prim::Loop) {
auto body = node->blocks().at(0);
std::vector<bool> loop_inputs_require =
fmap(node->inputs().slice(2), getRequiresGrad);
std::vector<bool> body_inputs_require = loop_inputs_require;
std::vector<bool> body_outputs_require(node->outputs().size(), false);
std::vector<bool> new_body_inputs_require = body_inputs_require;
std::vector<bool> new_body_outputs_require = body_outputs_require;
// continue iterating until the results have converged
do {
body_inputs_require = new_body_inputs_require;
body_outputs_require = new_body_outputs_require;
new_body_inputs_require =
bitwiseOr(body_inputs_require, body_outputs_require);
setRequiresGrad(
body->param_node()->outputs().slice(1), new_body_inputs_require);
PropagateRequiresGrad(body);
new_body_outputs_require =
fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
} while (new_body_inputs_require != body_inputs_require ||
new_body_outputs_require != body_outputs_require);
setRequiresGrad(node, bitwiseOr(body_outputs_require, loop_inputs_require));
} else {
PropagateRequiresGradSimpleNode(node);
}
}
// Visits the nodes of `block` in order, so each node sees the already-updated
// types of values produced earlier in the block.
void PropagateRequiresGrad(Block* block) {
  auto node_list = block->nodes();
  for (Node* n : node_list) {
    PropagateRequiresGrad(n);
  }
}
} // anonymous namespace
void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
PropagateRequiresGrad(graph->block());
}
} // namespace jit
} // namespace torch
|