File: test_on_disk_dataset.py

Package: pytorch-geometric 2.6.1-7
import os.path as osp
from typing import Any, Dict

import torch

from torch_geometric.data import Data, OnDiskDataset
from torch_geometric.testing import withPackage


@withPackage('sqlite3')
def test_pickle(tmp_path):
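    """Test appending, extending, and retrieving examples from an
    `OnDiskDataset`, and that its contents persist across re-opens."""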
    dataset = OnDiskDataset(tmp_path)
    assert len(dataset) == 0
    assert str(dataset) == 'OnDiskDataset(0)'
    assert osp.exists(osp.join(tmp_path, 'processed', 'sqlite.db'))

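    # Build four small random graphs to populate the dataset: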
    data_list = [
        Data(
            x=torch.randn(5, 8),
            edge_index=torch.randint(0, 5, (2, 16)),
            num_nodes=5,
        ) for _ in range(4)
    ]

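    # `append` inserts a single example; `extend` inserts a sequence: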
    dataset.append(data_list[0])
    assert len(dataset) == 1

    dataset.extend(data_list[1:])
    assert len(dataset) == 4

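    # `get` retrieves a single example by index: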
    out = dataset.get(0)
    assert torch.equal(out.x, data_list[0].x)
    assert torch.equal(out.edge_index, data_list[0].edge_index)
    assert out.num_nodes == data_list[0].num_nodes

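    # `multi_get` retrieves several examples at once: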
    out_list = dataset.multi_get([1, 2, 3])
    for out, data in zip(out_list, data_list[1:]):
        assert torch.equal(out.x, data.x)
        assert torch.equal(out.edge_index, data.edge_index)
        assert out.num_nodes == data.num_nodes

    dataset.close()

    # Test persistence of datasets:
    dataset = OnDiskDataset(tmp_path)
    assert len(dataset) == 4

    out = dataset.get(0)
    assert torch.equal(out.x, data_list[0].x)
    assert torch.equal(out.edge_index, data_list[0].edge_index)
    assert out.num_nodes == data_list[0].num_nodes

    dataset.close()


@withPackage('sqlite3')
def test_custom_schema(tmp_path):
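    """Test `OnDiskDataset` with a user-defined schema and custom
    `serialize`/`deserialize` hooks, counting how often each hook runs."""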
    class CustomSchemaOnDiskDataset(OnDiskDataset):
        def __init__(self, root: str):
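            # A schema declares the dtype and shape of each stored
            # attribute so the database can store it natively instead of
            # pickling whole objects; -1 marks a variable-sized dimension.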
            schema = {
                'x': dict(dtype=torch.float, size=(-1, 8)),
                'edge_index': dict(dtype=torch.long, size=(2, -1)),
                'num_nodes': int,
            }
            self.serialize_count = 0
            self.deserialize_count = 0
            super().__init__(root, schema=schema)

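        # Called whenever an example is written to the database: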
        def serialize(self, data: Data) -> Dict[str, Any]:
            self.serialize_count += 1
            return data.to_dict()

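        # Called whenever an example is read back from the database: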
        def deserialize(self, mapping: Dict[str, Any]) -> Data:
            self.deserialize_count += 1
            return Data.from_dict(mapping)

    dataset = CustomSchemaOnDiskDataset(tmp_path)
    assert len(dataset) == 0
    assert str(dataset) == 'CustomSchemaOnDiskDataset(0)'
    assert osp.exists(osp.join(tmp_path, 'processed', 'sqlite.db'))

    data_list = [
        Data(
            x=torch.randn(5, 8),
            edge_index=torch.randint(0, 5, (2, 16)),
            num_nodes=5,
        ) for _ in range(4)
    ]

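    # Each write should trigger exactly one `serialize` call: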
    dataset.append(data_list[0])
    assert dataset.serialize_count == 1
    assert len(dataset) == 1

    dataset.extend(data_list[1:])
    assert dataset.serialize_count == 4
    assert len(dataset) == 4

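    # Each read should trigger exactly one `deserialize` call: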
    out = dataset.get(0)
    assert dataset.deserialize_count == 1
    assert torch.equal(out.x, data_list[0].x)
    assert torch.equal(out.edge_index, data_list[0].edge_index)
    assert out.num_nodes == data_list[0].num_nodes

    out_list = dataset.multi_get([1, 2, 3])
    assert dataset.deserialize_count == 4
    for out, data in zip(out_list, data_list[1:]):
        assert torch.equal(out.x, data.x)
        assert torch.equal(out.edge_index, data.edge_index)
        assert out.num_nodes == data.num_nodes

    dataset.close()
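

# A minimal usage sketch (assumptions: only the `OnDiskDataset` API exercised
# above; the root path is hypothetical). Kept as comments so nothing executes
# at import time:
#
#   dataset = OnDiskDataset('/path/to/root')
#   dataset.extend([Data(x=torch.randn(5, 8)) for _ in range(10)])
#   assert len(dataset) == 10
#   first = dataset.get(0)
#   rest = dataset.multi_get(list(range(1, 10)))
#   dataset.close()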