1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536
|
# mypy: allow-untyped-defs
import sys
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from functools import partial, reduce
import torch
import torch.distributed as dist
import weakref
from torch._C._distributed_c10d import (
_create_work_from_future,
AllgatherOptions,
AllreduceOptions,
AllToAllOptions,
BarrierOptions,
BroadcastOptions,
ReduceScatterOptions,
ScatterOptions,
Store,
ReduceOp,
)
from torch.distributed.distributed_c10d import _CollOp, _store_based_barrier, P2POp
from torch.futures import Future
from torch.utils import _pytree as pytree
"""
TODO:
Lots of missing collectives.
Collectives validation.
Make timeout robust by making collectives respect the test deadline.
Make tests robust by making collectives interruptible.
We need some synchronization around cleanup to ensure that timed-out ranks don't cause spurious failures.
"""
def flatten_list(lst):
return pytree.tree_leaves(lst)
def ret_work(ret):
    """Wrap ``ret`` in an already-resolved future and return it as a work handle."""
    completed = Future()
    completed.set_result(ret)
    return _create_work_from_future(completed)
def binop_reduce(tensors, op):
    """Reduce ``tensors`` by stacking them and applying ``op`` along dim 0.

    Args:
        tensors: sequence of tensors accepted by ``torch.stack``.
        op: callable invoked as ``op(stacked, dim=0)``.

    Returns:
        The reduced tensor.
    """
    stacked = torch.stack(tensors)
    reduced = op(stacked, dim=0)
    # min/max return a namedtuple; the reduced tensor is its ``values`` field
    return reduced if isinstance(reduced, torch.Tensor) else reduced.values
def bitwise_reduce(tensors, op):
    """Fold ``tensors`` left to right with the two-argument callable ``op``."""
    folded = reduce(op, tensors)
    return folded
# Dispatch table from a ReduceOp to a callable that takes the list of
# per-rank tensors and returns the reduced tensor.
_reduce_ops = {
    # reductions expressed as a torch op applied over a stacked tensor
    **{
        red_op: partial(binop_reduce, op=torch_fn)
        for red_op, torch_fn in (
            (ReduceOp.SUM, torch.sum),
            (ReduceOp.AVG, torch.mean),
            (ReduceOp.PRODUCT, torch.prod),
            (ReduceOp.MIN, torch.min),
            (ReduceOp.MAX, torch.max),
        )
    },
    # reductions expressed as a pairwise fold over the tensors
    **{
        red_op: partial(bitwise_reduce, op=torch_fn)
        for red_op, torch_fn in (
            (ReduceOp.BAND, torch.bitwise_and),
            (ReduceOp.BOR, torch.bitwise_or),
            (ReduceOp.BXOR, torch.bitwise_xor),
        )
    },
}
class AllToAll:
    """All-to-all exchange over per-rank lists of tensors."""

    @torch.no_grad()
    def work(self, data):
        """Copy every rank's inputs into the matching output slots.

        ``data[r]`` is the ``(output_tensor_list, input_tensor_list)`` pair
        supplied by rank ``r``. After the call, slot ``s`` of rank ``d``'s
        output list holds a copy of slot ``d`` of rank ``s``'s input list.
        """
        for dest_rank, (dest_outputs, _) in enumerate(data):
            for src_rank, (_, src_inputs) in enumerate(data):
                dest_outputs[src_rank].copy_(src_inputs[dest_rank])
