File: torch/ao/quantization/fx/backend_config_utils.py

import torch
from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict
from torch.ao.quantization.backend_config import (
    get_native_backend_config,
    ObservationType,
)
from torch.ao.quantization.quantization_types import (
    Pattern,
    NodePattern,
    QuantizerCls,
)
from torch.ao.quantization.utils import (
    activation_dtype,
    get_combined_dict,
)

from ..backend_config import BackendConfig
from .quantization_patterns import QuantizeHandler
from .fusion_patterns import DefaultFuseHandler

from typing import Dict, Any, Callable, Optional

def get_quantize_handler_cls(
        observation_type,
        dtype_configs,
        num_tensor_args_to_observation_type,
        overwrite_output_fake_quantizer,
        overwrite_output_observer,
        input_output_observed):
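    """Return a ``QuantizeHandler`` subclass that bakes the given backend
    pattern configuration (observation type, dtype configs, and the optional
    output observer / fake-quantizer overrides) into its instances."""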

    class ConfigurableQuantizeHandler(QuantizeHandler):
        def __init__(
                self,
                node_pattern: NodePattern,
                modules: Dict[str, torch.nn.Module],
                root_node_getter: Optional[Callable] = None):
            super().__init__(node_pattern, modules, root_node_getter)
            if num_tensor_args_to_observation_type:
                assert self.num_tensor_args in num_tensor_args_to_observation_type, \
                    f"Must provide an observation_type config for num_tensor_args={self.num_tensor_args}" \
                    f" in num_tensor_args_to_observation_type for {node_pattern}"
                self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
            else:
                self.observation_type = observation_type
            self.dtype_configs = dtype_configs
            self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer
            self.overwrite_output_observer = overwrite_output_observer
            self.input_output_observed_ = input_output_observed

        def is_general_tensor_value_op(self) -> bool:
            return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT

        # TODO: change this to output activation
        def get_activation_ctr(
                self,
                qconfig: Any,
                pattern: Pattern,
                is_training: bool,
        ) -> Optional[Callable]:
            """
            Returns the constructor for the activation observer which should be
            used for the pattern matched to this handler. Some handlers override
            this to a different value than what is specified in the qconfig.
            """
            act_dtype = activation_dtype(qconfig)
            # TODO: change to is_qat
            if is_training:
                if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None:
                    return self.overwrite_output_fake_quantizer
            else:
                if act_dtype == torch.quint8 and self.overwrite_output_observer is not None:
                    return self.overwrite_output_observer
            return qconfig.activation

        # This is temporary, and will be removed soon
        def input_output_observed(self):
            return self.input_output_observed_


    return ConfigurableQuantizeHandler
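
# Example (a minimal sketch, not executed by this module): generating a handler
# class from an illustrative configuration. The empty dtype config list and the
# shared-observer observation type below are placeholders, not a real backend
# configuration.
#
#     handler_cls = get_quantize_handler_cls(
#         observation_type=ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
#         dtype_configs=[],
#         num_tensor_args_to_observation_type={},
#         overwrite_output_fake_quantizer=None,
#         overwrite_output_observer=None,
#         input_output_observed=True)
#     # handler_cls is a ConfigurableQuantizeHandler subclass; instances are
#     # created later with a matched node pattern and a module map.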

def get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]:
    """
    Note: a quantize handler is just a holder for some check methods
    (e.g. should_insert_observer_for_output); it could also be an enum.
    We can refactor this once the fbgemm/qnnpack path is fully converted
    to the new path. This is not exposed to backend developers.
    """
    pattern_to_quantize_handlers = {}
    for pattern, config in backend_config.configs.items():
        observation_type = config.observation_type
        dtype_configs = config.dtype_configs
        num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type
        overwrite_fake_quantizer = config._overwrite_output_fake_quantize
        overwrite_observer = config._overwrite_output_observer
        input_output_observed = config._input_output_observed
        if input_output_observed is None:
            input_output_observed = True
        pattern_to_quantize_handlers[pattern] = \
            get_quantize_handler_cls(
                observation_type,
                dtype_configs,
                num_tensor_args_to_observation_type,
                overwrite_fake_quantizer,
                overwrite_observer,
                input_output_observed)

    return pattern_to_quantize_handlers
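
# Example (a minimal sketch, not executed by this module): building the
# pattern -> handler-class map for the native (fbgemm/qnnpack) backend config
# imported above.
#
#     native_handlers = get_pattern_to_quantize_handlers(get_native_backend_config())
#     # native_handlers maps each supported pattern to a dynamically generated
#     # QuantizeHandler subclass configured from that pattern's config entry.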

# TODO: move this to torch/ao/quantization/backend_config/utils.py
def get_fusion_pattern_to_fuse_handler_cls(
        backend_config: BackendConfig) -> Dict[Pattern, Callable]:
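    """Return a map from fusion pattern to its fuse handler class (currently
    always ``DefaultFuseHandler``) for every pattern in ``backend_config``
    that specifies a fuser method."""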
    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
    for pattern, config in backend_config.configs.items():
        if config.fuser_method is not None:
            # TODO: is this logic right?
            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler

    return fusion_pattern_to_fuse_handlers
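
# Example (a minimal sketch, not executed by this module):
#
#     fuse_handlers = get_fusion_pattern_to_fuse_handler_cls(get_native_backend_config())
#     # fuse_handlers maps every native pattern that has a fuser method
#     # (e.g. conv + bn style fusions) to DefaultFuseHandler.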

# TODO: remove when all uses are changed to backend_config
def get_native_quant_patterns(
        additional_quant_patterns: Optional[Dict[Pattern, QuantizerCls]] = None
) -> Dict[Pattern, QuantizerCls]:
    """
    Return a map from pattern to quantize handlers based on the default patterns and the native backend_config.
    The returned map is sorted such that longer patterns will be encountered first when iterating through it.
    """
    patterns = get_default_quant_patterns()
    if additional_quant_patterns is not None:
        patterns = get_combined_dict(patterns, additional_quant_patterns)
    # TODO: currently we just extend the quantize handlers generated from
    # `get_native_backend_config`; in the future we can just assign
    # backend_config when everything is defined
    for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config()).items():
        patterns[pattern] = quantize_handler
    return sorted_patterns_dict(patterns)
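
# Example (a minimal sketch, not executed by this module): the default quant
# patterns extended with the handlers derived from the native backend config,
# sorted so that longer patterns come first. Extra entries can be merged in via
# the optional argument (my_pattern / MyHandler below are illustrative
# placeholders).
#
#     quant_patterns = get_native_quant_patterns()
#     quant_patterns = get_native_quant_patterns({my_pattern: MyHandler})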

get_fusion_pattern_to_fuse_handler_cls.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_native_quant_patterns.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_pattern_to_quantize_handlers.__module__ = "torch.ao.quantization.fx.backend_config_utils"

__all__ = [
    "get_fusion_pattern_to_fuse_handler_cls",
    "get_native_quant_patterns",
    "get_pattern_to_quantize_handlers",
]