File: remove_inplace_ops.cpp

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (83 lines) | stat: -rw-r--r-- 2,725 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <torch/csrc/jit/passes/remove_inplace_ops.h>

namespace torch {
namespace jit {
namespace {
// Maps each supported in-place op kind to the out-of-place op kind that
// replaces it. zero_ and fill_ have no direct out-of-place twin, so they map
// to the *_like factory ops, which take extra trailing arguments (see
// expectedInputCount). The anonymous namespace already gives this table
// internal linkage.
const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
    {aten::add_, aten::add},
    {aten::sub_, aten::sub},
    {aten::div_, aten::div},
    {aten::mul_, aten::mul},
    {aten::masked_fill_, aten::masked_fill},
    {aten::zero_, aten::zeros_like},
    {aten::fill_, aten::full_like},
};

// HACK: aten::zeros_like and aten::full_like accept TensorOptions-style
// trailing arguments that their in-place counterparts (zero_, fill_) do not
// have. Instead of running the frontend to obtain the correct arity, this
// table records the total number of inputs each replacement node must end up
// with, keyed by the in-place op kind. RemoveInplaceOps pads any missing
// trailing inputs with None so the defaults take effect. Running the frontend
// would be the cleaner long-term fix.
const std::unordered_map<NodeKind, int> expectedInputCount = {
    {aten::zero_, 6},
    {aten::fill_, 7},
};

// Returns true when `node` is an in-place op that has a registered
// out-of-place replacement in inPlaceToOutOfPlace.
bool isInplaceOp(const Node* node) {
  const auto match = inPlaceToOutOfPlace.find(node->kind());
  return match != inPlaceToOutOfPlace.end();
}

// Remove all in-place ops in `block` (recursing into every nested block) and
// replace them with out-of-place equivalents.
// e.g.
//   %foo = aten::add_(%foo, %n)
// becomes
//   %foo.2 = aten::add(%foo, %n)
// and every use of %foo that follows the new node is redirected to %foo.2.
//
// NOTE: this is NOT SAFE, since it assumes that the LHS is not aliased by
// another value. This is only to avoid breaking ONNX export; when alias
// analysis is done we can emit a warning if someone tries to export.
void RemoveInplaceOps(Block* block) {
  auto* graph = block->owningGraph();
  auto it = block->nodes().begin();
  while (it != block->nodes().end()) {
    auto* node = *it;
    // Advance first: `node` may be destroyed below, and the iterator must not
    // point at it when that happens.
    ++it;
    // Rewrite nested blocks before handling this node. The loop variable is
    // named `subBlock` so it does not shadow the `block` parameter.
    for (auto* subBlock : node->blocks()) {
      RemoveInplaceOps(subBlock);
    }

    if (!isInplaceOp(node)) {
      continue;
    }

    // Create the replacement out-of-place op right before the in-place one.
    auto* newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
    newNode->insertBefore(node);
    newNode->setScope(node->scope());
    // Copy the inputs unchanged.
    for (auto* input : node->inputs()) {
      newNode->addInput(input);
    }

    // Some replacements (zeros_like / full_like) take more inputs than the
    // in-place op had; pad the missing trailing inputs with None so the
    // op's defaults apply. A single lookup serves both test and value.
    const auto expected = expectedInputCount.find(node->kind());
    if (expected != expectedInputCount.end()) {
      const int missingInputs =
          expected->second - static_cast<int>(newNode->inputs().size());
      for (int i = 0; i < missingInputs; ++i) {
        auto* noneNode = graph->createNone();
        noneNode->insertBefore(newNode);
        newNode->addInput(noneNode->output());
      }
    }

    // Carry the in-place output's metadata over to the new output. Then
    // redirect the in-place node's uses, and every use of `self` that comes
    // after the new node, to the new output. Finally delete the in-place node.
    newNode->output()->copyMetadata(node->output());
    node->replaceAllUsesWith(newNode);
    node->inputs()[0]->replaceAllUsesAfterNodeWith(
        newNode, newNode->output());
    node->destroy();
  }
}
} // namespace

void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
  RemoveInplaceOps(graph->block());
}
} // namespace jit
} // namespace torch