File: portal.py

# -*- coding: utf-8 -*-
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the
autograd engine. The shared context of three functions (:class:`PortalBlue`,
:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is
one of the most important feature of :mod:`torchpipe.skip`.

The metaphor is inspired by Portalâ„¢ from Valve.

"""
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from ..copy import Context as CopyContext
from ..copy import Copy
from ..phony import get_phony
from ..stream import AbstractStream, get_device

__all__: List[str] = []
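

# A minimal usage sketch (an illustration only: ``Portal`` is internal plumbing
# driven by the skip-connection machinery, so real code never calls these
# methods directly):
#
#     portal = Portal(tensor, tensor_life=2)  # one life for blue(), one for orange()
#     phony = portal.blue()                   # hide the tensor behind a phony
#     # ... in the real pipeline the phony is joined into, and later forked
#     # back out of, the main lane of the autograd graph ...
#     tensor = portal.orange(phony)           # reveal the tensor again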


class Portal:
    """A portal for a tensor."""

    def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
        self.put_tensor(tensor, tensor_life)
        self.grad: Optional[Tensor] = None

    def blue(self) -> Tensor:
        """Creates a :class:`PortalBlue` which hides the underlying tensor from
        the autograd engine.

        Join the returned phony to the main lane of the autograd graph to
        ensure correct backpropagation::

            PortalBlue --+
                         |
            ---------- Join --

        """
        tensor = self.use_tensor()

        if tensor is None:
            return get_phony(torch.device("cpu"), requires_grad=False)

        return PortalBlue.apply(self, tensor)

    def orange(self, phony: Tensor) -> Optional[Tensor]:
        """Creates a :class:`PortalOrange` which retrieves the hidden tensor
        without losing ability of backpropagation.

        Give a phony forked from the main lane of an autograd graph::

                +-- PortalOrange --+
                |                  |
            -- Fork --------- f(a, b) --

        """
        self.check_tensor_life()

        if self.tensor is None:
            return self.use_tensor()

        return PortalOrange.apply(self, phony)

    def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor) -> Tensor:
        """Copies the hidden tensor by a :class:`PortalCopy`.

        Pass in a phony and use the returned phony to keep backpropagation intact::

                +-- PortalCopy --+
                |                |
            -- Fork ---------- Join --

        """
        if self.tensor is None:
            return get_phony(torch.device("cpu"), requires_grad=False)

        return PortalCopy.apply(self, prev_stream, next_stream, phony)
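
    # A hedged sketch of where copy() fits (``prev_stream`` and ``next_stream``
    # are placeholders for the CUDA streams that the pipeline manages; they are
    # not created here):
    #
    #     phony2 = portal.copy(prev_stream, next_stream, phony)
    #     # portal.tensor now lives on next_stream's device; join ``phony2``
    #     # into the main lane so that PortalCopy participates in backward.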

    def check_tensor_life(self) -> None:
        if self.tensor_life <= 0:
            raise RuntimeError("tensor in portal has been removed")

    def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None:
        """Stores a tensor into this portal."""
        # [Life of Tensor through Portal]
        #
        # The tensor can be retrieved by use_tensor() up to 'tensor_life'
        # times. When the life reaches 0, the tensor is deleted so that its
        # CUDA memory can be deallocated.
        #
        # The events below act on a tensor as it passes through a portal.
        # Note that [x] denotes the events which call use_tensor():
        #
        #  1. [x] blue()
        #  2. [ ]   PortalBlue.forward
        #  3. [ ] copy()
        #  4. [ ]   PortalCopy.forward
        #  5. [ ] orange()
        #  6. [x]   PortalOrange.forward
        # - - - - - - - - - - - - - - - - - - - - - - - - - - -
        #  7. [ ] orange() (recomputed)
        #  8. [x]   PortalOrange.forward (recomputed)
        #  9. [ ]   PortalOrange.backward
        # 10. [ ] PortalCopy.backward
        # 11. [x] blue() (recomputed)
        # 12. [ ]   PortalBlue.forward (recomputed)
        # 13. [ ]   PortalBlue.backward
        #
        self.tensor_life = tensor_life

        if tensor_life > 0:
            self.tensor = tensor
        else:
            self.tensor = None

    def use_tensor(self) -> Optional[Tensor]:
        """Retrieves the underlying tensor and decreases the tensor  life. When
        the life becomes 0, it the tensor will be removed.
        """
        self.check_tensor_life()

        tensor = self.tensor

        self.tensor_life -= 1

        if self.tensor_life <= 0:
            self.tensor = None

        return tensor
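
    # A sketch of the life countdown (an illustration; ``t`` is a hypothetical
    # tensor, and a life of 2 corresponds to one forward pass plus one
    # recomputation):
    #
    #     portal.put_tensor(t, tensor_life=2)
    #     portal.use_tensor()  # life 2 -> 1, returns t
    #     portal.use_tensor()  # life 1 -> 0, returns t and drops the reference
    #     portal.use_tensor()  # raises RuntimeError: tensor has been removed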

    def put_grad(self, grad: Tensor) -> None:
        """Stores a gradient into this portal."""
        self.grad = grad

    def use_grad(self) -> Tensor:
        """Retrieves and removes the underlying gradient. The gradient is
        always ephemeral.
        """
        if self.grad is None:
            raise RuntimeError("grad in portal has been removed or never set")

        grad = self.grad
        self.grad = None
        return grad
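
    # A sketch of the gradient hand-off (an illustration; in this module the
    # grad is stored by PortalOrange.backward and consumed exactly once by the
    # paired PortalBlue.backward):
    #
    #     portal.put_grad(g)     # what PortalOrange.backward does
    #     g = portal.use_grad()  # what PortalBlue.backward does
    #     portal.use_grad()      # raises RuntimeError: the grad is ephemeral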


# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and
# :class:`PortalCopy`.
class Context(CopyContext):
    portal: Portal


class PortalBlue(torch.autograd.Function):
    """Hides a tensor from the autograd engine by a :class:`Portal`."""

    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        portal: Portal,
        # This tensor must be retrieved by portal.use_tensor().
        tensor: Tensor,
    ) -> Tensor:
        ctx.portal = portal

        phony = get_phony(tensor.device, requires_grad=False)
        return phony.detach()

    @staticmethod
    # type: ignore[override]
    def backward(ctx: Context, grad_phony: Tensor) -> Tuple[None, Tensor]:
        # The paired PortalOrange should keep the gradient.
        grad = ctx.portal.use_grad()
        return None, grad


class PortalOrange(torch.autograd.Function):
    """Retrieves the hidden tensor from a :class:`Portal`."""

    @staticmethod
    # type: ignore[override]
    def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor:
        ctx.portal = portal

        tensor = portal.use_tensor()
        assert tensor is not None

        return tensor.detach()

    @staticmethod
    def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]:  # type: ignore[override]
        # The paired PortalBlue will use the gradient.
        ctx.portal.put_grad(grad)
        return None, None


class PortalCopy(torch.autograd.Function):
    """Copies the hidden tensor in a :class:`Portal`. It replaces the hidden
    tensor with the copied one.
    """

    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,
    ) -> Tensor:
        ctx.portal = portal

        assert portal.tensor is not None
        (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor)

        phony = get_phony(get_device(next_stream), requires_grad=False)
        return phony.detach()

    @staticmethod
    # type: ignore[override]
    def backward(ctx: Context, grad_phony: Tensor) -> Tuple[None, None, None, None]:
        portal = ctx.portal

        assert portal.grad is not None
        _, _, portal.grad = Copy.backward(ctx, portal.grad)

        return None, None, None, None
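

# An end-to-end sketch of the autograd round trip (a simplified illustration:
# it wires the two phonies directly together, whereas the real pipeline forks
# and joins them through the main lane of the graph):
#
#     t = torch.randn(3, requires_grad=True)
#     portal = Portal(t, tensor_life=2)  # blue() and PortalOrange use one life each
#     phony = portal.blue()              # PortalBlue hides `t` behind a phony
#     out = portal.orange(phony)         # PortalOrange reveals it, still tracked
#     out.sum().backward()               # PortalOrange.backward -> put_grad();
#                                        # PortalBlue.backward -> use_grad() -> t.grad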