1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620
|
# Owner(s): ["module: nn"]
import re
import unittest
from copy import deepcopy
from itertools import product
import torch
import torch.nn as nn
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfCrossRef,
skipIfTorchDynamo,
swap,
TEST_NUMPY,
TestCase,
)
from torch.utils._pytree import tree_map
# numpy is optional; tests that need it are skipped via TEST_NUMPY.
if TEST_NUMPY:
    import numpy as np
class TestLoadStateDict(NNTestCase):
    """Tests for ``nn.Module.load_state_dict``.

    Covers rejection of invalid inputs, strict / non-strict key matching,
    backward-compatible loading of older checkpoints, load pre-hooks, custom
    (de)serialization overrides, and ``assign=True`` loading into modules
    created on the meta device. Most tests run under ``@swap([True, False])``;
    ``test_load_state_dict_assign_meta`` reads the active setting through
    ``torch.__future__.get_swap_module_params_on_conversion()`` (presumably the
    flag that ``@swap`` toggles -- confirm in common_utils).
    """

    # Harness flags consumed by the test base classes defined outside this
    # file; presumably they enable CUDA memory-leak checking and running on a
    # non-default CUDA stream -- confirm in common_nn / common_utils.
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    @swap([True, False])
    def test_load_state_dict_invalid(self):
        # Non-tensor checkpoint values (a numpy array, then a nested tuple of
        # floats) must be rejected with a RuntimeError.
        m = torch.nn.Linear(2, 2, bias=False)

        state_dict = {"weight": np.random.randn(2, 2)}
        with self.assertRaisesRegex(
            RuntimeError,
            "expected torch.Tensor or Tensor-like object from checkpoint but received",
        ):
            m.load_state_dict(state_dict)

        state_dict = {"weight": ((1.0, 1.0), (2.0, 2.0))}
        with self.assertRaisesRegex(
            RuntimeError,
            "expected torch.Tensor or Tensor-like object from checkpoint but received",
        ):
            m.load_state_dict(state_dict)

    @swap([True, False])
    def test_load_state_dict_type(self):
        # A state_dict that is not dict-like (here a str and an int) must be
        # rejected with a TypeError.
        m = nn.Module()

        with self.assertRaisesRegex(
            TypeError, "Expected state_dict to be dict-like, got"
        ):
            m.load_state_dict("")
        with self.assertRaisesRegex(
            TypeError, "Expected state_dict to be dict-like, got"
        ):
            m.load_state_dict(2)

    @swap([True, False])
    @skipIfTorchDynamo("dynamo installs weakrefs on some params")
    def test_load_state_dict(self):
        # End-to-end check of strict / non-strict loading on a small nested
        # module tree. Every section below reuses and mutates the same `net`,
        # so the order of the sections matters.
        l = nn.Linear(5, 5)
        block = nn.Module()
        block.conv1 = nn.Conv2d(3, 3, 3, bias=True)
        block.conv2 = nn.Conv2d(3, 3, 3, bias=False)
        net = nn.Module()
        # linear1 and linear2 are the *same* module instance, so their
        # parameters alias each other.
        net.linear1 = l
        net.linear2 = l
        net.bn = nn.BatchNorm2d(2)
        net.block = block
        # A None child must be tolerated when saving and loading.
        net.add_module("empty", None)

        conv1_bias_dtype = block.conv1.bias.dtype
        state_dict = net.state_dict()
        state_dict.update(
            {
                "linear1.weight": torch.ones(5, 5),
                "block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
                "bn.running_mean": torch.randn(2),
            }
        )
        # Also test if a DDP state_dict can be loaded from a local model.
        # The "module."-prefixed entries are stripped back to local names by
        # consume_prefix_in_state_dict_if_present before loading.
        ddp_state_dict = net.state_dict()
        ddp_state_dict.update(
            {
                "module.linear1.weight": torch.ones(5, 5),
                "module.block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
                "module.bn.running_mean": torch.randn(2),
            }
        )
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        # Both variants must load cleanly (no missing/unexpected keys) and the
        # overridden values must end up in the module.
        for sd in [state_dict, ddp_state_dict]:
            incompatible_keys = net.load_state_dict(sd)
            self.assertEqual(len(incompatible_keys.missing_keys), 0)
            self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
            self.assertNotIn("Incompatible", str(incompatible_keys))
            self.assertEqual(net.linear1.weight, sd["linear1.weight"])
            self.assertEqual(net.block.conv1.bias, sd["block.conv1.bias"])
            self.assertEqual(net.bn.running_mean, sd["bn.running_mean"])

        # Unexpected top-level key: strict load raises; non-strict reports it.
        state_dict = net.state_dict()
        state_dict.update({"extra": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("extra", incompatible_keys.unexpected_keys)
        self.assertIn("Incompatible", str(incompatible_keys))

        # Unexpected dotted key that names a non-existent submodule.
        state_dict = net.state_dict()
        state_dict.update({"extra.param": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("extra.param", incompatible_keys.unexpected_keys)

        # Missing key alone, then missing + unexpected key together.
        state_dict = net.state_dict()
        del state_dict["linear1.weight"]
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
        self.assertIn("linear1.weight", incompatible_keys.missing_keys)
        state_dict.update({"extra.param": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("linear1.weight", incompatible_keys.missing_keys)
        self.assertIn("extra.param", incompatible_keys.unexpected_keys)

        # A shape mismatch is an error even when strict=False.
        state_dict = net.state_dict()
        state_dict.update({"bn.running_mean": torch.rand(14, 4)})  # wrong size
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        self.assertRaises(
            RuntimeError, lambda: net.load_state_dict(state_dict, strict=False)
        )

        # Partial non-strict load: the listed keys are overwritten and every
        # other entry of the module's state is left unchanged.
        state_dict = net.state_dict()
        old_state_dict = deepcopy(state_dict)
        state_dict = {
            "linear1.weight": torch.ones(5, 5),
            "block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
            "bn.running_mean": torch.randn(2),
            "nonexistent_key": torch.rand(3),
        }
        net.load_state_dict(state_dict, strict=False)
        self.assertEqual(net.linear1.weight, state_dict["linear1.weight"])
        self.assertEqual(net.block.conv1.bias, state_dict["block.conv1.bias"])
        self.assertEqual(net.bn.running_mean, state_dict["bn.running_mean"])
        new_state_dict = net.state_dict()
        del old_state_dict["linear1.weight"]
        del old_state_dict["block.conv1.bias"]
        del old_state_dict["bn.running_mean"]
        # NOTE(review): "linear2.weight" stays in old_state_dict although
        # linear2 aliases linear1; the comparison below still holds only
        # because linear1.weight already held torch.ones(5, 5) from the first
        # load above.
        for (
            k,
            v,
        ) in old_state_dict.items():
            self.assertTrue(v.equal(new_state_dict[k]))

    @swap([True, False])
    def test_load_state_dict_BC(self):
        # BatchNormNd
        # Added num_batches_tracked buffer at version 2. For state dict with
        # earlier versions or no versions, it should provide default value of 0.
        bn = nn.BatchNorm2d(3)
        state_dict = bn.state_dict()
        del state_dict["num_batches_tracked"]
        state_dict._metadata[""]["version"] = 1  # version 1
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)
        del state_dict._metadata[""]["version"]  # no version
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)

    @swap([True, False])
    def test_load_state_dict_child(self):
        # Builds three levels of nn.Sequential (10 copies per level) around a
        # Linear, registers a load pre-hook on a nested child, and checks that
        # the state_dict handed to that hook has exactly as many entries as
        # the child's own state_dict.
        base_module = nn.Linear(1, 1)
        model = base_module
        for _ in range(3):
            model = nn.Sequential(*[deepcopy(model) for _ in range(10)])

        def hook_fn(
            module,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            module_state_dict = module.state_dict()
            self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys()))

        model[0][0].register_load_state_dict_pre_hook(hook_fn)
        model.load_state_dict(model.state_dict(), strict=True)

    # fails swapping as LSTM installs weak references on the parameters
    @swap([False])
    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    def test_load_state_dict_ref_cycle(self):
        # load_state_dict shouldn't cause a reference cycle involving Tensors
        import gc

        m = torch.nn.LSTM(16, 16, bidirectional=True)

        # Collect first so that any garbage found by the second collect() can
        # be attributed to the load_state_dict call in between.
        gc.collect()
        m.load_state_dict(deepcopy(m).state_dict())
        refcycles = gc.collect()

        self.assertEqual(refcycles, 0)

    @swap([True, False])
    def test_load_state_dict_custom(self):
        # A module that overrides _save_to_state_dict/_load_from_state_dict:
        # `param` is stored under the key "serialized" as param + 1 and
        # restored by subtracting 1, while the child `sub` still goes through
        # the default (de)serialization.
        class CustomState(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.ones(1))
                self.sub = torch.nn.Linear(5, 5)

            def _save_to_state_dict(self, destination, prefix, keep_vars):
                destination[prefix + "serialized"] = self.param.data + 1

            def _load_from_state_dict(
                self,
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            ):
                # skip some of the error handling
                self.param.data.copy_(state_dict[prefix + "serialized"] - 1)

        # use sequential to verify nesting
        m = nn.Sequential(CustomState())
        with torch.no_grad():
            m[0].param[0] = 10
            m[0].sub.weight[0, 0] = 555
        state_dict = m.state_dict()
        self.assertEqual(state_dict["0.serialized"].item(), 11)
        self.assertIn("0.sub.weight", state_dict)
        self.assertNotIn("0.param", state_dict)
        del m
        mm = nn.Sequential(CustomState())
        self.assertEqual(mm[0].param[0].item(), 1)
        mm.load_state_dict(state_dict)
        self.assertEqual(mm[0].param[0].item(), 10)
        self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)

    @swap([True, False])
    @parametrize("keep_vars", [True, False])
    def test_load_state_dict_assign_meta(self, keep_vars):
        # Loading with assign=True into a module built on the meta device must
        # yield a usable module: parameter identity/requires_grad follow the
        # rules checked below, ordering of parameters/buffers is preserved,
        # and outputs match the source module.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)
                self.x = nn.Parameter(torch.rand(5), requires_grad=False)

            def forward(self, input):
                return self.x + self.bn(self.fc1(input))

        # Local name shadows the imported `swap` decorator inside this method
        # only; it holds whether module params are swapped on conversion.
        swap = torch.__future__.get_swap_module_params_on_conversion()
        net = MyModule()
        state_dict = net.state_dict(keep_vars=keep_vars)
        for v in state_dict.values():
            v.requires_grad_(False)

        with torch.device("meta"):
            net_meta = MyModule()

        net_meta_state_dict_old = net_meta.state_dict(keep_vars=True)
        net_meta.load_state_dict(state_dict, assign=True)

        # Make sure parameters and persistent buffers were assigned
        # NOTE(review): `_parameters` / `_buffers` appear to hold only members
        # registered directly on net_meta, so nested keys such as "fc1.weight"
        # or "bn.running_mean" may skip both branches below -- confirm.
        net_meta_state_dict = net_meta.state_dict(keep_vars=True)
        for key in state_dict.keys():
            if key in net_meta._parameters:
                if keep_vars and not swap:
                    # state_dict[key] is an nn.Parameter
                    self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                else:
                    if swap:
                        self.assertTrue(
                            net_meta_state_dict[key] is net_meta_state_dict_old[key]
                        )
                    else:
                        # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
                        self.assertTrue(
                            net_meta_state_dict[key] is not net_meta_state_dict_old[key]
                        )
                    # NOTE(review): duplicates the requires_grad check right
                    # after this else-branch.
                    self.assertEqual(
                        net_meta_state_dict_old[key].requires_grad,
                        net_meta_state_dict[key].requires_grad,
                    )
                self.assertEqual(
                    net_meta_state_dict_old[key].requires_grad,
                    net_meta_state_dict[key].requires_grad,
                )
                self.assertEqual(state_dict[key], net_meta_state_dict[key])
            elif (
                key in net_meta._buffers
                and key not in net_meta._non_persistent_buffers_set
            ):
                self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                self.assertEqual(state_dict[key], net_meta_state_dict[key])

        # Make sure that ordering of parameters and buffers is preserved
        net_named_parameters = net.named_parameters()
        net_named_buffers = net.named_buffers()
        net_meta_named_parameters = net_meta.named_parameters()
        net_meta_named_buffers = net_meta.named_buffers()

        for (n1, _), (n2, _) in zip(net_named_parameters, net_meta_named_parameters):
            self.assertEqual(n1, n2)

        for (n1, _), (n2, _) in zip(net_named_buffers, net_meta_named_buffers):
            self.assertEqual(n1, n2)

        # Make sure outputs are the same
        t = torch.randn(4, 3)
        out_net = net(t)
        out_net_meta = net_meta(t.clone())

        self.assertEqual(out_net, out_net_meta)

    @swap([True, False])
    def test_load_state_dict_assign_with_optimizer(self):
        # After training `net` a few Adam steps, a meta-device copy loaded with
        # assign=True plus a fresh optimizer loaded from the saved optimizer
        # state must track the original exactly over further steps.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        opt = torch.optim.Adam(net.parameters(), lr=1000)
        x = torch.randn(4, 3)
        num_iters = 3

        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

        opt_state_dict = deepcopy(opt.state_dict())
        net_state_dict = deepcopy(net.state_dict())

        with torch.device("meta"):
            net_meta = MyModule()

        net_meta.load_state_dict(net_state_dict, assign=True)

        # must create optimizer only after loading state_dict when assign=True
        # (presumably because assign=True replaces the meta parameter objects
        # an earlier optimizer would have captured).
        opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000)
        opt2.load_state_dict(opt_state_dict)

        # Step both models in lockstep on identical inputs.
        y = x.clone()
        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

            opt2.zero_grad()
            out2 = net_meta(y)
            out2.sum().backward()
            opt2.step()

        self.assertEqual(opt.state_dict(), opt2.state_dict())
        self.assertEqual(net.state_dict(), net_meta.state_dict())

    @swap([True, False])
    def test_load_state_dict_assign_shape_stride(self):
        # Assigned tensor is allowed to have different properties than initial
        # tensor except for shape
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict()
        # loading should be ok if stride is different
        state_dict["fc1.weight"] = torch.randn(3, 5).transpose(0, 1)
        net2 = MyModule()
        net2.load_state_dict(state_dict, strict=False, assign=True)

        # ...but a different shape must still be rejected.
        state_dict["fc1.weight"] = torch.randn(2, 4)
        with self.assertRaisesRegex(
            RuntimeError, "size mismatch for fc1.weight: copying a param with shape"
        ):
            net2.load_state_dict(state_dict, strict=False, assign=True)

    @swap([True, False])
    def test_load_state_dict_warn_assign(self):
        # Loading a CPU tensor into a meta-device module without assign=True
        # must emit a UserWarning.
        with torch.device("meta"):
            m = torch.nn.Linear(3, 5)
        state_dict = m.state_dict()
        state_dict["weight"] = torch.empty_like(state_dict["weight"], device="cpu")
        with self.assertWarnsRegex(
            UserWarning,
            "for weight: copying from a non-meta parameter in the checkpoint to a meta",
        ):
            m.load_state_dict(state_dict)

    @swap([True, False])
    def test_load_state_dict_with_unexpected_key(self):
        # Unexpected keys, including one whose prefix is a real parameter
        # name, raise under strict=True and are reported under strict=False.
        # NOTE: `state_dict` is rebound to the value returned by
        # load_state_dict(strict=False), whose `.unexpected_keys` is checked.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(5, 10)

        m = MyModule()

        # Unexpected key & strict = True
        with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
            state_dict = m.state_dict()
            state_dict["fc1.bad_suffix"] = torch.randn(5, 10)
            m.load_state_dict(state_dict)

        # Unexpected key & strict = False
        state_dict = m.load_state_dict(state_dict, strict=False)
        self.assertIn("fc1.bad_suffix", state_dict.unexpected_keys)

        # Unexpected key whose prefix matches a valid key & strict = True
        with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
            state_dict = m.state_dict()
            state_dict["fc1.weight.bad_suffix"] = torch.randn(5, 10)
            m.load_state_dict(state_dict)

        # Unexpected key whose prefix matches a valid key & strict = False
        state_dict = m.load_state_dict(state_dict, strict=False)
        self.assertIn("fc1.weight.bad_suffix", state_dict.unexpected_keys)
def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
    """Shared ``__torch_function__`` body for the ``MyLoadTensor*`` test subclasses.

    Calls to ``torch.Tensor.module_load`` are routed to a custom loader that
    decides how a checkpoint tensor ``src`` is converted before it replaces the
    module tensor ``dest``. Every other function runs with torch-function
    subclass handling disabled, and the result of ``detach`` is re-wrapped in
    ``cls`` when it is not already an instance of it.

    Args:
        cls: the tensor subclass whose ``__torch_function__`` delegates here.
        func: the intercepted function.
        types: unused; part of the ``__torch_function__`` protocol.
        args: positional arguments for ``func``.
        kwargs: keyword arguments for ``func`` (``None`` means none).

    Raises:
        AssertionError: from ``module_load`` when ``dest`` is not a ``cls``
            instance and ``src`` / ``dest`` do not have compatible types.
    """
    call_kwargs = {} if kwargs is None else kwargs

    def module_load(dest, src, assign=False):
        if isinstance(dest, cls):
            # The module slot already holds a ``cls`` instance.
            if assign:
                return src.detach()
            if type(src) is torch.Tensor:
                return cls(src)
            if type(src) is cls:
                return src.detach()
            # Any other source type: unwrap wrapper tensors before re-wrapping.
            payload = src._data if isinstance(src, MyWrapperLoadTensor) else src
            return cls(payload)

        # The module slot is not a ``cls`` instance: the checkpoint tensor must
        # be one, and the slot must be a plain Tensor/Parameter or a superclass
        # of ``cls``. These checks run whether or not ``assign`` is set.
        assert isinstance(
            src, cls
        ), f"Expected isinstance(src, {cls}) but got {type(src)}"
        dest_type = type(dest)
        assert dest_type in (torch.Tensor, torch.nn.Parameter) or issubclass(
            cls, dest_type
        )
        if assign:
            return src.detach()
        if not isinstance(src, MyWrapperLoadTensor):
            return torch.Tensor(src)
        if dest_type in {torch.Tensor, torch.nn.Parameter}:
            return src._data.detach()
        return dest_type(src._data)

    if func is torch.Tensor.module_load:
        return module_load(*args, **call_kwargs)

    with torch._C.DisableTorchFunctionSubclass():
        result = func(*args, **call_kwargs)
        # detach must return instance of same subclass for nn.Parameter()
        if func == torch.Tensor.detach and not isinstance(result, cls):
            return cls(result)
        return result
class MyLoadTensor(torch.Tensor):
    # Plain tensor subclass; its torch-function behavior (custom module_load,
    # subclass-preserving detach) is supplied by load_torch_function_handler.

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return load_torch_function_handler(
            cls, func, types, args=args, kwargs=kwargs
        )
# MyLoadTensor2 is unrelated by inheritance to MyLoadTensor/MyWrapperLoadTensor,
# so pairing it with them exercises loads between tensor subclasses and
# wrapper tensor subclasses where neither inherits from the other.
class MyLoadTensor2(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Same shared handler as MyLoadTensor, bound to this class.
        handled = load_torch_function_handler(
            cls, func, types, args=args, kwargs=kwargs
        )
        return handled
class MyBrokenLoadTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if func is torch.Tensor.module_load:
# wrong as this doesn't detach!
return args[1]
else:
with torch._C.DisableTorchFunctionSubclass():
# detach must return instance of same subclass for nn.Parameter()
if func == torch.Tensor.detach:
return cls(func(*args, **kwargs))
return func(*args, **kwargs)
class MyWrapperLoadTensor(MyLoadTensor):
    # Wrapper tensor subclass: the outer object only mirrors the metadata of
    # `data` (via _make_wrapper_subclass) while the real tensor lives in
    # `self._data`. Each dispatched op runs on the unwrapped tensors, and every
    # tensor in the result is wrapped again.

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        metadata = dict(
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            requires_grad=data.requires_grad,
            strides=data.stride(),
            storage_offset=data.storage_offset(),
        )
        return torch.Tensor._make_wrapper_subclass(cls, data.size(), **metadata)

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __repr__(self):
        return f"MyWrapperLoadTensor({self._data!r})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def _strip(obj):
            # Replace a wrapper by its inner tensor; leave anything else alone.
            return obj._data if isinstance(obj, MyWrapperLoadTensor) else obj

        def _rewrap(obj):
            # Wrap every tensor produced by the op.
            return MyWrapperLoadTensor(obj) if isinstance(obj, torch.Tensor) else obj

        result = func(*tree_map(_strip, args), **tree_map(_strip, kwargs))
        return tree_map(_rewrap, result)
class TestLoadStateDictSwap(TestCase):
    # load_state_dict with swapping enabled, across every pairing of plain /
    # subclassed tensors on the module side and on the state_dict side.

    @skipIfCrossRef
    @skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs")
    @swap([True])
    @parametrize("assign", [True, False])
    def test_swap_subclass(self, assign):
        def _create_model(subclass=None):
            # Bias-free Linear(2, 3) plus one extra buffer; optionally convert
            # both the weight and the buffer to `subclass`.
            model = torch.nn.Linear(2, 3, bias=False)
            model.buf = torch.nn.Buffer(torch.randn(2, 3))
            if subclass is None:
                return model
            model.weight = torch.nn.Parameter(subclass(model.weight))
            model.buf = subclass(model.buf)
            return model

        def _test(m_subclass=None, sd_subclass=None):
            module = _create_model(m_subclass)
            checkpoint = _create_model(sd_subclass).state_dict()
            module.load_state_dict(checkpoint, assign=assign)

            self.assertEqual(module.weight, checkpoint["weight"])
            self.assertEqual(module.buf, checkpoint["buf"])
            self.assertTrue(isinstance(module.weight, torch.nn.Parameter))
            self.assertTrue(not isinstance(module.buf, torch.nn.Parameter))

            # With assign=True the resulting types follow the checkpoint's
            # subclass; otherwise they follow the module's original subclass.
            winning_subclass = sd_subclass if assign else m_subclass
            if winning_subclass is None:
                expected_weight_type = torch.nn.Parameter
                expected_buf_type = torch.Tensor
            else:
                expected_weight_type = expected_buf_type = winning_subclass
            self.assertTrue(type(module.weight) is expected_weight_type)
            self.assertTrue(type(module.buf) is expected_buf_type)

        # (MyLoadTensor, MyWrapperLoadTensor) tests the behavior of (superclass, subclass)
        subclasses = [None, MyLoadTensor, MyLoadTensor2, MyWrapperLoadTensor]
        for m_s, sd_s in product(subclasses, repeat=2):
            _test(m_s, sd_s)

        # MyBrokenLoadTensor should error since its module_load doesn't call .detach()
        with self.assertRaisesRegex(
            RuntimeError, re.escape("Error(s) in loading state_dict for Linear:")
        ):
            _test(None, MyBrokenLoadTensor)
# Presumably expands the @parametrize-decorated methods of each class into
# concrete test methods (helper from common_utils -- confirm there).
instantiate_parametrized_tests(TestLoadStateDict)
instantiate_parametrized_tests(TestLoadStateDictSwap)

if __name__ == "__main__":
    # NOTE(review): presumably makes TestCase check that the default dtype is
    # left unchanged by each test -- confirm in common_utils.
    TestCase._default_dtype_check_enabled = True
    run_tests()
|