1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394
|
# -*- coding: utf-8 -*-
# Owner(s): ["module: unknown"]
import copy
import logging
import torch
from torch import nn
from torch.ao.sparsity._experimental.pruner import BasePruner, PruningParametrization, ZeroesParametrization
from torch.nn.utils import parametrize
from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Devices every test is run on. Without CUDA the second entry is also the CPU,
# so the set collapses to a single device.
DEVICES = {
    torch.device(device_name)
    for device_name in ("cpu", "cuda" if torch.cuda.is_available() else "cpu")
}

# These layers should have pruned indices zero-ed, not removed.
NEEDS_ZEROS = {nn.BatchNorm2d}
class Linear(nn.Module):
    r"""Two bias-free 16 -> 16 Linear layers: one inside an ``nn.Sequential``
    (``seq``) and one registered directly on the model (``linear``)."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(nn.Linear(16, 16, bias=False))
        self.linear = nn.Linear(16, 16, bias=False)

    def forward(self, x):
        return self.linear(self.seq(x))
class LinearB(nn.Module):
    r"""Two biased 16 -> 16 Linear layers: one inside an ``nn.Sequential``
    (``seq``) and one registered directly on the model (``linear``)."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(nn.Linear(16, 16, bias=True))
        self.linear = nn.Linear(16, 16, bias=True)

    def forward(self, x):
        return self.linear(self.seq(x))
class MultipleLinear(nn.Module):
    r"""Bias-free Linear stack 7 -> 5 -> 8 -> 6 with ReLUs inside ``seq``,
    followed by a bias-free 6 -> 4 ``linear`` outside it."""

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(7, 5, bias=False),
            nn.ReLU(),
            nn.Linear(5, 8, bias=False),
            nn.ReLU(),
            nn.Linear(8, 6, bias=False),
        ]
        self.seq = nn.Sequential(*layers)
        self.linear = nn.Linear(6, 4, bias=False)

    def forward(self, x):
        return self.linear(self.seq(x))
class MultipleLinearB(nn.Module):
    r"""Biased Linear stack 7 -> 5 -> 8 -> 6 with ReLUs inside ``seq``,
    followed by a biased 6 -> 4 ``linear`` outside it."""

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 8, bias=True),
            nn.ReLU(),
            nn.Linear(8, 6, bias=True),
        ]
        self.seq = nn.Sequential(*layers)
        self.linear = nn.Linear(6, 4, bias=True)

    def forward(self, x):
        return self.linear(self.seq(x))
class MultipleLinearMixed(nn.Module):
    r"""Linear stack 7 -> 5 -> 8 -> 6 with ReLUs inside ``seq`` plus a 6 -> 4
    ``linear`` outside it; only the first and last layers of ``seq`` have biases."""

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 8, bias=False),
            nn.ReLU(),
            nn.Linear(8, 6, bias=True),
        ]
        self.seq = nn.Sequential(*layers)
        self.linear = nn.Linear(6, 4, bias=False)

    def forward(self, x):
        return self.linear(self.seq(x))
class Conv2dA(nn.Module):
    r"""Bias-free 3x3 convolutions: 1 -> 32 channels inside ``seq`` and
    32 -> 64 channels in ``conv2d`` outside it."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, bias=False),
        )
        self.conv2d = nn.Conv2d(32, 64, kernel_size=3, stride=1, bias=False)

    def forward(self, x):
        features = self.seq(x)
        return self.conv2d(features)
class Conv2dB(nn.Module):
    r"""Biased 3x3 convolutions: 1 -> 32 channels inside ``seq`` and
    32 -> 64 channels in ``conv2d`` outside it."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, bias=True),
        )
        self.conv2d = nn.Conv2d(32, 64, kernel_size=3, stride=1, bias=True)

    def forward(self, x):
        features = self.seq(x)
        return self.conv2d(features)
class Conv2dC(nn.Module):
    r"""3x3 convolutions with mixed biases: a biased 1 -> 32 conv inside
    ``seq`` and a bias-free 32 -> 64 ``conv2d`` outside it."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, bias=True),
        )
        self.conv2d = nn.Conv2d(32, 64, kernel_size=3, stride=1, bias=False)

    def forward(self, x):
        features = self.seq(x)
        return self.conv2d(features)
class Conv2dBN(nn.Module):
    r"""Biased 3x3 convolutions, each followed by a BatchNorm2d: a 1 -> 32
    conv/bn pair inside ``seq`` and a 32 -> 64 ``conv2d``/``bn`` pair outside it."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, bias=True),
            nn.BatchNorm2d(32),
        )
        self.conv2d = nn.Conv2d(32, 64, kernel_size=3, stride=1, bias=True)
        self.bn = nn.BatchNorm2d(64)

    def forward(self, x):
        out = self.seq(x)
        out = self.conv2d(out)
        return self.bn(out)
