# mypy: allow-untyped-defs
import operator
from typing import List

import torch
from torch.ao.quantization.backend_config import (
BackendConfig,
BackendPatternConfig,
DTypeConfig,
ObservationType,
)
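
# Shared dtype config for the weighted quantized ops below: quint8 activations,
# qint8 weights, and float bias.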
weighted_op_quint8_dtype_config = DTypeConfig(
input_dtype=torch.quint8,
output_dtype=torch.quint8,
weight_dtype=torch.qint8,
bias_dtype=torch.float,
)


def get_linear_configs() -> List[BackendPatternConfig]:
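    """Return the BackendPatternConfigs for linear patterns (addmm/mm) in the
    decomposed aten graph."""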
linear_configs = []
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
dtype_configs = [weighted_op_quint8_dtype_config]
# TODO: need to fix the way we insert observers for this pattern
# should be solved in the new fusion API
    # the reason this doesn't work: the pattern is a bit complicated and we
    # don't have a way to specify which input of the pattern we would like to
    # observe
# pattern:
# bias input weight
# \ | /
# \ | t
# \ | /
# addmm
    # we want to observe "weight" as weight, but there is no way to convey this
    # information with the current pattern language
#
# right now:
# original:
# weight - t \
# input - addmm
# observed (no hack):
# weight - t - observer \
# input - observer - addmm
# target:
# weight - observer - t \
# input - observer - addmm
# def root_node_getter(node_pattern):
# addmm, bias, act, weight = node_pattern
# return addmm
# linear_configs.append(
    #     # note: MatchAllNode would need to be imported from torch.ao.quantization.utils
    #     BackendPatternConfig((torch.ops.aten.addmm.default, MatchAllNode, MatchAllNode, torch.ops.aten.t.default))
# .set_observation_type(observation_type) # noqa: E131
# .set_dtype_configs(dtype_configs)
# ._set_root_node_getter(root_node_getter))
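    # aten.addmm(bias, mat1, mat2) takes the bias as argument 0 and the
    # (transposed) weight as argument 2, hence the index mapping below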
linear_configs.append(
BackendPatternConfig(torch.ops.aten.addmm.default)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
._set_input_type_to_index({"weight": 2, "bias": 0})
)
# linear is decomposed to `t - mm` if bias is not present
linear_configs.append(
BackendPatternConfig(torch.ops.aten.mm.default)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
._set_input_type_to_index({"weight": 1})
)
return linear_configs


def get_conv_configs() -> List[BackendPatternConfig]:
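    """Return the BackendPatternConfigs for convolution and conv + relu patterns."""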
conv_configs = []
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
dtype_configs = [weighted_op_quint8_dtype_config]
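    # aten.convolution(input, weight, bias, ...) takes the weight as argument 1
    # and the bias as argument 2, hence the index mapping below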
conv_configs.append(
BackendPatternConfig(torch.ops.aten.convolution.default)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
._set_input_type_to_index({"weight": 1, "bias": 2})
)
conv_configs.append(
BackendPatternConfig(
(torch.ops.aten.convolution.default, torch.ops.aten.relu.default)
)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
._set_input_type_to_index({"weight": 1, "bias": 2})
)
# TODO: remove when functionalization is supported in PT2 mode
conv_configs.append(
BackendPatternConfig(
(torch.ops.aten.convolution.default, torch.ops.aten.relu_.default)
)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
._set_input_type_to_index({"weight": 1, "bias": 2})
)
return conv_configs


def get_pooling_configs() -> List[BackendPatternConfig]:
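    """Return the BackendPatternConfigs for max pooling patterns."""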
backend_pattern_configs = []
observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
dtype_configs = [weighted_op_quint8_dtype_config]
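    # aten.max_pool2d_with_indices returns (values, indices); the pattern below
    # matches getitem(maxpool_output, 0) and treats the maxpool node as the
    # root of the pattern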
def root_node_getter(node_pattern):
getitem, maxpool, index = node_pattern
return maxpool
backend_pattern_configs.append(
BackendPatternConfig()
._set_pattern_complex_format(
(operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0)
)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
._set_root_node_getter(root_node_getter)
)
return backend_pattern_configs


def get_relu_configs() -> List[BackendPatternConfig]:
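    """Return the BackendPatternConfigs for standalone relu."""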
backend_pattern_configs = []
observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
dtype_configs = [weighted_op_quint8_dtype_config]
backend_pattern_configs.append(
BackendPatternConfig(torch.ops.aten.relu.default)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
)
return backend_pattern_configs


def get_binary_op_configs() -> List[BackendPatternConfig]:
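    """Return the BackendPatternConfigs for binary ops (add, add_, and their
    relu-fused variants)."""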
binary_op_configs: List[BackendPatternConfig] = []
dtype_configs = [weighted_op_quint8_dtype_config]
num_tensor_args_to_observation_type_mapping = {
        # TODO: this is not used right now since we have an extra check in prepare;
        # will need to change this to NO_OBSERVER later after we implement
        # Tensor dtype inference properly
0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
}
for op_with_quantized_bop_scalar_variant in [
torch.ops.aten.add.Tensor,
torch.ops.aten.add_.Tensor,
]:
bop_patterns = [
(op_with_quantized_bop_scalar_variant, torch.ops.aten.relu.default),
op_with_quantized_bop_scalar_variant,
            # TODO: remove when functionalization is supported in PT2 mode
(op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default),
]
binary_op_configs.extend(
BackendPatternConfig(bop_pattern)
.set_dtype_configs(dtype_configs) # noqa: E131
._set_num_tensor_args_to_observation_type(
num_tensor_args_to_observation_type_mapping
)
for bop_pattern in bop_patterns
)
return binary_op_configs


def get_qnnpack_pt2e_backend_config() -> BackendConfig:
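    """Return the BackendConfig for the qnnpack PyTorch 2.0 export backend."""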
return (
BackendConfig("qnnpack_pytorch_2.0_export")
.set_backend_pattern_configs(get_linear_configs())
.set_backend_pattern_configs(get_binary_op_configs())
.set_backend_pattern_configs(get_conv_configs())
.set_backend_pattern_configs(get_pooling_configs())
.set_backend_pattern_configs(get_relu_configs())
)
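

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: this BackendConfig
    # targets decomposed aten graphs in the PT2 export flow, but it can be
    # passed anywhere a `backend_config` argument is accepted, e.g. the FX
    # prepare API below. The toy model here is only a stand-in.
    import torch.nn as nn
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import prepare_fx

    toy_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 3, 32, 32),)
    prepared = prepare_fx(
        toy_model,
        get_default_qconfig_mapping(),
        example_inputs,
        backend_config=get_qnnpack_pt2e_backend_config(),
    )
    print(prepared.graph)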