1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
|
from typing import Optional
import torch.distributed as dist
from . import config
# Lazily-created process group used for compiler collectives; populated on the
# first qualifying call to get_compile_pg() and reused afterwards.
_COMPILE_PG: Optional[dist.ProcessGroup] = None


def get_compile_pg() -> Optional[dist.ProcessGroup]:
    """Return the process group reserved for compiler collectives, if any.

    The group is created lazily on the first call that finds compiler
    collectives enabled (``config.enable_compiler_collectives``) and
    ``torch.distributed`` both available and initialized. It is then stored in
    the module-level ``_COMPILE_PG`` and returned unchanged on later calls
    (nothing in this function ever resets the cache).

    Returns:
        The cached process group created with tag ``"pt2_compile_pg"``, or
        ``None`` when compiler collectives are disabled or torch.distributed
        is unavailable / not initialized.
    """
    global _COMPILE_PG

    # Short-circuit order matters: the config flag is read first, and
    # dist.is_initialized() is only queried after dist.is_available()
    # has returned a truthy value.
    collectives_enabled = (
        config.enable_compiler_collectives
        and dist.is_available()
        and dist.is_initialized()
    )
    if not collectives_enabled:
        return None

    if _COMPILE_PG is None:
        # NOTE(review): a timeout of datetime.timedelta(seconds=2) was once
        # sketched for this call but is not passed, so the group gets whatever
        # default timeout _new_group_with_tag applies.
        _COMPILE_PG = dist.distributed_c10d._new_group_with_tag(
            pg_tag="pt2_compile_pg"
        )
    return _COMPILE_PG
|