File: remove_dropout.cpp

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (55 lines) | stat: -rw-r--r-- 1,621 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <torch/csrc/jit/passes/remove_dropout.h>

namespace torch {
namespace jit {

namespace {
bool isDropoutRemovable(const Node* node) {
  const auto inputs = node->inputs();
  TORCH_INTERNAL_ASSERT(inputs.size() == 3);
  const Value* training_input = inputs[2];
  auto optional_ivalue = toIValue(training_input);
  if (!optional_ivalue) {
    return false;
  }
  const IValue& val = optional_ivalue.value();
  TORCH_INTERNAL_ASSERT(val.isBool());
  const bool is_training = val.toBool();
  return !is_training;
}

void removeDropoutImpl(Block* block) {
  std::vector<Node*> deleted_nodes;

  for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); it++) {
    Node* node = *it;
    for (auto block : node->blocks()) {
      removeDropoutImpl(block);
    }
    if ((node->kind() == c10::Symbol::fromQualString("aten::dropout") ||
         node->kind() == c10::Symbol::fromQualString("aten::dropout_")) &&
        isDropoutRemovable(*it)) {
      // Input tensor of dropout.
      Value* input_value = node->inputs()[0];
      // Output tensor.
      Value* output_value = node->outputs()[0];
      output_value->replaceAllUsesWith(input_value);
      deleted_nodes.push_back(node);
    }
  }
  for (auto del_node : deleted_nodes) {
    del_node->destroy();
  }
}
} // namespace

// Strips removable dropout nodes from the graph of `module`'s `forward` method.
//
// Fails with TORCH_CHECK if the module has a `training` attribute and reports
// that it is in training mode. A module without a `training` attribute is
// accepted. Only the `forward` graph (and the blocks nested inside it) is
// rewritten; other methods are not visited.
void removeDropout(script::Module& module) {
  // `is_training()` is consulted only when the attribute exists.
  const bool in_training_mode =
      module.hasattr("training") && module.is_training();
  TORCH_CHECK(
      !in_training_mode,
      "Dropout removal module in training mode is not yet supported");
  const auto forward_graph = module.get_method("forward").graph();
  removeDropoutImpl(forward_graph->block());
}

} // namespace jit
} // namespace torch