File: remove_dropout.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (62 lines) | stat: -rw-r--r-- 1,876 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <torch/csrc/jit/passes/remove_dropout.h>

namespace torch {
namespace jit {

namespace {
// Decides whether a dropout node may be removed from the graph.
//
// The node's third input (index 2) is its training flag. Returns true only
// when toIValue() can resolve that flag to a constant and the constant is
// `false`. Returns false when the flag is constant `true` or cannot be
// resolved to a constant (e.g. it is only known at runtime).
// Internal-asserts that the node has exactly three inputs and that a resolved
// flag is a bool.
bool isDropoutRemovable(const Node* node) {
  const auto node_inputs = node->inputs();
  TORCH_INTERNAL_ASSERT(node_inputs.size() == 3);

  const auto train_flag = toIValue(node_inputs[2]);
  if (!train_flag.has_value()) {
    // Not a constant: we cannot prove dropout is a no-op here.
    return false;
  }

  TORCH_INTERNAL_ASSERT(train_flag->isBool());
  return !train_flag->toBool();
}

void removeDropoutImpl(Block* block) {
  std::vector<Node*> deleted_nodes;

  for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); it++) {
    Node* node = *it;
    for (auto block : node->blocks()) {
      removeDropoutImpl(block);
    }
    if ((node->kind() == c10::Symbol::fromQualString("aten::dropout") ||
         node->kind() == c10::Symbol::fromQualString("aten::dropout_") ||
         node->kind() == c10::Symbol::fromQualString("aten::feature_dropout") ||
         node->kind() ==
             c10::Symbol::fromQualString("aten::feature_dropout_")) &&
        isDropoutRemovable(*it)) {
      // Input tensor of dropout.
      Value* input_value = node->inputs()[0];
      // Output tensor.
      Value* output_value = node->outputs()[0];
      output_value->replaceAllUsesWith(input_value);
      deleted_nodes.push_back(node);
    }
  }
  for (auto del_node : deleted_nodes) {
    del_node->destroy();
  }
}
} // namespace

void removeDropout(std::shared_ptr<Graph>& graph) {
  removeDropoutImpl(graph->block());
}

// Removes eval-mode dropout from the `forward` method of `module`.
//
// Throws (via TORCH_CHECK) if the module has a `training` attribute and the
// module reports is_training(). Only the graph of `forward` is rewritten;
// other methods are not touched.
// NOTE(review): isDropoutRemovable only removes nodes whose training flag
// resolves to a constant. A module whose dropouts read the training attribute
// at runtime may therefore need to be frozen first for this to have any
// effect. Confirm against callers.
void removeDropout(script::Module& module) {
  const bool training_mode =
      module.hasattr("training") && module.is_training();
  TORCH_CHECK(
      !training_mode,
      "Dropout removal module in training mode is not yet supported");

  auto forward_graph = module.get_method("forward").graph();
  removeDropout(forward_graph);
}

} // namespace jit
} // namespace torch