import torch.nn as nn
import torch.nn.intrinsic as nni
from typing import Union, Callable, Tuple, Dict, Optional, Type
from torch.ao.quantization.utils import Pattern
from torch.ao.quantization.utils import get_combined_dict
from torch.ao.quantization.utils import MatchAllNode
import itertools
def fuse_conv_bn(is_qat, conv, bn):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
is_qat: a flag for whether we are using quantization aware training fusion
or post training quantization fusion
conv: Module instance of type conv2d/conv3d
bn: Spatial BN instance that needs to be fused with the conv
Examples::
>>> m1 = nn.Conv2d(10, 20, 3)
>>> b1 = nn.BatchNorm2d(20)
>>> # xdoctest: +SKIP
>>> m2 = fuse_conv_bn(m1, b1)
"""
assert(conv.training == bn.training),\
"Conv and BN both must be in the same mode (train or eval)."
fused_module_class_map = {
nn.Conv1d: nni.ConvBn1d,
nn.Conv2d: nni.ConvBn2d,
nn.Conv3d: nni.ConvBn3d,
}
if is_qat:
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with track_running_stats set to True'
        fused_module_class = fused_module_class_map.get(type(conv), None)
if fused_module_class is not None:
return fused_module_class(conv, bn)
else:
raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
else:
return nn.utils.fuse_conv_bn_eval(conv, bn)
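# Illustrative sketch (an added example, not part of the original module):
# a minimal post-training use of fuse_conv_bn. Both modules must be in the
# same mode; with is_qat=False the eval path folds the BN statistics into
# the conv weights and returns a plain nn.Conv2d.
def _example_fuse_conv_bn():
    conv = nn.Conv2d(10, 20, 3).eval()
    bn = nn.BatchNorm2d(20).eval()
    return fuse_conv_bn(False, conv, bn)  # nn.Conv2d with BN folded in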
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
is_qat: a flag for whether we are using quantization aware training fusion
or post training quantization fusion
conv: Module instance of type conv2d/conv3d
bn: Spatial BN instance that needs to be fused with the conv
Examples::
>>> m1 = nn.Conv2d(10, 20, 3)
>>> b1 = nn.BatchNorm2d(20)
>>> r1 = nn.ReLU(inplace=False)
>>> # xdoctest: +SKIP
>>> m2 = fuse_conv_bn_relu(m1, b1, r1)
"""
    assert(conv.training == bn.training == relu.training),\
        "Conv, BN and ReLU all must be in the same mode (train or eval)."
    fused_module: Optional[Type[nn.Sequential]] = None
if is_qat:
map_to_fused_module_train = {
nn.Conv1d: nni.ConvBnReLU1d,
nn.Conv2d: nni.ConvBnReLU2d,
nn.Conv3d: nni.ConvBnReLU3d,
}
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with track_running_stats set to True'
fused_module = map_to_fused_module_train.get(type(conv), None)
if fused_module is not None:
return fused_module(conv, bn, relu)
else:
raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
else:
map_to_fused_module_eval = {
nn.Conv1d: nni.ConvReLU1d,
nn.Conv2d: nni.ConvReLU2d,
nn.Conv3d: nni.ConvReLU3d,
}
fused_module = map_to_fused_module_eval.get(type(conv), None)
if fused_module is not None:
fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
return fused_module(fused_conv, relu)
else:
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
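# Illustrative sketch (added example): eval-mode fusion of conv + bn + relu
# first folds the BatchNorm into the conv, then wraps the result together
# with the ReLU in an nni.ConvReLU2d container.
def _example_fuse_conv_bn_relu():
    conv = nn.Conv2d(3, 8, 3).eval()
    bn = nn.BatchNorm2d(8).eval()
    relu = nn.ReLU().eval()
    return fuse_conv_bn_relu(False, conv, bn, relu)  # an nni.ConvReLU2d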
def fuse_linear_bn(is_qat, linear, bn):
r"""Given the linear and bn modules, fuses them and returns the fused module
Args:
is_qat: a flag for whether we are using quantization aware training fusion
or post training quantization fusion
linear: Module instance of type Linear
bn: BatchNorm1d instance that needs to be fused with the linear layer
Examples::
>>> m1 = nn.Linear(20, 10)
>>> b1 = nn.BatchNorm1d(10)
>>> # xdoctest: +SKIP
>>> m2 = fuse_linear_bn(m1, b1)
"""
assert(linear.training == bn.training),\
"Linear and BN both must be in the same mode (train or eval)."
if is_qat:
assert bn.num_features == linear.out_features,\
"Output features of Linear must match num_features of BatchNorm1d"
assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
        assert bn.track_running_stats,\
            "Only support fusing BatchNorm1d with track_running_stats set to True"
return nni.LinearBn1d(linear, bn)
else:
return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
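# Illustrative sketch (added example): post-training fusion of Linear + BN1d
# folds the BN statistics into the linear weights and bias.
def _example_fuse_linear_bn():
    linear = nn.Linear(20, 10).eval()
    bn = nn.BatchNorm1d(10).eval()
    return fuse_linear_bn(False, linear, bn)  # nn.Linear with BN folded in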
def fuse_convtranspose_bn(is_qat, convt, bn):
r"""Given ConvTranspose and bn modules, fuses them and returns the fused module
Args:
convt: Module instance of type ConvTransposeNd
bn: BatchNormNd instance that needs to be fused with the linear layer.
batch norm N should match the ConvTranspose N
Examples::
>>> m1 = nn.ConvTranspose2d(10, 20, 3)
>>> b1 = nn.BatchNorm2d(20)
>>> # xdoctest: +SKIP
>>> m2 = fuse_convtranspose_bn(m1, b1)
"""
assert(convt.training == bn.training),\
"ConvTranspose and BN both must be in the same mode (train or eval)."
if is_qat:
        raise NotImplementedError("Fusing ConvTranspose+BatchNorm is not yet supported in QAT.")
else:
return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)
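# Illustrative sketch (added example): eval-mode fusion of ConvTranspose2d
# and BatchNorm2d; transpose=True accounts for the transposed weight layout.
def _example_fuse_convtranspose_bn():
    convt = nn.ConvTranspose2d(10, 20, 3).eval()
    bn = nn.BatchNorm2d(20).eval()
    return fuse_convtranspose_bn(False, convt, bn)  # nn.ConvTranspose2d with BN folded in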
def sequential_wrapper2(sequential):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules
"""
def fuser_method(is_qat, m1, m2):
return sequential(m1, m2)
return fuser_method
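# Illustrative sketch (added example): wrapping nni.ConvReLU2d so that it
# conforms to the (is_qat, m1, m2) fuser-method interface used below.
def _example_sequential_wrapper2():
    fuser = sequential_wrapper2(nni.ConvReLU2d)
    conv = nn.Conv2d(3, 8, 3).eval()
    relu = nn.ReLU().eval()
    return fuser(False, conv, relu)  # is_qat is ignored; nni.ConvReLU2d(conv, relu)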
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
(nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
(nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
(nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
(nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
(nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
(nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
(nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
(nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
(nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
(nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
(nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
(nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
(nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
}
def get_fuser_method(op_list, additional_fuser_method_mapping=None):
''' Get fuser method for the given list of module types,
return None if fuser method does not exist
'''
if additional_fuser_method_mapping is None:
additional_fuser_method_mapping = {}
all_mappings = get_combined_dict(DEFAULT_OP_LIST_TO_FUSER_METHOD,
additional_fuser_method_mapping)
fuser_method = all_mappings.get(op_list, None)
    assert fuser_method is not None, "did not find fuser method for: {}".format(op_list)
return fuser_method
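# Illustrative sketch (added example): looking up a fuser method by the tuple
# of module types in forward order and applying it.
def _example_get_fuser_method():
    method = get_fuser_method((nn.Conv2d, nn.BatchNorm2d, nn.ReLU))  # fuse_conv_bn_relu
    conv = nn.Conv2d(3, 8, 3).eval()
    bn = nn.BatchNorm2d(8).eval()
    relu = nn.ReLU().eval()
    return method(False, conv, bn, relu)  # an nni.ConvReLU2d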
def reverse_sequential_wrapper2(sequential):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules, with
the order of two inputs reversed
"""
def fuser_method(is_qat, m1, m2):
return sequential(m2, m1)
return fuser_method
def reverse2(f):
    """Adapts a two-module fuser method to the reversed argument order,
    so that a pattern matched as (bn, conv) can call f(is_qat, conv, bn)."""
    def reversed(is_qat, x, y):
        return f(is_qat, y, x)
    return reversed
def reverse3(f):
    """Adapts a three-module fuser method to the reversed, nested argument
    order (is_qat, x, (y, z)), calling f(is_qat, z, y, x); e.g. a pattern
    matched as (relu, (bn, conv)) can call f(is_qat, conv, bn, relu)."""
    def reversed(is_qat, x, w):
        y, z = w
        return f(is_qat, z, y, x)
    return reversed
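# Illustrative sketch (added example): reverse3 lets the reversed pattern
# (relu, (bn, conv)) reuse fuse_conv_bn_relu, which expects (conv, bn, relu).
def _example_reverse3():
    fuser = reverse3(fuse_conv_bn_relu)
    conv = nn.Conv2d(3, 8, 3).eval()
    bn = nn.BatchNorm2d(8).eval()
    relu = nn.ReLU().eval()
    return fuser(False, relu, (bn, conv))  # same as fuse_conv_bn_relu(False, conv, bn, relu)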
DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
(nn.BatchNorm1d, nn.Conv1d): reverse2(fuse_conv_bn),
(nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): reverse3(fuse_conv_bn_relu),
(nn.BatchNorm2d, nn.Conv2d): reverse2(fuse_conv_bn),
(nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
(nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
(nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
(nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d),
(nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d),
(nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d),
(nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn),
(nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU),
(nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d),
(nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d),
(nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn),
(nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn),
(nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn),
}
def get_valid_patterns(op_pattern):
"""
    Returns a list of valid patterns generated from the op_pattern.
    Since MatchAllNode can match any type of node, a pattern such as
    (torch.nn.Conv2d, torch.add) should also be able to match keys like
    (MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode)
Example Input:
(torch.add, (torch.nn.ReLU, torch.nn.Conv2d))
Example Output:
[(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)),
(torch.add, (torch.nn.ReLU, MatchAllNode)),
(torch.add, (MatchAllNode, torch.nn.Conv2d)),
(torch.add, (MatchAllNode, MatchAllNode)),
(MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)),
(MatchAllNode, (torch.nn.ReLU, MatchAllNode)),
(MatchAllNode, (MatchAllNode, torch.nn.Conv2d)),
(MatchAllNode, (MatchAllNode, MatchAllNode)),
]
"""
result = []
if isinstance(op_pattern, (tuple, list)):
sub_combs = []
for sub_pattern in op_pattern:
sub_combs.append(get_valid_patterns(sub_pattern))
result = list(itertools.product(*sub_combs))
else:
result = [op_pattern, MatchAllNode]
return result
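# Illustrative sketch (added example): a flat two-element pattern expands to
# the four combinations of concrete types and MatchAllNode.
def _example_get_valid_patterns():
    return get_valid_patterns((nn.ReLU, nn.Conv2d))
    # [(nn.ReLU, nn.Conv2d), (nn.ReLU, MatchAllNode),
    #  (MatchAllNode, nn.Conv2d), (MatchAllNode, MatchAllNode)]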
def get_fuser_method_new(
op_pattern: Pattern,
fuser_method_mapping: Optional[Dict[Pattern, Union[nn.Sequential, Callable]]] = None):
""" This will be made defult after we deparate the get_fuser_method
Would like to implement this first and have a separate PR for deprecation
"""
if fuser_method_mapping is None:
fuser_method_mapping = DEFAULT_PATTERN_TO_FUSER_METHOD
    op_patterns = get_valid_patterns(op_pattern)
    fuser_method = None
    # avoid shadowing the op_pattern argument so that the assertion message
    # below reports the original pattern rather than the last expanded one
    for pattern in op_patterns:
        fuser_method = fuser_method_mapping.get(pattern, None)
        if fuser_method is not None:
            break
    assert fuser_method is not None, "did not find fuser method for: {}".format(op_pattern)
return fuser_method
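# Illustrative sketch (added example): pattern-based lookup; patterns here are
# in reverse (outermost-op-first) order, so conv -> bn -> relu is written as
# (relu, (bn, conv)), and the returned method takes arguments in that order.
def _example_get_fuser_method_new():
    method = get_fuser_method_new((nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)))
    conv = nn.Conv2d(3, 8, 3).eval()
    bn = nn.BatchNorm2d(8).eval()
    relu = nn.ReLU().eval()
    return method(False, relu, (bn, conv))  # an nni.ConvReLU2d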