# Owner(s): ["oncall: distributed"]

from copy import deepcopy
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.distributed._composable import _get_registry, contract
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])
        self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])
        self.p = nn.Parameter(torch.randn(10, 10), requires_grad=True)
        # Plain tensor attribute used as a buffer-like accumulator (note: not
        # registered via `register_buffer`, so it is not in the state dict).
        self.b = torch.zeros(1)

    def forward(self, x, y):
        with torch.no_grad():
            self.b += x.sum() + y.sum()
        return self.p + self.seq1(x) + self.seq2(y)
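

# NOTE: `contract` (from torch.distributed._composable) is the decorator under
# test. Judging from the behaviors exercised below, it wraps a composable API
# function so that each call (1) records the API in a per-module registry
# (queried via `_get_registry`), (2) exposes per-module state through
# `<api>.state(module)`, and (3) verifies the API does not change parameter
# FQNs. A minimal sketch of the usage pattern (`my_api` is hypothetical):
#
#     @contract()
#     def my_api(module: nn.Module) -> nn.Module:
#         my_api.state(module).initialized = True  # per-module state
#         return module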


class TestContract(TestCase):
    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_add_hooks(self):
        # All hooks are identity functions, so the hooked copy must produce
        # the same outputs and gradients as the hook-free model.
        def forward_pre_hook(
            module: nn.Module, inp: Tuple[torch.Tensor, ...]
        ) -> Tuple[torch.Tensor, ...]:
            return inp

        def forward_hook(
            module: nn.Module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor
        ) -> torch.Tensor:
            return out

        def backward_pre_hook(
            module: nn.Module, grad_output: Tuple[torch.Tensor, ...]
        ) -> Tuple[torch.Tensor, ...]:
            return grad_output

        def backward_hook(
            module: nn.Module,
            grad_input: Tuple[torch.Tensor, ...],
            grad_output: Tuple[torch.Tensor, ...],
        ) -> Tuple[torch.Tensor, ...]:
            return grad_input

        @contract()
        def noop_api(module: nn.Module) -> nn.Module:
            module.register_forward_pre_hook(forward_pre_hook)
            module.register_forward_hook(forward_hook)
            module.register_full_backward_pre_hook(backward_pre_hook)
            module.register_full_backward_hook(backward_hook)
            return module

        model = ToyModel()
        model_with_hooks = deepcopy(model)
        noop_api(model_with_hooks.seq1)
        noop_api(model_with_hooks.seq2)

        x, y = torch.randn(10, 10), torch.randn(10, 10)
        model(x, y).sum().backward()
        model_with_hooks(x, y).sum().backward()
        for p1, p2 in zip(model.parameters(), model_with_hooks.parameters()):
            self.assertEqual(p1, p2)
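
    # Returning a wrapper from a contract API would re-parent the wrapped
    # module under a new attribute, changing parameter FQNs (e.g.
    # "0.weight" -> "module.0.weight"); `contract` rejects this (presumably
    # to keep state_dict keys stable), as the next test verifies.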

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_modify_fqn(self):
        class ModelWrapper(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, x):
                return self.module(x)

        @contract()
        def wrap_module(module: nn.Module) -> nn.Module:
            return ModelWrapper(module)

        model = ToyModel()

        regex = (
            "Checking parameters: Composable distributed API implementations "
            "cannot modify FQNs."
        )
        with self.assertRaisesRegex(RuntimeError, regex):
            wrap_module(model.seq1)

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_state(self):
        def check_and_update_state_hook(
            module: nn.Module, inp: Tuple[torch.Tensor, ...]
        ) -> Tuple[torch.Tensor, ...]:
            # `api` is looked up lazily from the enclosing scope, so it is
            # defined by the time this hook first runs.
            self.assertEqual(api.state(module).dummy_state, 7)
            api.state(module).dummy_state = 8
            return inp

        # FIXME: the circular reference looks a bit weird. Shall we make
        # `.state` a top-level API instead of attaching it to the contract
        # API?
        @contract()
        def api(module: nn.Module) -> nn.Module:
            api.state(module).dummy_state = 7
            module.register_forward_pre_hook(check_and_update_state_hook)
            return module

        model = ToyModel()
        api(model.seq1)
        self.assertEqual(api.state(model.seq1).dummy_state, 7)
        model(torch.zeros(10, 10), torch.zeros(10, 10))
        self.assertEqual(api.state(model.seq1).dummy_state, 8)
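
    # `_get_registry` returns the ordered mapping of composable API names that
    # have been applied to a module, or None for modules that have none.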

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_registry(self):
        @contract()
        def api1(module: nn.Module) -> nn.Module:
            return module

        @contract()
        def api2(module: nn.Module) -> nn.Module:
            return module

        model = ToyModel()
        model = api1(model)
        self.assertEqual(1, len(_get_registry(model)))
        self.assertIn("api1", _get_registry(model))
        model = api2(model)
        self.assertEqual(2, len(_get_registry(model)))
        self.assertEqual(list(_get_registry(model).keys()), ["api1", "api2"])
        # The registry lives on the module the APIs were applied to, not on
        # its submodules.
        self.assertIsNone(_get_registry(model.seq1))
        self.assertIsNone(_get_registry(model.seq2))

        # Re-applying the same API to the same module must fail.
        with self.assertRaisesRegex(AssertionError, "api1 has already been applied"):
            model = api1(model)
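
    # `contract` also accepts a list of modules. Modules passed together in a
    # single call share one state object and one registry, while modules from
    # separate calls get distinct ones.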

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_multi_module_api(self):
        @contract()
        def multi_module_api(modules: List[nn.Module]) -> List[nn.Module]:
            return modules

        model = nn.Sequential(*[nn.Linear(3, 3) for _ in range(5)])
        multi_module_api([model[0], model[1]])
        multi_module_api([model[2], model[3]])
        multi_module_api([model[4]])

        # Check that modules have the same state and registry iff they shared
        # the same API call
        states = [multi_module_api.state(module) for module in model]
        self.assertEqual(states[0], states[1])
        self.assertEqual(states[2], states[3])
        self.assertNotEqual(states[0], states[2])
        self.assertNotEqual(states[0], states[4])
        self.assertNotEqual(states[2], states[4])
        registries = [_get_registry(module) for module in model]
        self.assertEqual(registries[0], registries[1])
        self.assertEqual(registries[2], registries[3])
        self.assertNotEqual(registries[0], registries[2])
        self.assertNotEqual(registries[0], registries[4])
        self.assertNotEqual(registries[2], registries[4])

        # Check that applying an API to a module multiple times errors
        model = nn.Sequential(*[nn.Linear(3, 3) for _ in range(5)])
        multi_module_api([model[0], model[1]])
        with self.assertRaisesRegex(
            AssertionError,
            "Each distinct composable distributed API can only be applied to "
            "a module once. multi_module_api has already been applied to the "
            "following module:",
        ):
            multi_module_api([model[0], model[2]])


if __name__ == "__main__":
    run_tests()