class AllToAllBase:
    """All-to-all exchange over one flat buffer per rank, split along dim 0."""

    @torch.no_grad()
    def work(self, data):
        """Copy each rank's input segments into the matching output segments.

        ``data[r]`` is the tuple ``(output_buffer, input_buffer,
        output_split_sizes, input_split_sizes)`` supplied by rank ``r``.
        Segment ``d`` of rank ``s``'s input buffer is copied into segment
        ``s`` of rank ``d``'s output buffer. Empty or ``None`` split sizes
        mean an even split across ranks.
        """
        world_size = len(data)
        # The send offsets depend only on the source rank, so compute them once
        # per rank instead of once per (dest, src) pair.
        input_bounds = [
            self._size_cumsum(in_buf.size(0), in_splits, world_size).tolist()
            for _, in_buf, _, in_splits in data
        ]
        for dest_rank, (output_buffer, _, output_split_sizes, _) in enumerate(data):
            out_bounds = self._size_cumsum(
                output_buffer.size(0), output_split_sizes, world_size
            ).tolist()
            for src_rank, (_, input_buffer, _, _) in enumerate(data):
                in_bounds = input_bounds[src_rank]
                output_buffer[out_bounds[src_rank]:out_bounds[src_rank + 1]].copy_(
                    input_buffer[in_bounds[dest_rank]:in_bounds[dest_rank + 1]]
                )

    def _size_cumsum(self, buf_size: int, sizes: Union[torch.Tensor, List[int], None], world_size: int) -> torch.Tensor:
        """Return the segment boundaries implied by ``sizes`` as an int64 tensor.

        The result has ``len(sizes) + 1`` entries, starts at 0, and entry
        ``i + 1`` is the end offset of segment ``i``.

        Args:
            buf_size: length of the buffer along dim 0; only used for the
                even-split fallback.
            sizes: per-rank segment lengths; ``None`` or empty means every rank
                gets ``buf_size // world_size`` elements.
            world_size: number of ranks.
        """
        if sizes is None or len(sizes) == 0:
            sizes = torch.full(
                (world_size,), buf_size // world_size, dtype=torch.int64
            )
        if not isinstance(sizes, torch.Tensor):
            sizes = torch.tensor(sizes, dtype=torch.int64)
        assert sizes.dtype == torch.int64
        # Prepend a zero so the cumulative sum yields [start_0, end_0, end_1, ...].
        leading_zero = torch.tensor([0], dtype=torch.int64, device=sizes.device)
        return torch.cumsum(torch.cat((leading_zero, sizes), dim=0), dim=0)
class AllReduce:
    """All-reduce: each tensor slot is reduced across ranks and written back to every rank."""

    def __init__(self, op):
        """Remember the reduction kind, rejecting any not present in ``_reduce_ops``."""
        reduce_kind = op.op
        if reduce_kind not in _reduce_ops:
            raise NotImplementedError(
                f"AllReduce op {op.op} not supported on multithreaded pg for now."
            )
        self.op = reduce_kind

    @torch.no_grad()
    def work(self, data):
        """Reduce in place; ``data[r]`` is the tensor list supplied by rank ``r``."""
        for idx, rank0_tensor in enumerate(data[0]):
            # the reduction itself is carried out on rank 0's device
            target_device = rank0_tensor.device
            gathered = [rank_tensors[idx].to(target_device) for rank_tensors in data]
            # mimic the reduce across all ranks
            reduced = _reduce_ops[self.op](gathered)
            # hand the reduced value back to each rank on its own device
            for rank_tensors in data:
                dst = rank_tensors[idx]
                dst.copy_(reduced.to(dst.device))
class AllGather:
    """All-gather: every rank receives a copy of every rank's input tensor."""

    @torch.no_grad()
    def work(self, data):
        """``data[r]`` is ``(output_tensors, input_tensor)`` for rank ``r``.

        Slot ``s`` of ``output_tensors[0]`` on every rank is overwritten with
        rank ``s``'s single input tensor.
        """
        for src_rank, rank_data in enumerate(data):
            inputs = rank_data[1]
            # Can't handle all_gather with multiple tensors
            assert len(inputs) == 1
            payload = inputs[0]
            for dest_data in data:
                dest_data[0][0][src_rank].copy_(payload)
