# -*- coding: utf-8 -*-
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the
autograd engine. The shared context of three functions (:class:`PortalBlue`,
:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is
one of the most important features of :mod:`torchpipe.skip`.
The metaphor is inspired by Portal™ from Valve.
"""
from typing import List, Optional, Tuple
import torch
from torch import Tensor
from ..copy import Context as CopyContext
from ..copy import Copy
from ..phony import get_phony
from ..stream import AbstractStream, get_device
__all__: List[str] = []
class Portal:
    """A portal for a tensor.

    A portal keeps a tensor outside of the autograd graph so it can be
    carried between the forward and backward passes. :meth:`blue` hides the
    tensor, :meth:`orange` retrieves it, and :meth:`copy` moves it to
    another stream. A separate :attr:`grad` slot carries the gradient back
    from the paired orange to the paired blue during backpropagation.
    """

    def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
        self.put_tensor(tensor, tensor_life)
        # Gradient stashed by PortalOrange.backward and consumed by
        # PortalBlue.backward.
        self.grad: Optional[Tensor] = None

    def blue(self) -> Tensor:
        """Creates a :class:`PortalBlue` which hides the underlying tensor from
        the autograd engine.

        Join the returning phony to the main lane of the autograd graph to
        assure the correct backpropagation::

            PortalBlue --+
                         |
            ---------- Join --

        """
        tensor = self.use_tensor()

        if tensor is None:
            # Nothing to hide; emit a detached CPU phony instead.
            return get_phony(torch.device("cpu"), requires_grad=False)

        return PortalBlue.apply(self, tensor)

    def orange(self, phony: Tensor) -> Optional[Tensor]:
        """Creates a :class:`PortalOrange` which retrieves the hidden tensor
        without losing ability of backpropagation.

        Give a phony forked from the main lane of an autograd graph::

            +-- PortalOrange --+
            |                  |
            -- Fork --------- f(a, b) --

        """
        self.check_tensor_life()

        if self.tensor is None:
            # No tensor is stored, but still consume one life so the
            # bookkeeping stays consistent with the stored-tensor path.
            return self.use_tensor()

        return PortalOrange.apply(self, phony)

    def copy(
        self, prev_stream: "AbstractStream", next_stream: "AbstractStream", phony: Tensor,
    ) -> Tensor:
        """Copies the hidden tensor by a :class:`PortalCopy`.

        Give a phony and use the returning phony to keep backpropagation::

            +-- PortalCopy --+
            |                |
            -- Fork ---------- Join --

        """
        if self.tensor is None:
            # Nothing to copy; emit a detached CPU phony instead.
            return get_phony(torch.device("cpu"), requires_grad=False)

        return PortalCopy.apply(self, prev_stream, next_stream, phony)

    def check_tensor_life(self) -> None:
        """Raises :exc:`RuntimeError` if the tensor's life is exhausted."""
        if self.tensor_life <= 0:
            raise RuntimeError("tensor in portal has been removed")

    def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None:
        """Stores a tensor into this portal."""
        # [Life of Tensor through Portal]
        #
        # The tensor can be retrieved by use_tensor() up to 'tensor_life'
        # times. When the life becomes 0, the tensor will be deleted for
        # deallocation in CUDA memory.
        #
        # The below events participate in a tensor through a portal.
        # Note that [x] denotes the events which call use_tensor():
        #
        #  1. [x] blue()
        #  2. [ ]   PortalBlue.forward
        #  3. [ ] copy()
        #  4. [ ]   PortalCopy.forward
        #  5. [ ] orange()
        #  6. [x]   PortalOrange.forward
        # - - - - - - - - - - - - - - - - - - - - - - - - - - -
        #  7. [ ] orange() (recomputed)
        #  8. [x]   PortalOrange.forward (recomputed)
        #  9. [ ]   PortalOrange.backward
        # 10. [ ]   PortalCopy.backward
        # 11. [x] blue() (recomputed)
        # 12. [ ]   PortalBlue.forward (recomputed)
        # 13. [ ]   PortalBlue.backward
        #
        self.tensor_life = tensor_life

        if tensor_life > 0:
            self.tensor = tensor
        else:
            # Zero (or negative) life: drop the tensor immediately so its
            # memory can be reclaimed.
            self.tensor = None

    def use_tensor(self) -> Optional[Tensor]:
        """Retrieves the underlying tensor and decreases the tensor life. When
        the life becomes 0, the tensor will be removed.
        """
        self.check_tensor_life()

        tensor = self.tensor
        self.tensor_life -= 1

        if self.tensor_life <= 0:
            self.tensor = None

        return tensor

    def put_grad(self, grad: Tensor) -> None:
        """Stores a gradient into this portal."""
        self.grad = grad

    def use_grad(self) -> Tensor:
        """Retrieves and removes the underlying gradient. The gradient is
        always ephemeral.
        """
        if self.grad is None:
            raise RuntimeError("grad in portal has been removed or never set")

        grad = self.grad
        self.grad = None
        return grad
# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and
# :class:`PortalCopy`.
class Context(CopyContext):
    """Autograd context that carries a portal from forward to backward."""
    # The portal this autograd function operates on; set in each forward().
    portal: Portal
class PortalBlue(torch.autograd.Function):
    """Hides a tensor from the autograd engine by a :class:`Portal`.

    The forward pass consumes the real tensor but returns only a detached
    phony, so the hidden tensor never enters the autograd graph. The
    backward pass emits the gradient that the paired :class:`PortalOrange`
    stashed into the portal.
    """

    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        portal: Portal,
        # This tensor must be retrieved by portal.use_tensor().
        tensor: Tensor,
    ) -> Tensor:
        ctx.portal = portal

        # A phony on the tensor's device stands in for the hidden tensor.
        marker = get_phony(tensor.device, requires_grad=False)
        return marker.detach()

    @staticmethod
    # type: ignore[override]
    def backward(ctx: Context, grad_phony: Tensor) -> Tuple[None, Tensor]:
        # The paired PortalOrange should keep the gradient.
        return None, ctx.portal.use_grad()
