File: qconfig_multi_mapping.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (249 lines) | stat: -rw-r--r-- 10,195 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
# mypy: allow-untyped-defs
from __future__ import annotations

import copy
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Union

import torch
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER


if TYPE_CHECKING:
    from torch.ao.quantization.qconfig import QConfigAny

__all__ = ["QConfigMultiMapping"]

# Maps each qconfig "style" name to the name of the setter method that inserts
# a qconfig for that style. The keys are the same strings that are read as
# attributes of QConfigMapping objects via ``getattr`` elsewhere in this file;
# the values are resolved with ``getattr`` as method names on both
# QConfigMapping and QConfigMultiMapping objects.
_QCONFIG_STYLE_TO_METHOD: Dict[str, str] = {
    "global_qconfig": "set_global",
    "object_type_qconfigs": "set_object_type",
    "module_name_regex_qconfigs": "set_module_name_regex",
    "module_name_qconfigs": "set_module_name",
    "module_name_object_type_order_qconfigs": "set_module_name_object_type_order",
}


def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None:
    """
    Remove every ``None`` entry and every repeated qconfig from ``qconfig_list``.

    The list is modified in place. The first occurrence of each distinct
    qconfig is kept and the relative order of the kept entries is preserved.
    Two qconfigs count as duplicates when
    ``torch.ao.quantization.qconfig_equals`` returns True for them.

    Args:
        qconfig_list: list of qconfigs (entries may be ``None``) to clean up.

    Returns:
        None; ``qconfig_list`` itself is updated.
    """
    unique_qconfigs: List[QConfigAny] = []
    for cur_qconfig in qconfig_list:
        # skip (rather than stop at) None entries so that later entries are
        # still inspected; stopping here would leave trailing Nones and
        # duplicates in the list
        if cur_qconfig is None:
            continue
        already_kept = any(
            torch.ao.quantization.qconfig_equals(cur_qconfig, kept_qconfig)
            for kept_qconfig in unique_qconfigs
        )
        if not already_kept:
            unique_qconfigs.append(cur_qconfig)
    # slice assignment keeps the caller's list object and replaces its contents
    qconfig_list[:] = unique_qconfigs


