1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417
|
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import signal
import socket
import time
import uuid
from string import Template
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
from torch.distributed.elastic.agent.server.api import (
RunResult,
SimpleElasticAgent,
WorkerGroup,
WorkerSpec,
WorkerState,
)
from torch.distributed.elastic.agent.server.health_check_server import (
create_healthcheck_server,
HealthCheckServer,
)
from torch.distributed.elastic.metrics.api import prof
from torch.distributed.elastic.multiprocessing import (
LogsSpecs,
PContext,
start_processes,
)
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
if TYPE_CHECKING:
from torch.distributed.elastic.events.api import EventMetadataValue
# Module-level logger used by LocalElasticAgent for all of its log output.
logger = get_logger(__name__)
# Public names of this module: the agent class plus the names of the
# environment variables it reads.
__all__ = [
"LocalElasticAgent",
"TORCHELASTIC_ENABLE_FILE_TIMER",
"TORCHELASTIC_TIMER_FILE",
"TORCHELASTIC_HEALTH_CHECK_PORT",
]
# Name of the environment variable that turns on the file-timer watchdog.
# LocalElasticAgent._setup_local_watchdog starts a FileTimerServer only when
# this variable is set to "1".
TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER"
# Name of the environment variable holding the health check port.
# LocalElasticAgent._setup_healthcheck starts a health check server only when
# this variable is set to a value that parses as an integer.
TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT"
# Name of the environment variable holding the watchdog file path. When it is
# set (or a path is generated because the watchdog is enabled), the path is
# copied into every worker's environment under this same name.
TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE"
class LocalElasticAgent(SimpleElasticAgent):
"""An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.
This agent is deployed per host and is configured to spawn ``n`` workers.
When using GPUs, ``n`` maps to the number of GPUs available on the host.
The local agent does not communicate to other local agents deployed on
other hosts, even if the workers may communicate inter-host. The worker id
is interpreted to be a local process. The agent starts and stops all worker
processes as a single unit.
The worker function and argument passed to the worker function must be
python multiprocessing compatible. To pass multiprocessing data structures
to the workers you may create the data structure in the same multiprocessing
context as the specified ``start_method`` and pass it as a function argument.
The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait
for other agents to finish. This acts as a safety net to handle cases where
workers finish at different times, to prevent agents from viewing workers
that finished early as a scale-down event. It is strongly advised that the
user code deal with ensuring that workers are terminated in a synchronous
manner rather than relying on the exit_barrier_timeout.
A named pipe based watchdog can be enabled in ```LocalElasticAgent``` if an
environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has
been defined in the ```LocalElasticAgent``` process.
Optionally, another environment variable ```TORCHELASTIC_TIMER_FILE```
can be set with a unique file name for the named pipe. If the environment
variable ```TORCHELASTIC_TIMER_FILE``` is not set, ```LocalElasticAgent```
will internally create a unique file name and set it to the environment
variable ```TORCHELASTIC_TIMER_FILE```, and this environment variable will
be propagated to the worker processes to allow them to connect to the same
named pipe that ```LocalElasticAgent``` uses.
Logs are written to the specified log directory. Each log line will be by default
prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``).
Log prefixes can be customized by passing a `template string
<https://docs.python.org/3/library/string.html#template-strings>`_ as the
``log_line_prefix_template`` argument.
The following macros (identifiers) are substituted at runtime:
``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with
global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:``.
Example launching function
::
def trainer(args) -> str:
return "do train"
def main():
start_method="spawn"
shared_queue= multiprocessing.get_context(start_method).Queue()
spec = WorkerSpec(
role="trainer",
local_world_size=nproc_per_process,
entrypoint=trainer,
args=("foobar",),
...<OTHER_PARAMS...>)
agent = LocalElasticAgent(spec, start_method)
results = agent.run()
if results.is_failed():
print("trainer failed")
else:
print(f"rank 0 return value: {results.return_values[0]}")
# prints -> rank 0 return value: do train
Example launching binary
::
def main():
spec = WorkerSpec(
role="trainer",
local_world_size=nproc_per_process,
entrypoint="/usr/local/bin/trainer",
args=("--trainer-args", "foobar"),
...<OTHER_PARAMS...>)
agent = LocalElasticAgent(spec)
results = agent.run()
if not results.is_failed():
print("binary launches do not have return values")
"""
def __init__(
self,
spec: WorkerSpec,
logs_specs: LogsSpecs,
start_method="spawn",
exit_barrier_timeout: float = 300,
log_line_prefix_template: Optional[str] = None,
):
super().__init__(spec, exit_barrier_timeout)
self._start_method = start_method
self._pcontext: Optional[PContext] = None
self._rdzv_handler = spec.rdzv_handler
self._log_line_prefix_template = log_line_prefix_template
self._worker_watchdog: Optional[timer.FileTimerServer] = None
self._logs_specs = logs_specs
self._health_check_server: Optional[HealthCheckServer] = None
def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
watchdog_enabled = os.getenv(enable_watchdog_env_name)
watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
watchdog_file_path = os.getenv(watchdog_file_env_name)
if watchdog_enabled is not None and str(watchdog_enabled) == "1":
if watchdog_file_path is None:
watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path)
if not envs:
logger.warning(
"Empty envs variables, using empty run_id for FileTimerServer"
)
run_id = ""
else:
run_id = envs[0]["TORCHELASTIC_RUN_ID"]
self._worker_watchdog = timer.FileTimerServer(
file_path=watchdog_file_path,
run_id=run_id,
max_interval=0.1,
daemon=True,
log_event=self._log_watchdog_event,
)
self._worker_watchdog.start()
logger.info("FileTimerServer started")
else:
logger.info(
"Environment variable '%s' not found. Do not start FileTimerServer.",
enable_watchdog_env_name,
)
# Propagate the watchdog file env to worker processes
if watchdog_file_path is not None:
for worker_env in envs.values():
worker_env[watchdog_file_env_name] = watchdog_file_path
@staticmethod
def _get_current_time_secs() -> int:
return int(time.time())
def _setup_healthcheck(self) -> None:
healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT
healthcheck_port = os.getenv(healthcheck_port_env_name)
if healthcheck_port is not None:
logger.info(
"Found healthcheck port %s: %s",
healthcheck_port_env_name,
healthcheck_port,
)
if self._worker_watchdog is None:
logger.info(
"FileTimerServer doesn't exist, using current time as dummy callback"
)
alive_callback = LocalElasticAgent._get_current_time_secs
else:
alive_callback = self._worker_watchdog.get_last_progress_time
try:
healthcheck_port_as_int = int(healthcheck_port)
self._health_check_server = create_healthcheck_server(
alive_callback=alive_callback,
port=healthcheck_port_as_int,
timeout=60,
)
self._health_check_server.start()
except ValueError:
logger.info(
"Invalid healthcheck port value: '%s', expecting integer. Not starting healthcheck server.",
healthcheck_port,
)
else:
logger.info(
"Environment variable '%s' not found. Do not start health check.",
healthcheck_port_env_name,
)
def _get_fq_hostname(self) -> str:
return socket.getfqdn(socket.gethostname())
def _log_watchdog_event(
self,
name: str,
request: Optional[timer.FileTimerRequest],
) -> None:
wg = self._worker_group
spec = wg.spec
md = {"watchdog_event": name}
if request is not None:
md["worker_pid"] = str(request.worker_pid)
md["scope_id"] = request.scope_id
md["expiration_time"] = str(request.expiration_time)
md["signal"] = str(request.signal)
md_str = json.dumps(md)
state = "RUNNING"
metadata: Dict[str, EventMetadataValue] = {
"run_id": spec.rdzv_handler.get_run_id(),
"global_rank": None,
"group_rank": wg.group_rank,
"worker_id": None,
"role": spec.role,
"hostname": self._get_fq_hostname(),
"state": state,
"total_run_time": self._total_execution_time,
"rdzv_backend": spec.rdzv_handler.get_backend(),
"raw_error": None,
"metadata": md_str,
"agent_restarts": spec.max_restarts - self._remaining_restarts,
}
# Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later.
# The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry.
event = events.Event(
name=name, source=events.EventSource.AGENT, metadata=metadata
)
events.record(event)
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _stop_workers(
self, worker_group: WorkerGroup, is_restart: bool = False
) -> None:
self._shutdown(is_restart=is_restart)
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
spec = worker_group.spec
store = worker_group.store
assert store is not None
restart_count = spec.max_restarts - self._remaining_restarts
use_agent_store: bool = spec.rdzv_handler.use_agent_store
logger.info("use_agent_store: %s", use_agent_store)
args: Dict[int, Tuple] = {}
envs: Dict[int, Dict[str, str]] = {}
log_line_prefixes: Optional[Dict[int, str]] = (
{} if self._log_line_prefix_template else None
)
for worker in worker_group.workers:
local_rank = worker.local_rank
worker_env = {
"LOCAL_RANK": str(local_rank),
"RANK": str(worker.global_rank),
"GROUP_RANK": str(worker_group.group_rank),
"ROLE_RANK": str(worker.role_rank),
"ROLE_NAME": spec.role,
"LOCAL_WORLD_SIZE": str(spec.local_world_size),
"WORLD_SIZE": str(worker.world_size),
"GROUP_WORLD_SIZE": str(worker_group.group_world_size),
"ROLE_WORLD_SIZE": str(worker.role_world_size),
"MASTER_ADDR": worker_group.master_addr,
"MASTER_PORT": str(worker_group.master_port),
"TORCHELASTIC_RESTART_COUNT": str(restart_count),
"TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
"TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
"TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
"TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv(
"TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1)
),
}
if "OMP_NUM_THREADS" in os.environ:
worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
if self._log_line_prefix_template:
log_line_prefix = Template(
self._log_line_prefix_template
).safe_substitute(
role_name=spec.role,
rank=worker.global_rank,
local_rank=local_rank,
)
log_line_prefixes[local_rank] = log_line_prefix
envs[local_rank] = worker_env
worker_args = list(spec.args)
worker_args = macros.substitute(worker_args, str(local_rank))
args[local_rank] = tuple(worker_args)
self._setup_local_watchdog(envs=envs)
self._setup_healthcheck()
assert spec.entrypoint is not None
assert self._logs_specs is not None
self._pcontext = start_processes(
name=spec.role,
entrypoint=spec.entrypoint,
args=args,
envs=envs,
logs_specs=self._logs_specs,
log_line_prefixes=log_line_prefixes,
start_method=self._start_method,
)
return self._pcontext.pids()
def _shutdown(
self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False
) -> None:
if self._worker_watchdog is not None:
self._worker_watchdog.stop()
self._worker_watchdog = None
if self._health_check_server is not None:
self._health_check_server.stop()
self._health_check_server = None
if self._pcontext:
self._pcontext.close(death_sig)
if not is_restart and self._rdzv_handler:
self._rdzv_handler.shutdown()
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
role = worker_group.spec.role
worker_pids = {w.id for w in worker_group.workers}
assert self._pcontext is not None
pc_pids = set(self._pcontext.pids().values())
if worker_pids != pc_pids:
logger.error(
"[%s] worker pids do not match process_context pids."
" Expected: %s, actual: %s",
role,
worker_pids,
pc_pids,
)
return RunResult(state=WorkerState.UNKNOWN)
result = self._pcontext.wait(0)
if result:
if result.is_failed():
# map local rank failure to global rank
worker_failures = {}
for local_rank, failure in result.failures.items():
worker = worker_group.workers[local_rank]
worker_failures[worker.global_rank] = failure
return RunResult(
state=WorkerState.FAILED,
failures=worker_failures,
)
else:
# copy ret_val_queue into a map with a global ranks
workers_ret_vals = {}
for local_rank, ret_val in result.return_values.items():
worker = worker_group.workers[local_rank]
workers_ret_vals[worker.global_rank] = ret_val
return RunResult(
state=WorkerState.SUCCEEDED,
return_values=workers_ret_vals,
)
else:
return RunResult(state=WorkerState.HEALTHY)
|