File: store.py

#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from datetime import timedelta
from typing import Callable, Iterable, List, Optional

import torch


DistStoreError = torch._C._DistStoreError

_NUM_MEMBERS = "/num_members"
_LAST_MEMBER_CHECKIN = "/last_member"
_TRACE = "/TRACE"
_TRACING_GATE = "/TRACING_GATE"
_MAX_TRACE_MISSING_RANKS = 16


__all__ = ["store_timeout", "get_all", "synchronize", "barrier"]


@contextmanager
def store_timeout(store, timeout: float):
    """
    This sets the timeout and then restores the old timeout when the context
    manager exits.

    Args:
        store: the store to set the timeout on
        timeout: the timeout to set
    """

    old_timeout = store.timeout
    store.set_timeout(timedelta(seconds=timeout))
    try:
        yield
    finally:
        # restore the previous timeout even if the body raises
        store.set_timeout(old_timeout)
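

# --- Illustrative sketch (not part of the original module) ------------------
# A minimal example of how ``store_timeout`` is typically used. The TCPStore
# endpoint below is a placeholder; any c10d store exposing ``set_timeout``
# behaves the same way.
def _example_store_timeout():
    import torch.distributed as dist

    store = dist.TCPStore(
        "127.0.0.1", 29500, world_size=1, is_master=True, wait_for_workers=False
    )
    store.set("ready", "1")
    with store_timeout(store, 5.0):
        # store operations inside the block use the 5 second timeout
        store.wait(["ready"])
    # the previous timeout is restored when the context manager exits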


def get_all(store, rank: int, prefix: str, world_size: int):
    r"""
    Given a store and a prefix, the method goes through the array of keys
    of the following format: ``{prefix}{idx}``, where idx is in a range
    from 0 to size, and tries to retrieve the data.

    The Rank0 process waits at the end to make sure all other processes
    finished the procedure before exiting.

    Usage

    ::

     values = get_all(store, 'torchelastic/data', 3)
     value1 = values[0] # retrieves the data for key torchelastic/data0
     value2 = values[1] # retrieves the data for key torchelastic/data1
     value3 = values[2] # retrieves the data for key torchelastic/data2

    """
    data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)])

    barrier_key = _barrier_nonblocking(
        store=store,
        world_size=world_size,
        key_prefix=f"{prefix}/finished",
    )
    if rank == 0:
        # Rank 0 runs the TCPStore daemon, so it needs to exit last.
        # Otherwise, the barrier may time out if the rank 0 process finishes
        # its work before the other processes have finished `get_all`.
        store.wait([barrier_key])

    return data_arr
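

# --- Illustrative sketch (not part of the original module) ------------------
# ``get_all`` expects every rank to have published its ``{prefix}{idx}`` key
# first; otherwise ``multi_get`` blocks until the keys appear or the store
# timeout expires. ``store``, ``rank`` and ``world_size`` are assumed to be
# supplied by the caller (one process per rank).
def _example_get_all(store, rank: int, world_size: int):
    # each rank publishes its own payload under 'torchelastic/data{rank}'
    store.set(f"torchelastic/data{rank}", f"payload-from-{rank}".encode())
    # then every rank collects the payloads of all ranks
    return get_all(store, rank, "torchelastic/data", world_size)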


def synchronize(
    store,
    data: bytes,
    rank: int,
    world_size: int,
    key_prefix: str,
    timeout: float = 300,
) -> List[bytes]:
    """
    Synchronizes ``world_size`` agents between each other using the underlying c10d store.
    The ``data`` will be available on each of the agents.

    Note: The data on the path is not deleted, as a result there can be stale data if
        you use the same key_prefix twice.

    Time complexity: O(N) per worker, O(N^2) globally.
    """
    with store_timeout(store, timeout):
        store.set(f"{key_prefix}{rank}", data)
        agent_data = get_all(store, rank, key_prefix, world_size)
        return agent_data
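

# --- Illustrative sketch (not part of the original module) ------------------
# One process per rank would run something like the following; the TCPStore
# endpoint and the payload are placeholders. The ``key_prefix`` should be
# unique per call because the written keys are never deleted.
def _example_synchronize(rank: int, world_size: int):
    import torch.distributed as dist

    store = dist.TCPStore(
        "127.0.0.1", 29500, world_size=world_size, is_master=(rank == 0)
    )
    payload = f"host-info-for-rank-{rank}".encode()
    all_payloads = synchronize(
        store, payload, rank, world_size, key_prefix="torchelastic/sync/0/"
    )
    # all_payloads[i] now holds the bytes published by rank i
    return all_payloads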


def _try_detecting_missing_ranks(
    store,
    world_size: int,
    key_prefix: str,
    rank: int,
    rank_decoder: Callable[[int], str],
    trace_timeout: float,
) -> Optional[Iterable[str]]:
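    # Tracing handshake: every rank records a ``{key_prefix}{rank}/TRACE`` key.
    # Rank 0 scans those keys to find ranks that never checked in and then sets
    # the TRACING_GATE key; non-zero ranks wait on the gate and refer the
    # caller to rank 0 for the collected information.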
    store.set(f"{key_prefix}{rank}{_TRACE}", "<val_ignored>")

    def _find_missing_ranks():
        missing_rank_info = set()
        ranks_missing = 0
        for i in range(1, world_size):
            # Cap the number of reported ranks to reduce noise: assuming
            # roughly 8 ranks per node, this is enough to tell whether one
            # node or more than one node has timed out.
            if ranks_missing >= _MAX_TRACE_MISSING_RANKS:
                break
            try:
                if ranks_missing == 0:
                    store.wait(
                        [f"{key_prefix}{i}{_TRACE}"], timedelta(seconds=trace_timeout)
                    )
                else:
                    # use the shortest possible timeout, since some ranks have
                    # already failed to check in
                    store.wait([f"{key_prefix}{i}{_TRACE}"], timedelta(milliseconds=1))
            except DistStoreError:
                ranks_missing += 1
                missing_rank_info.add(rank_decoder(i))
        return missing_rank_info

    def _checkin():
        try:
            store.wait([f"{key_prefix}{_TRACING_GATE}"])
            return [f"[<check rank 0 ({rank_decoder(0)}) for missing rank info>]"]
        except DistStoreError:
            # if rank 0 is the source of the timeout, the gate never opens;
            # return None so the caller re-raises the original exception
            return None

    if rank == 0:
        missing_rank_info = _find_missing_ranks()
        store.set(f"{key_prefix}{_TRACING_GATE}", "<val_ignored>")
        return missing_rank_info
    else:
        return _checkin()


def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str:
    """
    Does all the non-blocking operations for a barrier and returns the final key
    that can be waited on.
    """
    num_members_key = key_prefix + _NUM_MEMBERS
    last_member_key = key_prefix + _LAST_MEMBER_CHECKIN

    idx = store.add(num_members_key, 1)
    if idx == world_size:
        store.set(last_member_key, "<val_ignored>")

    return last_member_key
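

# --- Illustrative sketch (not part of the original module) ------------------
# The non-blocking barrier is an atomic counter plus a "gate" key: every
# participant increments the counter, the last one to arrive sets the gate,
# and everyone blocks on the gate. A caller would typically do:
def _example_barrier_nonblocking(store, world_size: int):
    gate_key = _barrier_nonblocking(store, world_size, "torchelastic/epoch/0")
    store.wait([gate_key])  # returns once all participants have arrived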


def barrier(
    store,
    world_size: int,
    key_prefix: str,
    barrier_timeout: float = 300,
    rank: Optional[int] = None,
    rank_tracing_decoder: Optional[Callable[[int], str]] = None,
    trace_timeout: float = 10,
) -> None:
    """
    A global lock between agents. This will pause all workers until at least
    ``world_size`` workers respond.

    This uses a fast incrementing index to assign waiting ranks and a success
    flag set by the last worker.

    Time complexity: O(1) per worker, O(N) globally.

    Optionally, passing ``rank`` will enable tracing of missing ranks on
    timeouts. The ``rank_tracing_decoder`` callable can be used to convert a
    rank into more meaningful application-level information (e.g. a hostname).

    Note: Since the data is not removed from the store, the barrier can be used
        only once per unique ``key_prefix``.
    """

    if rank is None:
        assert rank_tracing_decoder is None, "Tracing requires rank information"

    with store_timeout(store, barrier_timeout):
        last_member_key = _barrier_nonblocking(
            store=store, world_size=world_size, key_prefix=key_prefix
        )
        try:
            store.wait([last_member_key])
        except DistStoreError as e:
            if rank is None:
                raise e
            else:
                missing_ranks = _try_detecting_missing_ranks(
                    store,
                    world_size,
                    key_prefix,
                    rank,
                    rank_tracing_decoder or (lambda x: str(x)),
                    trace_timeout,
                )
                if missing_ranks is not None:
                    raise DistStoreError(
                        "Timed out waiting on barrier on "
                        "rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format(
                            rank,
                            key_prefix,
                            world_size,
                            f"[{', '.join(missing_ranks)}]",
                            barrier_timeout,
                        )
                    ) from None
                else:
                    raise e
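

# --- Illustrative sketch (not part of the original module) ------------------
# Passing ``rank`` enables missing-rank tracing on timeout. The decoder below
# is a placeholder that maps a rank to a human-readable label; a real agent
# might map it to a hostname instead.
def _example_barrier_with_tracing(store, rank: int, world_size: int):
    barrier(
        store,
        world_size,
        key_prefix="torchelastic/barrier/0",
        barrier_timeout=60,
        rank=rank,
        rank_tracing_decoder=lambda r: f"rank-{r}",
        trace_timeout=10,
    )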