class QConfigMultiMapping:
    """
    This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s
    so that multiple QConfigs can be specified for each QConfig matching style.

    The user can specify QConfigs using the following methods (in increasing match priority):

        ``set_global`` : sets the global (default) QConfigs

        ``set_object_type`` : sets the QConfigs for a given module type, function, or method name

        ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string

        ``set_module_name`` : sets the QConfigs for modules matching the given module name

        ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination
        of the given module name, object type, and the index at which the module appears

    Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a
    single QConfig. The passed in list is copied before it is processed, so the caller's list is never modified.

    Example usage::

        qconfig_mapping = (
            QConfigMultiMapping()
            .set_global([qconfig1, qconfig2])
            .set_object_type(torch.nn.Linear, [qconfig2, qconfig3])
            .set_object_type(torch.nn.ReLU, [qconfig1])
            .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2])
            .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3])
            .set_module_name("module1", [None])
            .set_module_name("module2", [qconfig2])
            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3])
        )

    """

    def __init__(self) -> None:
        # initialize this with 1 QConfigMapping to avoid corner cases
        self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]

    def _handle_list_size_mismatch(
        self, qconfig_list: List[QConfigAny], style: str
    ) -> None:
        """
        Make ``qconfig_list`` and ``self.qconfig_mappings_list`` the same length.

        Either new QConfigMappings are appended to ``self.qconfig_mappings_list``
        or ``qconfig_list`` is padded in place with ``None``.
        ``style`` is accepted but not used by this method.
        """
        # this method handles cases where the size of qconfig_list does not match
        # the size of qconfig_mappings_list.
        # Issue: Consider a user inserting global_qconfig A and B first, then inserting
        # qconfig C as an object_type_qconfig for conv ops. If we internally store
        # 1 QConfigMapping with A and C and another with just B, then the
        # second QConfigMapping will match B to conv ops (which is not wanted), since B is global.

        # we avoid this by maintaining the invariant that if any QConfigMapping
        # has a qconfig style+key with a qconfig in it, all QConfigMappings must
        # have either a qconfig or None for that same style+key. In the above
        # example, a None qconfig would prevent the unwanted match in the
        # second QConfigMapping

        if len(qconfig_list) > len(self.qconfig_mappings_list):
            # Case: we have more qconfigs (in qconfig_list) than QConfigMappings

            # Add new QConfigMappings (initialized so we maintain the `invariant`)

            new_qconfig_mapping = QConfigMapping()
            # copy the style+keys of an existing QConfigMapping into the new
            # one as `None`. Only the first QConfigMapping is inspected: by
            # the invariant above, all of them hold the same style+keys.
            if self.qconfig_mappings_list:
                reference_mapping = self.qconfig_mappings_list[0]
                # global_qconfig has None by default
                for check_style in _QCONFIG_STYLE_ORDER[1:]:
                    qconfigs_dict = getattr(reference_mapping, check_style)
                    target_qconfigs_dict = getattr(new_qconfig_mapping, check_style)
                    for key in qconfigs_dict:
                        target_qconfigs_dict[key] = None

            # insert copies of this new QConfigMapping until all entries
            # in qconfig_list can fit among the QConfigMappings
            while len(qconfig_list) > len(self.qconfig_mappings_list):
                self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping))
        else:
            # Case: we have fewer qconfigs in qconfig_list than QConfigMappings

            # pad qconfig_list with `None` until length is same
            while len(qconfig_list) < len(self.qconfig_mappings_list):
                qconfig_list.append(None)

    # this function applies the insertion method across each QConfigMapping
    def _insert_qconfig_list(
        self,
        style: str,
        args: List[Union[str, int, Callable]],
        qconfig_list: List[QConfigAny],
    ) -> None:
        """
        Insert one qconfig from ``qconfig_list`` into each stored QConfigMapping.

        Args:
            style: key of ``_QCONFIG_STYLE_TO_METHOD`` selecting the setter to call
            args: positional arguments passed to the setter before the qconfig
            qconfig_list: qconfigs to distribute; this list is not modified
        """
        # work on a private copy: the clean-up and padding steps below mutate
        # the list they are given, and the caller's argument must stay intact
        qconfig_list = list(qconfig_list)

        # the list is cleaned with _remove_duplicates_and_none to make the
        # ordering of qconfigs deterministic upon insertion.
        _remove_duplicates_and_none(qconfig_list)

        self._handle_list_size_mismatch(qconfig_list, style)
        method_name = _QCONFIG_STYLE_TO_METHOD[style]
        for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list):
            # uses QConfigMapping set method to insert qconfig
            set_method = getattr(qconfig_mapping, method_name)
            set_method(*args, qconfig)

    def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping:
        """
        Set global QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info
        """
        self._insert_qconfig_list("global_qconfig", [], global_qconfig_list)
        return self

    def set_object_type(
        self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set object type QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info
        """
        self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list)
        return self

    def set_module_name_regex(
        self, module_name_regex: str, qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set module_name_regex QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info
        """
        self._insert_qconfig_list(
            "module_name_regex_qconfigs", [module_name_regex], qconfig_list
        )
        return self

    def set_module_name(
        self, module_name: str, qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set module_name QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info
        """
        self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list)
        return self

    def set_module_name_object_type_order(
        self,
        module_name: str,
        object_type: Callable,
        index: int,
        qconfig_list: List[QConfigAny],
    ) -> QConfigMultiMapping:
        """
        Set module_name_object_type_order QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info
        """
        self._insert_qconfig_list(
            "module_name_object_type_order_qconfigs",
            [module_name, object_type, index],
            qconfig_list,
        )
        return self

    def __repr__(self) -> str:
        # one stored QConfigMapping per line, each followed by a comma
        mapping_lines = "".join(
            f"\n{qconfig_mapping!r}," for qconfig_mapping in self.qconfig_mappings_list
        )
        return f"{self.__class__.__name__} [{mapping_lines}\n]"

    @classmethod
    def from_list_qconfig_mapping(
        cls, qconfig_mapping_list: List[QConfigMapping]
    ) -> QConfigMultiMapping:
        """
        Creates a QConfigMultiMapping from a list of QConfigMappings

        The given QConfigMappings are deep-copied; the input list and its
        elements are not modified.
        """
        new_qconfig_multi_mapping = cls()

        new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy(
            qconfig_mapping_list
        )

        # we need to avoid the issue described in _handle_list_size_mismatch,
        # so we reinsert all the qconfigs using the QConfigMultiMapping
        # set methods

        # go through all qconfig styles
        # note: global can be ignored since it is None by default
        for style in _QCONFIG_STYLE_ORDER[1:]:
            # gather all key+qconfigs for current style
            # into qconfig_dict_list
            qconfig_dict_list: Dict[Any, List[QConfigAny]] = {}
            for qconfig_mapping in qconfig_mapping_list:
                for key, qconfig in getattr(qconfig_mapping, style).items():
                    qconfig_dict_list.setdefault(key, []).append(qconfig)

            # reinsert all gathered key+qconfigs
            set_method = getattr(
                new_qconfig_multi_mapping, _QCONFIG_STYLE_TO_METHOD[style]
            )
            for key, qconfig_list in qconfig_dict_list.items():
                # tuple keys hold several positional setter arguments
                if isinstance(key, tuple):
                    set_method(*key, qconfig_list)
                else:
                    set_method(key, qconfig_list)

        return new_qconfig_multi_mapping