#pragma once
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {

// Post-tracing cleanup pass: rewrites the scope blocks left behind by the
// tracer into real method calls.
//
// Immediately after tracing, the graph is ill-formed. Each traced submodule
// call appears as a prim::TracedModuleForward node, and the ops recorded
// inside that submodule's scope sit in the node's nested block. For example:
//
// graph(%self : ClassType<Module>,
//       %input.1 : Float(3, 4)):
//   %1 : ClassType<Module> = prim::GetAttr[name="relu1"](%self)
//   %2 : ClassType<Module> = prim::GetAttr[name="relu2"](%self)
//   %3 : ClassType<Module> = prim::GetAttr[name="rrr"](%2)
//    = prim::TracedModuleForward[scope="__module.relu1"]()
//     block0():
//       %input : Float(3, 4) = aten::relu(%input.1)
//       -> ()
//    = prim::TracedModuleForward[scope="__module.relu2"]()
//     block0():
//        = prim::TracedModuleForward[scope="__module.relu2.rrr"]()
//         block0():
//           %6 : Float(3, 4) = aten::relu(%input)
//           -> ()
//       -> ()
//   return (%6)
//
// Here %input is defined inside relu1's block but used inside the
// relu2.rrr block. Likewise, %6 is defined in a nested block but returned
// from the top level. In neither case does the def dominate its use.
//
// The pass repairs the graph in three steps:
//   1) Lift each Value's def to as high a scope as is needed for it to
//      dominate all of its uses. In the example, %input has to move to the
//      top-level block so that its use by the second `relu` is dominated.
//   2) Lambda-lift the blocks, so that every value a scope uses has its def
//      captured by that scope.
//   3) Turn each scope block into a method on its corresponding Module, and
//      replace each TracedModuleForward node with a CallMethod node that
//      invokes that method.
//
// Afterwards the graph is well-formed and built from proper method calls.
//
// graph: the freshly traced graph, which is rewritten by the steps above.
//   It is taken by non-const reference to a shared_ptr, so the pass is able
//   to reseat it. NOTE(review): whether it actually does is not visible from
//   this header; confirm in the implementation.
// self: presumably the Module that was traced (the graph's %self). The
//   methods created in step 3 are attached to it and to the modules reached
//   through it. NOTE(review): this header does not show whether nullptr is
//   accepted; confirm before passing one.
TORCH_API void FixupTraceScopeBlocks(
    std::shared_ptr<Graph>& graph,
    Module* self);

} // namespace jit
} // namespace torch