1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
|
import copy
from typing import Any
import pytest
import torch
from torch_geometric.data.storage import BaseStorage
def test_base_storage():
    """Exercise the core behavior of ``BaseStorage``.

    Covers attribute-style assignment and deletion, ``get`` with defaults,
    the key-filtered ``keys``/``values``/``items`` views, the dict and
    keyword constructors, shallow vs. deep copying, and the
    ``AttributeError`` raised for an unknown attribute.
    """
    storage = BaseStorage()
    assert storage._mapping == {}

    storage.x = torch.zeros(1)
    storage.y = torch.ones(1)
    assert len(storage) == 2
    # Dict equality compares the one-element tensors via `bool(a == b)`.
    assert storage._mapping == {'x': torch.zeros(1), 'y': torch.ones(1)}
    assert storage.x is not None
    assert storage.y is not None

    assert torch.allclose(storage.get('x', None), storage.x)
    assert torch.allclose(storage.get('y', None), storage.y)
    assert storage.get('z', 2) == 2
    assert storage.get('z', None) is None

    # Views restricted to the requested keys skip the missing 'z'.
    assert len(list(storage.keys('x', 'y', 'z'))) == 2
    assert set(storage.keys('x', 'y', 'z')) == {'x', 'y'}
    assert len(list(storage.values('x', 'y', 'z'))) == 2
    assert len(list(storage.items('x', 'y', 'z'))) == 2

    del storage.y
    assert len(storage) == 1
    assert storage.x is not None

    # Construction from a dict and from keyword arguments.
    storage = BaseStorage({'x': torch.zeros(1)})
    assert len(storage) == 1
    assert storage.x is not None

    storage = BaseStorage(x=torch.zeros(1))
    assert len(storage) == 1
    assert storage.x is not None

    # A shallow copy is a distinct storage object that shares tensor memory.
    storage = BaseStorage(x=torch.zeros(1))
    copied_storage = copy.copy(storage)
    assert storage == copied_storage
    assert copied_storage is not storage
    assert storage.x.data_ptr() == copied_storage.x.data_ptr()
    assert int(storage.x) == 0
    assert int(copied_storage.x) == 0

    # A deep copy additionally duplicates the underlying tensor data.
    deepcopied_storage = copy.deepcopy(storage)
    assert storage == deepcopied_storage
    assert deepcopied_storage is not storage
    assert storage.x.data_ptr() != deepcopied_storage.x.data_ptr()
    assert int(storage.x) == 0
    assert int(deepcopied_storage.x) == 0

    with pytest.raises(AttributeError, match="has no attribute 'asdf'"):
        storage.asdf
def test_storage_tensor_methods():
    """Check that tensor-style methods on ``BaseStorage`` reach stored tensors.

    Each method is called on the storage, and its effect is asserted on the
    stored tensor ``x``. Pinned memory is only exercised when CUDA is
    available.
    """
    x = torch.randn(5)
    storage = BaseStorage({'x': x})

    storage = storage.clone()
    assert storage.x.data_ptr() != x.data_ptr()

    storage = storage.contiguous()
    assert storage.x.is_contiguous()

    storage = storage.to('cpu')
    assert storage.x.device == torch.device('cpu')

    storage = storage.cpu()
    assert storage.x.device == torch.device('cpu')

    if torch.cuda.is_available():
        storage = storage.pin_memory()
        assert storage.x.is_pinned()

    storage = storage.share_memory_()
    # `Tensor.is_shared` is a method: the uncalled attribute is a bound
    # method (always truthy), so it must be invoked to test anything.
    assert storage.x.is_shared()

    storage = storage.detach_()
    assert not storage.x.requires_grad

    storage = storage.detach()
    assert not storage.x.requires_grad

    storage = storage.requires_grad_()
    assert storage.x.requires_grad
def test_setter_and_getter():
    """Check that a property defined on a ``BaseStorage`` subclass works.

    Assigning to ``my_property`` must go through the property setter, which
    stores the value on ``_my_property``. Reading it back must go through
    the property getter.
    """
    class MyStorage(BaseStorage):
        @property
        def my_property(self) -> Any:
            return self._my_property

        @my_property.setter
        def my_property(self, value: Any) -> None:
            self._my_property = value

    storage = MyStorage()
    storage.my_property = 'hello'
    assert storage.my_property == 'hello'
    # Assert the concrete value the setter wrote to the backing attribute.
    assert storage._my_property == 'hello'
    # The assignment was handled by the setter, so no 'my_property' entry
    # should have been created in the storage mapping.
    assert 'my_property' not in storage._mapping
|