File: _distributed_c10d.pyi

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (421 lines) | stat: -rw-r--r-- 9,964 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
from datetime import timedelta
from enum import Enum
from typing import Optional, List, Any, Tuple, overload, Union

from torch import Tensor

# This module is defined in torch/csrc/distributed/c10d/init.cpp

# Module-level constants exported by the native module; this stub only
# declares their types.
# NOTE(review): presumably the default byte size of the first gradient
# bucket -- confirm against init.cpp.
_DEFAULT_FIRST_BUCKET_BYTES: int
# Used in this file as the default `timeout` of `Work.wait`.
_DEFAULT_NO_TIMEOUT: timedelta
# NOTE(review): presumably the default process-group timeout -- confirm.
_DEFAULT_PG_TIMEOUT: timedelta

# Enumeration naming the built-in communication hooks; it is the type of the
# `comm_hook_type` parameter of `_register_builtin_comm_hook`.
# Member values are elided (`...`) in this stub.
class BuiltinCommHookType(Enum):
    ALLREDUCE = ...
    FP16_COMPRESS = ...

def _register_comm_hook(
    reducer: Reducer,
    state: Any,
    comm_hook: Any,
):
    """Signature stub for the native ``_register_comm_hook`` binding.

    Takes a ``Reducer`` plus two untyped (``Any``) objects, ``state`` and
    ``comm_hook``. No return type is declared.
    """
def _register_builtin_comm_hook(
    reducer: Reducer,
    comm_hook_type: BuiltinCommHookType,
):
    """Signature stub for the native ``_register_builtin_comm_hook`` binding.

    Takes a ``Reducer`` and a ``BuiltinCommHookType`` member selecting the
    hook. No return type is declared.
    """

# Stub for the native gradient-bucket object. Every member is a method (not a
# property); only `set_buffer` takes an argument.
class GradBucket:
    def index(self) -> int: ...
    # `buffer` returns one Tensor and `set_buffer` accepts one Tensor.
    # NOTE(review): presumably getter/setter for the same storage -- confirm.
    def buffer(self) -> Tensor: ...
    def gradients(self) -> List[Tensor]: ...
    def is_last(self) -> bool: ...
    def set_buffer(self, tensor: Tensor) -> None: ...
    def parameters(self) -> List[Tensor]: ...

# Stub for the native Reducer. Only the constructor is declared here; the
# trailing `...` stands for members this stub does not list.
class Reducer:
    # All seven arguments are required (no defaults are declared).
    # `bucket_indices` is a list of integer lists and `expect_sparse_gradients`
    # is a list of booleans.
    # NOTE(review): presumably one bool per entry of `params` -- confirm.
    def __init__(
        self,
        params: List[Tensor],
        bucket_indices: List[List[int]],
        process_group: ProcessGroup,
        expect_sparse_gradients: List[bool],
        bucket_bytes_cap: int,
        find_unused_parameters: bool,
        gradient_as_bucket_view: bool,
    ) -> None: ...
    ...

# Stub for the native Logger, which is constructed from a Reducer. The
# trailing `...` stands for members this stub does not list.
class Logger:
    def __init__(self, reducer: Reducer) -> None: ...
    # Takes the module name, device ids, output device and two boolean flags;
    # no return type is declared.
    def set_construction_data_and_log(
        self,
        module_name: str,
        device_ids: List[int],
        output_device: int,
        broadcast_buffers: bool,
        has_sync_bn: bool,
    ): ...
    ...

def get_debug_level():
    """Signature stub for the native ``get_debug_level`` binding.

    Declared without parameters and without a return annotation.
    NOTE(review): presumably returns a ``DebugLevel`` -- confirm against the
    native definition before annotating.
    """
def set_debug_level():
    """Signature stub for the native ``set_debug_level`` binding.

    Declared without parameters and without a return annotation.
    NOTE(review): a setter would normally take the level to set; this stub
    declares no parameter -- confirm against the native definition.
    """
def set_debug_level_from_env():
    """Signature stub for the native ``set_debug_level_from_env`` binding.

    Declared without parameters and without a return annotation.
    """

# Enumeration of debug levels with three members; values are elided (`...`)
# in this stub.
class DebugLevel(Enum):
    OFF = ...
    INFO = ...
    DETAIL = ...

# Stub for the reduction-operator class. The operators are declared as class
# attributes with elided values, so this stub does not pin down their type.
# `ReduceOp` is the type of the `reduceOp` option fields and `ReduceOp.SUM` is
# the default `op` of several ProcessGroup methods in this file.
class ReduceOp:

    SUM = ...
    PRODUCT = ...
    MIN = ...
    MAX = ...
    BAND = ...
    BOR = ...
    BXOR = ...
    PREMUL_SUM = ...
    UNUSED = ...

    # Nested enum; no members are declared in this stub.
    class RedOpType(Enum): ...

