File: build_quantization_configs.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (64 lines) | stat: -rw-r--r-- 1,923 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
This script generates the default values of the quantization backend configs
and writes them to a text file for use in the documentation.
"""

import os.path

import torch
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config.utils import (
    entry_to_pretty_str,
    remove_boolean_dispatch_from_name,
)


# Output directory, located next to this script. Despite the historical
# "IMAGE" in its name, this script only writes a text file into it.
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "quantization_backend_configs"
)

# Create the directory if it doesn't exist. ``exist_ok=True`` avoids the
# check-then-create race of ``os.path.exists`` followed by ``os.mkdir``
# (which raises FileExistsError if the directory appears in between).
os.makedirs(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, exist_ok=True)

# Path of the generated text file that holds the default backend config.
output_path = os.path.join(
    QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt"
)

def _sort_key_func(entry):
    """Return the key used to order backend config entries in the output.

    The entry's ``"pattern"`` is reduced to a single op by repeatedly taking
    the last element while it is a tuple. The op is then passed through
    ``remove_boolean_dispatch_from_name`` and, unless it is already a string
    (methods are already strings), converted to its fully qualified name with
    ``torch.typename``.

    We want

        torch.nn.modules.pooling.AdaptiveAvgPool1d

    and

        torch._VariableFunctionsClass.adaptive_avg_pool1d

    to be next to each other, so the name is lower-cased, stripped of
    underscores, and only its last dotted component is used as the key.

    Args:
        entry: one config dict from ``get_native_backend_config_dict()["configs"]``;
            only its ``"pattern"`` value is read.

    Returns:
        The normalized string key.
    """
    pattern = entry["pattern"]
    while isinstance(pattern, tuple):
        pattern = pattern[-1]

    pattern = remove_boolean_dispatch_from_name(pattern)
    if not isinstance(pattern, str):
        # methods are already strings
        pattern = torch.typename(pattern)

    pattern_str_normalized = pattern.lower().replace("_", "")
    return pattern_str_normalized.split(".")[-1]


native_backend_config_dict = get_native_backend_config_dict()
configs = native_backend_config_dict["configs"]
# ``list.sort`` is stable, so entries with equal keys keep their original order.
configs.sort(key=_sort_key_func)

# Build the full output before touching the file: opening with "w" truncates
# it, so doing the work first avoids leaving an empty/partial file behind if
# fetching, sorting, or formatting the configs raises.
entries = ",\n".join(entry_to_pretty_str(entry) for entry in configs)

# Explicit encoding so the output does not depend on the platform locale.
with open(output_path, "w", encoding="utf-8") as f:
    f.write(entries)