class Scatter:
    """Scatter: rank ``src`` hands one tensor to every rank."""

    def __init__(self, src):
        self.src = src

    @torch.no_grad()
    def work(self, data):
        """``data[r]`` is ``(output_tensors, input_tensors)`` for rank ``r``.

        Only the source rank's ``input_tensors`` is read; its single inner
        list is indexed by destination rank.
        """
        src_inputs = data[self.src][1]
        # Can't handle scatter with multiple input tensor list
        assert len(src_inputs) == 1
        chunks = src_inputs[0]
        for rank, rank_data in enumerate(data):
            outputs = rank_data[0]
            # Can't handle scatter with multiple output tensor
            assert len(outputs) == 1
            outputs[0].copy_(chunks[rank])
class Gather:
    """Gather: rank ``dst`` collects one tensor from every rank."""

    def __init__(self, dst):
        self.dst = dst

    @torch.no_grad()
    def work(self, data):
        """``data[r]`` is ``(output_tensors, input_tensors)`` for rank ``r``.

        Only the destination rank's ``output_tensors`` is written; its single
        inner list is indexed by source rank.
        """
        dst_outputs = data[self.dst][0]
        # Can't handle gather with multiple tensor lists
        assert len(dst_outputs) == 1
        gathered = dst_outputs[0]
        for rank, rank_data in enumerate(data):
            inputs = rank_data[1]
            # Can't handle gather with multiple tensor lists
            assert len(inputs) == 1
            gathered[rank].copy_(inputs[0])
class ReduceScatter:
    """Reduce-scatter supporting SUM and AVG reductions."""

    def __init__(self, op):
        if op != dist.ReduceOp.SUM and op != dist.ReduceOp.AVG:
            raise NotImplementedError(f"ReduceScatter does not support {op}")
        self.op = op

    @torch.no_grad()
    def work(self, data):
        """``data[r]`` is ``(output_tensor, scatter_list)`` for rank ``r``.

        Chunk ``i`` of each rank's single scatter list is accumulated into the
        single output tensor of rank ``i``. For AVG, the sums are divided by
        the number of ranks afterwards.
        """
        # destination ranks whose output already holds a first contribution
        seeded = set()
        for rank_data in data:
            scatter_lists = rank_data[1]
            # Can't handle reduce_scatter with multiple scatter list
            assert len(scatter_lists) == 1
            for dst_rank, chunk in enumerate(scatter_lists[0]):
                dst_outputs = data[dst_rank][0]
                # Can't handle reduce_scatter with multiple output tensor
                assert len(dst_outputs) == 1
                out = dst_outputs[0]
                moved = chunk.to(out.device)
                if dst_rank in seeded:
                    out.add_(moved)
                else:
                    # first contribution overwrites whatever the output held
                    out.copy_(moved)
                    seeded.add(dst_rank)
        if self.op == dist.ReduceOp.AVG:
            num_ranks = len(data)
            for rank_data in data:
                rank_data[0][0] /= num_ranks
class Broadcast:
    """Broadcast: every rank's tensors are overwritten with rank ``src``'s."""

    def __init__(self, src):
        self.src = src

    @torch.no_grad()
    def work(self, data):
        """``data[r]`` is the (possibly nested) tensor list supplied by rank ``r``."""
        src_tensors = flatten_list(data[self.src])
        for rank_data in data:
            dst_tensors = flatten_list(rank_data)
            for pos, src_tensor in enumerate(src_tensors):
                dst_tensors[pos].copy_(src_tensor)
