# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
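

# Verifies that FSDP's forward accepts container inputs (list and dict) on a single GPU.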
class TestInput(FSDPTest):
    @property
    def world_size(self):
        return 1

    @skip_if_lt_x_gpu(1)
    @parametrize("input_cls", [subtest(dict, name="dict"), subtest(list, name="list")])
    def test_input_type(self, input_cls):
        """Test FSDP with input being a list or a dict, only single GPU."""

        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                # Unwrap the tensor from the container before the linear layer.
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)
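        # Wrap the model with FSDP and move it to the GPU.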
        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)
        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
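            # Wrap the input tensor in the parametrized container type.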
            if input_cls is list:
                in_data = [in_data]
            else:
                self.assertTrue(input_cls is dict)
                in_data = {"in": in_data}
            # Forward, backward, and optimizer step should all work with the wrapped input.
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()


instantiate_parametrized_tests(TestInput)

if __name__ == "__main__":
    run_tests()