1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
|
# -*- coding: utf-8 -*-
# Owner(s): ["module: unknown"]
import warnings
from pathlib import PurePath

from torch import nn
from torch.ao.sparsity import WeightNormSparsifier
from torch.ao.sparsity import BaseScheduler, LambdaSL, CubicSL
from torch.testing._internal.common_utils import TestCase
class ImplementedScheduler(BaseScheduler):
def get_sl(self):
if self.last_epoch > 0:
return [group['sparsity_level'] * 0.5
for group in self.sparsifier.groups]
else:
return list(self.base_sl)
class TestScheduler(TestCase):
    """Tests for the generic scheduler contract (via ``ImplementedScheduler``) and ``LambdaSL``."""

    def test_constructor(self):
        """The scheduler attaches to the sparsifier and records the base levels."""
        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        self.assertIs(scheduler.sparsifier, sparsifier,
                      msg="Sparsifier is not properly attached")
        self.assertEqual(scheduler._step_count, 1,
                         msg="Scheduler is initialized with incorrect step count")
        self.assertEqual(scheduler.base_sl, [sparsifier.groups[0]['sparsity_level']],
                         msg="Scheduler did not store the base sparsity levels correctly")

    def test_order_of_steps(self):
        """Check the ordering warning for scheduler and sparsifier steps.

        A warning is expected when the scheduler steps before the sparsifier.
        No warning from the base scheduler is expected when the order is correct.
        """
        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Correct order: record every warning. Warnings from other modules are
        # tolerated and filtered out below.
        with warnings.catch_warnings(record=True) as w:
            # Without "always", the default filter can suppress a warning that
            # was already shown once from the same location. That would let a
            # regression slip through this negative check unnoticed.
            warnings.simplefilter("always")
            sparsifier.step()
            scheduler.step()

        # Make sure there is no warning related to the base_scheduler. Compare
        # path components instead of splitting on '/', so that Windows path
        # separators are handled too.
        base_scheduler_parts = ('torch', 'ao', 'sparsity', 'scheduler', 'base_scheduler.py')
        offending = [
            str(warning.message)
            for warning in w
            if PurePath(warning.filename).parts[-5:] == base_scheduler_parts
        ]
        self.assertFalse(
            offending,
            msg=f"Unexpected base_scheduler warnings with correct step order: {offending}")

    def test_step(self):
        """A single correctly-ordered step halves the sparsity level."""
        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        self.assertEqual(sparsifier.groups[0]['sparsity_level'], 0.5)

        scheduler = ImplementedScheduler(sparsifier)
        # Constructing the scheduler leaves the level untouched.
        self.assertEqual(sparsifier.groups[0]['sparsity_level'], 0.5)

        sparsifier.step()
        scheduler.step()
        self.assertEqual(sparsifier.groups[0]['sparsity_level'], 0.25)

    def test_lambda_scheduler(self):
        """LambdaSL scales the base level (0.5) by ``lambda(epoch)``."""
        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        self.assertEqual(sparsifier.groups[0]['sparsity_level'], 0.5)

        scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10)
        self.assertEqual(sparsifier.groups[0]['sparsity_level'], 0.0)  # Epoch 0
        scheduler.step()
        self.assertEqual(sparsifier.groups[0]['sparsity_level'], 5.0)  # Epoch 1
class TestCubicScheduler(TestCase):
def setUp(self):
self.model_sparse_config = [
{'tensor_fqn': '0.weight', 'sparsity_level': 0.8},
{'tensor_fqn': '2.weight', 'sparsity_level': 0.4},
]
self.sorted_sparse_levels = [conf['sparsity_level'] for conf in self.model_sparse_config]
self.initial_sparsity = 0.1
self.initial_step = 3
def _make_model(self, **kwargs):
model = nn.Sequential(
nn.Linear(13, 17),
nn.Dropout(0.5),
nn.Linear(17, 3),
)
return model
def _make_scheduler(self, model, **kwargs):
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=self.model_sparse_config)
scheduler_args = {
'init_sl': self.initial_sparsity,
'init_t': self.initial_step,
}
scheduler_args.update(kwargs)
scheduler = CubicSL(sparsifier, **scheduler_args)
return sparsifier, scheduler
@staticmethod
def _get_sparsity_levels(sparsifier, precision=32):
r"""Gets the current levels of sparsity in a sparsifier."""
return [round(group['sparsity_level'], precision) for group in sparsifier.groups]
def test_constructor(self):
model = self._make_model()
sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True)
self.assertIs(
scheduler.sparsifier, sparsifier,
msg="Sparsifier is not properly attached")
self.assertEqual(
scheduler._step_count, 1,
msg="Scheduler is initialized with incorrect step count")
self.assertEqual(
scheduler.base_sl, self.sorted_sparse_levels,
msg="Scheduler did not store the target sparsity levels correctly")
# Value before t_0 is 0
self.assertEqual(
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(0.0),
msg="Sparsifier is not reset correctly after attaching to the Scheduler")
# Value before t_0 is s_0
model = self._make_model()
sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=False)
self.assertEqual(
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(self.initial_sparsity),
msg="Sparsifier is not reset correctly after attaching to the Scheduler")
def test_step(self):
# For n=5, dt=2, there will be totally 10 steps between s_0 and s_f, starting from t_0
model = self._make_model()
sparsifier, scheduler = self._make_scheduler(
model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5)
scheduler.step()
scheduler.step()
self.assertEqual(scheduler._step_count, 3, msg="Scheduler step_count is expected to increment")
# Value before t_0 is supposed to be 0
self.assertEqual(
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(0.0),
msg="Scheduler step updating the sparsity level before t_0")
scheduler.step() # Step = 3 => sparsity = initial_sparsity
self.assertEqual(
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(self.initial_sparsity),
msg="Sparsifier is not reset to initial sparsity at the first step")
scheduler.step() # Step = 4 => sparsity ~ [0.3, 0.2]
self.assertEqual(
self._get_sparsity_levels(sparsifier, 1), [0.3, 0.2],
msg="Sparsity level is not set correctly after the first step")
current_step = scheduler._step_count - scheduler.init_t[0] - 1
more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step
for _ in range(more_steps_needed): # More steps needed to final sparsity level
scheduler.step()
self.assertEqual(
self._get_sparsity_levels(sparsifier), self.sorted_sparse_levels,
msg="Sparsity level is not reaching the target level afer delta_t * n steps ")
|