File: test_storage.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (106 lines) | stat: -rw-r--r-- 3,071 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import copy
from typing import Any

import pytest
import torch

from torch_geometric.data.storage import BaseStorage


def test_base_storage():
    """Exercise the basic container behaviour of ``BaseStorage``.

    Covers attribute assignment and deletion, ``get`` with defaults, the
    filtered ``keys``/``values``/``items`` views, construction from a
    dictionary or from keyword arguments, shallow and deep copying, and the
    ``AttributeError`` raised when accessing an unknown attribute.
    """
    storage = BaseStorage()
    assert storage._mapping == {}
    storage.x = torch.zeros(1)
    storage.y = torch.ones(1)
    assert len(storage) == 2
    # Compare keys and values explicitly: `dict.__eq__` on tensor values
    # relies on `bool(tensor == tensor)`, which is only well-defined for
    # single-element tensors.
    assert set(storage._mapping) == {'x', 'y'}
    assert torch.equal(storage._mapping['x'], torch.zeros(1))
    assert torch.equal(storage._mapping['y'], torch.ones(1))
    assert storage.x is not None
    assert storage.y is not None

    assert torch.allclose(storage.get('x', None), storage.x)
    assert torch.allclose(storage.get('y', None), storage.y)
    assert storage.get('z', 2) == 2
    assert storage.get('z', None) is None
    # 'z' is not stored, so only two of the three requested names show up.
    assert len(list(storage.keys('x', 'y', 'z'))) == 2
    assert len(list(storage.values('x', 'y', 'z'))) == 2
    assert len(list(storage.items('x', 'y', 'z'))) == 2

    del storage.y
    assert len(storage) == 1
    assert storage.x is not None

    # Construction from a dictionary.
    storage = BaseStorage({'x': torch.zeros(1)})
    assert len(storage) == 1
    assert storage.x is not None

    # Construction from keyword arguments.
    storage = BaseStorage(x=torch.zeros(1))
    assert len(storage) == 1
    assert storage.x is not None

    # A shallow copy is a distinct object that shares the tensor memory.
    storage = BaseStorage(x=torch.zeros(1))
    copied_storage = copy.copy(storage)
    assert storage == copied_storage
    assert storage is not copied_storage
    assert storage.x.data_ptr() == copied_storage.x.data_ptr()
    assert int(storage.x) == 0
    assert int(copied_storage.x) == 0

    # A deep copy is a distinct object with its own tensor memory.
    deepcopied_storage = copy.deepcopy(storage)
    assert storage == deepcopied_storage
    assert storage is not deepcopied_storage
    assert storage.x.data_ptr() != deepcopied_storage.x.data_ptr()
    assert int(storage.x) == 0
    assert int(deepcopied_storage.x) == 0

    with pytest.raises(AttributeError, match="has no attribute 'asdf'"):
        storage.asdf


def test_storage_tensor_methods():
    """Check that tensor-style methods on ``BaseStorage`` reach its tensors.

    Each method is called on the storage and the result is verified by
    inspecting the stored tensor ``x``.
    """
    x = torch.randn(5)
    storage = BaseStorage({'x': x})

    # Cloning must allocate new memory for the stored tensor.
    storage = storage.clone()
    assert storage.x.data_ptr() != x.data_ptr()

    storage = storage.contiguous()
    assert storage.x.is_contiguous()

    storage = storage.to('cpu')
    assert storage.x.device == torch.device('cpu')

    storage = storage.cpu()
    assert storage.x.device == torch.device('cpu')

    # Pinning is only exercised when CUDA is available.
    if torch.cuda.is_available():
        storage = storage.pin_memory()
        assert storage.x.is_pinned()

    storage = storage.share_memory_()
    # `is_shared` is a method: without the call the assertion would test the
    # (always truthy) bound method object and could never fail.
    assert storage.x.is_shared()

    storage = storage.detach_()
    assert not storage.x.requires_grad

    storage = storage.detach()
    assert not storage.x.requires_grad

    storage = storage.requires_grad_()
    assert storage.x.requires_grad


def test_setter_and_getter():
    """Check that properties defined on a ``BaseStorage`` subclass work.

    A value assigned through the property setter must be readable both via
    the property getter and via the private attribute backing it.
    """
    class MyStorage(BaseStorage):
        @property
        def my_property(self) -> Any:
            return self._my_property

        @my_property.setter
        def my_property(self, value: Any):
            self._my_property = value

    storage = MyStorage()
    storage.my_property = 'hello'
    assert storage.my_property == 'hello'
    # Compare against the expected value; comparing the attribute with itself
    # is a tautology that can never fail.
    assert storage._my_property == 'hello'