#!/usr/bin/env python3
# Owner(s): ["oncall: distributed"]
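
# Checks that tensors already in shared memory can be sent over RPC without
# copying: ShareMemoryRPCPickler below routes tensor/storage serialization
# through torch.multiprocessing's reduction functions.
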
import contextlib
import copyreg
import os
import sys

import torch
import torch.distributed as dist

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import torch.distributed.rpc as rpc
import torch.multiprocessing.reductions as TorchMpReductions
from torch import multiprocessing
from torch.distributed.rpc.api import _use_rpc_pickler
from torch.distributed.rpc.internal import _InternalRPCPickler
from torch.testing._internal.common_utils import run_tests, TestCase
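

# Temporarily switch torch.multiprocessing to the "file_system" sharing
# strategy, restoring the previous strategy on exit.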
@contextlib.contextmanager
def fs_sharing():
    prev_strategy = multiprocessing.get_sharing_strategy()
    multiprocessing.set_sharing_strategy("file_system")
    try:
        yield
    finally:
        multiprocessing.set_sharing_strategy(prev_strategy)
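

# An _InternalRPCPickler whose dispatch table reduces storages and tensors
# with torch.multiprocessing's reducers, so RPC arguments that already live
# in shared memory are shared with the peer instead of copied.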
class ShareMemoryRPCPickler(_InternalRPCPickler):
    def __init__(self) -> None:
        super().__init__()
        # pyre-fixme[4]: Attribute must be annotated.
        self._dispatch_table = copyreg.dispatch_table.copy()
        # Replace the default reducers so storages and tensors are shared
        # through torch.multiprocessing instead of being serialized by value.
        for t in torch._storage_classes:
            self._dispatch_table[t] = TorchMpReductions.reduce_storage
        for t in torch._tensor_classes:
            self._dispatch_table[t] = TorchMpReductions.reduce_tensor
        self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor
        self._dispatch_table[torch.nn.parameter.Parameter] = (
            TorchMpReductions.reduce_tensor
        )
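

# Entry point for the spawned peer: joins the RPC group as "worker1" and
# serves requests until shutdown. The argument is the process index supplied
# by multiprocessing.spawn and is unused.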
def worker_loop(a):
    rpc.init_rpc("worker1", rank=1, world_size=2)
    rpc.shutdown()
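

# Remote target: receiving the module is enough to exercise the pickler on
# both ends, so the body is a no-op.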
def worker_fn(m):
    pass
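

# Spawns a peer process, installs ShareMemoryRPCPickler, and round-trips a
# shared-memory nn.Linear over RPC between the two workers.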
class TestRPCPickler(TestCase):
    def test_case(self):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"

        with fs_sharing():
            # Launch the peer process ("worker1") without blocking.
            r = multiprocessing.spawn(worker_loop, join=False)

            try:
                with _use_rpc_pickler(ShareMemoryRPCPickler()):
                    rpc.init_rpc("worker0", rank=0, world_size=2)

                    # Send a shared-memory module over RPC; to_here() blocks
                    # until the remote call has completed.
                    m = torch.nn.Linear(1, 2)
                    m.share_memory()
                    rref = rpc.remote("worker1", worker_fn, args=(m,))
                    rref.to_here()
            finally:
                rpc.shutdown()
                r.join()

if __name__ == "__main__":
    run_tests()