1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
|
# Owner(s): ["oncall: distributed"]
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torch import nn
from torch.distributed.pipeline.sync.skip import pop, skippable, stash
from torch.distributed.pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker
@pytest.fixture(autouse=True)
def skip_tracker():
    """Provide a fresh ``SkipTracker`` that is active for the whole test.

    The fixture is ``autouse``, so every test in this module runs inside
    ``use_skip_tracker`` with its own tracker instance; tests that name the
    fixture as an argument additionally receive that instance.
    """
    tracker = SkipTracker()
    with use_skip_tracker(tracker):
        yield tracker
def test_stash(skip_tracker):
    """Running a layer that stashes ``"foo"`` records one entry in the tracker."""

    @skippable(stash=["foo"])
    class Stash(nn.Module):
        def forward(self, x):
            yield stash("foo", x)
            return x * 2  # noqa: B901

    layer = Stash()

    # Before the layer runs, the tracker holds nothing.
    assert len(skip_tracker.tensors) == 0

    with use_skip_tracker(skip_tracker):
        layer(torch.tensor(42))

    # Exactly one tensor ("foo") was stashed by the forward pass.
    assert len(skip_tracker.tensors) == 1
def test_pop():
    """A tensor stashed as ``"foo"`` is handed back by a layer that pops ``"foo"``."""

    @skippable(stash=["foo"])
    class Stash(nn.Module):
        def forward(self, x):
            yield stash("foo", x)
            return x * 2  # noqa: B901

    @skippable(pop=["foo"])
    class Pop(nn.Module):
        def forward(self, x):
            popped = yield pop("foo")
            return popped

    stash_layer = Stash()
    pop_layer = Pop()

    intermediate = stash_layer(torch.tensor(42))
    result = pop_layer(intermediate)

    # Pop returns the stashed original (42), not the doubled activation.
    assert result.item() == 42
def test_declare_but_not_use():
    """Declaring a stash or pop of ``"foo"`` without yielding it raises ``RuntimeError``."""

    @skippable(stash=["foo"])
    class Stash(nn.Module):
        def forward(self, x):
            return x * 2

    @skippable(pop=["foo"])
    class Pop(nn.Module):
        def forward(self, x):
            return x * 3

    stash_layer = Stash()
    pop_layer = Pop()

    # Each layer is checked on its own with a fresh input tensor.
    for layer in (stash_layer, pop_layer):
        with pytest.raises(RuntimeError):
            layer(torch.tensor(42))
def test_stash_not_declared():
    """Stashing ``"foo"`` from a layer that did not declare it raises ``RuntimeError``."""

    @skippable()
    class Stash(nn.Module):
        def forward(self, x):
            yield stash("foo", x)
            return x * 2  # noqa: B901

    undeclared = Stash()

    with pytest.raises(RuntimeError):
        undeclared(torch.tensor(42))
def test_pop_not_declared():
    """Popping ``"foo"`` from a layer that did not declare it raises ``RuntimeError``."""

    @skippable(stash=["foo"])
    class Stash(nn.Module):
        def forward(self, x):
            yield stash("foo", x)
            return x * 2  # noqa: B901

    @skippable()
    class Pop(nn.Module):
        def forward(self, x):
            popped = yield pop("foo")
            return popped

    stash_layer = Stash()
    pop_layer = Pop()

    # The stash itself is declared and runs outside the raises block.
    latent = stash_layer(torch.tensor(42))

    with pytest.raises(RuntimeError):
        pop_layer(latent)
def test_pop_not_stashed():
    """Popping ``"foo"`` when no layer has stashed it raises ``RuntimeError``."""

    @skippable(pop=["foo"])
    class Pop(nn.Module):
        def forward(self, x):
            yield pop("foo")

    lonely_pop = Pop()

    with pytest.raises(RuntimeError):
        lonely_pop(torch.tensor(42))
def test_stash_none():
    """Stashing ``None`` under a declared name runs without raising."""

    @skippable(stash=["foo"])
    class Stash(nn.Module):
        def forward(self, x):
            yield stash("foo", None)
            return x * 2  # noqa: B901

    layer = Stash()

    # Smoke test: the call completing is the whole check.
    layer(torch.tensor(42))
|