File: helper.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (216 lines) | stat: -rw-r--r-- 7,489 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#pragma once
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/quantization/quantization_type.h>

#include <functional>
#include <regex>

namespace torch {
namespace jit {

// Re-export graph_rewrite_helper::getFuncName into torch::jit, so that code
// including this header can refer to it as getFuncName inside this namespace.
using graph_rewrite_helper::getFuncName;

// List of (module, method name) pairs: each entry couples a Module with the
// name of one of its methods.
using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
// Quantization parameters as a list of (name, value) pairs, for example
// _scale, _zero_point, _scalar_type and _axis (the latter for per channel
// quantization).
// Note that this is a vector of pairs rather than an associative container,
// so nothing in the type itself enforces unique names.
using QParamVector = std::vector<std::pair<std::string, IValue>>;

// =========== helper functions for Value =========
// Returns whether `v` is a weight. Weights are singled out because they
// need to be observed with a weight observer.
// NOTE(review): which ops/argument positions count as "weight" is decided in
// the implementation file, which is not visible from this header.
TORCH_API bool isWeight(Value* v);

// Returns whether `v` is the bias of a conv or linear op. Such biases are
// not quantized.
TORCH_API bool isBiasOfConvOrLinear(Value* v);

// NOTE(review): undocumented in the original header. Judging by the name,
// this presumably reports whether `v` is a non-input argument of an
// embedding_bag op -- confirm against the implementation.
TORCH_API bool isEmbeddingBagNonInput(Value* v);

// Looks for a use of `v` as a scalar input of a clamp op and returns it.
// The result is optional; presumably it is c10::nullopt when `v` has no such
// use -- confirm against the implementation.
// NOTE(review): unlike most declarations in this header this one carries no
// TORCH_API annotation -- confirm that this is intentional.
c10::optional<Use> getClampScalarInputUse(Value* v);

// For a given value `v`, returns the list of values whose observed/quantized
// state decides the state of `v`: if those values are observed/quantized,
// `v` can be treated as observed/quantized too, because its quantization
// parameters can be derived from them.
TORCH_API std::vector<Value*> getPassThroughInputs(Value* v);

// Clones the method named `orig_method_name` into a method named
// `new_method_name`.
// NOTE(review): `module` is taken by non-const reference, so the clone is
// presumably added to that same module -- confirm against the implementation.
TORCH_API void cloneMethod(
    Module& module,
    const std::string& orig_method_name,
    const std::string& new_method_name);

// Returns whether `v`, a value in the graph, is a Scalar value.
TORCH_API bool isScalar(Value* v);

// Returns whether `value` is the input of the graph.
// NOTE(review): the name ("hit") hints at a search that reaches a graph
// input rather than a direct comparison; the header does not say which --
// confirm against the implementation.
TORCH_API bool hitGraphInput(Value* value);

// Converts a mangled name into an unmangled name. For example
//   __torch__.torch.ao.nn.quantized.modules.conv.___torch_mangle_7.Conv2d
// becomes
//   __torch__.torch.ao.nn.quantized.modules.conv.Conv2d
// `orig_name` is taken by const reference and the unmangled name is returned
// by value.
TORCH_API std::string removeTorchMangle(const std::string& orig_name);

// Returns the name of the module that corresponds to `value`.
// The result is optional; presumably it is c10::nullopt when `value` does
// not correspond to a module -- confirm against the implementation.
TORCH_API c10::optional<std::string> getModuleName(Value* value);

// =========== helper functions for Node =========
// NOTE(review): the four predicates below are undocumented in the original
// header. Their names suggest that each classifies `n` as a "general" op
// with a single input (shape aten functions, value aten functions,
// CallFunction nodes, and aten functions in general, respectively) --
// confirm the exact op lists and how the predicates relate to each other in
// the implementation.
TORCH_API bool isSingleInputGeneralShapeAtenFunction(Node* n);

TORCH_API bool isSingleInputGeneralValueAtenFunction(Node* n);

TORCH_API bool isSingleInputGeneralCallFunction(Node* n);

TORCH_API bool isSingleInputGeneralAtenFunction(Node* n);

// NOTE(review): undocumented in the original header; presumably reports
// whether `n` is a clamp op -- confirm which node kinds are accepted.
TORCH_API bool isClamp(Node* n);

// Returns whether `n` produces the same result regardless of whether its
// input tensor is quantized or not. Example: aten::size.
TORCH_API bool isTensorInfoNode(Node* n);

// Returns whether `n` is a propagate op that has a single input,
// e.g. aten::cat.
TORCH_API bool isPropagateQuantSingleInputOp(Node* n);

// Returns whether `n` is a propagate op that has two inputs, e.g. aten::add.
TORCH_API bool isPropagateQuantBinaryOp(Node* n);

// Returns whether `n` is a node that is quantized or left unquantized
// depending on whether its input is quantized. Example: aten::cat.
TORCH_API bool isPropagateQuantOp(Node* n);

// Returns whether `n` is a binary op such as aten::add or aten::mul whose
// input 1 is a scalar. These ops are quantized to quantized::{op}_scalar.
TORCH_API bool isBinaryOpWithScalarInput(Node* n);

// NOTE(review): undocumented in the original header. The return type is an
// optional tuple of a c10::QScheme and a QParamVector; going by the name,
// this presumably yields the fixed quantization parameters associated with
// `n`, and c10::nullopt when there are none -- confirm against the
// implementation.
TORCH_API c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(
    Node* n);

// The graph of some `builtin` CallFunctions, like `linear`, is deliberately
// not analyzed, because the op boundary should be preserved.
// NOTE(review): presumably this returns true when `n` is a CallFunction that
// is not one of those builtins -- confirm against the implementation.
TORCH_API bool userDefinedCallFunction(Node* n);

// Returns whether `n` has a scalar input.
TORCH_API bool hasScalarInput(Node* n);

// Returns whether `n` is quantizable.
// `quant_type` defaults to QuantType::STATIC when the caller omits it.
TORCH_API bool nodeQuantizable(
    Node* n,
    QuantType quant_type = QuantType::STATIC);

// Returns whether `n` is a node that only requires quantization of its
// weight value, e.g. embedding_bag.
// NOTE(review): declared without TORCH_API, unlike most declarations in this
// header -- confirm that this is intentional.
bool isWeightOnlyStaticQuantOp(Node* n);

// Returns whether `use`, a use of a value, is quantizable. The answer
// depends on both the use node and the offset.
TORCH_API bool useQuantizable(const Use& use, QuantType quant_type);

// Given a CallFunction node `n`, extracts the graph of the called function.
TORCH_API std::shared_ptr<Graph> getCallFunctionGraph(Node* n);

// Returns whether `use` is a CallFunction of name `func_name` and, if
// `nth_arg` is provided, whether the used value is the nth argument of that
// function.
// NOTE(review): earlier wording called the used value `v`, but there is no
// such parameter; the value is presumably the one `use` refers to. Whether
// `nth_arg` is zero-based is not stated in this header -- confirm.
bool matchCallFuncToUse(
    const Use& use,
    const std::string& func_name,
    c10::optional<int> nth_arg);

// Returns whether `use` is an AtenFunction of name `func_name` and, if
// `nth_arg` is provided, whether the used value is the nth argument of that
// function.
// NOTE(review): earlier wording called the used value `v`, but there is no
// such parameter; the value is presumably the one `use` refers to. Whether
// `nth_arg` is zero-based is not stated in this header -- confirm.
bool matchAtenFuncToUse(
    const Use& use,
    const std::string& func_name,
    c10::optional<int> nth_arg);

// =========== helper functions for Block =========
// Returns whether `block` will always raise an Exception.
TORCH_API bool alwaysRaisesException(Block* block);

// =========== helper functions for Module  ==========
// TODO: remove
// NOTE(review): undocumented in the original header. Returns a list of
// strings; presumably the names leading from `self` to `instance` -- confirm
// against the implementation.
TORCH_API std::vector<std::string> getModuleAccessPath(
    Value* instance,
    Value* self);
// TODO: remove
// NOTE(review): undocumented in the original header. Presumably resolves
// `path`, a list of names, starting at `module` and returns the module found
// there -- confirm against the implementation.
TORCH_API Module
findChildModule(const Module& module, const std::vector<std::string>& path);

// Given a CallMethod node `n`, returns the module instance corresponding to
// the instance Value.
// TODO: refactor all current uses of this function to the Opt one
TORCH_API Module getInvokedModule(Module& module, Node* n, Value* self);

// Given a CallMethod node `n`, returns the module instance corresponding to
// the instance Value if the instance is a module; otherwise returns
// c10::nullopt.
// NOTE(review): declared without TORCH_API, unlike getInvokedModule --
// confirm that this is intentional.
c10::optional<Module> getInvokedModuleOpt(
    const Module& module,
    Node* n,
    Value* self);

// ==================== filter functions for matches ==============
// Filter: checks that the Value named `vname` is a constant of int value
// `value`.
// NOTE(review): `vname` is presumably resolved to a Value through `vmap` and
// `match` -- confirm against the implementation.
bool is_int_constant(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap,
    const std::string& vname,
    int value);

// Filter: checks that the %alpha argument of aten::add is the constant 1.
bool aten_add_alpha_is_one(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// Filter: checks that the functional in CallFunction is relu.
bool is_functional_relu(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// Filter: checks that the module is torch.nn.ReLU.
bool is_relu_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// NOTE(review): undocumented in the original header. Has the same signature
// as is_relu_module; presumably a filter that checks that the module is a
// linear module -- confirm the exact module type in the implementation.
bool is_linear_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// TODO: add a macro to declare the filters
// NOTE(review): the five convolution filters below are undocumented in the
// original header. They share the signature of is_relu_module; going by
// their names, each presumably checks that the module is of the
// corresponding convolution kind -- confirm the exact module types in the
// implementation.

// Presumably matches a 1d convolution module.
bool is_conv1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// Presumably matches a 2d convolution module.
bool is_conv2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// Presumably matches a 3d convolution module.
bool is_conv3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// Presumably matches a 1d transposed convolution module.
bool is_conv_transpose1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// Presumably matches a 2d transposed convolution module.
bool is_conv_transpose2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// NOTE(review): the two batch norm filters below are undocumented in the
// original header. They share the signature of is_relu_module; going by
// their names, they presumably check that the module is a 2d / 3d batch norm
// module respectively -- confirm the exact module types in the
// implementation.
bool is_batchnorm2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

bool is_batchnorm3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

} // namespace jit
} // namespace torch