class Collective:
    """Rendezvous point for a single collective call across all ranks.

    Every rank calls :meth:`join` with its own data. Rank 0 waits until all
    ranks have joined, runs ``collective.work`` on the per-rank data, and then
    wakes up the other ranks.
    """

    def __init__(self, world_size, collective, pg):
        """
        Args:
            world_size: number of ranks expected to call :meth:`join`.
            collective: object whose ``work(data)`` method is run by rank 0.
            pg: object exposing a ``_terminate`` event that is checked while
                waiting.
        """
        self._world_size = world_size
        self._collective = collective
        # rank 0 waits on this until every rank has joined
        self._start_cond = threading.Condition()
        # non-zero ranks wait on this until rank 0 has run the collective
        self._done_cond = threading.Condition()
        # per-rank data, indexed by rank
        self._data = [None] * world_size
        # number of ranks that have joined so far
        self._count = 0
        # set to True by rank 0 once ``work`` has returned
        self._done = False
        self._pg = pg

    def join(self, rank, data):
        """Deposit ``data`` for ``rank`` and block until the collective has run.

        Args:
            rank: index of the calling rank; rank 0 executes the collective.
            data: the calling rank's data, stored at ``self._data[rank]``.

        Returns:
            The result of ``ret_work(data)`` for the caller's own ``data``.

        Raises:
            SystemExit: if the ``_terminate`` event of ``pg`` is set while
                (or after) waiting.
        """
        with self._start_cond:
            self._data[rank] = data
            self._count += 1
            # notify rank 0
            if self._count == self._world_size:
                # when rank 0 itself is the last to join nobody needs waking;
                # its wait_for predicate below is already true
                if rank > 0:
                    self._start_cond.notify()
            if rank == 0:
                self._start_cond.wait_for(
                    lambda: self._count == self._world_size or self._pg._terminate.is_set()
                )
                # SystemExit is not a subclass of Exception but BaseException
                # and can be distinguished from normal exception raised from program errors
                # so that we can hide it from the exception queue
                if self._pg._terminate.is_set():
                    sys.exit("Test termination event occurs.")
        with self._done_cond:
            # wait for rank 0 to finish
            if rank > 0:
                self._done_cond.wait_for(lambda: self._done or self._pg._terminate.is_set())
                if self._pg._terminate.is_set():
                    sys.exit("Test termination event occurs.")
            else:
                # copy data around
                self._collective.work(self._data)
                self._done = True
                # release every rank blocked in the branch above
                self._done_cond.notify_all()
        return ret_work(data)
