# Owner(s): ["oncall: distributed"]

import os
import weakref

import test_c10d_common

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._C._distributed_c10d import _create_work_from_future
from torch.futures import Future
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import run_tests
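

# Helper for the "wrapper" code path below: it turns an already-computed result
# into a c10d Work object by creating a Future, resolving it immediately, and
# handing it to _create_work_from_future.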
def create_work(result):
    future = Future()
    future.set_result(result)
    return _create_work_from_future(future)
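

# Work subclass used by the "subclass" code path. It keeps counters on the
# owning process group so the tests can assert that DDP actually calls
# wait()/get_future() on the work objects it gets back.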
class MyWork(dist._Work):
    def __init__(self, result, pg):
        super().__init__()
        self.result_ = result
        self.future_ = torch.futures.Future()
        self.future_.set_result(result)
        self.pg_ = weakref.ref(pg)

    def wait(self, timeout):
        self.pg_().wait_count += 1
        return True

    def get_future(self):
        self.pg_().get_future_count += 1
        return self.future_
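

# Pure-Python ProcessGroup implementing only the collectives DDP needs
# (broadcast, allgather, allreduce). Depending on use_wrapper it returns either
# MyWork instances or Work objects built through create_work().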
class LonelyRankProcessGroup(dist.ProcessGroup):
    """
    This PG only supports world_size of 1
    """

    def __init__(self, rank, world, use_wrapper):
        super().__init__(rank, world)
        assert rank == 0
        assert world == 1

        self._rank = rank
        self._world = world
        self.wait_count = 0
        self.get_future_count = 0
        self.use_wrapper = use_wrapper
        self._work = []

    def broadcast(self, tensor_list, opts):
        if self.use_wrapper:
            return create_work(tensor_list)
        res = MyWork(tensor_list, self)
        self._work.append(res)
        return res
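
    # With world_size == 1 the collectives are data no-ops; allgather only has
    # to copy the single input tensor into the single output slot.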
    def allgather(self, output_tensors, input_tensor, opts):
        for o, i in zip(output_tensors[0], input_tensor):
            o.copy_(i)
        if self.use_wrapper:
            return create_work(output_tensors)
        res = MyWork(output_tensors, self)
        self._work.append(res)
        return res

    def allreduce(self, tensors, opts):
        if self.use_wrapper:
            return create_work(tensors)
        res = MyWork(tensors, self)
        self._work.append(res)
        return res

    def size(self):
        return self._world

    def getBackendName(self):
        return "lonely-pg"

    def __repr__(self):
        return f"PLG w:{self._world} r:{self._rank}"


# We cannot use parametrize as some tests are defined on the base class and use _get_process_group
class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self):
        return 1

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _get_process_group(self):
        return LonelyRankProcessGroup(self.rank, self.world_size, self.use_wrapper)
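
    # Runs one forward/backward through DDP driven by the python process group,
    # then repeats it on the plain module and checks that the gradients match,
    # i.e. the python PG did not break gradient synchronization.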
    def test_ddp_invoke_work_object(self):
        pg = self._get_process_group()

        torch.manual_seed(123)
        model = nn.Sequential(
            nn.Linear(2, 2),
            nn.ReLU(),
        )
        wrapped_model = model
        input_tensor = torch.rand(2)
        model = DDP(model, process_group=pg)
        model(input_tensor).sum().backward()

        ddp_grad = wrapped_model[0].bias.grad.clone()

        wrapped_model.zero_grad()
        wrapped_model(input_tensor).sum().backward()
        self.assertEqual(wrapped_model[0].bias.grad, ddp_grad)
        if not self.use_wrapper:
            self.assertTrue(pg.wait_count > 0)
            self.assertTrue(pg.get_future_count > 0)
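
    # The remaining tests reuse the generic DDP test from the common base class,
    # driving it with the python process group on CPU.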
    def test_ddp_with_pypg(self):
        pg = self._get_process_group()

        self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None)

    def test_ddp_with_pypg_with_grad_views(self):
        pg = self._get_process_group()

        self._test_ddp_with_process_group(
            pg, [torch.device("cpu")], device_ids=None, gradient_as_bucket_view=True
        )
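

# The two concrete test classes exercise both ways of producing Work objects:
# subclassing dist._Work (use_wrapper=False) and wrapping a resolved Future via
# _create_work_from_future (use_wrapper=True).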
class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiProcessTestCase):
    @property
    def use_wrapper(self):
        return False


class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiProcessTestCase):
    @property
    def use_wrapper(self):
        return True


if __name__ == '__main__':
    run_tests()