# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module, Sequential
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN, run_tests

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
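

# InnerModel already wraps its Linear layer in FSDP; the test below wraps the
# whole module in FSDP again, so the (re)wrapped model contains a nested FSDP
# instance.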
class InnerModel(Module):
    def __init__(self):
        super().__init__()
        self.layers = Sequential(FSDP(Linear(5, 5)))

    def forward(self, x):
        return self.layers(x)


class TestMultipleWrapping(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_multiple_wrapping(self):
        """
        This test simulates wrapping a module again after training in order to
        run inference, i.e. the case where, later in a session, the model is
        wrapped in FSDP even though it already contains nested FSDP wrappers.
        """
        inner_model = InnerModel()
        model = FSDP(inner_model).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        # Train for a few steps so the comparison below is against trained
        # parameters rather than the initialization.
        for i in range(3):
            input = torch.rand((1, 5), dtype=torch.float).cuda()
            input.requires_grad = True
            output = model(input)
            output.sum().backward()
            optim.step()
            optim.zero_grad()

        input = torch.rand((1, 5), dtype=torch.float).cuda()
        output = model(input)

        # Rewrap the inner model, which already contains a nested FSDP
        # instance, and check that the rewrapped model produces the same
        # output for the same input.
        rewrapped_model = FSDP(inner_model).cuda()
        rewrapped_output = rewrapped_model(input)
        self.assertEqual(output, rewrapped_output)


if __name__ == "__main__":
    run_tests()