class PortalOrange(torch.autograd.Function):
    """Retrieves the hidden tensor from a :class:`Portal`.

    The forward pass takes a phony (used only for graph ordering) and
    returns a detached view of the portal's stashed tensor. The backward
    pass stashes the incoming gradient into the portal for the paired
    :class:`PortalBlue` to emit.
    """

    @staticmethod
    # type: ignore[override]
    def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor:
        hidden = portal.use_tensor()
        assert hidden is not None

        ctx.portal = portal
        return hidden.detach()

    @staticmethod
    def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]:  # type: ignore[override]
        # The paired PortalBlue will use the gradient.
        ctx.portal.put_grad(grad)
        return None, None
class PortalCopy(torch.autograd.Function):
    """Copies the hidden tensor in a :class:`Portal`. It replaces the hidden
    tensor with the copied one.
    """
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,
    ) -> Tensor:
        ctx.portal = portal
        assert portal.tensor is not None
        # Delegate to Copy.forward with this function's OWN ctx, so whatever
        # Copy saves on the context is available to Copy.backward below.
        # The portal's stored tensor is replaced with the copy.
        (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor)
        # The returned phony is on the destination stream's device.
        phony = get_phony(get_device(next_stream), requires_grad=False)
        return phony.detach()
    @staticmethod
    # type: ignore[override]
    def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]:
        portal = ctx.portal
        assert portal.grad is not None
        # Copy.backward yields a 3-sequence whose last element is the
        # gradient copied back; it replaces the portal's stashed gradient.
        _, _, portal.grad = Copy.backward(ctx, portal.grad)
        # No direct gradients flow to the portal, streams, or phony inputs.
        return None, None, None, None