1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
|
import threading
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from torch.testing._internal.dist_utils import dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
class MyModule:
    """Minimal stateful object holding a single 3x3 weight tensor ``w``.

    Instances are created on remote workers in the tests and driven via RPC.
    ``lock`` is a class-level ``threading.Lock`` shared by every instance.
    """

    lock = threading.Lock()

    def __init__(self, requires_grad=True):
        # A private generator is used instead of torch.manual_seed(0): the
        # default generator is shared by all threads, so concurrent draws from
        # several RPC threads could interleave and make the initial weights
        # non-deterministic. Seeding a dedicated generator with 0 gives every
        # instance identical weights.
        rng = torch.Generator().manual_seed(0)
        self.w = torch.rand(
            (3, 3),
            requires_grad=requires_grad,
            generator=rng,
        )

    def forward(self, t1):
        """Return the matrix product of the weight with ``t1``."""
        weight = self.w
        return torch.mm(weight, t1)

    def get_w(self):
        """Return the weight tensor itself (not a copy)."""
        return self.w
class FailingOptimizer(optim.Optimizer):
    """Optimizer whose ``step`` always raises; used to test error propagation."""

    def __init__(self, params):
        # No hyperparameters are needed, so the defaults dict is empty.
        super().__init__(params, defaults={})

    def step(self, closure=None):
        """Raise ``ValueError`` unconditionally instead of updating parameters."""
        message = "Error running optimizer."
        raise ValueError(message)
class OptimizerFailingOnConstructor(optim.Optimizer):
    """Optimizer that raises during construction, after base-class init runs."""

    def __init__(self, params):
        super().__init__(params, defaults={})
        # Fail only once the base Optimizer has finished initializing.
        raise ValueError("Error creating optimizer.")

    def step(self, closure=None):
        """Not implemented: construction always fails before this is usable."""
        raise NotImplementedError
def _call_method(method, obj_rref, *args, **kwargs):
return method(obj_rref.local_value(), *args, **kwargs)
def remote_method(method, obj_rref, *args, **kwargs):
    """Invoke a method on a remote object via ``rpc.remote``.

    The call is executed on ``obj_rref.owner()`` through ``_call_method``.

    Args:
        method: the method to call (for example, ``Class.method``).
        obj_rref (RRef): remote reference to the object to call it on.
        args: positional arguments forwarded to the method.
        kwargs: keyword arguments forwarded to the method.

    Returns:
        An RRef to the result of the remote method call.
    """
    forwarded = [method, obj_rref, *args]
    owner = obj_rref.owner()
    return rpc.remote(owner, _call_method, args=forwarded, kwargs=kwargs)
def rpc_async_method(method, obj_rref, *args, **kwargs):
    """Invoke a method on a remote object via ``rpc.rpc_async``.

    The call is executed on ``obj_rref.owner()`` through ``_call_method``.

    Args:
        method: the method to call (for example, ``Class.method``).
        obj_rref (RRef): remote reference to the object to call it on.
        args: positional arguments forwarded to the method.
        kwargs: keyword arguments forwarded to the method.

    Returns:
        A Future for the result of the method call.
    """
    forwarded = [method, obj_rref, *args]
    owner = obj_rref.owner()
    return rpc.rpc_async(owner, _call_method, args=forwarded, kwargs=kwargs)
