File: __init__.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (51 lines) | stat: -rw-r--r-- 1,629 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

import sys
import torch


def is_available() -> bool:
    """Report whether this torch build exposes distributed autograd.

    Returns:
        ``True`` when ``torch._C`` has a ``_dist_autograd_init`` attribute,
        ``False`` otherwise.
    """
    # A unique sentinel distinguishes "attribute missing" from any value the
    # attribute could legitimately hold (including None).
    missing = object()
    init_hook = getattr(torch._C, "_dist_autograd_init", missing)
    return init_hook is not missing


# Import-time setup. When the build does not expose distributed autograd
# (is_available() is False) nothing below runs, so none of the names from
# torch._C._distributed_autograd are bound in this module.
if is_available():
    # The init call runs before the import below; a falsy result aborts the
    # import of this module with a RuntimeError.
    if not torch._C._dist_autograd_init():
        raise RuntimeError("Failed to initialize torch.distributed.autograd")

    from torch._C._distributed_autograd import (
        get_gradients,
        backward,
        _init,
        _new_context,
        _release_context,
        _get_max_id,
        _is_valid_context,
        _retrieve_context,
        _current_context,
        _get_debug_info,
        DistAutogradContext,
    )

class context:
    """
    Context manager wrapping the forward and backward passes of a
    distributed autograd run.

    Entering the ``with`` block creates a new distributed autograd context
    and yields its ``context_id``; leaving the block releases that context.
    The ``context_id`` is required to uniquely identify a distributed
    backward pass on all workers. Each worker stores metadata associated
    with this ``context_id``, which is required to correctly execute a
    distributed autograd pass.

    Example::
        >>> import torch.distributed.autograd as dist_autograd
        >>> # xdoctest: +SKIP
        >>> with dist_autograd.context() as context_id:
        >>>   t1 = torch.rand((3, 3), requires_grad=True)
        >>>   t2 = torch.rand((3, 3), requires_grad=True)
        >>>   loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>   dist_autograd.backward(context_id, [loss])
    """

    def __enter__(self):
        # Keep the context object on the instance so __exit__ can look up
        # the same context id when the block ends.
        new_ctx = _new_context()
        self.autograd_context = new_ctx
        return new_ctx._context_id()

    def __exit__(self, type, value, traceback):
        # The exception details are not inspected; the implicit None return
        # lets any exception raised inside the block propagate.
        ctx_id = self.autograd_context._context_id()
        _release_context(ctx_id)