# Options object for the `opts` parameter of `ProcessGroup.broadcast`.
class BroadcastOptions:
    rootRank: int
    rootTensor: int
    timeout: timedelta

# Options object for the `opts` parameter of `ProcessGroup.allreduce`; also the
# base class of `AllreduceCoalescedOptions`.
class AllreduceOptions:
    reduceOp: ReduceOp
    timeout: timedelta

# Options object for `ProcessGroup.allreduce_coalesced`; declares no fields
# beyond those inherited from `AllreduceOptions`.
class AllreduceCoalescedOptions(AllreduceOptions): ...

# Options object for the `opts` parameter of `ProcessGroup.reduce`.
class ReduceOptions:
    reduceOp: ReduceOp
    rootRank: int
    rootTensor: int
    timeout: timedelta

# Options object for the `opts` parameter of `ProcessGroup.allgather`,
# `_allgather_base` and `allgather_coalesced`.
class AllGatherOptions:
    timeout: timedelta

# Options object for the `opts` parameter of `ProcessGroup.gather`.
class GatherOptions:
    rootRank: int
    timeout: timedelta

# Options object for the `opts` parameter of `ProcessGroup.scatter`.
class ScatterOptions:
    rootRank: int
    timeout: timedelta

# Options object for the `opts` parameter of `ProcessGroup.reduce_scatter`.
class ReduceScatterOptions:
    reduceOp: ReduceOp
    timeout: timedelta

# Options object for the `opts` parameter of `ProcessGroup.barrier`.
class BarrierOptions:
    device_ids: List[int]
    timeout: timedelta

# Options object for the `opts` parameter of `ProcessGroup.alltoall_base` and
# `ProcessGroup.alltoall`.
class AllToAllOptions:
    timeout: timedelta

# Base class of the key-value store bindings; FileStore, HashStore, TCPStore
# and PrefixStore in this file derive from it.
class Store:
    # Values are passed in as `str` (`set`, `compare_set`) but are returned as
    # `bytes` (`get`, `compare_set`).
    def set(self, key: str, value: str): ...
    def get(self, key: str) -> bytes: ...
    # Takes an integer `value` and returns an integer.
    # NOTE(review): presumably increments a counter and returns the new
    # total -- confirm.
    def add(self, key: str, value: int) -> int: ...
    def compare_set(self, key: str, expected_value: str, desired_value: str) -> bytes: ...
    def delete_key(self, key: str) -> bool: ...
    def num_keys(self) -> int: ...
    def set_timeout(self, timeout: timedelta): ...
    # `wait` is overloaded: keys only, or keys plus an explicit timeout.
    @overload
    def wait(self, keys: List[str]): ...
    @overload
    def wait(self, keys: List[str], timeout: timedelta): ...

# Store subclass constructed from a path string; `numWorkers` is optional (its
# default value is elided in this stub).
class FileStore(Store):
    def __init__(self, path: str, numWorkers: int = ...) -> None: ...

# Store subclass whose constructor takes no arguments.
class HashStore(Store):
    def __init__(self) -> None: ...

# Store subclass constructed from a host name and port. Every other
# constructor argument is optional, with its default elided in this stub.
class TCPStore(Store):
    # `world_size` is the only argument that may be None.
    def __init__(
        self,
        host_name: str,
        port: int,
        world_size: Optional[int] = ...,
        is_master: bool = ...,
        timeout: timedelta = ...,
        wait_for_workers: bool = ...,
        multi_tenant: bool = ...
    ) -> None: ...
    # Read-only properties. Note the constructor argument is named `host_name`
    # while the property is `host`.
    @property
    def host(self) -> str: ...
    @property
    def port(self) -> int: ...

# Store subclass constructed from a prefix string and another Store, which is
# exposed through the read-only `underlying_store` property.
class PrefixStore(Store):
    def __init__(self, prefix: str, store: Store) -> None: ...
    @property
    def underlying_store(self) -> Store: ...

# Handle type returned by every ProcessGroup operation declared in this file.
# The trailing `...` stands for members this stub does not list.
class Work:
    def is_completed(self) -> bool: ...
    def is_success(self) -> bool: ...
    def exception(self) -> Any: ...
    # `timeout` defaults to the module-level `_DEFAULT_NO_TIMEOUT`.
    def wait(self, timeout: timedelta = _DEFAULT_NO_TIMEOUT) -> bool: ...
    # Both the public and the underscore-prefixed spelling are declared, with
    # identical signatures.
    def source_rank(self) -> int: ...
    def _source_rank(self) -> int: ...
    def result(self) -> List[Tensor]: ...
    def synchronize(self): ...
    ...