class ProcessLocalGroup(dist.ProcessGroup):
    """Process group whose collectives rendezvous through shared ``Collective`` objects.

    Each collective method looks up (or creates) the in-flight ``Collective``
    for this group, joins it with this rank's data, and then unregisters it.
    """

    # Guards every access to ``_cur_coll_on_pgs``.
    _coll_lock = threading.Lock()
    # Maps a group's ``pg_name`` to the ``Collective`` currently in flight on it.
    _cur_coll_on_pgs = {}
    # Set by ``exception_handle``; ranks blocked in a collective check it and exit.
    _terminate = threading.Event()

    @classmethod
    def _start_coll(cls, collective, pg):
        """Return the in-flight ``Collective`` for ``pg``, creating it if absent."""
        with cls._coll_lock:
            # pg_name is unique, we use that to record the mapping between pg and collective
            if pg.pg_name not in cls._cur_coll_on_pgs:
                cls._cur_coll_on_pgs[pg.pg_name] = Collective(pg.size(), collective, cls)
            return cls._cur_coll_on_pgs[pg.pg_name]

    @classmethod
    def _end_coll(cls, collective, pg):
        """Unregister ``collective`` from ``pg`` if it is still the registered one."""
        # This is racily called by all ranks, so only one will work
        with cls._coll_lock:
            if pg.pg_name in cls._cur_coll_on_pgs and cls._cur_coll_on_pgs[pg.pg_name] == collective:
                cls._cur_coll_on_pgs.pop(pg.pg_name)

    @classmethod
    def exception_handle(cls, exc):
        """Signal termination and wake every rank blocked in a collective.

        Args:
            exc: the triggering exception; not inspected here.
        """
        cls._terminate.set()
        # Snapshot under the lock: other threads may be adding or removing
        # entries via _start_coll/_end_coll/reset, and iterating the live dict
        # could raise "dictionary changed size during iteration".
        with cls._coll_lock:
            pending = list(cls._cur_coll_on_pgs.values())
        for coll in pending:
            with coll._start_cond:
                coll._start_cond.notify()
            with coll._done_cond:
                coll._done_cond.notify_all()

    @classmethod
    def reset(cls):
        """Forget all in-flight collectives and clear the termination flag."""
        with cls._coll_lock:
            cls._cur_coll_on_pgs = {}
            cls._terminate.clear()

    def _run_collective(self, collective, payload):
        """Join this rank into ``collective`` with ``payload`` and return the work handle."""
        coll = ProcessLocalGroup._start_coll(collective, self)
        res = coll.join(self._rank, payload)
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def alltoall_base(
        self,
        output_buffer: torch.Tensor,
        input_buffer: torch.Tensor,
        output_split_sizes: Optional[List[int]],
        input_split_sizes: Optional[List[int]],
        opts=AllToAllOptions()
    ) -> torch.Tensor:
        """All-to-all over flat buffers; ``opts`` is not consulted.

        NOTE: the value actually returned is whatever ``Collective.join``
        returns (a work handle), despite the ``torch.Tensor`` annotation.
        """
        return self._run_collective(
            AllToAllBase(),
            (output_buffer, input_buffer, output_split_sizes, input_split_sizes),
        )

    def alltoall(self, output_tensor_list, input_tensor_list, opts=AllToAllOptions()):
        """All-to-all over tensor lists; ``opts`` is not consulted."""
        return self._run_collective(AllToAll(), (output_tensor_list, input_tensor_list))

    def allreduce(self, tensor_list, opts=AllreduceOptions()):
        """All-reduce ``tensor_list`` in place using ``opts.reduceOp``."""
        return self._run_collective(AllReduce(opts.reduceOp), tensor_list)

    def allreduce_coalesced(self, tensor_list, opts=AllreduceOptions()):
        """Same as :meth:`allreduce`; the whole list goes through one collective."""
        return self._run_collective(AllReduce(opts.reduceOp), tensor_list)

    def barrier(self, opts=BarrierOptions()):
        """Synchronise all ranks by all-reducing a throwaway tensor; ``opts`` is not consulted."""
        return self.allreduce(tensor_list=[torch.ones(1)])

    def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()):
        """All-gather ``input_tensor`` into ``output_tensors``; ``opts`` is not consulted."""
        return self._run_collective(AllGather(), (output_tensors, input_tensor))

    def _allgather_base(self, output_tensor, input_tensor, opts=AllgatherOptions()):
        """All-gather into one flat tensor by chunking it into ``world_size`` pieces."""
        tensor_list = list(torch.chunk(output_tensor, self._world_size))
        return self.allgather([tensor_list], [input_tensor], opts)

    def broadcast(self, tensor_list, opts=BroadcastOptions()):
        """Broadcast ``tensor_list`` from rank ``opts.rootRank``."""
        return self._run_collective(Broadcast(opts.rootRank), tensor_list)

    def scatter(self, output_tensors, input_tensors, opts=ScatterOptions()):
        """Scatter ``input_tensors`` from rank ``opts.rootRank``."""
        return self._run_collective(Scatter(opts.rootRank), (output_tensors, input_tensors))

    def gather(self, output_tensors, input_tensors, opts=ScatterOptions()):
        """Gather onto rank ``opts.rootRank``; only ``rootRank`` is read from ``opts``."""
        return self._run_collective(Gather(opts.rootRank), (output_tensors, input_tensors))

    def reduce_scatter(self, output_tensor, scatter_list, opts=ReduceScatterOptions()):
        """Reduce-scatter ``scatter_list`` into ``output_tensor`` using ``opts.reduceOp``."""
        return self._run_collective(ReduceScatter(opts.reduceOp), (output_tensor, scatter_list))

    def _reduce_scatter_base(self, output_tensor, input_tensor, opts=ReduceScatterOptions()):
        """Reduce-scatter one flat tensor by chunking it into ``world_size`` pieces."""
        tensor_list = list(torch.chunk(input_tensor, self._world_size))
        return self.reduce_scatter([output_tensor], [tensor_list], opts)

    def reduce_scatter_tensor_coalesced(self, output_tensors, input_tensors, opts=ReduceScatterOptions()):
        """Run one reduce-scatter per tensor pair and return the last work handle.

        All earlier handles are waited on before returning. Raises
        ``IndexError`` when the inputs are empty.
        """
        works = [
            self._reduce_scatter_base(output_tensor, input_tensor, opts)
            for output_tensor, input_tensor
            in zip(output_tensors, input_tensors)
        ]
        for work in works[:-1]:
            work.wait()
        return works[-1]

    def allgather_into_tensor_coalesced(self, output_tensor_list, input_tensor_list, opts=AllgatherOptions()):
        """Run one all-gather per tensor pair and return the last work handle.

        Returns ``None`` when the inputs are empty.
        """
        res = None
        for o_t, i_t in zip(output_tensor_list, input_tensor_list):
            # forward the caller's options rather than silently dropping them
            res = self._allgather_base(o_t, i_t, opts)
        return res

    def __init__(self, rank, world_size):
        """Create the group for ``rank`` out of ``world_size`` ranks.

        Keeps a weak reference to the current c10d world; when a
        ``ThreadLocalWorld`` is installed, the calling thread's underlying
        world object is used.
        """
        super().__init__(rank, world_size)
        self._rank = rank
        self._world_size = world_size
        world = dist.distributed_c10d._world
        if isinstance(world, ThreadLocalWorld):
            world = world._get_world()
        self._world = weakref.ref(world)
        self._ctx = torch.autograd.set_multithreading_enabled(False)

    def size(self):
        """Return the number of ranks in this group."""
        return self._world_size

    @property
    def pg_name(self):
        """
        return the global registered name of the current pg in the world
        """
        return self._world().pg_names[self]

    @property
    def group_name(self):
        """Alias of :attr:`pg_name`."""
        return self.pg_name

    def getBackendName(self):
        """Return the backend name, ``"threaded"``."""
        return "threaded"

    def __repr__(self):
        return f"ThreadedPG world_size:{self._world_size} rank:{self._rank}"
