# Owner(s): ["oncall: distributed"]
import json
import logging
import os
import re
import sys
from functools import partial, wraps
import torch
import torch.distributed as dist
from torch.distributed.c10d_logger import _c10d_logger, _exception_logger
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
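
# Run against the NCCL backend with between 2 and 4 ranks, scaled to the
# number of visible GPUs.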
BACKEND = dist.Backend.NCCL
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))
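

# Decorator that sets up the default process group before the test body and
# tears it down afterwards, skipping when there are fewer GPUs than ranks.
# The `func=None` default lets it be applied either bare (@with_comms) or
# called (@with_comms()).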
def with_comms(func=None):
if func is None:
        return partial(with_comms)
@wraps(func)
def wrapper(self, *args, **kwargs):
if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
self.dist_init()
        func(self, *args, **kwargs)
self.destroy_comms()
return wrapper
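

# MultiProcessTestCase spawns `world_size` processes in setUp; each rank then
# runs every test method below independently.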
class C10dErrorLoggerTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
os.environ["BACKEND"] = BACKEND
self._spawn_processes()
@property
def device(self):
return (
torch.device(self.rank)
if BACKEND == dist.Backend.NCCL
else torch.device("cpu")
)
@property
def world_size(self):
return WORLD_SIZE
@property
def process_group(self):
return dist.group.WORLD
def destroy_comms(self):
# Wait for all ranks to reach here before starting shutdown.
dist.barrier()
dist.destroy_process_group()
def dist_init(self):
dist.init_process_group(
backend=BACKEND,
world_size=self.world_size,
rank=self.rank,
init_method=f"file://{self.file_name}",
)
# set device for nccl pg for collectives
if BACKEND == "nccl":
torch.cuda.set_device(self.rank)
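
    # _c10d_logger is module-level state: it should exist and carry exactly
    # one NullHandler, so c10d stays silent unless the application adds
    # its own handlers.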
def test_get_or_create_logger(self):
self.assertIsNotNone(_c10d_logger)
self.assertEqual(1, len(_c10d_logger.handlers))
self.assertIsInstance(_c10d_logger.handlers[0], logging.NullHandler)
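
    # src=self.world_size + 1 names a rank that does not exist, so broadcast
    # fails. The first helper lets the exception propagate; the second
    # swallows it so the output of @_exception_logger can be inspected.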
@_exception_logger
def _failed_broadcast_raise_exception(self):
tensor = torch.arange(2, dtype=torch.int64)
dist.broadcast(tensor, self.world_size + 1)
@_exception_logger
def _failed_broadcast_not_raise_exception(self):
try:
tensor = torch.arange(2, dtype=torch.int64)
dist.broadcast(tensor, self.world_size + 1)
except Exception:
pass
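
    # A failing collective should raise, and should also leave a record in
    # _c10d_logger whose payload carries the process-group metadata checked
    # field by field below.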
@with_comms
def test_exception_logger(self) -> None:
with self.assertRaises(Exception):
self._failed_broadcast_raise_exception()
with self.assertLogs(_c10d_logger, level="DEBUG") as captured:
self._failed_broadcast_not_raise_exception()
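        # The captured line embeds the metadata as a Python dict repr; extract
        # the braced span and swap single quotes for double quotes so it
        # parses as JSON.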
error_msg_dict = json.loads(
re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
)
self.assertEqual(len(error_msg_dict), 9)
self.assertIn("pg_name", error_msg_dict.keys())
self.assertEqual("None", error_msg_dict["pg_name"])
self.assertIn("func_name", error_msg_dict.keys())
self.assertEqual("broadcast", error_msg_dict["func_name"])
self.assertIn("backend", error_msg_dict.keys())
self.assertEqual("nccl", error_msg_dict["backend"])
self.assertIn("nccl_version", error_msg_dict.keys())
nccl_ver = torch.cuda.nccl.version()
self.assertEqual(
".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
)
# In this test case, group_size = world_size, since we don't have multiple processes on one node.
self.assertIn("group_size", error_msg_dict.keys())
self.assertEqual(str(self.world_size), error_msg_dict["group_size"])
self.assertIn("world_size", error_msg_dict.keys())
self.assertEqual(str(self.world_size), error_msg_dict["world_size"])
self.assertIn("global_rank", error_msg_dict.keys())
self.assertIn(str(dist.get_rank()), error_msg_dict["global_rank"])
# In this test case, local_rank = global_rank, since we don't have multiple processes on one node.
self.assertIn("local_rank", error_msg_dict.keys())
self.assertIn(str(dist.get_rank()), error_msg_dict["local_rank"])
if __name__ == "__main__":
run_tests()