# Owner(s): ["oncall: distributed"]
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch

from torch.distributed.pipeline.sync.stream import (
    CPUStream,
    current_stream,
    default_stream,
    get_device,
    is_cuda,
    new_stream,
    record_stream,
    use_device,
    use_stream,
    wait_stream,
)

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
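
# The `cuda_sleep` fixture used by the CUDA tests below is not defined in this
# file; it is assumed to be provided by the test suite's conftest.py as a
# helper that enqueues a CUDA kernel spinning for roughly the given number of
# seconds on the current stream.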


class TestNewStream:
    def test_new_stream_cpu(self):
        stream = new_stream(torch.device("cpu"))
        assert stream is CPUStream

    @skip_if_no_cuda
    def test_new_stream_cuda(self):
        stream = new_stream(torch.device("cuda"))
        assert isinstance(stream, torch.cuda.Stream)
        assert stream != torch.cuda.default_stream()


class TestCurrentStream:
    def test_current_stream_cpu(self):
        stream = current_stream(torch.device("cpu"))
        assert stream is CPUStream

    @skip_if_no_cuda
    def test_current_stream_cuda(self):
        stream = current_stream(torch.device("cuda"))
        assert isinstance(stream, torch.cuda.Stream)
        assert stream == torch.cuda.current_stream()


class TestDefaultStream:
    def test_default_stream_cpu(self):
        stream = default_stream(torch.device("cpu"))
        assert stream is CPUStream

    @skip_if_no_cuda
    def test_default_stream_cuda(self):
        stream = default_stream(torch.device("cuda"))
        assert isinstance(stream, torch.cuda.Stream)
        assert stream == torch.cuda.default_stream()
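

# `use_device` and `use_stream` are context managers. For CPU they are assumed
# to be no-ops; for CUDA, `use_stream` is expected to route work onto the given
# stream, in the spirit of `torch.cuda.stream`.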


class TestUseDevice:
    def test_use_device_cpu(self):
        with use_device(torch.device("cpu")):
            pass

    @skip_if_no_cuda
    def test_use_device_cuda(self):
        with use_device(torch.device("cuda")):
            pass


class TestUseStream:
    def test_use_stream_cpu(self):
        with use_stream(CPUStream):
            pass

    @skip_if_no_cuda
    def test_use_stream_cuda(self):
        stream = new_stream(torch.device("cuda"))
        with use_stream(stream):
            assert current_stream(torch.device("cuda")) == stream


class TestGetDevice:
    def test_get_device_cpu(self):
        assert get_device(CPUStream).type == "cpu"

    @skip_if_no_cuda
    def test_get_device_cuda(self):
        stream = current_stream(torch.device("cuda"))
        assert get_device(stream).type == "cuda"


class TestWaitStream:
    def _test_wait_stream(self, source, target, cuda_sleep=None):
        with use_stream(target):
            if is_cuda(target):
                cuda_sleep(0.5)
            x = torch.ones(100, 100, device=get_device(target))

        wait_stream(source, target)

        with use_stream(source):
            assert x.sum().item() == 10000
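
    # `wait_stream(source, target)` makes `source` wait until all work queued
    # on `target` has finished. The sleep on a CUDA target widens the window:
    # without the wait, `source` could read `x` before the fill completes.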
    def test_wait_stream_cpu_cpu(self):
        source = CPUStream
        target = CPUStream
        self._test_wait_stream(source, target)

    @skip_if_no_cuda
    def test_wait_stream_cpu_cuda(self, cuda_sleep):
        source = CPUStream
        target = new_stream(torch.device("cuda"))
        self._test_wait_stream(source, target, cuda_sleep)

    @skip_if_no_cuda
    def test_wait_stream_cuda_cpu(self, cuda_sleep):
        source = new_stream(torch.device("cuda"))
        target = CPUStream
        self._test_wait_stream(source, target, cuda_sleep)

    @skip_if_no_cuda
    def test_wait_stream_cuda_cuda(self, cuda_sleep):
        source = current_stream(torch.device("cuda"))
        target = new_stream(torch.device("cuda"))
        self._test_wait_stream(source, target, cuda_sleep)


class TestRecordStream:
    def test_record_stream_cpu(self):
        # It should silently ignore CPU tensors.
        x = torch.rand(1, device=torch.device("cpu"))
        record_stream(x, CPUStream)
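
    # For CUDA tensors, `record_stream` is assumed to wrap
    # `torch.Tensor.record_stream`, telling the caching allocator not to
    # reuse the tensor's block until the recorded stream has finished with it.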
    @skip_if_no_cuda
    def test_record_stream_cuda(self, cuda_sleep):
        # This test detects unexpected block reallocation. To make the test
        # reliable, the stream that allocates tensors is isolated: the
        # allocator will not reuse free blocks that were allocated on another
        # stream.
        stream_alloc = new_stream(torch.device("cuda"))
        with torch.cuda.stream(stream_alloc):
            x = torch.rand(1, device=torch.device("cuda"))

        stream = new_stream(torch.device("cuda"))
        record_stream(x, stream)
        with use_stream(stream):
            cuda_sleep(0.5)

        # 'x' is deleted from Python's perspective, but its block is still
        # required by 'stream', so 'y' must not be allocated into that block.
        data_ptr = x.data_ptr()
        del x
        stream_alloc.synchronize()
        with torch.cuda.stream(stream_alloc):
            y = torch.rand(1, device=torch.device("cuda"))
        assert y.data_ptr() != data_ptr

        # Pause Python until 'stream' finishes its queued tasks. Now the block
        # of 'x' is free to be reallocated.
        wait_stream(CPUStream, stream)
        with torch.cuda.stream(stream_alloc):
            z = torch.rand(1, device=torch.device("cuda"))
        assert z.data_ptr() == data_ptr

    @skip_if_no_cuda
    def test_record_stream_shifted_view(self, cuda_sleep):
        # Issue: https://github.com/pytorch/pytorch/issues/27366
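        # A shifted view's data pointer is offset from the base of its
        # storage; recording the view must still protect the whole block
        # (presumably by recording on a temporary tensor that shares the
        # storage, the workaround adopted for the issue above).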
        stream_alloc = new_stream(torch.device("cuda"))
        with torch.cuda.stream(stream_alloc):
            x = torch.rand(2, device=torch.device("cuda"))

        y = x[1:]
        assert y.data_ptr() > x.data_ptr()

        stream = new_stream(torch.device("cuda"))
        with use_stream(stream):
            cuda_sleep(0.5)
        record_stream(y, stream)

        data_ptr = x.data_ptr()
        del x, y

        stream_alloc.synchronize()
        with torch.cuda.stream(stream_alloc):
            z = torch.rand(2, device=torch.device("cuda"))
        assert z.data_ptr() != data_ptr