File: canonicalize_graph_fuser_ops.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (100 lines) | stat: -rw-r--r-- 3,850 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>

namespace torch {
namespace jit {

struct ChunkOutput {
  ChunkOutput(Value* v, size_t o) : val(v), offset(o){};
  Value* val;
  size_t offset;
};

// Collects every tensor extracted from the list returned by the aten::chunk
// node `chunk`, so the caller can replace it with a prim::ConstantChunk that
// has `chunks` outputs.
//
// Each use of the chunk result list must be one of:
//   * `aten::select(t[] list, int idx) -> t` with a constant `idx` in
//     [0, chunks) and a Tensor-typed output, or
//   * a prim::ListUnpack with exactly `chunks` outputs.
//
// Returns c10::nullopt if `chunks` is not a constant, if any use has another
// form, or if a select index cannot be mapped onto a ConstantChunk output.
// Otherwise returns one ChunkOutput per extracted value; an empty vector
// means the list has no uses.
static c10::optional<std::vector<ChunkOutput>> getChunkOutputs(Node* chunk) {
  const auto maybe_chunks = chunk->get<int64_t>(attr::chunks);
  if (!maybe_chunks) {
    return c10::nullopt;
  }
  const int64_t num_chunks = *maybe_chunks;

  std::vector<ChunkOutput> outputs;
  for (const auto& list_use : chunk->output()->uses()) {
    Node* user = list_use.user;
    if (user->matches("aten::select(t[] list, int idx) -> t", attr::idx) &&
        user->output()->type()->cast<TensorType>()) {
      const int64_t idx = user->get<int64_t>(attr::idx).value();
      // A negative index counts from the end of the list aten::chunk actually
      // returns, which can be shorter than `chunks`. It therefore has no fixed
      // ConstantChunk output to map to. An index >= chunks would read past the
      // ConstantChunk outputs. Leave both cases untouched for the runtime.
      if (idx < 0 || idx >= num_chunks) {
        return c10::nullopt;
      }
      outputs.emplace_back(user->output(), static_cast<size_t>(idx));
    } else if (user->kind() == prim::ListUnpack) {
      auto unpack_outputs = user->outputs();
      // An unpack arity that differs from the number of chunks sometimes
      // happens if the sizes can't be evenly divided by the number of chunks;
      // such graphs are not rewritten.
      if (static_cast<int64_t>(unpack_outputs.size()) != num_chunks) {
        return c10::nullopt;
      }
      for (const auto i : c10::irange(unpack_outputs.size())) {
        outputs.emplace_back(unpack_outputs[i], i);
      }
    } else {
      return c10::nullopt;
    }
  }
  return outputs;
}

static void CanonicalizeOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto sub : it->blocks())
      CanonicalizeOps(sub);
    if (it->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
        it->matches(
            "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
        it->matches("aten::mul(Tensor self, Tensor other) -> Tensor") ||
        it->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
      // Replace rank 0 Tensor constants with scalar constants.
      if (auto other = it->get<at::Tensor>(attr::other)) {
        if (other->dim() == 0) {
          WithInsertPoint insert_guard{*it};
          auto graph = it->owningGraph();
          auto new_other = graph->insertConstant(other->item());
          std::vector<Value*> inputs = it->inputs().vec();
          inputs.at(1) = new_other;
          Value* new_output =
              graph->insertNode(graph->create(it->kind(), inputs))->output();
          new_output->node()->copyMetadata(*it);
          new_output->copyMetadata(it->output());
          it->output()->replaceAllUsesWith(new_output);
        }
      }
    } else if (it->matches(
                   "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
                   /*const_inputs=*/{attr::chunks, attr::dim})) {
      // Replace aten::chunk (which returns a list) with ConstantChunk with the
      // outputs unpacked.
      if (auto orig_outputs = getChunkOutputs(*it)) {
        WithInsertPoint guard(*it);
        auto* self = it->namedInput(attr::self);
        auto* graph = it->owningGraph();
        const auto chunks = it->get<int64_t>(attr::chunks).value();
        const auto dim = it->get<int64_t>(attr::dim).value();
        auto* node =
            graph->insertNode(graph->create(prim::ConstantChunk, chunks));
        node->addInput(self);
        node->i_(attr::chunks, chunks)->i_(attr::dim, dim);
        node->copyMetadata(*it);
        for (const auto& orig_out : *orig_outputs) {
          orig_out.val->replaceAllUsesWith(node->outputs()[orig_out.offset]);
          node->outputs()[orig_out.offset]->setType(orig_out.val->type());
        }
      }
    }
  }
}

void CanonicalizeOps(const std::shared_ptr<Graph>& graph) {
  CanonicalizeOps(graph->block());
  GRAPH_DUMP("After CanonicalizeOps: ", graph);
  EliminateDeadCode(graph);
}

} // namespace jit
} // namespace torch