#include <torch/csrc/jit/passes/remove_dropout.h>
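
// This pass removes dropout nodes that can be proven to be no-ops: a dropout
// whose `train` argument is the compile-time constant `false` simply forwards
// its input. As an illustrative sketch (not verbatim IR from a real graph),
//
//   %train : bool = prim::Constant[value=0]()
//   %y : Tensor = aten::dropout(%x, %p, %train)
//
// is rewritten so that every use of %y reads %x directly, and the dropout
// node is destroyed.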
namespace torch {
namespace jit {
namespace {
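// A dropout node is removable only when its `train` flag (the third input)
// resolves to a constant `false`; if the flag is unknown or `true`, the
// dropout must stay in place.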
bool isDropoutRemovable(const Node* node) {
  const auto inputs = node->inputs();
  TORCH_INTERNAL_ASSERT(inputs.size() == 3);
  const Value* training_input = inputs[2];
  auto optional_ivalue = toIValue(training_input);
  if (!optional_ivalue) {
    // The `train` argument is not a compile-time constant, so we cannot
    // prove the dropout is a no-op.
    return false;
  }
  const IValue& val = optional_ivalue.value();
  TORCH_INTERNAL_ASSERT(val.isBool());
  const bool is_training = val.toBool();
  return !is_training;
}
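
// Walks `block` and all nested blocks in reverse order, rewires every
// removable dropout node so its consumers read the dropout input directly,
// and destroys the dead nodes afterwards.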
void removeDropoutImpl(Block* block) {
  std::vector<Node*> deleted_nodes;

  for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); it++) {
    Node* node = *it;
    // Recurse into sub-blocks (e.g. the bodies of prim::If and prim::Loop).
    for (auto sub_block : node->blocks()) {
      removeDropoutImpl(sub_block);
    }
    if ((node->kind() == c10::Symbol::fromQualString("aten::dropout") ||
         node->kind() == c10::Symbol::fromQualString("aten::dropout_") ||
         node->kind() == c10::Symbol::fromQualString("aten::feature_dropout") ||
         node->kind() ==
             c10::Symbol::fromQualString("aten::feature_dropout_")) &&
        isDropoutRemovable(node)) {
      // Rewire all users of the dropout output to read its input directly.
      Value* input_value = node->inputs()[0];
      Value* output_value = node->outputs()[0];
      output_value->replaceAllUsesWith(input_value);
      deleted_nodes.push_back(node);
    }
  }
  // Destroy the dead nodes only after the traversal, so the iterator above
  // is never invalidated mid-loop.
  for (auto del_node : deleted_nodes) {
    del_node->destroy();
  }
}
} // namespace
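
// Entry point for a standalone graph: strips all removable dropout variants
// (aten::dropout, aten::dropout_, aten::feature_dropout,
// aten::feature_dropout_) from the graph.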
void removeDropout(std::shared_ptr<Graph>& graph) {
  removeDropoutImpl(graph->block());
}
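
// Entry point for a scripted module: dropout removal is only valid in eval
// mode, so modules still in training mode are rejected up front before the
// forward method's graph is rewritten.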
void removeDropout(script::Module& module) {
  TORCH_CHECK(
      !module.hasattr("training") || !module.is_training(),
      "Dropout removal on a module in training mode is not yet supported");
  auto graph = module.get_method("forward").graph();
  removeDropout(graph);
}
} // namespace jit
} // namespace torch