class SimplePruner(BasePruner):
    r"""Toy pruner whose ``update_mask`` marks output index 1 as pruned."""

    def update_mask(self, module, tensor_name, **kwargs):
        # Only the first parametrization registered on ``tensor_name`` is touched.
        first_param = getattr(module.parametrizations, tensor_name)[0]
        first_param.pruned_outputs.add(1)
class MultiplePruner(BasePruner):
    r"""Toy pruner whose ``update_mask`` marks output indices 1 and 2 as pruned."""

    def update_mask(self, module, tensor_name, **kwargs):
        # Only the first parametrization registered on ``tensor_name`` is touched.
        first_param = getattr(module.parametrizations, tensor_name)[0]
        for index in (1, 2):
            first_param.pruned_outputs.add(index)
class TestBasePruner(TestCase):
    r"""Unit tests for ``BasePruner`` using the toy models and pruners above.

    Each ``test_*`` method loops over ``DEVICES`` and delegates to a
    ``_test_*_on_device`` helper; the ``_check_*`` helpers assert invariants
    over every module referenced by ``pruner.groups``.
    """

    @staticmethod
    def _get_group_modules(config):
        r"""Return the modules of a single pruner group as a list.

        ``config['module']`` is either a single module or a tuple of modules
        (as produced by an explicit config such as ``[(conv, bn), ...]``).
        """
        group_module = config['module']
        if type(group_module) is tuple:
            return list(group_module)
        return [group_module]

    @staticmethod
    def _make_conv2d_models_and_configs():
        r"""Build fresh Conv2d test models, each paired with its ``prepare`` config.

        Called once per device so that every device run starts from an
        unprepared model. The ``Conv2dBN`` config holds references to that
        model's own submodules, so these models are rebuilt rather than
        deep-copied: a deep copy would leave the config pointing at the
        original model's modules instead of the copy that is actually run.
        """
        bn_model = Conv2dBN()
        bn_config = [(bn_model.seq[0], bn_model.seq[1]),
                     (bn_model.conv2d, bn_model.bn)]
        models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model]
        configs = [None, None, None, bn_config]
        return list(zip(models, configs))

    def _check_pruner_prepared(self, model, pruner, device):
        r"""Assert each grouped module is on ``device``, has a ``mask`` and has
        the expected parametrization registered first on its weight."""
        for config in pruner.groups:
            for module in self._get_group_modules(config):
                assert module.weight.device.type == device.type
                # Check mask exists
                assert hasattr(module, 'mask')
                # Check parametrization exists and is correct
                assert parametrize.is_parametrized(module)
                assert hasattr(module, "parametrizations")
                # Layers in NEEDS_ZEROS have pruned indices zeroed rather than
                # removed, so they carry a different parametrization type.
                # Assumes this is the 1st/only parametrization on the weight.
                if isinstance(module, tuple(NEEDS_ZEROS)):
                    expected_type = ZeroesParametrization
                else:
                    expected_type = PruningParametrization
                assert type(module.parametrizations.weight[0]) is expected_type

    def _check_pruner_mask_squashed(self, model, pruner, device):
        r"""Assert ``squash_mask`` removed both the parametrizations and the mask."""
        for config in pruner.groups:
            for module in self._get_group_modules(config):
                assert module.weight.device.type == device.type
                assert not hasattr(module, "parametrizations")
                assert not hasattr(module, 'mask')

    def _check_pruner_valid_before_step(self, model, pruner, device):
        r"""Assert no outputs are marked as pruned before the first ``step``."""
        self._check_pruner_valid_after_step(model, pruner, set(), device)

    def _check_pruner_valid_after_step(self, model, pruner, pruned_set, device):
        r"""Assert each grouped module is on ``device`` and its first weight
        parametrization records exactly ``pruned_set`` as pruned outputs."""
        for config in pruner.groups:
            for module in self._get_group_modules(config):
                assert module.weight.device.type == device.type
                assert module.parametrizations.weight[0].pruned_outputs == pruned_set

    def _test_constructor_on_device(self, model, device):
        # Instantiating BasePruner directly must raise a TypeError that names
        # update_mask (presumably its abstract method). Python 3.12+ quotes the
        # method name ("... method 'update_mask'"), so the pattern must not
        # require a space immediately before it.
        self.assertRaisesRegex(TypeError, 'BasePruner .*update_mask',
                               BasePruner)
        model1 = copy.deepcopy(model).to(device)
        pruner = SimplePruner(None)
        pruner.prepare(model1, None)
        for group in pruner.groups:
            assert group['module'].weight.device.type == device.type
        assert len(pruner.groups) == 2
        pruner.step()
        # Can instantiate the model with configs; the explicit config limits
        # the groups to model2.linear and the constructor default must be
        # visible in that group.
        model2 = copy.deepcopy(model).to(device)
        pruner = SimplePruner({'test': 3})
        pruner.prepare(model2, [model2.linear])
        assert len(pruner.groups) == 1
        assert pruner.groups[0]['module_fqn'] == 'linear'
        assert 'test' in pruner.groups[0]
        assert pruner.groups[0]['test'] == 3

    def test_constructor(self):
        model = Linear()
        for device in DEVICES:
            self._test_constructor_on_device(model, torch.device(device))

    def _test_prepare_linear_on_device(self, model, device):
        model = copy.deepcopy(model).to(device)
        x = torch.ones(128, 16, device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, None)
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == (128, 16)

    def test_prepare_linear(self):
        models = [Linear(), LinearB()]  # without and with bias
        for device in DEVICES:
            for model in models:
                self._test_prepare_linear_on_device(model, torch.device(device))

    def _test_prepare_conv2d_on_device(self, model, config, device):
        # Prepares ``model`` in place: ``config`` may reference its submodules.
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == (1, 64, 24, 24)

    def test_prepare_conv2d(self):
        for device in DEVICES:
            for model, config in self._make_conv2d_models_and_configs():
                model = model.to(device)
                self._test_prepare_conv2d_on_device(model, config, torch.device(device))

    def _test_squash_mask_linear_on_device(self, model, device):
        model = copy.deepcopy(model).to(device)
        x = torch.ones(128, 16, device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, None)
        pruner.squash_mask()
        self._check_pruner_mask_squashed(model, pruner, device)
        assert model(x).shape == (128, 16)

    def test_squash_mask_linear(self):
        models = [Linear(), LinearB()]  # without and with bias
        for device in DEVICES:
            for model in models:
                self._test_squash_mask_linear_on_device(model, torch.device(device))

    def _test_squash_mask_conv2d_on_device(self, model, config, device):
        # No deepcopy: ``config`` may reference ``model``'s own submodules, and
        # a copy would leave the pruner operating on the originals while the
        # copy is the one run below.
        model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        pruner.squash_mask()
        self._check_pruner_mask_squashed(model, pruner, device)
        assert model(x).shape == (1, 64, 24, 24)

    def test_squash_mask_conv2d(self):
        for device in DEVICES:
            for model, config in self._make_conv2d_models_and_configs():
                model = model.to(device)
                self._test_squash_mask_conv2d_on_device(model, config, torch.device(device))

    def _test_step_linear_on_device(self, model, is_basic, device):
        # Work on a copy so that every device run prepares an unpruned model.
        model = copy.deepcopy(model).to(device)
        if is_basic:
            pruner, expected_pruned = SimplePruner(None), {1}
        else:
            pruner, expected_pruned = MultiplePruner(None), {1, 2}
        pruner.prepare(model, None)
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        self._check_pruner_valid_after_step(model, pruner, expected_pruned, device)

    def test_step_linear(self):
        basic_models = [Linear(), LinearB()]
        complex_models = [MultipleLinear(), MultipleLinearB(), MultipleLinearMixed()]
        for device in DEVICES:
            for model in basic_models:
                self._test_step_linear_on_device(model, True, torch.device(device))
            for model in complex_models:
                self._test_step_linear_on_device(model, False, torch.device(device))

    def _test_step_conv2d_on_device(self, model, config, device):
        # Steps ``model`` in place: ``config`` may reference its submodules, so
        # callers pass a freshly built model for each device.
        model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        if type(model) is Conv2dBN:
            # Each BatchNorm must report the same pruned outputs as its conv.
            assert pruner.get_module_pruned_outputs(model.seq[1]) == pruner.get_module_pruned_outputs(model.seq[0])
            assert pruner.get_module_pruned_outputs(model.bn) == pruner.get_module_pruned_outputs(model.conv2d)
        self._check_pruner_valid_after_step(model, pruner, {1}, device)
        # The output shape is expected to stay full-size after the step.
        assert model(x).shape == (1, 64, 24, 24)

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_step_conv2d(self):
        # One config per model; Conv2dBN is paired with its (conv, bn) config.
        for device in DEVICES:
            for model, config in self._make_conv2d_models_and_configs():
                self._test_step_conv2d_on_device(model, config, torch.device(device))
|