1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
|
from datetime import timedelta
import logging
import os
import threading
import warnings
from typing import Generator, Tuple
from urllib.parse import urlparse
import torch
import torch.distributed as dist
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Counter that init_rpc embeds in the PrefixStore prefix ("rpc_prefix_<n>")
# and increments on every invocation, so repeated initializations in the same
# process use distinct key prefixes.
_init_counter = 0
# Held by init_rpc while it reads and increments _init_counter.
_init_counter_lock = threading.Lock()
def is_available():
    """
    Report whether this torch build exposes RPC support.

    Returns:
        bool: ``True`` when ``torch._C`` has an ``_rpc_init`` attribute,
        ``False`` otherwise.
    """
    rpc_bindings_present = hasattr(torch._C, "_rpc_init")
    return rpc_bindings_present
# When the RPC bindings are present, call ``torch._C._rpc_init()`` while the
# module is being imported and abort the import if it reports failure.
if is_available():
    if not torch._C._rpc_init():
        raise RuntimeError("Failed to initialize torch.distributed.rpc")
if is_available():
from torch._C._distributed_c10d import Store
from torch._C._distributed_rpc import (
_disable_jit_rref_pickle,
_enable_jit_rref_pickle,
_disable_server_process_global_profiler,
_enable_server_process_global_profiler,
_set_and_start_rpc_agent,
_reset_current_rpc_agent,
_delete_all_user_and_unforked_owner_rrefs,
_destroy_rref_context,
_set_profiler_node_id,
_is_current_rpc_agent_set,
_rref_context_get_debug_info,
_cleanup_python_rpc_handler,
_invoke_rpc_builtin,
_invoke_rpc_python_udf,
_invoke_rpc_torchscript,
_invoke_remote_builtin,
_invoke_remote_python_udf,
_invoke_remote_torchscript,
_set_rpc_timeout,
_get_current_rpc_agent,
get_rpc_timeout,
enable_gil_profiling,
RpcBackendOptions,
_TensorPipeRpcBackendOptionsBase,
RpcAgent,
PyRRef,
TensorPipeAgent,
RemoteProfilerManager,
WorkerInfo,
_DEFAULT_INIT_METHOD,
_DEFAULT_NUM_WORKER_THREADS,
_UNSET_RPC_TIMEOUT,
_DEFAULT_RPC_TIMEOUT_SEC,
) # noqa: F401
from . import api, backend_registry, functions
from .api import * # noqa: F401,F403
import numbers
import torch.distributed.autograd as dist_autograd
from .backend_registry import BackendType
from .options import TensorPipeRpcBackendOptions # noqa: F401
from .server_process_global_profiler import (
_server_process_global_profile,
)
# Module-level slot for the rendezvous generator. init_rpc assigns the result
# of dist.rendezvous() here (via ``global``) so the rendezvous state stays
# alive after init_rpc returns instead of being destroyed before all
# processes finish handshaking.
rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]
def init_rpc(
    name,
    backend=None,
    rank=-1,
    world_size=None,
    rpc_backend_options=None,
):
    r"""
    Initializes RPC primitives such as the local RPC agent
    and distributed autograd, which immediately makes the current
    process ready to send and receive RPCs.

    Args:
        name (str): a globally unique name of this node. (e.g.,
            ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
            Name can only contain number, alphabet, underscore, colon,
            and/or dash, and must be shorter than 128 characters.
        backend (BackendType, optional): The type of RPC backend
            implementation. The supported value is
            ``BackendType.TENSORPIPE`` (the default). When omitted while
            ``rpc_backend_options`` is given, the backend is inferred from
            the type of the options object.
            See :ref:`rpc-backends` for more information.
        rank (int): a globally unique id/rank of this node.
        world_size (int, optional): The number of workers in the group.
            When left as ``None`` (or any other falsy value) no rendezvous
            is performed and the store is created directly from
            ``rpc_backend_options``.
        rpc_backend_options (RpcBackendOptions, optional): The options
            passed to the RpcAgent constructor. It must be an agent-specific
            subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
            and contains agent-specific initialization configurations. By
            default, for all agents, it sets the default timeout to 60
            seconds and performs the rendezvous with an underlying process
            group initialized using ``init_method = "env://"``,
            meaning that environment variables ``MASTER_ADDR`` and
            ``MASTER_PORT`` need to be set properly. See
            :ref:`rpc-backends` for more information and find which options
            are available.

    Raises:
        TypeError: if ``backend`` is not a ``BackendType`` member, if
            ``rpc_backend_options`` is not an ``RpcBackendOptions``
            instance, or if no backend can be inferred from the options.
    """
    # Both module-level names are rebound further down; declare them once,
    # up front, instead of inside the nested blocks that assign them.
    global rendezvous_iterator, _init_counter

    torch._C._log_api_usage_once("torch.distributed.init_rpc")

    if backend is not None and not isinstance(
        backend, backend_registry.BackendType
    ):
        raise TypeError("Argument backend must be a member of BackendType")

    if rpc_backend_options is not None and not isinstance(
        rpc_backend_options, RpcBackendOptions
    ):
        raise TypeError(
            "Argument rpc_backend_options must be an instance of RpcBackendOptions"
        )

    # Try to detect the backend from the options
    if backend is None and rpc_backend_options is not None:
        for candidate_backend in BackendType:
            # The options class of a backend is discovered by default
            # constructing that backend's options and taking their type.
            candidate_options_type = type(
                backend_registry.construct_rpc_backend_options(candidate_backend)
            )
            if isinstance(rpc_backend_options, candidate_options_type):
                backend = candidate_backend
                break
        else:
            # The loop finished without a match.
            raise TypeError(
                f"Could not infer backend for options {rpc_backend_options}"
            )
        # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
        if backend != BackendType.TENSORPIPE:  # type: ignore[attr-defined]
            logger.warning(
                "RPC was initialized with no explicit backend but with options "  # type: ignore[attr-defined]
                f"corresponding to {backend}, hence that backend will be used "
                f"instead of the default {BackendType.TENSORPIPE}. To silence this "
                f"warning pass `backend={backend}` explicitly."
            )

    if backend is None:
        backend = BackendType.TENSORPIPE  # type: ignore[attr-defined]

    if rpc_backend_options is None:
        # default construct a set of RPC backend options.
        rpc_backend_options = backend_registry.construct_rpc_backend_options(
            backend
        )

    # Create store, performs rendezvous for static RPC group.
    if not world_size:
        # No world size was passed in (only the argument is inspected here),
        # so the store is created for the dynamic group setting.
        store = dist._create_store_from_options(rpc_backend_options, rank)
    else:
        # This rendezvous state sometimes is destroyed before all processes
        # finishing handshaking. To avoid that issue, we make it global to
        # keep it alive.
        rendezvous_iterator = dist.rendezvous(
            rpc_backend_options.init_method, rank=rank, world_size=world_size
        )
        store, _, _ = next(rendezvous_iterator)
    # Use same timeout as RPC.
    store.set_timeout(timedelta(seconds=rpc_backend_options.rpc_timeout))

    # Use a PrefixStore to distinguish multiple invocations.
    with _init_counter_lock:
        store = dist.PrefixStore(f"rpc_prefix_{_init_counter}", store)
        _init_counter += 1

    # Initialize autograd before RPC since _init_rpc_backend guarantees all
    # processes sync via the store. If we initialize autograd after RPC,
    # there could be a race where some nodes might have initialized autograd
    # and others might not have. As a result, a node calling
    # torch.distributed.autograd.backward() would run into errors since
    # other nodes might not have been initialized.
    dist_autograd._init(rank)

    _set_profiler_node_id(rank)
    # Initialize RPC.
    _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
