# Owner(s): ["module: nn"]

import tempfile
from copy import deepcopy
from functools import partial
from unittest import expectedFailure

import torch
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.utils.parametrize import (
    register_parametrization,
    remove_parametrizations,
)
from torch.testing._internal.common_subclass import (
    DiagTensorBelow,
    subclass_db,
)
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    subtest,
)
from torch.testing._internal.logging_tensor import LoggingTensor
from torch.utils._pytree import tree_map

# The current test methodology in this file is to test a variety of real use cases
# with a set of fully-fledged tensor subclasses. In the future, this may change
# to more narrowly specify toy subclasses for each of the specific invariants under
# test, avoiding the need to maintain the set of fully-fledged tensor subclasses.

# Decorator for parametrizing tests across the various tensor classes.
parametrize_tensor_cls = parametrize("tensor_cls", [
    subtest(tensor_cls, name=info.name) for tensor_cls, info in subclass_db.items()])


class TestSubclass(TestCase):
    def _create_tensor(self, tensor_cls):
        return subclass_db[tensor_cls].create_fn(3)
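
    # nn.Parameter construction should respect the requires_grad argument and
    # leave the wrapped subclass tensor itself untouched.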
    @parametrize_tensor_cls
    @parametrize("tensor_requires_grad", [False, True])
    def test_param_invariants(self, tensor_cls, tensor_requires_grad):
        x = self._create_tensor(tensor_cls).requires_grad_(tensor_requires_grad)
        param = nn.Parameter(x, requires_grad=(not tensor_requires_grad))

        self.assertIsInstance(param, nn.Parameter)
        # Ensure requires_grad passed to Parameter's constructor takes precedence.
        self.assertEqual(param.requires_grad, not tensor_requires_grad)

        # Ensure original tensor is not mutated by Parameter construction.
        self.assertNotIsInstance(x, nn.Parameter)
        self.assertEqual(x.requires_grad, tensor_requires_grad)
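
    # deepcopy() should produce a distinct copy that preserves the subclass type
    # (and "parameter-ness" when wrapped in nn.Parameter).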
    @skipIfTorchDynamo()
    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_deepcopy(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        x_copy = deepcopy(x)
        self.assertEqual(x, x_copy)
        self.assertEqual(x.__class__, x_copy.__class__)
        self.assertIsNot(x, x_copy)
        self.assertIsInstance(x_copy, tensor_cls)
        if as_param:
            # Deepcopy should preserve both custom type and "parameter-ness".
            self.assertIsInstance(x_copy, nn.Parameter)
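
    # A torch.save / torch.load round trip should likewise preserve both the
    # subclass type and "parameter-ness".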
    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_serialization(self, tensor_cls, as_param):
        with tempfile.TemporaryFile() as f:
            x = self._create_tensor(tensor_cls)
            if as_param:
                x = nn.Parameter(x)
            torch.save(x, f)
            f.seek(0)
            x_loaded = torch.load(f)

            self.assertEqual(x, x_loaded)
            self.assertIsNot(x, x_loaded)
            self.assertIsInstance(x_loaded, tensor_cls)
            if as_param:
                # Serialization should preserve both custom type and "parameter-ness".
                self.assertIsInstance(x_loaded, nn.Parameter)
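
    # The repr should mention the subclass name exactly once, and "Parameter"
    # only when the tensor is wrapped in nn.Parameter.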
@skipIfTorchDynamo("Visible only with functorch as functorch monkeypatches tensor str")
@parametrize_tensor_cls
@parametrize("as_param", [False, True])
def test_repr(self, tensor_cls, as_param):
x = self._create_tensor(tensor_cls)
if as_param:
x = nn.Parameter(x)
str_repr = x.__repr__()
if tensor_cls is not torch.Tensor:
self.assertEqual(str_repr.count(f"{tensor_cls.__name__}("), 1)
self.assertEqual(str_repr.count("Parameter"), 1 if as_param else 0)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_type_propagation(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)

        # Call the add operator to produce an output tensor.
        output = x + self._create_tensor(torch.Tensor)

        # Custom type should be propagated across operations if closed under the op, but
        # "parameter-ness" should not be.
        if subclass_db[tensor_cls].closed_under_ops:
            self.assertIsInstance(output, tensor_cls)
        else:
            self.assertIsInstance(output, torch.Tensor)
        self.assertNotIsInstance(output, nn.Parameter)
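
    # Parameters of a custom subclass type should work end-to-end: collected in
    # state_dict() (8 entries across p1 / p_list / p_dict), used in forward(),
    # differentiated through backward(), and updated by an optimizer step.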
    @parametrize_tensor_cls
    def test_module_optimization(self, tensor_cls):
        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.p1 = nn.Parameter(create_fn())

                self.p_list = nn.ParameterList([create_fn() for _ in range(3)])
                self.p_list.append(create_fn())

                self.p_dict = nn.ParameterDict({
                    'foo': create_fn(),
                    'bar': create_fn(),
                })
                self.p_dict['baz'] = create_fn()

                with torch.no_grad():
                    nn.init.normal_(self.p1)
                    for p in self.p_list:
                        nn.init.uniform_(p)
                    for _, p in self.p_dict.items():
                        nn.init.uniform_(p)

            def forward(self, x):
                out = self.p1 + x
                for p in self.p_list:
                    out = p + out
                for _, v in self.p_dict.items():
                    out = v + out
                return out

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 8)

        optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
        m(create_fn()).sum().backward(torch.tensor(1))
        optimizer.step()
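
    # Parametrizations registered on a subclass-typed parameter should preserve
    # the subclass type through both property access and forward().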
    @parametrize_tensor_cls
    @parametrize("leave_parametrized", [False, True])
    def test_parametrization(self, tensor_cls, leave_parametrized):
        # TODO: Either implement set_() properly for these tensor subclasses or apply a
        # more general fix to avoid the need for special set_() handling. For now, skip
        # testing these as they're expected to fail.
        if tensor_cls in [LoggingTensor, DiagTensorBelow]:
            return

        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = nn.Parameter(create_fn())

            def forward(self, x):
                return self.weight + x

        class MyParametrization(nn.Module):
            def forward(self, X):
                return -X

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 1)
        register_parametrization(m, 'weight', MyParametrization())
        self.assertIsInstance(m.weight, tensor_cls)
        output = m(self._create_tensor(torch.Tensor))
        self.assertIsInstance(output, tensor_cls)
        remove_parametrizations(m, 'weight', leave_parametrized=leave_parametrized)

    # Lazy modules with custom tensors are not supported yet.
    @expectedFailure
    @parametrize_tensor_cls
    def test_lazy_module(self, tensor_cls):
        if tensor_cls is torch.Tensor:
            self.fail('dummy fail for base tensor until the test passes for subclasses')

        class MyLazyModule(LazyModuleMixin, nn.Module):
            def __init__(self):
                super().__init__()
                self.param = nn.UninitializedParameter()

            def initialize_parameters(self, input) -> None:  # type: ignore[override]
                if self.has_uninitialized_params():
                    with torch.no_grad():
                        self.param.materialize(input.shape)
                        nn.init.uniform_(self.param)

            def forward(self, x):
                return self.param + x

        m = MyLazyModule()
        self.assertTrue(m.has_uninitialized_params())
        output = m(self._create_tensor(tensor_cls))
        self.assertFalse(m.has_uninitialized_params())
        self.assertIsInstance(m.param, tensor_cls)

    def test_non_rewrapping_torch_dispatch_subclass_as_parameter_throws_for_detach(self):
        # Define a subclass that does not rewrap for any function in its __torch_dispatch__ impl.
        class NonRewrappingTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, t: torch.Tensor):
                # Create a wrapper subclass instance sharing t's metadata (shape,
                # dtype, device); the real data lives in self.tensor.
                r = super(NonRewrappingTensor, cls)._make_wrapper_subclass(
                    cls, t.shape, dtype=t.dtype, requires_grad=t.requires_grad, device=t.device)
                return r

            def __init__(self, t) -> None:
                self.tensor: torch.Tensor = t

            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e) -> torch.Tensor:
                    # Pass through anything that isn't an instance of this subclass.
                    return e.tensor if isinstance(e, NonRewrappingTensor) else e

                r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                # Return an unwrapped tensor no longer of original subclass type.
                return r

        with self.assertRaisesRegex(RuntimeError, r"requires that detach\(\) returns an instance of the same type"):
            param = nn.Parameter(NonRewrappingTensor(torch.randn(3)))


instantiate_parametrized_tests(TestSubclass)

if __name__ == '__main__':
    run_tests()