# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN, run_tests

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestUnevenParamShard(FSDPTest):
    def _get_ref_results(self, model, input, my_lr):
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().to(self.rank)
            v = torch.Tensor(input[self.rank]).to(self.rank)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
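            # With loss = out.sum(), the gradient w.r.t. the (transposed) weight
            # is each rank's input row broadcast across the output columns; FSDP
            # averages gradients across ranks, hence summing the first
            # `world_size` input rows and dividing by `world_size`.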
            v = torch.Tensor(input[: self.world_size]).to(self.rank)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(self.world_size)
            ref_weight_out = weight - grad.T * my_lr

            return ref_forward_output_my_rank, ref_weight_out

    @skip_if_lt_x_gpu(2)
    def test_one_iteration(self):
        """Test FSDP with an uneven division of parameter shards."""
        model = Linear(3, 3, bias=False)
        input = torch.rand(8, 3)
        my_lr = 0.1

        ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
            model, input, my_lr
        )

        model.to(self.rank)
        model = FSDP(model)
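        # Wrapping with FSDP flattens and shards the weight across ranks, so
        # each rank now holds only its own (possibly padded) shard.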
        optim = SGD(model.parameters(), lr=my_lr)
        self.assertTrue(len(input) >= self.world_size)
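        # Each rank trains on its own row of the shared input, so the
        # per-rank forward outputs differ while the averaged gradient matches
        # the reference computed in _get_ref_results().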
        in_data = torch.Tensor(input[self.rank]).to(self.rank)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()
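
        # Gather the full, unsharded weight on every rank so it can be
        # compared against the reference update.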
        with model.summon_full_params(model):
            weight_out = model.module.weight.T.clone()
            self.assertEqual(ref_forward_output_my_rank, out)
            self.assertEqual(ref_weight_out, weight_out)


if __name__ == "__main__":
    run_tests()