def _create_threaded_pg(prefix_store, rank, world_size, timeout):
    """Build a ``ProcessLocalGroup`` for ``rank`` and synchronise on ``prefix_store``.

    Args:
        prefix_store: store used for the store-based barrier.
        rank: rank of the caller within the new group.
        world_size: number of ranks in the new group.
        timeout: timeout forwarded to the store-based barrier.

    Returns:
        The newly created ``ProcessLocalGroup``.
    """
    threaded_pg = ProcessLocalGroup(rank, world_size)
    # https://github.com/pytorch/pytorch/pull/103033 made the store based
    # barrier optional in c10d. The collectives of a threaded pg are assumed
    # to be single threaded, but when a device mesh involves sub groups,
    # different threads may be initializing different groups, which leads to
    # race conditions. For example, with a mesh of [[0, 1], [2, 3]] the sub
    # groups (dim 0 and 1) are initialized independently in different threads.
    # Class or global variables cannot be relied on in that case; a store
    # based barrier is needed so that each group is known to be ready before
    # any collective is invoked on it.
    # The prefix store is already per group, hence the empty name.
    _store_based_barrier(rank, prefix_store, "", world_size, timeout)
    return threaded_pg
# Register `_create_threaded_pg` as the constructor of the "threaded" backend.
dist.Backend.register_backend(
    "threaded",
    _create_threaded_pg,
    devices=["cpu", "cuda"],
)
@dataclass
class WorldData:
    """Process-group bookkeeping for one world.

    ``ThreadLocalWorld`` keeps one instance per thread and exposes each field
    through a property of the same name.
    """

    # ``ThreadLocalWorld._get_world`` initialises this to None
    default_pg: Optional[dist.ProcessGroup]
    # presumably process group -> (backend name, store) — TODO confirm against c10d
    pg_map: Dict[dist.ProcessGroup, Tuple[str, Optional[Store]]]
    # process group -> registered name (read by ``ProcessLocalGroup.pg_name``)
    pg_names: Dict[dist.ProcessGroup, str]
    # presumably process group -> {global rank: group rank} — TODO confirm against c10d
    pg_group_ranks: Dict[dist.ProcessGroup, Dict[int, int]]
    pg_backend_config: Dict[dist.ProcessGroup, str]
    # ``ThreadLocalWorld._get_world`` initialises this to 0
    group_count: int
    tags_to_pg: Dict[str, List[dist.ProcessGroup]]
    pg_to_tag: Dict[dist.ProcessGroup, str]
    pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
