File: _replicated_tensor_ddp_utils.py

Package: pytorch 1.13.1+dfsg-4 (Debian bookworm)
from contextlib import contextmanager

_DDP_WITH_REPLICATED_TENSOR = False

@contextmanager
def _ddp_replicated_tensor(val):
    """
    A context manager to tag tensors in the forward pass of DDP to be
    ``ReplicatedTensor``. This can be used by ReplicatedTensor inter-op
    during the forward pass to perform appropriate optimizations.

    This context manager needs to wrap DDP creation; modifying the underlying
    module passed into DDP after leaving this context manager causes
    inconsistencies, because the changes will not be picked up during the
    forward pass.
    """
    global _DDP_WITH_REPLICATED_TENSOR
    old_val = _DDP_WITH_REPLICATED_TENSOR
    _DDP_WITH_REPLICATED_TENSOR = val
    try:
        yield
    finally:
        _DDP_WITH_REPLICATED_TENSOR = old_val

def _ddp_with_replicated_tensor_enabled():
    # Read-only accessor; ``global`` is not needed to read a module-level name.
    return _DDP_WITH_REPLICATED_TENSOR

def _set_ddp_with_replicated_tensor(value):
    # Sets the flag directly, without the save/restore scoping that
    # ``_ddp_replicated_tensor`` provides.
    global _DDP_WITH_REPLICATED_TENSOR
    _DDP_WITH_REPLICATED_TENSOR = value
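
A minimal usage sketch of the helpers above, assuming this private module is
importable from torch.nn.parallel as in the PyTorch 1.13 source tree; the DDP
construction that the context manager is meant to wrap is elided to a comment
so the sketch runs standalone, without a process group.

# Hypothetical usage sketch, not part of the file above. The import path is
# an assumption based on the file's location in the PyTorch 1.13 source tree.
from torch.nn.parallel._replicated_tensor_ddp_utils import (
    _ddp_replicated_tensor,
    _ddp_with_replicated_tensor_enabled,
    _set_ddp_with_replicated_tensor,
)

# The flag is off by default.
assert not _ddp_with_replicated_tensor_enabled()

with _ddp_replicated_tensor(True):
    # DDP creation belongs inside this block, e.g.:
    #   ddp_model = torch.nn.parallel.DistributedDataParallel(model)
    assert _ddp_with_replicated_tensor_enabled()

# On exit, the previous value is restored.
assert not _ddp_with_replicated_tensor_enabled()

# The setter flips the flag directly, with no save/restore scoping.
_set_ddp_with_replicated_tensor(True)
assert _ddp_with_replicated_tensor_enabled()
_set_ddp_with_replicated_tensor(False)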