# Base class of the process-group bindings. It declares rank/size queries plus
# the collective and point-to-point operations; every operation returns a
# `Work` handle.
#
# Stub conventions used below:
#   * each `opts` parameter is annotated with its options class and its default
#     is elided with `...`, so type checkers see a real type instead of an
#     implicit `Any` and the stub does not appear to construct option objects;
#   * most operations are overloaded: one form takes lists of tensors plus an
#     options object, the other takes plain tensors and positional extras.
class ProcessGroup:
    class Options: ...
    def __init__(self) -> None: ...
    def rank(self) -> int: ...
    def size(self) -> int: ...
    @overload
    def broadcast(
        self,
        tensors: List[Tensor],
        opts: BroadcastOptions = ...,
    ) -> Work: ...
    @overload
    def broadcast(
        self,
        tensor: Tensor,
        root: int,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensors: List[Tensor],
        opts: AllreduceOptions = ...,
    ) -> Work: ...
    # The `op` defaults are kept as written (`ReduceOp.SUM`) because the
    # attribute's type is not pinned down by the `ReduceOp` stub.
    @overload
    def allreduce(
        self,
        tensors: List[Tensor],
        op=ReduceOp.SUM,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensor: Tensor,
        op=ReduceOp.SUM,
    ) -> Work: ...
    def allreduce_coalesced(
        self,
        tensors: List[Tensor],
        opts: AllreduceCoalescedOptions = ...,
    ) -> Work: ...
    @overload
    def reduce(
        self,
        tensors: List[Tensor],
        opts: ReduceOptions = ...,
    ) -> Work: ...
    @overload
    def reduce(
        self,
        tensor: Tensor,
        root: int,
        op=ReduceOp.SUM,
    ) -> Work: ...
    @overload
    def allgather(
        self,
        output_tensors: List[List[Tensor]],
        input_tensors: List[Tensor],
        opts: AllGatherOptions = ...,
    ) -> Work: ...
    @overload
    def allgather(
        self,
        output_tensors: List[Tensor],
        input_tensor: Tensor,
    ) -> Work: ...
    def _allgather_base(
        self,
        output: Tensor,
        input: Tensor,
        opts: AllGatherOptions = ...,
    ) -> Work: ...
    def allgather_coalesced(
        self,
        output_lists: List[List[Tensor]],
        input_list: List[Tensor],
        opts: AllGatherOptions = ...,
    ) -> Work: ...
    @overload
    def gather(
        self,
        output_tensors: List[List[Tensor]],
        input_tensors: List[Tensor],
        opts: GatherOptions = ...,
    ) -> Work: ...
    @overload
    def gather(
        self,
        output_tensors: List[Tensor],
        input_tensor: Tensor,
        root: int,
    ) -> Work: ...
    @overload
    def scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts: ScatterOptions = ...,
    ) -> Work: ...
    @overload
    def scatter(
        self,
        output_tensor: Tensor,
        input_tensors: List[Tensor],
        root: int,
    ) -> Work: ...
    @overload
    def reduce_scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts: ReduceScatterOptions = ...,
    ) -> Work: ...
    # Parameter names are kept exactly as declared (plural `output_tensors`
    # for a single Tensor, singular `input_tensor` for a list) because callers
    # may pass them by keyword.
    @overload
    def reduce_scatter(
        self,
        output_tensors: Tensor,
        input_tensor: List[Tensor],
    ) -> Work: ...
    def _reduce_scatter_base(
        self,
        outputTensor: Tensor,
        inputTensor: Tensor,
    ) -> Work: ...
    @overload
    def alltoall_base(
        self,
        output_tensor: Tensor,
        input_tensor: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        opts: AllToAllOptions = ...,
    ) -> Work: ...
    @overload
    def alltoall_base(
        self,
        output: Tensor,
        input: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
    ) -> Work: ...
    @overload
    def alltoall(
        self,
        output_tensor: List[Tensor],
        input_tensor: List[Tensor],
        opts: AllToAllOptions = ...,
    ) -> Work: ...
    @overload
    def alltoall(
        self,
        output: List[Tensor],
        input: List[Tensor],
    ) -> Work: ...
    # Point-to-point operations: a list of tensors, the peer rank (omitted for
    # `recv_anysource`) and an integer tag.
    def send(
        self,
        tensors: List[Tensor],
        dstRank: int,
        tag: int,
    ) -> Work: ...
    def recv(
        self,
        tensors: List[Tensor],
        srcRank: int,
        tag: int,
    ) -> Work: ...
    def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ...
    def barrier(self, opts: BarrierOptions = ...) -> Work: ...

