1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330
|
# Owner(s): ["module: optimizer"]
import itertools
import pickle
import torch
from torch.optim.swa_utils import (
AveragedModel,
get_ema_multi_avg_fn,
get_swa_multi_avg_fn,
update_bn,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
load_tests,
parametrize,
TestCase,
)
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. The self-assignment marks the imported name as used,
# which silences flake8's unused-import warning for it.
load_tests = load_tests
class TestSWAUtils(TestCase):
class SWATestDNN(torch.nn.Module):
def __init__(self, input_features):
super().__init__()
self.n_features = 100
self.fc1 = torch.nn.Linear(input_features, self.n_features)
self.bn = torch.nn.BatchNorm1d(self.n_features)
def compute_preactivation(self, x):
return self.fc1(x)
def forward(self, x):
x = self.fc1(x)
x = self.bn(x)
return x
class SWATestCNN(torch.nn.Module):
def __init__(self, input_channels):
super().__init__()
self.n_features = 10
self.conv1 = torch.nn.Conv2d(
input_channels, self.n_features, kernel_size=3, padding=1
)
self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3)
def compute_preactivation(self, x):
return self.conv1(x)
def forward(self, x):
x = self.conv1(x)
x = self.bn(x)
return x
def _test_averaged_model(self, net_device, swa_device, ema):
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2),
torch.nn.BatchNorm2d(5, momentum=0.3),
torch.nn.Conv2d(5, 2, kernel_size=3),
torch.nn.ReLU(),
torch.nn.Linear(5, 5),
torch.nn.ReLU(),
torch.nn.Linear(5, 10),
).to(net_device)
averaged_params, averaged_dnn = self._run_averaged_steps(dnn, swa_device, ema)
for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertEqual(p_avg, p_swa)
# Check that AveragedModel is on the correct device
self.assertTrue(p_swa.device == swa_device)
self.assertTrue(p_avg.device == net_device)
self.assertTrue(averaged_dnn.n_averaged.device == swa_device)
def _run_averaged_steps(self, dnn, swa_device, ema):
ema_decay = 0.999
if ema:
averaged_dnn = AveragedModel(
dnn, device=swa_device, multi_avg_fn=get_ema_multi_avg_fn(ema_decay)
)
else:
averaged_dnn = AveragedModel(
dnn, device=swa_device, multi_avg_fn=get_swa_multi_avg_fn()
)
averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
n_updates = 10
for i in range(n_updates):
for p, p_avg in zip(dnn.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
if ema:
p_avg += (
p.detach()
* ema_decay ** (n_updates - i - 1)
* ((1 - ema_decay) if i > 0 else 1.0)
)
else:
p_avg += p.detach() / n_updates
averaged_dnn.update_parameters(dnn)
return averaged_params, averaged_dnn
@parametrize("ema", [True, False])
def test_averaged_model_all_devices(self, ema):
cpu = torch.device("cpu")
self._test_averaged_model(cpu, cpu, ema)
if torch.cuda.is_available():
cuda = torch.device(0)
self._test_averaged_model(cuda, cpu, ema)
self._test_averaged_model(cpu, cuda, ema)
self._test_averaged_model(cuda, cuda, ema)
@parametrize("ema", [True, False])
def test_averaged_model_mixed_device(self, ema):
if not torch.cuda.is_available():
return
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
)
dnn[0].cuda()
dnn[1].cpu()
averaged_params, averaged_dnn = self._run_averaged_steps(dnn, None, ema)
for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertEqual(p_avg, p_swa)
# Check that AveragedModel is on the correct device
self.assertTrue(p_avg.device == p_swa.device)
def test_averaged_model_state_dict(self):
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
)
averaged_dnn = AveragedModel(dnn)
averaged_dnn2 = AveragedModel(dnn)
n_updates = 10
for i in range(n_updates):
for p in dnn.parameters():
p.detach().add_(torch.randn_like(p))
averaged_dnn.update_parameters(dnn)
averaged_dnn2.load_state_dict(averaged_dnn.state_dict())
for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
self.assertEqual(p_swa, p_swa2)
self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)
def test_averaged_model_default_avg_fn_picklable(self):
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.BatchNorm2d(5),
torch.nn.Linear(5, 5),
)
averaged_dnn = AveragedModel(dnn)
pickle.dumps(averaged_dnn)
@parametrize("use_multi_avg_fn", [True, False])
@parametrize("use_buffers", [True, False])
def test_averaged_model_exponential(self, use_multi_avg_fn, use_buffers):
# Test AveragedModel with EMA as avg_fn and use_buffers as True.
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.BatchNorm2d(5, momentum=0.3),
torch.nn.Linear(5, 10),
)
decay = 0.9
if use_multi_avg_fn:
averaged_dnn = AveragedModel(
dnn, multi_avg_fn=get_ema_multi_avg_fn(decay), use_buffers=use_buffers
)
else:
def avg_fn(p_avg, p, n_avg):
return decay * p_avg + (1 - decay) * p
averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=use_buffers)
if use_buffers:
dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers()))
else:
dnn_params = list(dnn.parameters())
averaged_params = [
torch.zeros_like(param)
for param in dnn_params
if param.size() != torch.Size([])
]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
for p, p_avg in zip(dnn_params, averaged_params):
if p.size() == torch.Size([]):
continue
p.detach().add_(torch.randn_like(p))
if i == 0:
updated_averaged_params.append(p.clone())
else:
updated_averaged_params.append(
(p_avg * decay + p * (1 - decay)).clone()
)
averaged_dnn.update_parameters(dnn)
averaged_params = updated_averaged_params
if use_buffers:
for p_avg, p_swa in zip(
averaged_params,
itertools.chain(
averaged_dnn.module.parameters(), averaged_dnn.module.buffers()
),
):
self.assertEqual(p_avg, p_swa)
else:
for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertEqual(p_avg, p_swa)
for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
self.assertEqual(b_avg, b_swa)
def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
preactivation_sum = torch.zeros(dnn.n_features)
preactivation_squared_sum = torch.zeros(dnn.n_features)
if cuda:
preactivation_sum = preactivation_sum.cuda()
preactivation_squared_sum = preactivation_squared_sum.cuda()
total_num = 0
for x in dl_x:
x = x[0]
if cuda:
x = x.cuda()
dnn.forward(x)
preactivations = dnn.compute_preactivation(x)
if len(preactivations.shape) == 4:
preactivations = preactivations.transpose(1, 3)
preactivations = preactivations.contiguous().view(-1, dnn.n_features)
total_num += preactivations.shape[0]
preactivation_sum += torch.sum(preactivations, dim=0)
preactivation_squared_sum += torch.sum(preactivations**2, dim=0)
preactivation_mean = preactivation_sum / total_num
preactivation_var = preactivation_squared_sum / total_num
preactivation_var = preactivation_var - preactivation_mean**2
update_bn(dl_xy, dnn, device=x.device)
self.assertEqual(preactivation_mean, dnn.bn.running_mean)
self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
def _reset_bn(module):
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
module.running_mean = torch.zeros_like(module.running_mean)
module.running_var = torch.ones_like(module.running_var)
# reset batch norm and run update_bn again
dnn.apply(_reset_bn)
update_bn(dl_xy, dnn, device=x.device)
self.assertEqual(preactivation_mean, dnn.bn.running_mean)
self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
# using the dl_x loader instead of dl_xy
dnn.apply(_reset_bn)
update_bn(dl_x, dnn, device=x.device)
self.assertEqual(preactivation_mean, dnn.bn.running_mean)
self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
def test_update_bn_dnn(self):
# Test update_bn for a fully-connected network with BatchNorm1d
objects, input_features = 100, 5
x = torch.rand(objects, input_features)
y = torch.rand(objects)
ds_x = torch.utils.data.TensorDataset(x)
ds_xy = torch.utils.data.TensorDataset(x, y)
dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
dnn = self.SWATestDNN(input_features=input_features)
dnn.train()
self._test_update_bn(dnn, dl_x, dl_xy, False)
if torch.cuda.is_available():
dnn = self.SWATestDNN(input_features=input_features)
dnn.train()
self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True)
self.assertTrue(dnn.training)
def test_update_bn_cnn(self):
# Test update_bn for convolutional network and BatchNorm2d
objects = 100
input_channels = 3
height, width = 5, 5
x = torch.rand(objects, input_channels, height, width)
y = torch.rand(objects)
ds_x = torch.utils.data.TensorDataset(x)
ds_xy = torch.utils.data.TensorDataset(x, y)
dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
cnn = self.SWATestCNN(input_channels=input_channels)
cnn.train()
self._test_update_bn(cnn, dl_x, dl_xy, False)
if torch.cuda.is_available():
cnn = self.SWATestCNN(input_channels=input_channels)
cnn.train()
self._test_update_bn(cnn.cuda(), dl_x, dl_xy, True)
self.assertTrue(cnn.training)
def test_bn_update_eval_momentum(self):
# check that update_bn preserves eval mode
objects = 100
input_channels = 3
height, width = 5, 5
x = torch.rand(objects, input_channels, height, width)
ds_x = torch.utils.data.TensorDataset(x)
dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
cnn = self.SWATestCNN(input_channels=input_channels)
cnn.eval()
update_bn(dl_x, cnn)
self.assertFalse(cnn.training)
# check that momentum is preserved
self.assertEqual(cnn.bn.momentum, 0.3)
# Expand the @parametrize-decorated methods of TestSWAUtils into concrete test
# cases (see torch.testing._internal.common_utils).
instantiate_parametrized_tests(TestSWAUtils)


# This module is not meant to be executed directly; running it only prints a
# pointer to the intended entry point.
if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")
|