1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
|
# -*- coding: utf-8 -*-
# Owner(s): ["module: unknown"]
import logging
import torch
from torch.ao.sparsity.sparsifier.utils import (
fqn_to_module,
get_arg_info_from_tensor_fqn,
module_to_fqn,
)
from torch.testing._internal.common_quantization import (
ConvBnReLUModel,
ConvModel,
FunctionalLinear,
LinearAddModel,
ManualEmbeddingBagLinear,
SingleLayerLinearModel,
TwoLayerLinearModel,
)
from torch.testing._internal.common_utils import TestCase
# Log record layout used for this test module; configured once at import time.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Model classes (from common_quantization) that every test below instantiates
# with no arguments and then runs the fqn helpers against.
model_list = [
    ConvModel,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
    LinearAddModel,
    ConvBnReLUModel,
    ManualEmbeddingBagLinear,
    FunctionalLinear,
]
class TestSparsityUtilFunctions(TestCase):
    """Tests for the fqn <-> module/tensor lookup helpers in sparsifier.utils.

    Every test is run once per model class in ``model_list``; each model is
    checked inside ``subTest`` so a failure reports which model broke.
    """

    @staticmethod
    def _tensor_fqn_entries(model):
        """Yield ``(module, module_fqn, tensor_name, tensor, tensor_fqn)`` for
        every parameter registered directly on a module of ``model``.

        ``model.modules()`` already includes ``model`` itself, so parameters on
        the root module are covered without adding the root a second time.
        """
        for module in model.modules():
            module_fqn = module_to_fqn(model, module)
            for tensor_name, tensor in module.named_parameters(recurse=False):
                # Root-level tensors have an empty module fqn and therefore no
                # "<module_fqn>." prefix.
                if module_fqn == "":
                    tensor_fqn = tensor_name
                else:
                    tensor_fqn = module_fqn + "." + tensor_name
                yield module, module_fqn, tensor_name, tensor, tensor_fqn

    def test_module_to_fqn(self):
        """
        Tests that module_to_fqn works as expected when compared to the known
        good ``module.get_submodule(fqn)`` function.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for module in model.modules():
                    fqn = module_to_fqn(model, module)
                    # A None fqn would otherwise surface as an unrelated
                    # AttributeError inside get_submodule.
                    self.assertIsNotNone(fqn)
                    self.assertIs(model.get_submodule(fqn), module)

    def test_module_to_fqn_fail(self):
        """
        Tests that module_to_fqn returns None when the target module is not
        part of the model.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                fqn = module_to_fqn(model, torch.nn.Linear(3, 3))
                self.assertIsNone(fqn)

    def test_module_to_fqn_root(self):
        """
        Tests that module_to_fqn returns '' when model and target module are
        the same.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                fqn = module_to_fqn(model, model)
                self.assertEqual(fqn, "")

    def test_fqn_to_module(self):
        """
        Tests that fqn_to_module operates as the inverse of module_to_fqn.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for module in model.modules():
                    fqn = module_to_fqn(model, module)
                    check_module = fqn_to_module(model, fqn)
                    self.assertIs(check_module, module)

    def test_fqn_to_module_fail(self):
        """
        Tests that fqn_to_module returns None when it is asked for an fqn that
        does not exist in the model.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                check_module = fqn_to_module(model, "foo.bar.baz")
                self.assertIsNone(check_module)

    def test_fqn_to_module_for_tensors(self):
        """
        Tests that fqn_to_module also resolves tensors, namely every parameter
        of the model. The tensor_fqn is built from module_to_fqn on the owning
        module plus the parameter's name.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for _, _, _, tensor, tensor_fqn in self._tensor_fqn_entries(model):
                    check_tensor = fqn_to_module(model, tensor_fqn)
                    self.assertIs(check_tensor, tensor)

    def test_get_arg_info_from_tensor_fqn(self):
        """
        Tests that get_arg_info_from_tensor_fqn works for all parameters of the
        model. Builds each tensor_fqn the same way as
        test_fqn_to_module_for_tensors and compares the result with the known
        (parent) module, module_fqn from module_to_fqn, and tensor_name.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for module, module_fqn, tensor_name, _, tensor_fqn in (
                    self._tensor_fqn_entries(model)
                ):
                    arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                    self.assertIs(arg_info["module"], module)
                    self.assertEqual(arg_info["module_fqn"], module_fqn)
                    self.assertEqual(arg_info["tensor_name"], tensor_name)
                    self.assertEqual(arg_info["tensor_fqn"], tensor_fqn)

    def test_get_arg_info_from_tensor_fqn_fail(self):
        """
        Tests get_arg_info_from_tensor_fqn on a tensor_fqn that does not exist
        in the model: the string fields are still derived by splitting the fqn,
        but the returned module is None.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                tensor_fqn = "foo.bar.baz"
                arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                self.assertIsNone(arg_info["module"])
                self.assertEqual(arg_info["module_fqn"], "foo.bar")
                self.assertEqual(arg_info["tensor_name"], "baz")
                self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")
|