1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
|
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import logging
from typing import cast, Optional
from torch.distributed import PrefixStore, Store, TCPStore
from torch.distributed.elastic.rendezvous import (
RendezvousHandler,
RendezvousInfo,
RendezvousParameters,
RendezvousStoreInfo,
)
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
# Public names of this module.
__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Timeout, in seconds, that ``create_rdzv_handler`` falls back to when the
# rendezvous config has no "timeout" entry.
_default_timeout_seconds = 600
class StaticTCPRendezvous(RendezvousHandler):
"""
Static rendezvous that is a wrapper around the TCPStore.
Creates TCPStore based on the input parameters with the
listener on the agent with group_rank=0
"""
def __init__(
self,
master_addr: str,
master_port: int,
rank: int,
world_size: int,
run_id: str,
timeout: int,
):
self.master_addr = master_addr
self.master_port = master_port
self.rank = rank
self.world_size = world_size
self.run_id = run_id
self.timeout = datetime.timedelta(seconds=timeout)
self._store: Optional[Store] = None
def get_backend(self) -> str:
return "static"
@property
def use_agent_store(self) -> bool:
return True
def next_rendezvous(self) -> RendezvousInfo:
logger.info("Creating TCPStore as the c10d::Store implementation")
is_master = self.rank == 0
if not self._store:
self._store = TCPStore( # type: ignore[call-arg]
self.master_addr,
self.master_port,
self.world_size,
is_master,
self.timeout,
multi_tenant=True,
)
store = PrefixStore(self.run_id, self._store)
# TCPStore server instance is used by trainer code
bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port)
return RendezvousInfo(
store,
self.rank,
self.world_size,
bootstrap_store_info,
)
def is_closed(self):
return False
def set_closed(self):
pass
def num_nodes_waiting(self):
return 0
def get_run_id(self) -> str:
return self.run_id
def shutdown(self) -> bool:
return True
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    Build a ``StaticTCPRendezvous`` from rendezvous parameters.

    Args:
        params: Rendezvous parameters. ``params.config`` must contain
            ``"rank"`` and may contain ``"timeout"`` (seconds);
            ``params.endpoint`` must be of the form ``host:port``.

    Returns:
        A handler configured with the parsed address/port, the rank,
        ``params.max_nodes`` as the world size, and ``params.run_id``.

    Raises:
        ValueError: If the rank is missing or is not a valid integer, the
            endpoint is empty, or the endpoint carries no port.
    """
    if "rank" not in params.config:
        raise ValueError(
            "rank is absent in RendezvousParameters. "
            "Try add --node-rank to the cmd request"
        )

    endpoint = params.endpoint.strip()
    if not endpoint:
        raise ValueError(
            "endpoint is absent in RendezvousParameters. "
            "Try add --master-port and --master-addr to the cmd request"
        )

    # -1 is a sentinel default: it comes back unchanged when no port is given.
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
    if master_port == -1:
        raise ValueError(
            f"Port is absent in endpoint: {endpoint}. Try launching with --master-port"
        )

    world_size = params.max_nodes
    # Config values may arrive as strings (e.g. "0"); ``cast`` alone would not
    # convert them, and the handler compares the rank against the integer 0
    # to decide which agent hosts the store.
    rank = cast(int, params.get_as_int("rank"))
    run_id = params.run_id
    timeout = (
        int(params.config["timeout"])
        if "timeout" in params.config
        else _default_timeout_seconds
    )

    return StaticTCPRendezvous(
        master_addr, master_port, rank, world_size, run_id, timeout
    )
|