File: test_bias_correction_eager.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (103 lines) | stat: -rw-r--r-- 4,173 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_quantization import skipIfNoFBGEMM

from torch.ao.quantization import default_qconfig
from torch.ao.quantization import QuantWrapper
import torch.ao.ns._numeric_suite as ns

from torch.ao.quantization._correct_bias import (
    _supported_modules,
    _supported_modules_quantized,
    bias_correction,
    get_module,
    get_param,
    parent_child_names
)

import copy


class TestBiasCorrectionEager(QuantizationTestCase):
    def compute_sqnr(self, x, y):
        Ps = torch.norm(x)
        Pn = torch.norm(x - y)
        return 20 * torch.log10(Ps / Pn)

    def correct_artificial_bias_quantize(self, float_model, img_data):
        ''' Adding artificial bias and testing if bias persists after bias
            correction. This test case changes the bias of a quantized submodule
        '''
        artificial_model = copy.deepcopy(float_model)
        artificial_model.qconfig = default_qconfig
        torch.ao.quantization.prepare(artificial_model, inplace=True)
        for data in img_data:
            artificial_model(data[0])
        torch.ao.quantization.convert(artificial_model, inplace=True)

        # manually changing bias
        for name, submodule in artificial_model.named_modules():
            if type(submodule) in _supported_modules:
                x = get_param(submodule, 'bias')
                weight = get_param(submodule, 'weight')
                if x is not None:
                    submodule.set_weight_bias(weight, x.data * 3)

        bias_correction(float_model, artificial_model, img_data, target_modules=_supported_modules_quantized)

        # Trims off the shadow module,
        for name, submodule in artificial_model.named_modules():
            if isinstance(submodule, ns.Shadow):
                parent_name, child_name = parent_child_names(name)
                parent = get_module(artificial_model, parent_name)
                parent._modules[child_name] = submodule.orig_module

        for name, artificial_submodule in artificial_model.named_modules():
            if type(artificial_submodule) in _supported_modules_quantized:
                submodule = get_module(float_model, name)
                float_bias = get_param(submodule, 'bias')
                artificial_bias = get_param(artificial_submodule, 'bias')

                self.assertTrue(self.compute_sqnr(float_bias, artificial_bias) > 30,
                                "Correcting quantized bias produced too much noise, sqnr score too low")

    @skipIfNoFBGEMM
    def test_linear_chain(self):
        class LinearChain(nn.Module):
            def __init__(self):
                super(LinearChain, self).__init__()
                self.linear1 = nn.Linear(3, 4)
                self.linear2 = nn.Linear(4, 5)
                self.linear3 = nn.Linear(5, 6)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.linear3(x)
                return x
        float_model = QuantWrapper(LinearChain())
        img_data = [(torch.rand(10, 3, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
                    for _ in range(50)]
        self.correct_artificial_bias_quantize(float_model, img_data)

    @skipIfNoFBGEMM
    def test_conv_chain(self):
        class ConvChain(nn.Module):
            def __init__(self):
                super(ConvChain, self).__init__()
                self.conv2d1 = nn.Conv2d(3, 4, 5, 5)
                self.conv2d2 = nn.Conv2d(4, 5, 5, 5)
                self.conv2d3 = nn.Conv2d(5, 6, 5, 5)

            def forward(self, x):
                x = self.conv2d1(x)
                x = self.conv2d2(x)
                x = self.conv2d3(x)
                return x
        float_model = QuantWrapper(ConvChain())
        img_data = [(torch.rand(10, 3, 125, 125, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
                    for _ in range(50)]
        self.correct_artificial_bias_quantize(float_model, img_data)