File: normalize_ops.cpp

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (87 lines) | stat: -rw-r--r-- 3,067 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <c10/util/Exception.h>

namespace torch {
namespace jit {

namespace {

// Maps each alias op (e.g. aten::absolute) to its canonical op
// (e.g. aten::abs). In-place variants (trailing underscore) are mapped to
// the in-place form of the canonical op.
static const std::unordered_map<Symbol, Symbol> alias_map = {
    {aten::absolute, aten::abs},
    {aten::absolute_, aten::abs_},
    {aten::clip, aten::clamp},
    {aten::clip_, aten::clamp_},
    {aten::linalg_det, aten::det},
    {aten::ger, aten::outer},
    {aten::arccos, aten::acos},
    {aten::arccos_, aten::acos_},
    {aten::arcsin, aten::asin},
    {aten::arcsin_, aten::asin_},
    {aten::arctan, aten::atan},
    {aten::arctan_, aten::atan_},
    {aten::arccosh, aten::acosh},
    {aten::arccosh_, aten::acosh_},
    {aten::arcsinh, aten::asinh},
    {aten::arcsinh_, aten::asinh_},
    {aten::arctanh, aten::atanh},
    {aten::arctanh_, aten::atanh_},
    {aten::fix, aten::trunc},
    {aten::fix_, aten::trunc_},
    {aten::negative, aten::neg},
    {aten::negative_, aten::neg_},
    {aten::subtract, aten::sub},
    {aten::subtract_, aten::sub_},
    {aten::greater_equal, aten::ge},
    {aten::greater_equal_, aten::ge_},
    {aten::greater, aten::gt},
    {aten::greater_, aten::gt_},
    {aten::less_equal, aten::le},
    {aten::less_equal_, aten::le_},
    {aten::less, aten::lt},
    {aten::less_, aten::lt_},
    {aten::not_equal, aten::ne},
    {aten::not_equal_, aten::ne_},
    {aten::divide, aten::div},
    {aten::divide_, aten::div_},
    {aten::multiply, aten::mul},
    {aten::multiply_, aten::mul_},
    {aten::true_divide, aten::div},
    {aten::true_divide_, aten::div_},
};

// Inserts a new node of kind `new_symbol` right before `node`, wires up the
// same inputs, re-creates each output (copying its metadata) and redirects
// all uses of the old outputs to the new ones, then copies the node-level
// metadata. The original node is left in the graph; the caller destroys it
// (see normalizeOpAliases).
void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) {
  WithInsertPoint guard{node};
  auto owning_graph = node->owningGraph();
  auto normalized =
      owning_graph->insertNode(owning_graph->create(new_symbol, 0));
  for (Value* input : node->inputs()) {
    normalized->addInput(input);
  }
  for (Value* old_output : node->outputs()) {
    auto fresh_output = normalized->addOutput()->copyMetadata(old_output);
    old_output->replaceAllUsesWith(fresh_output);
  }
  normalized->copyMetadata(node);
  // The replacement kind must resolve to a registered operator; if it does
  // not, the alias table entry is wrong.
  TORCH_INTERNAL_ASSERT(
      normalized->maybeOperator(),
      "invalid symbol replacement:",
      new_symbol,
      node->kind());
}

// Having multiple ops in our IR that do the same thing makes the IR more
// difficult to consume for downstream users of the IR, such as our own
// optimization passes; here we convert op aliases into a standard form.
// Returns true when the node at `iter` was an alias and has been replaced
// (and the iterator advanced past it); false otherwise.
bool normalizeOpAliases(graph_node_list_iterator& iter) {
  const auto entry = alias_map.find(iter->kind());
  if (entry == alias_map.end()) {
    return false;
  }
  replaceNodeWithNewSymbol(*iter, entry->second);
  iter.destroyCurrent();
  return true;
}

// Walks every node in `block`, recursing into nested blocks first, and
// normalizes alias ops in place.
void NormalizeOps(Block* block) {
  auto it = block->nodes().begin();
  const auto end = block->nodes().end();
  while (it != end) {
    // Handle sub-blocks (e.g. control-flow bodies) before the node itself.
    for (auto sub : it->blocks()) {
      NormalizeOps(sub);
    }
    // When the node is rewritten, normalizeOpAliases already advances the
    // iterator (via destroyCurrent); only step forward otherwise.
    if (!normalizeOpAliases(it)) {
      ++it;
    }
  }
}

} // namespace

// Public entry point: normalizes op aliases across the whole graph by
// delegating to the block-level pass on the graph's top-level block.
void NormalizeOps(const std::shared_ptr<Graph>& graph) {
  NormalizeOps(graph->block());
}

} // namespace jit
} // namespace torch