File: test_portal.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (157 lines) | stat: -rw-r--r-- 4,862 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Owner(s): ["oncall: distributed"]

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch

from torch.distributed.pipeline.sync.dependency import fork, join
from torch.distributed.pipeline.sync.skip.portal import Portal
from torch.distributed.pipeline.sync.stream import default_stream


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_copy_returns_on_next_device():
    """Check that ``Portal.copy`` returns a phony living on the next device.

    The previous stream is the CPU default stream and the next stream is the
    CUDA default stream, so the phony handed back by ``copy`` is expected to
    report a ``cuda`` device type.
    """
    cpu_stream = default_stream(torch.device("cpu"))
    cuda_stream = default_stream(torch.device("cuda"))
    portal = Portal(torch.rand(1), tensor_life=1)

    # Start from an empty phony tensor that is known to live on the CPU.
    source_phony = torch.zeros(0, requires_grad=True)
    assert source_phony.device.type == "cpu"

    # After the copy the returned phony must be on the CUDA side.
    copied_phony = portal.copy(cpu_stream, cuda_stream, source_phony)
    assert copied_phony.device.type == "cuda"


def test_blue_orange():
    """Send a tensor through a portal and verify gradients on both inputs.

    ``tensor2`` enters the portal via ``blue()`` and leaves it via
    ``orange()``, while ``tensor1`` travels along the main path through
    ``join`` and ``fork``.
    """
    tensor1 = torch.rand(1, requires_grad=True)
    tensor2 = torch.rand(1, requires_grad=True)

    # Equivalent to: output = tensor1 * 2 + tensor2
    #
    #                +----------------------+
    #                |                      |
    # tensor2 -- PortalBlue -+      +- PortalOrange -+
    #                        |      |                |
    # tensor1 ------------ Join -- Fork --- Mul --- Add -- output
    #
    portal = Portal(tensor2, tensor_life=2)
    blue_phony = portal.blue()
    trunk = join(tensor1, blue_phony)
    trunk, fork_phony = fork(trunk)
    teleported = portal.orange(fork_phony)

    doubled = trunk * 2
    output = doubled + teleported
    output.backward()

    # d(output)/d(tensor1) == 2 and d(output)/d(tensor2) == 1.
    grad_wrt_tensor1 = torch.tensor([2.0])
    grad_wrt_tensor2 = torch.tensor([1.0])
    assert torch.allclose(tensor1.grad, grad_wrt_tensor1)
    assert torch.allclose(tensor2.grad, grad_wrt_tensor2)


def test_blue_orange_not_requires_grad():
    """Same portal round trip, but the teleported tensor tracks no gradient.

    ``tensor2`` is created without ``requires_grad``, so after the backward
    pass only ``tensor1`` is expected to carry a gradient.
    """
    tensor1 = torch.rand(1, requires_grad=True)
    tensor2 = torch.rand(1)

    # Equivalent to: output = tensor1 * 2 + tensor2
    #
    #                +----------------------+
    #                |                      |
    # tensor2 -- PortalBlue -+      +- PortalOrange -+
    #                        |      |                |
    # tensor1 ------------ Join -- Fork --- Mul --- Add -- output
    #
    portal = Portal(tensor2, tensor_life=2)
    blue_phony = portal.blue()
    trunk = join(tensor1, blue_phony)
    trunk, fork_phony = fork(trunk)
    teleported = portal.orange(fork_phony)

    doubled = trunk * 2
    output = doubled + teleported
    output.backward()

    # Only the main-path input accumulates a gradient.
    grad_wrt_tensor1 = torch.tensor([2.0])
    assert torch.allclose(tensor1.grad, grad_wrt_tensor1)
    assert tensor2.grad is None


def test_use_grad():
    """A gradient put into a portal can be taken out exactly once."""
    tensor = torch.rand(1, requires_grad=True)
    portal = Portal(tensor, tensor_life=1)
    portal.put_grad(tensor)

    # The first read hands back the very same object that was stored.
    taken = portal.use_grad()
    assert taken is tensor

    # The gradient is ephemeral: a second read must raise.
    with pytest.raises(RuntimeError):
        portal.use_grad()


class TestTensorLife:
    """Tests for how a portal's ``tensor_life`` budget is used up.

    Every test obtains its portal from the ``new_portal`` fixture, whose
    teardown asserts that the portal no longer holds a tensor.
    """

    @pytest.fixture
    def new_portal(self):
        """Yield a factory ``make(tensor_life) -> (portal, tensor)``.

        The most recently created portal is remembered so that the teardown
        code after the ``yield`` can inspect it.
        """
        latest = None

        def make(tensor_life):
            nonlocal latest
            tensor = torch.rand(1, requires_grad=True)
            latest = Portal(tensor, tensor_life)
            return latest, tensor

        yield make

        # A test using this fixture must exhaust the tensor in the portal.
        with pytest.raises(RuntimeError):
            latest.check_tensor_life()
        assert latest.tensor is None

    def test_tensor_life_0(self, new_portal):
        """With a life of 0 the portal never holds the tensor."""
        portal, _ = new_portal(0)
        assert portal.tensor is None

    def test_tensor_life_1(self, new_portal):
        """A life of 1 is used up by a single ``blue()`` call."""
        portal, tensor = new_portal(1)
        assert portal.tensor is tensor

        portal.blue()

    def test_tensor_life_2(self, new_portal):
        """A life of 2 covers one ``blue()`` followed by one ``orange()``."""
        portal, tensor = new_portal(2)
        assert portal.tensor is tensor

        phony = portal.blue()
        received = portal.orange(phony)
        # The tensor coming out of the portal shares storage with the input.
        assert received.data_ptr() == tensor.data_ptr()

    def test_tensor_life_3(self, new_portal):
        """A life of 3 covers one ``blue()`` and two ``orange()`` calls."""
        portal, tensor = new_portal(3)
        assert portal.tensor is tensor

        phony = portal.blue()
        for _ in range(2):
            assert portal.orange(phony).data_ptr() == tensor.data_ptr()

    def test_tensor_life_4(self, new_portal):
        """A life of 4 covers ``blue()``, two ``orange()`` calls and ``blue()``."""
        portal, tensor = new_portal(4)
        assert portal.tensor is tensor

        phony = portal.blue()
        for _ in range(2):
            assert portal.orange(phony).data_ptr() == tensor.data_ptr()
        portal.blue()

    def test_tensor_life_3_plus_1(self, new_portal):
        """Use up a life of 3, then refill with a new tensor of life 1."""
        portal, tensor = new_portal(3)
        assert portal.tensor is tensor

        phony = portal.blue()
        for _ in range(2):
            assert portal.orange(phony).data_ptr() == tensor.data_ptr()

        # Put a fresh tensor with a single life, then spend it.
        replacement = torch.rand(1, requires_grad=True)
        portal.put_tensor(replacement, tensor_life=1)
        portal.blue()