1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
|
from __future__ import annotations
import pytest
from distributed.protocol import deserialize, serialize
np = pytest.importorskip("numpy")
torch = pytest.importorskip("torch")
def test_tensor():
    """Round-trip a plain tensor built from a numpy range and compare values.

    Also checks that the header produced by ``serialize`` names the "dask"
    serializer family as the one that handled the tensor.
    """
    source = np.arange(10)
    tensor = torch.Tensor(source)

    header, frames = serialize(tensor)
    # The header records which serializer encoded the object.
    assert header["serializer"] == "dask"

    restored = deserialize(header, frames)
    assert (source == restored.numpy()).all()
@pytest.mark.parametrize("requires_grad", [True, False])
def test_grad(requires_grad):
    """Round-trip a float tensor with and without autograd enabled.

    When ``requires_grad`` is set, a gradient of ones is attached first. The
    test then checks that the flag, the values and the gradient all survive
    ``serialize``/``deserialize``. It also checks that the source tensor's
    flag is left as it was.
    """
    values = np.arange(10)
    tensor = torch.tensor(values, dtype=torch.float, requires_grad=requires_grad)
    if requires_grad:
        # Gradient of all ones, checked on the restored tensor below.
        tensor.grad = torch.ones_like(tensor)

    restored = deserialize(*serialize(tensor))

    assert restored.requires_grad is requires_grad
    # Serializing must not change the original tensor's autograd flag.
    assert tensor.requires_grad is requires_grad
    assert np.allclose(restored.detach().numpy(), values)
    if requires_grad:
        assert np.allclose(restored.grad.numpy(), 1)
def test_resnet():
    """Round-trip a torchvision ResNet-18 and check architecture and weights.

    Skipped when torchvision is not installed. Comparing ``str(model)`` only
    shows that the module tree (layer types and their hyper-parameters)
    survived. It says nothing about the tensors held by the model. The state
    dicts are therefore also compared entry by entry, so a serializer that
    dropped or corrupted parameters or buffers makes the test fail.
    """
    torchvision = pytest.importorskip("torchvision")
    model = torchvision.models.resnet.resnet18()
    header, frames = serialize(model)
    model2 = deserialize(header, frames)
    assert str(model) == str(model2)

    state = model.state_dict()
    state2 = model2.state_dict()
    # Same set of parameter/buffer names on both sides...
    assert state.keys() == state2.keys()
    # ...and bit-identical tensors for every one of them.
    for name, tensor in state.items():
        assert torch.equal(tensor, state2[name]), name
def test_deserialize_grad():
    """Check that a random 8x1 tensor keeps ``requires_grad`` after a round trip.

    The restored values must also stay close to the float64 numpy source.
    """
    data = np.random.rand(8, 1)
    tensor = torch.tensor(data, requires_grad=True, dtype=torch.float)

    restored = deserialize(*serialize(tensor))
    assert restored.requires_grad

    # Detach first: torch refuses .numpy() on a tensor that requires grad.
    values = restored.detach().numpy()
    assert np.allclose(data, values)
|