File: build_quantization_configs.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (62 lines) | stat: -rw-r--r-- 1,920 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
This script will generate default values of quantization configs.
These are for use in the documentation.
"""

import torch
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config.utils import (
    entry_to_pretty_str,
    remove_boolean_dispatch_from_name,
)
import os.path


# Directory, located next to this script, that receives the generated
# backend config listing. Despite the "IMAGE" in the constant's name, the
# only file this script writes there is a plain-text one.
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.realpath(os.path.join(__file__, "..")),
    "quantization_backend_configs"
)

# Create the directory if it does not exist yet. ``exist_ok=True`` does this
# in one call; checking ``os.path.exists`` first and then calling ``os.mkdir``
# raises FileExistsError when another process creates the directory between
# the check and the creation (e.g. parallel documentation builds).
os.makedirs(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, exist_ok=True)

# Full path of the text file the rest of the script writes.
output_path = os.path.join(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt")

def _sort_key_func(entry):
    """Return the string used to order one backend config entry.

    Args:
        entry: a config dict with a ``'pattern'`` key. The pattern may be a
            (possibly nested) tuple, a string, or another object that
            ``torch.typename`` can name.

    Returns:
        The last dot-separated component of the pattern's name, lower-cased
        and with underscores removed.
    """
    pattern = entry['pattern']
    # For tuple patterns, keep descending into the last element until a
    # non-tuple is reached; that element alone determines the sort key.
    while isinstance(pattern, tuple):
        pattern = pattern[-1]

    pattern = remove_boolean_dispatch_from_name(pattern)
    if not isinstance(pattern, str):
        # methods are already strings
        pattern = torch.typename(pattern)

    # we want
    #
    #   torch.nn.modules.pooling.AdaptiveAvgPool1d
    #
    # and
    #
    #   torch._VariableFunctionsClass.adaptive_avg_pool1d
    #
    # to be next to each other, so convert to all lower case
    # and remove the underscores, and compare the last part
    # of the string
    pattern_str_normalized = pattern.lower().replace('_', '')
    return pattern_str_normalized.split('.')[-1]


# Build the complete text before touching the output file: opening with "w"
# truncates it immediately, so generating the content inside the ``with``
# block would leave an empty or partial file behind if any step raised.
native_backend_config_dict = get_native_backend_config_dict()

# sorted() leaves the list owned by the returned dict untouched.
configs = sorted(native_backend_config_dict['configs'], key=_sort_key_func)

# One pretty-printed entry per config, separated by a comma and a newline.
entries = ",\n".join(entry_to_pretty_str(entry) for entry in configs)

# Explicit encoding keeps the generated file independent of the locale of
# the machine building the documentation.
with open(output_path, "w", encoding="utf-8") as f:
    f.write(entries)