File: function_extraction.h

#pragma once

#include <torch/csrc/jit/ir/ir.h>

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {

// This API is used by serialization/export.cpp to extract function
// information. It converts the graph in two steps:
//    1. Extract the subgraph pattern of each function and define it as a
//    local function node.
//    2. Replace each occurrence of that subgraph pattern with a single node
//    of the local function node type.
// A map of function attribute information is also returned, since Torch IR
// cannot represent this information inside the Graph object.
// export.cpp then serializes the ONNX model, emitting a FunctionProto for
// each extracted function based on the above information.
namespace onnx {

// The following return type tracks information regarding function attributes
// that cannot be traced through Torch IR.
// NodeAttrNameMap maps each IR Node inside a function subgraph to a mapping
// from that node's attribute name to the corresponding function attribute
// name. Here's an example of exporting CELU and LayerNorm.
//
// clang-format off
// class M(torch.nn.Module):
//     def __init__(self):
//         super().__init__()
//         self.lns = torch.nn.ModuleList([torch.nn.LayerNorm(3, eps = i) for i in range(2)])
//         self.celu1 = torch.nn.CELU(1.0)
//         self.celu2 = torch.nn.CELU(2.0)

//     def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
//         res1 = self.celu1(x)
//         res2 = self.celu2(y)
//         for ln in self.lns:
//             z = ln(z)
//         return res1 + res2 + z
// clang-format on
//
// Returning
//
// NodeAttrNameMap:
// {
//    %1 : Float(2, 3) = onnx::Celu[alpha=2.](%y) : {
//      'alpha' : 'Celu_alpha'
//    }
// }
//
// This information helps graph._export_onnx construct function attributes for
// the ONNX local FunctionProto.
using NodeAttrNameMap = std::
    unordered_map<const Node*, std::unordered_map<std::string, std::string>>;
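//
// A minimal sketch of how a consumer might read entries of this map. The
// variable names below are illustrative only; the actual consumer is the ONNX
// export code in serialization/export.cpp.
//
// clang-format off
// for (const auto& entry : node_attr_map) {
//   const Node* node = entry.first;  // IR node inside the function subgraph
//   for (const auto& name_pair : entry.second) {
//     const std::string& node_attr_name = name_pair.first;   // e.g. "alpha"
//     const std::string& func_attr_name = name_pair.second;  // e.g. "Celu_alpha"
//   }
// }
// clang-format on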

TORCH_API NodeAttrNameMap ONNXFunctionExtraction(
    std::shared_ptr<Graph>& graph,
    const std::unordered_set<std::string>& module_names,
    const std::vector<std::string>& param_names);
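
// A hedged usage sketch (illustrative only; the actual call site is the ONNX
// export path in serialization/export.cpp, and the module and parameter names
// below are made up):
//
// clang-format off
// std::shared_ptr<Graph> graph = ...;  // graph produced during ONNX export
// std::unordered_set<std::string> module_names = {"torch.nn.modules.activation.CELU"};
// std::vector<std::string> param_names = {"lns.0.weight", "lns.0.bias"};
// NodeAttrNameMap node_attr_map =
//     ONNXFunctionExtraction(graph, module_names, param_names);
// // node_attr_map is then used when constructing FunctionProto attributes.
// clang-format on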

TORCH_API void ONNXClearScopeRecords();

TORCH_API void ONNXTrackScopeAttributes(
    std::shared_ptr<Graph>& graph,
    std::map<std::string, IValue>& attributes);

} // namespace onnx

} // namespace jit
} // namespace torch