File: test_dependency.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (146 lines) | stat: -rw-r--r-- 3,092 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Owner(s): ["oncall: distributed"]

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import weakref

import pytest
import torch

from torch.distributed.pipeline.sync.dependency import Fork, Join, fork, join


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_fork_join():
    """Check the backward order imposed by fork/join across two tensors.

    A CPU tensor ``a`` is forked and a separate CUDA tensor ``b`` is joined
    with the resulting phony.  Each lane passes through a ``Log`` function
    that records its tag when its backward runs; the test asserts that the
    tag of ``b``'s lane (2) is recorded before the tag of ``a``'s lane (1).
    """
    logs = []

    class Log(torch.autograd.Function):
        """Pass a tensor through and append ``number`` to ``logs`` on backward."""

        @staticmethod
        def forward(ctx, number, tensor):
            ctx.number = number
            return tensor.detach()

        @staticmethod
        def backward(ctx, grad):
            logs.append(ctx.number)
            # The integer tag gets no gradient; the tensor gradient is
            # forwarded unchanged.
            return None, grad

    a = torch.rand(1, device="cpu", requires_grad=True)
    b = torch.rand(1, device="cuda", requires_grad=True)

    a = Log.apply(1, a)

    a, phony = fork(a)
    # Join the CUDA tensor with the phony.  Joining ``a`` itself (as this
    # test previously did) discarded ``b`` and left a single linear chain
    # whose backward order holds trivially, without any fork/join dependency
    # between independent lanes.
    b = join(b, phony)

    b = Log.apply(2, b)
    b = b.to("cpu")

    (a + b).backward()

    assert logs == [2, 1]


def test_fork_join_enable_grad():
    """With grad enabled, fork and join return new tensors tracked by autograd.

    The forked tensor and the phony must both require grad and carry a
    ``Fork`` backward node; the joined tensor must require grad and carry a
    ``Join`` backward node.
    """
    source = torch.rand(1, requires_grad=True)

    with torch.enable_grad():
        forked, phony = fork(source)

    assert phony.requires_grad
    assert forked is not source

    # Both outputs of fork take part in autograd ...
    for tensor in (forked, phony):
        assert tensor.requires_grad
    # ... and both come from the Fork function.
    for tensor in (forked, phony):
        assert type(tensor.grad_fn) is Fork._backward_cls

    with torch.enable_grad():
        joined = join(forked, phony)

    assert joined is not forked

    assert joined.requires_grad
    assert type(joined.grad_fn) is Join._backward_cls


def test_fork_join_no_grad(monkeypatch):
    """Under ``torch.no_grad()`` fork and join hand back the same tensor object.

    ``torch.autograd.Function.apply`` is patched to raise, so the test also
    fails if either helper goes through an autograd function.
    """

    def forbidden_apply(*args):
        raise AssertionError("Function.apply called")

    monkeypatch.setattr("torch.autograd.Function.apply", forbidden_apply)

    source = torch.rand(1, requires_grad=True)

    with torch.no_grad():
        forked, phony = fork(source)

    # The phony stays outside autograd and the input is returned as-is.
    assert not phony.requires_grad
    assert forked is source

    with torch.no_grad():
        joined = join(forked, phony)

    assert joined is forked


def test_fork_leak():
    """The autograd ctx seen in backward must be collectable after fork/join.

    A weak reference to the ctx is captured during backward; once the local
    tensors are deleted, that reference has to be dead.
    """
    ctx_ref = None

    class Passthrough(torch.autograd.Function):
        """Identity function whose backward stores a weakref to its ctx."""

        @staticmethod
        def forward(ctx, input):
            return input

        @staticmethod
        def backward(ctx, grad):
            nonlocal ctx_ref
            ctx_ref = weakref.ref(ctx)
            return grad

    # A single name is rebound at every step so that no intermediate tensor
    # stays referenced from this frame.
    t = torch.rand(1, requires_grad=True)
    t = Passthrough.apply(t)
    t, phony = fork(t)
    t = join(t, phony)

    t.backward()
    del t, phony

    assert ctx_ref() is None


def test_join_when_fork_not_requires_grad():
    """Joining with a phony from a no-grad fork leaves the tensor without grad."""
    base = torch.rand(2, 1)
    first, second = base.chunk(2)

    assert not first.requires_grad
    first, phony = fork(first)
    # Neither output of fork is tracked when its input was not.
    assert not first.requires_grad
    assert not phony.requires_grad

    assert not second.requires_grad
    second = join(second, phony)
    assert not second.requires_grad


def test_join_when_fork_requires_grad():
    """Joining with a phony from a grad-tracking fork makes the tensor require grad."""
    base = torch.rand(2, 1)
    first, second = base.chunk(2)

    first.requires_grad_()
    assert first.requires_grad
    first, phony = fork(first)
    # Both outputs of fork are tracked when its input was.
    assert first.requires_grad
    assert phony.requires_grad

    # ``second`` starts untracked and only picks up requires_grad via join.
    assert not second.requires_grad
    second = join(second, phony)
    assert second.requires_grad