class DistOptimizerTest(RpcAgentTestFixture):
    """End-to-end tests for ``DistributedOptimizer`` over RPC.

    Each test places ``MyModule`` instances on the workers at ``rank + 1``
    and ``rank + 2`` (modulo ``world_size``), then drives the forward pass,
    distributed backward pass and optimizer step from the current rank.
    """

    def _remote_worker_names(self):
        """Return the names of the workers at ``rank + 1`` and ``rank + 2``.

        Both indices wrap around modulo ``world_size``.
        """
        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
        return owner1, owner2

    @dist_init()
    def test_dist_optim_exception(self):
        """An error raised by a remote optimizer's ``step`` reaches the caller."""
        # distributed version
        owner1, owner2 = self._remote_worker_names()
        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule)
        remote_param1 = remote_method(MyModule.get_w, remote_module1)
        remote_param2 = remote_method(MyModule.get_w, remote_module2)

        dist_optim = DistributedOptimizer(
            FailingOptimizer, [remote_param1, remote_param2]
        )

        with dist_autograd.context() as context_id:
            g_cpu = torch.Generator()
            g_cpu.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
            output2 = rpc_async_method(
                MyModule.forward, remote_module2, output1.wait()
            )
            loss = torch.add(output2.wait(), t1).sum()

            dist_autograd.backward(context_id, [loss])
            # FailingOptimizer.step raises ValueError on the remote workers.
            with self.assertRaisesRegex(Exception, "Error running optimizer"):
                dist_optim.step(context_id)

    @dist_init()
    def test_dist_optim_exception_on_constructor(self):
        """An error raised while constructing remote optimizers reaches the caller."""
        # distributed version
        owner1, owner2 = self._remote_worker_names()
        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule)
        remote_param1 = remote_method(MyModule.get_w, remote_module1)
        remote_param2 = remote_method(MyModule.get_w, remote_module2)

        # Construction itself is expected to fail, so the result is not bound.
        with self.assertRaisesRegex(Exception, "Error creating optimizer."):
            DistributedOptimizer(
                OptimizerFailingOnConstructor, [remote_param1, remote_param2]
            )

    def _test_dist_optim_base(self, optim_cls, *args, **kwargs):
        """Check that ``DistributedOptimizer`` reproduces a local optimizer step.

        Runs one forward/backward/step with ``optim_cls`` on two local
        ``MyModule`` instances, then repeats the same computation (same seeded
        inputs) on two remote modules through ``DistributedOptimizer``. It
        asserts that the remote weights changed and equal the local results.

        Args:
            optim_cls: a ``torch.optim.Optimizer`` subclass.
            *args, **kwargs: forwarded to ``optim_cls`` in both runs.
        """
        # local version
        module1 = MyModule()
        module2 = MyModule()
        params = [module1.get_w(), module2.get_w()]
        local_optim = optim_cls(params, *args, **kwargs)

        old_w1 = module1.w.clone().detach()
        old_w2 = module2.w.clone().detach()

        g_cpu = torch.Generator()
        g_cpu.manual_seed(0)
        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        output1 = module1.forward(t2)
        output2 = module2.forward(output1)
        loss = torch.add(output2, t1).sum()

        loss.backward()
        local_optim.step()

        # distributed version
        owner1, owner2 = self._remote_worker_names()
        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule)
        remote_param1 = remote_method(MyModule.get_w, remote_module1)
        remote_param2 = remote_method(MyModule.get_w, remote_module2)

        # sanity check: local and remote initial weights should match
        self.assertEqual(old_w1, remote_param1.to_here())
        self.assertEqual(old_w2, remote_param2.to_here())

        dist_optim = DistributedOptimizer(
            optim_cls, [remote_param1, remote_param2], *args, **kwargs
        )

        with dist_autograd.context() as context_id:
            # Re-seed so the remote run draws the same t1/t2 as the local run.
            g_cpu.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
            output2 = rpc_async_method(
                MyModule.forward, remote_module2, output1.wait()
            )
            loss = torch.add(output2.wait(), t1)

            dist_autograd.backward(context_id, [loss.sum()])
            dist_optim.step(context_id)

            new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
            new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()

            # ensure optimizer changed weights
            self.assertNotEqual(old_w1, new_w1)
            self.assertNotEqual(old_w2, new_w2)
            # ensure local equals remote
            self.assertEqual(new_w1, module1.get_w())
            self.assertEqual(new_w2, module2.get_w())

    @dist_init()
    def test_dist_optim(self):
        """Run the local-vs-distributed comparison for several optimizers."""
        self._test_dist_optim_base(optim.Adagrad, lr=0.05)
        self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True)
        self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True)
        self._test_dist_optim_base(optim.SGD, lr=0.05)
        self._test_dist_optim_base(
            optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True
        )
        self._test_dist_optim_base(optim.Adadelta, rho=0.95)
        self._test_dist_optim_base(optim.RMSprop, lr=0.05)
        self._test_dist_optim_base(optim.Adamax, lr=0.05)
        self._test_dist_optim_base(optim.Rprop, lr=0.05)

    def _test_dist_optim_none_grads(self, optim_cls, *args, **kwargs):
        """Check ``DistributedOptimizer`` when one parameter receives no gradient.

        Like ``_test_dist_optim_base``, except that the second module's weight
        is created with ``requires_grad=False``. The test asserts that the
        first weight changes, the second stays unchanged, and both match the
        local run.

        Args:
            optim_cls: a ``torch.optim.Optimizer`` subclass.
            *args, **kwargs: forwarded to ``optim_cls`` in both runs.
        """
        # local version
        module1 = MyModule()
        module2 = MyModule(requires_grad=False)
        params = [module1.get_w(), module2.get_w()]
        local_optim = optim_cls(params, *args, **kwargs)

        old_w1 = module1.w.clone().detach()
        old_w2 = module2.w.clone().detach()

        g_cpu = torch.Generator()
        g_cpu.manual_seed(0)
        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        output1 = module1.forward(t2)
        output2 = module2.forward(output1)
        loss = torch.add(output2, t1).sum()

        loss.backward()
        local_optim.step()

        # distributed version
        owner1, owner2 = self._remote_worker_names()
        remote_module1 = rpc.remote(owner1, MyModule)
        # Positional False maps to MyModule(requires_grad=False).
        remote_module2 = rpc.remote(owner2, MyModule, args=(False,))
        remote_param1 = remote_module1.remote().get_w()
        remote_param2 = remote_module2.remote().get_w()

        # sanity check: local and remote initial weights should match
        self.assertEqual(old_w1, remote_param1.to_here())
        self.assertEqual(old_w2, remote_param2.to_here())

        dist_optim = DistributedOptimizer(
            optim_cls, [remote_param1, remote_param2], *args, **kwargs
        )

        with dist_autograd.context() as context_id:
            # Re-seed so the remote run draws the same t1/t2 as the local run.
            g_cpu.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            output1 = remote_module1.rpc_async().forward(t2)
            output2 = remote_module2.rpc_async().forward(output1.wait())
            loss = torch.add(output2.wait(), t1)

            dist_autograd.backward(context_id, [loss.sum()])
            dist_optim.step(context_id)

            new_w1 = remote_module1.rpc_async().get_w().wait()
            new_w2 = remote_module2.rpc_async().get_w().wait()

            # ensure optimizer changed weights for w1
            self.assertNotEqual(old_w1, new_w1)

            # ensure optimizer not changed weights for w2
            self.assertEqual(old_w2, new_w2)
            # ensure local equals remote
            self.assertEqual(new_w1, module1.get_w())
            self.assertEqual(new_w2, module2.get_w())

    @dist_init()
    def test_dist_optim_none_grads(self):
        """Run the no-gradient-parameter comparison for several optimizers."""
        self._test_dist_optim_none_grads(optim.SGD, lr=0.05)
        self._test_dist_optim_none_grads(optim.RMSprop, lr=0.05)
        self._test_dist_optim_none_grads(optim.Rprop, lr=0.05)
        self._test_dist_optim_none_grads(optim.Adadelta, rho=0.95)
|