class ThreadLocalWorld:
    """World object that keeps a separate ``WorldData`` for every thread.

    Each property forwards to the calling thread's ``WorldData``.
    """

    _world = threading.local()

    def _get_world(self) -> WorldData:
        """Return the calling thread's ``WorldData``, creating an empty one on first use."""
        local_state = ThreadLocalWorld._world
        try:
            return local_state.world
        except AttributeError:
            local_state.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {})
            return local_state.world

    @property
    def default_pg(self):
        """The calling thread's default process group."""
        world = self._get_world()
        return world.default_pg

    @default_pg.setter
    def default_pg(self, value):
        world = self._get_world()
        world.default_pg = value

    @property
    def pg_map(self):
        """The calling thread's ``pg_map``."""
        world = self._get_world()
        return world.pg_map

    @property
    def pg_names(self):
        """The calling thread's ``pg_names``."""
        world = self._get_world()
        return world.pg_names

    @property
    def pg_group_ranks(self):
        """The calling thread's ``pg_group_ranks``."""
        world = self._get_world()
        return world.pg_group_ranks

    @property
    def pg_backend_config(self):
        """The calling thread's ``pg_backend_config``."""
        world = self._get_world()
        return world.pg_backend_config

    @property
    def group_count(self) -> int:
        """The calling thread's ``group_count``."""
        world = self._get_world()
        return world.group_count

    @group_count.setter
    def group_count(self, value):
        world = self._get_world()
        world.group_count = value

    @property
    def tags_to_pg(self):
        """The calling thread's ``tags_to_pg``."""
        world = self._get_world()
        return world.tags_to_pg

    @property
    def pg_to_tag(self):
        """The calling thread's ``pg_to_tag``."""
        world = self._get_world()
        return world.pg_to_tag

    @property
    def pg_coalesce_state(self) -> Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]:
        """The calling thread's ``pg_coalesce_state``."""
        world = self._get_world()
        return world.pg_coalesce_state
# c10d world object that was active before `_install_threaded_pg` replaced it;
# `_uninstall_threaded_pg` puts it back.
_old_pg_world = None
# Object returned by `torch.autograd.set_multithreading_enabled(False)` in
# `_install_threaded_pg`.
_ctx_manager = None
def _install_threaded_pg():
    """Swap the c10d world for a fresh ``ThreadLocalWorld`` and return it.

    The previous world is remembered in ``_old_pg_world``. The object returned
    by ``torch.autograd.set_multithreading_enabled(False)`` is kept in
    ``_ctx_manager``.
    """
    global _old_pg_world, _ctx_manager
    c10d = dist.distributed_c10d
    _old_pg_world = c10d._world
    c10d._world = ThreadLocalWorld()
    _ctx_manager = torch.autograd.set_multithreading_enabled(False)
    return c10d._world
def _uninstall_threaded_pg():
    """Put back the c10d world object saved in ``_old_pg_world``."""
    c10d = dist.distributed_c10d
    c10d._world = _old_pg_world
|