File: fixup_trace_scope_blocks.h

pytorch 1.7.1-7
#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

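// Directly after tracing, the graph is ill-formed: the tracer has inserted
// prim::TracedModuleForward nodes whose blocks define values that are used
// outside those blocks.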
// Example:
//
// graph(%self : ClassType<Module>,
//       %input.1 : Float(3, 4)):
//   %1 : ClassType<Module> = prim::GetAttr[name="relu1"](%self)
//   %2 : ClassType<Module> = prim::GetAttr[name="relu2"](%self)
//   %3 : ClassType<Module> = prim::GetAttr[name="rrr"](%2)
//    = prim::TracedModuleForward[scope="__module.relu1"]()
//     block0():
//       %input : Float(3, 4) = aten::relu(%input.1)
//       -> ()
//    = prim::TracedModuleForward[scope="__module.relu2"]()
//     block0():
//        = prim::TracedModuleForward[scope="__module.relu2.rrr"]()
//         block0():
//           %6 : Float(3, 4) = aten::relu(%input)
//           -> ()
//       -> ()
//   return (%6)
//
// In this pass, we:
//   1) Lift Value defs to as high a scope as needed to ensure that
//      they dominate all their uses. For example, `%input` in the above
//      graph is defined inside the `relu1` block, so it needs to be lifted
//      to the top-level block so that its use in the second `relu`
//      operator is dominated.
//   2) Lambda-lift the blocks, so that every value used within a scope
//      but defined outside it is captured as an explicit input.
//   3) Convert the scope blocks into methods on their respective Modules,
//      and convert TracedModuleForward nodes to CallMethod nodes that call
//      those methods.
//
// After this pass, the graph is well-formed and uses proper method calls.
TORCH_API void FixupTraceScopeBlocks(
    std::shared_ptr<Graph>& graph,
    Module* self);

} // namespace jit
} // namespace torch
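
Step 1 of the pass description can be illustrated with a small standalone model. The sketch below is not the PyTorch implementation and uses none of its types: `ToyScope`, `depthOf`, `commonAncestor`, `liftedScope` and `exampleScopes` are hypothetical names introduced here. It assumes the scopes form a tree with a single root, and that a def should move to the innermost scope enclosing both the def and all of its uses. The real pass also has to respect node order within a block, which this toy ignores.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for a nested block; unrelated to torch::jit::Block.
struct ToyScope {
  std::string name;
  int parent; // index into the scope table, or -1 for the single root
};

// Number of parent links between `scope` and the root.
int depthOf(const std::vector<ToyScope>& scopes, int scope) {
  assert(scope >= 0 && scope < static_cast<int>(scopes.size()));
  int depth = 0;
  while (scopes[scope].parent != -1) {
    scope = scopes[scope].parent;
    ++depth;
  }
  return depth;
}

// Innermost scope that encloses both `a` and `b` (assumes one root).
int commonAncestor(const std::vector<ToyScope>& scopes, int a, int b) {
  int da = depthOf(scopes, a);
  int db = depthOf(scopes, b);
  // Walk the deeper scope up until both are at the same depth.
  while (da > db) { a = scopes[a].parent; --da; }
  while (db > da) { b = scopes[b].parent; --db; }
  // Walk both up in lockstep until they meet.
  while (a != b) {
    a = scopes[a].parent;
    b = scopes[b].parent;
  }
  return a;
}

// Scope a def must be lifted to so that it encloses every use.
int liftedScope(
    const std::vector<ToyScope>& scopes,
    int defScope,
    const std::vector<int>& useScopes) {
  int target = defScope;
  for (int use : useScopes) {
    target = commonAncestor(scopes, target, use);
  }
  return target;
}

// Scope table mirroring the scope names in the example graph.
std::vector<ToyScope> exampleScopes() {
  return {
      {"__module", -1},          // 0: top-level block
      {"__module.relu1", 0},     // 1
      {"__module.relu2", 0},     // 2
      {"__module.relu2.rrr", 2}, // 3
  };
}
```

With this table, `liftedScope(exampleScopes(), 1, {3})` models `%input`, which is defined in the `relu1` block and used in the `relu2.rrr` block. The only scope enclosing both is the top-level one, matching the example given in step 1.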