# ProcessGroup subclass with no additional members declared; it is the return
# type of `_round_robin_process_groups`.
class ProcessGroupRoundRobin(ProcessGroup): ...

def _round_robin_process_groups(process_groups: List[ProcessGroup]) -> ProcessGroupRoundRobin:
    """Signature stub for the native ``_round_robin_process_groups`` binding.

    Takes a list of ``ProcessGroup`` objects and returns a
    ``ProcessGroupRoundRobin``.
    """

# Gloo process-group binding: constructed from a Store, this process's rank,
# the group size and a timeout.
class ProcessGroupGloo(ProcessGroup):
    # Nested marker classes; no members are declared in this stub.
    class Device: ...
    class Options: ...
    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ) -> None: ...
    # Both arguments are optional strings. The stub previously spelled the
    # defaults as `str()` (an empty string) and left the parameters untyped;
    # they are now annotated, with the default elided per stub convention.
    @staticmethod
    def create_device(hostname: str = ..., interface: str = ...) -> Device: ...
    @staticmethod
    def create_default_device() -> Device: ...

# ProcessGroup subclass constructed from another ProcessGroup (`pg`) plus a
# ProcessGroupGloo (`gloo_pg`); exposes a `wrapped_pg` attribute typed as
# ProcessGroup.
class _ProcessGroupWrapper(ProcessGroup):
    def __init__(
        self,
        pg: ProcessGroup,
        gloo_pg: ProcessGroupGloo
    ) -> None: ...
    wrapped_pg: ProcessGroup


# NCCL process-group binding: constructed from a Store, rank, size and timeout
# (the same constructor shape as ProcessGroupGloo and ProcessGroupUCC).
# The trailing `...` stands for members this stub does not list.
class ProcessGroupNCCL(ProcessGroup):
    class Options: ...
    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ) -> None: ...
    # Private static helpers that take no arguments and return None.
    # NOTE(review): presumably they bracket a group of NCCL calls -- confirm.
    @staticmethod
    def _group_start() -> None: ...
    @staticmethod
    def _group_end() -> None: ...
    ...

# UCC process-group binding: constructed from a Store, rank, size and timeout.
# No members beyond the constructor are declared.
class ProcessGroupUCC(ProcessGroup):
    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ) -> None: ...

# MPI process-group binding. Unlike the Gloo/NCCL/UCC bindings in this file,
# its constructor takes no Store or timeout; it takes rank, size and an
# integer `pgComm`.
# NOTE(review): `pgComm` is presumably an MPI communicator handle -- confirm.
class ProcessGroupMPI(ProcessGroup):
    def __init__(
        self,
        rank: int,
        size: int,
        pgComm: int,
    ) -> None: ...
    # Static factory: builds a ProcessGroupMPI from a list of integer ranks.
    @staticmethod
    def create(ranks: List[int]) -> ProcessGroupMPI: ...

def _compute_bucket_assignment_by_size(
    tensors: List[Tensor],
    bucket_size: int,
    expect_sparse_gradient: List[bool],
    tensor_indices: List[int],
) -> Tuple[List[List[int]], List[int]]: ...
def _broadcast_coalesced(
    process_group: ProcessGroup, tensors: List[Tensor], buffer_size: int, src: int
):
    """Signature stub for the native ``_broadcast_coalesced`` binding.

    Takes a ``ProcessGroup``, a list of tensors, an integer buffer size and an
    integer ``src``. No return type is declared.
    """
def _test_python_store(
    store: Store,
):
    """Signature stub for the native ``_test_python_store`` binding.

    Takes a single ``Store``. No return type is declared.
    """
def _verify_params_across_processes(
    process_group: ProcessGroup, params: List[Tensor], logger: Optional[Logger]
):
    """Signature stub for the native ``_verify_params_across_processes``.

    Takes a ``ProcessGroup``, a list of tensors and a ``Logger`` that may be
    ``None``. No return type is declared.
    """
def _make_nccl_premul_sum(
    factor: Union[float, List[Tensor]],
) -> ReduceOp:
    """Signature stub for the native ``_make_nccl_premul_sum`` binding.

    ``factor`` is either a float or a list of tensors; the result is a
    ``ReduceOp``.
    """