type_mapping = {
backend: backend_registry.BackendType,
store: dist.Store,
name: str,
rank: numbers.Integral,
# world_size can be None for a dynamic group
world_size: (numbers.Integral, type(None)),
rpc_backend_options: RpcBackendOptions,
}
for arg, arg_type in type_mapping.items():
if not isinstance(arg, arg_type): # type: ignore[arg-type]
raise RuntimeError(
"Argument {} must be of type {} but got type {}".format(
arg, arg_type, type(arg)
)
)
def _init_rpc_backend(
    backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
    store=None,
    name=None,
    rank=-1,
    world_size=None,
    rpc_backend_options=None,
):
    """
    Validate the arguments, build the RPC agent for ``backend`` and register
    it with the RPC API state.

    Raises:
        RuntimeError: if argument validation fails or an RPC agent is
            already set for this process.
    """
    _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)

    # Refuse to initialize a second time while an agent is still set.
    if _is_current_rpc_agent_set():
        raise RuntimeError("RPC is already initialized")

    # Initialize RPC.
    agent_kwargs = dict(
        store=store,
        name=name,
        rank=rank,
        world_size=world_size,
        rpc_backend_options=rpc_backend_options,
    )
    agent = backend_registry.init_backend(backend, **agent_kwargs)

    api._init_rpc_states(agent)
@api._require_initialized
def _get_debug_info():
info = _rref_context_get_debug_info()
info.update(api._get_current_rpc_agent().get_debug_info())
info.update(dist_autograd._get_debug_info())
return info
|