File: fixup_trace_scope_blocks.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (47 lines) | stat: -rw-r--r-- 1,673 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Pass that repairs the graph produced by tracing.
//
// Directly after tracing, we have an ill-formed graph with blocks inserted.
// Example:
//
// graph(%self : ClassType<Module>,
//       %input.1 : Float(3, 4)):
//   %1 : ClassType<Module> = prim::GetAttr[name="relu1"](%self)
//   %2 : ClassType<Module> = prim::GetAttr[name="relu2"](%self)
//   %3 : ClassType<Module> = prim::GetAttr[name="rrr"](%2)
//    = prim::TracedModuleForward[scope="__module.relu1"]()
//     block0():
//       %input : Float(3, 4) = aten::relu(%input.1),
//       -> ()
//    = prim::TracedModuleForward[scope="__module.relu2"](),
//     block0():
//        = prim::TracedModuleForward[scope="__module.relu2.rrr"](),
//         block0():
//           %6 : Float(3, 4) = aten::relu(%input),
//           -> ()
//       -> ()
//   return (%6)
//
// The graph above is ill-formed because values are used outside the block
// that defines them: `%input` is defined inside the `relu1` block but used
// inside the `relu2.rrr` block, and `%6` is defined inside a nested block but
// returned from the top-level graph.
//
// In this pass, we:
//   1) Lift Value defs to as high a scope as needed to ensure that
//      they dominate all their uses. For example, `input` in the above
//      graph needs to be lifted to the top-level block so that its use
//      in the second `relu` operator is dominated.
//   2) Lambda lift the blocks. This ensures that all values used within
//      each scope have their defs captured.
//   3) Convert the scope blocks into methods on their respective Modules,
//      and convert TracedModuleForward nodes to CallMethod nodes into those
//      methods.
//
// Then, we'll have a well-formed graph with proper method calls.
//
// Parameters:
//   graph - the traced graph to fix up. It is taken by non-const reference
//           to the owning shared_ptr.
//           NOTE(review): whether the pass only mutates the Graph in place
//           or may also reseat the pointer is not visible from this header;
//           confirm against the implementation.
//   self  - NOTE(review): presumably the Module that `%self` in the graph
//           refers to, i.e. the root of the module hierarchy on which the
//           new methods are defined; null-handling is not specified here.
//           Confirm against the implementation and callers.
TORCH_API void FixupTraceScopeBlocks(
    std::shared_ptr<Graph>& graph,
    Module* self);

} // namespace jit
} // namespace torch