File: test_sparsity_utils.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (150 lines) | stat: -rw-r--r-- 5,779 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# -*- coding: utf-8 -*-
# Owner(s): ["module: unknown"]


import logging

import torch
from torch.ao.sparsity.sparsifier.utils import (
    fqn_to_module,
    get_arg_info_from_tensor_fqn,
    module_to_fqn,
)

from torch.testing._internal.common_quantization import (
    ConvBnReLUModel,
    ConvModel,
    FunctionalLinear,
    LinearAddModel,
    ManualEmbeddingBagLinear,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
)
from torch.testing._internal.common_utils import TestCase

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

model_list = [
    ConvModel,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
    LinearAddModel,
    ConvBnReLUModel,
    ManualEmbeddingBagLinear,
    FunctionalLinear,
]


class TestSparsityUtilFunctions(TestCase):
    """
    Unit tests for module_to_fqn, fqn_to_module and get_arg_info_from_tensor_fqn.

    Every test runs against each model class in ``model_list``. A ``subTest``
    labelled with the model class name is opened per model, so a failure
    reports which model triggered it.
    """

    @staticmethod
    def _tensor_fqn(module_fqn, tensor_name):
        """
        Join a module fqn and a tensor name into a tensor fqn.

        Tensors registered directly on the root module have ``module_fqn == ""``.
        They must not get a leading ``"."``, so the bare tensor name is returned.
        """
        return f"{module_fqn}.{tensor_name}" if module_fqn else tensor_name

    def test_module_to_fqn(self):
        """
        Tests that module_to_fqn works as expected when compared to the known
        good module.get_submodule(fqn) function. For every module of the model
        (root included), resolving the returned fqn must yield that same module
        object.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                # modules() yields the root model itself first, then every
                # submodule, so the root does not need to be appended separately.
                for module in model.modules():
                    fqn = module_to_fqn(model, module)
                    self.assertIsNotNone(fqn)
                    self.assertIs(model.get_submodule(fqn), module)

    def test_module_to_fqn_fail(self):
        """
        Tests that module_to_fqn returns None when the given module is not part
        of the model (neither the model itself nor one of its descendants).
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                # A freshly constructed module cannot belong to ``model``.
                self.assertIsNone(module_to_fqn(model, torch.nn.Linear(3, 3)))

    def test_module_to_fqn_root(self):
        """
        Tests that module_to_fqn returns '' when model and target module are the same.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                self.assertEqual(module_to_fqn(model, model), "")

    def test_fqn_to_module(self):
        """
        Tests that fqn_to_module operates as the inverse of module_to_fqn:
        fqn_to_module(model, module_to_fqn(model, m)) must be ``m`` itself.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for module in model.modules():
                    fqn = module_to_fqn(model, module)
                    self.assertIsNotNone(fqn)
                    self.assertIs(fqn_to_module(model, fqn), module)

    def test_fqn_to_module_fail(self):
        """
        Tests that fqn_to_module returns None when the fqn does not resolve to
        anything inside the model.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                self.assertIsNone(fqn_to_module(model, "foo.bar.baz"))

    def test_fqn_to_module_for_tensors(self):
        """
        Tests that fqn_to_module also resolves tensors, here every parameter of
        the model. For each module, the tensor_fqn is built from module_to_fqn
        of that module plus the parameter name. Resolving the tensor_fqn must
        return that exact parameter object.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for module in model.modules():
                    module_fqn = module_to_fqn(model, module)
                    self.assertIsNotNone(module_fqn)
                    # recurse=False: only tensors owned directly by ``module``,
                    # so the expected fqn is exactly module_fqn + tensor_name.
                    for tensor_name, tensor in module.named_parameters(recurse=False):
                        tensor_fqn = self._tensor_fqn(module_fqn, tensor_name)
                        self.assertIs(fqn_to_module(model, tensor_fqn), tensor)

    def test_get_arg_info_from_tensor_fqn(self):
        """
        Tests that get_arg_info_from_tensor_fqn works for all parameters of the model.
        Generates a tensor_fqn in the same way as test_fqn_to_module_for_tensors,
        then compares the result against the known (parent) module, the
        tensor_name, and the module_fqn from module_to_fqn.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for module in model.modules():
                    module_fqn = module_to_fqn(model, module)
                    self.assertIsNotNone(module_fqn)
                    for tensor_name, _ in module.named_parameters(recurse=False):
                        tensor_fqn = self._tensor_fqn(module_fqn, tensor_name)
                        arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                        self.assertIs(arg_info["module"], module)
                        self.assertEqual(arg_info["module_fqn"], module_fqn)
                        self.assertEqual(arg_info["tensor_name"], tensor_name)
                        self.assertEqual(arg_info["tensor_fqn"], tensor_fqn)

    def test_get_arg_info_from_tensor_fqn_fail(self):
        """
        Tests get_arg_info_from_tensor_fqn on a tensor_fqn that does not exist
        in the model. The string fields are still derived from the input by
        splitting it at its last '.', but the returned module is expected to be
        None.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                tensor_fqn = "foo.bar.baz"
                arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                self.assertIsNone(arg_info["module"])
                self.assertEqual(arg_info["module_fqn"], "foo.bar")
                self.assertEqual(arg_info["tensor_name"], "baz")
                self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")