File: test_sparse.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (311 lines) | stat: -rw-r--r-- 10,164 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import os.path as osp

import pytest
import torch

import torch_geometric.typing
from torch_geometric.io import fs
from torch_geometric.profile import benchmark
from torch_geometric.testing import is_full_test, withCUDA, withPackage
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import (
    dense_to_sparse,
    is_sparse,
    is_torch_sparse_tensor,
    to_edge_index,
    to_torch_coo_tensor,
    to_torch_csc_tensor,
    to_torch_csr_tensor,
    to_torch_sparse_tensor,
)
from torch_geometric.utils.sparse import cat


def test_dense_to_sparse():
    """Check ``dense_to_sparse`` on single, batched and masked dense inputs.

    Every case is verified in eager mode and, during full test runs, also
    against the TorchScript-compiled version of ``dense_to_sparse``.
    """
    def check(inputs, expected_index, expected_attr):
        # Eager conversion of the positional ``inputs`` tuple.
        out_index, out_attr = dense_to_sparse(*inputs)
        assert out_index.tolist() == expected_index
        assert out_attr.tolist() == expected_attr

        # The scripted function must agree with the eager one.
        if is_full_test():
            scripted = torch.jit.script(dense_to_sparse)
            out_index, out_attr = scripted(*inputs)
            assert out_index.tolist() == expected_index
            assert out_attr.tolist() == expected_attr

    # A single 2x2 dense adjacency matrix (zero entries are dropped):
    single = torch.tensor([
        [3.0, 1.0],
        [2.0, 0.0],
    ])
    check((single, ), [[0, 0, 1], [0, 1, 0]], [3, 1, 2])

    # A batch of two 2x2 matrices; the second graph's nodes are numbered
    # from 2 onwards in the output:
    batched = torch.tensor([
        [
            [3.0, 1.0],
            [2.0, 0.0],
        ],
        [
            [0.0, 1.0],
            [0.0, 2.0],
        ],
    ])
    check((batched, ), [[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]], [3, 1, 2, 1, 2])

    # A batch of 3x3 matrices with a node mask; the masked-out third node of
    # the first graph is skipped, so the second graph starts at index 2:
    padded = torch.tensor([
        [
            [3.0, 1.0, 0.0],
            [2.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
        ],
        [
            [0.0, 1.0, 0.0],
            [0.0, 2.0, 3.0],
            [0.0, 5.0, 0.0],
        ],
    ])
    node_mask = torch.tensor([[True, True, False], [True, True, True]])
    check(
        (padded, node_mask),
        [[0, 0, 1, 2, 3, 3, 4], [0, 1, 0, 3, 3, 4, 3]],
        [3, 1, 2, 1, 2, 3, 5],
    )


def test_dense_to_sparse_bipartite():
    """Batched non-square (bipartite) dense inputs are flattened per axis."""
    batch_size, num_rows, num_cols = 2, 10, 5
    edge_index, _ = dense_to_sparse(
        torch.rand(batch_size, num_rows, num_cols))

    # The expected maxima correspond to offsetting row indices by
    # ``num_rows`` and column indices by ``num_cols`` for each graph:
    assert edge_index[0].max() == batch_size * num_rows - 1
    assert edge_index[1].max() == batch_size * num_cols - 1


def test_is_torch_sparse_tensor():
    """Only native PyTorch sparse tensors are recognized."""
    dense = torch.randn(5, 5)
    native_sparse = dense.to_sparse()

    assert not is_torch_sparse_tensor(dense)
    assert is_torch_sparse_tensor(native_sparse)

    # A torch-sparse ``SparseTensor`` is sparse, but not a PyTorch one:
    if torch_geometric.typing.WITH_TORCH_SPARSE:
        wrapped = SparseTensor.from_dense(dense)
        assert not is_torch_sparse_tensor(wrapped)


def test_is_sparse():
    """Both PyTorch sparse tensors and ``SparseTensor`` count as sparse."""
    dense = torch.randn(5, 5)

    sparse_inputs = [dense.to_sparse()]
    if torch_geometric.typing.WITH_TORCH_SPARSE:
        sparse_inputs.append(SparseTensor.from_dense(dense))

    assert not is_sparse(dense)
    for value in sparse_inputs:
        assert is_sparse(value)


def test_to_torch_coo_tensor():
    """Check conversion of ``edge_index`` into a ``torch.sparse_coo`` tensor.

    Integer index tensors are compared with ``torch.equal`` rather than
    ``torch.allclose``: ``allclose`` broadcasts its inputs, so a wrongly
    shaped index tensor could still pass. ``equal`` also checks the shape.
    Floating-point values are still compared with ``allclose``.
    """
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3],
        [1, 0, 2, 1, 3, 2],
    ])
    edge_attr = torch.randn(edge_index.size(1), 8)

    # The result is coalesced regardless of the ``is_coalesced`` hint:
    for is_coalesced in (False, True):
        adj = to_torch_coo_tensor(edge_index, is_coalesced=is_coalesced)
        assert adj.is_coalesced()
        assert adj.size() == (4, 4)
        assert adj.layout == torch.sparse_coo
        assert torch.equal(adj.indices(), edge_index)

    # An explicit ``size`` overrides the size inferred from ``edge_index``:
    adj = to_torch_coo_tensor(edge_index, size=6)
    assert adj.size() == (6, 6)
    assert adj.layout == torch.sparse_coo
    assert torch.equal(adj.indices(), edge_index)

    # Multi-dimensional edge features become a trailing dense dimension:
    adj = to_torch_coo_tensor(edge_index, edge_attr)
    assert adj.size() == (4, 4, 8)
    assert adj.layout == torch.sparse_coo
    assert torch.equal(adj.indices(), edge_index)
    assert torch.allclose(adj.values(), edge_attr)

    if is_full_test():
        jit = torch.jit.script(to_torch_coo_tensor)
        adj = jit(edge_index, edge_attr)
        assert adj.size() == (4, 4, 8)
        assert adj.layout == torch.sparse_coo
        assert torch.equal(adj.indices(), edge_index)
        assert torch.allclose(adj.values(), edge_attr)


def test_to_torch_csr_tensor():
    """Check conversion of ``edge_index`` into a ``torch.sparse_csr`` tensor.

    The CSR result is converted back to a coalesced COO tensor so that its
    indices and values can be compared with the inputs. Index tensors are
    compared with ``torch.equal``, which (unlike ``torch.allclose``) does
    not broadcast and therefore also verifies the shape.
    """
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3],
        [1, 0, 2, 1, 3, 2],
    ])

    # Structure only (no edge weights):
    adj = to_torch_csr_tensor(edge_index)
    assert adj.size() == (4, 4)
    assert adj.layout == torch.sparse_csr
    assert torch.equal(adj.to_sparse_coo().coalesce().indices(), edge_index)

    # One scalar weight per edge:
    edge_weight = torch.randn(edge_index.size(1))
    adj = to_torch_csr_tensor(edge_index, edge_weight)
    assert adj.size() == (4, 4)
    assert adj.layout == torch.sparse_csr
    coo = adj.to_sparse_coo().coalesce()
    assert torch.equal(coo.indices(), edge_index)
    assert torch.allclose(coo.values(), edge_weight)

    # Multi-dimensional edge features are only exercised on PyTorch >= 2.0:
    if torch_geometric.typing.WITH_PT20:
        edge_attr = torch.randn(edge_index.size(1), 8)
        adj = to_torch_csr_tensor(edge_index, edge_attr)
        assert adj.size() == (4, 4, 8)
        assert adj.layout == torch.sparse_csr
        coo = adj.to_sparse_coo().coalesce()
        assert torch.equal(coo.indices(), edge_index)
        assert torch.allclose(coo.values(), edge_attr)


def test_to_torch_csc_tensor():
    """Check conversion of ``edge_index`` into a ``torch.sparse_csc`` tensor.

    The CSC result is converted back to a coalesced COO tensor for the
    comparison. On PyTorch < 2.0 (``WITH_PT20`` false) the converted indices
    must first be flipped or permuted before they match ``edge_index``.
    Index tensors are compared with ``torch.equal``, which (unlike
    ``torch.allclose``) does not broadcast and also verifies the shape.
    """
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3],
        [1, 0, 2, 1, 3, 2],
    ])

    # Structure only (no edge weights):
    adj = to_torch_csc_tensor(edge_index)
    assert adj.size() == (4, 4)
    assert adj.layout == torch.sparse_csc
    adj_coo = adj.to_sparse_coo().coalesce()
    if torch_geometric.typing.WITH_PT20:
        assert torch.equal(adj_coo.indices(), edge_index)
    else:
        assert torch.equal(adj_coo.indices().flip([0]), edge_index)

    # One scalar weight per edge:
    edge_weight = torch.randn(edge_index.size(1))
    adj = to_torch_csc_tensor(edge_index, edge_weight)
    assert adj.size() == (4, 4)
    assert adj.layout == torch.sparse_csc
    adj_coo = adj.to_sparse_coo().coalesce()
    if torch_geometric.typing.WITH_PT20:
        assert torch.equal(adj_coo.indices(), edge_index)
        assert torch.allclose(adj_coo.values(), edge_weight)
    else:
        # Reorder by row index so the entries line up with ``edge_index``:
        perm = adj_coo.indices()[0].argsort()
        assert torch.equal(adj_coo.indices()[:, perm], edge_index)
        assert torch.allclose(adj_coo.values()[perm], edge_weight)

    # Multi-dimensional edge features are only exercised on PyTorch >= 2.0:
    if torch_geometric.typing.WITH_PT20:
        edge_attr = torch.randn(edge_index.size(1), 8)
        adj = to_torch_csc_tensor(edge_index, edge_attr)
        assert adj.size() == (4, 4, 8)
        assert adj.layout == torch.sparse_csc
        adj_coo = adj.to_sparse_coo().coalesce()
        assert torch.equal(adj_coo.indices(), edge_index)
        assert torch.allclose(adj_coo.values(), edge_attr)


@withPackage('torch>=2.1.0')
def test_to_torch_coo_tensor_save_load(tmp_path):
    """A coalesced COO adjacency stays coalesced after a save/load cycle."""
    adj = to_torch_coo_tensor(
        torch.tensor([
            [0, 1, 1, 2, 2, 3],
            [1, 0, 2, 1, 3, 2],
        ]),
        is_coalesced=False,
    )
    assert adj.is_coalesced()

    file_path = osp.join(tmp_path, 'adj.t')
    torch.save(adj, file_path)
    restored = fs.torch_load(file_path)
    assert restored.is_coalesced()


def test_to_edge_index():
    """Convert a sparse COO adjacency of a 4-node path graph to edge_index.

    The eager function is always checked; the TorchScript-compiled version
    is additionally checked during full test runs.
    """
    adj = torch.tensor([
        [0., 1., 0., 0.],
        [1., 0., 1., 0.],
        [0., 1., 0., 1.],
        [0., 0., 1., 0.],
    ]).to_sparse()
    expected_index = [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]
    expected_attr = [1.] * 6

    converters = [to_edge_index]
    if is_full_test():
        converters.append(torch.jit.script(to_edge_index))

    for convert in converters:
        edge_index, edge_attr = convert(adj)
        assert edge_index.tolist() == expected_index
        assert edge_attr.tolist() == expected_attr


@withCUDA
@pytest.mark.parametrize(
    'layout',
    [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc],
)
@pytest.mark.parametrize('dim', [0, 1, (0, 1)])
def test_cat(layout, dim, device):
    """Concatenate a sparse matrix with itself along ``dim`` and verify it.

    With ``dim=(0, 1)`` the second copy is placed block-diagonally, i.e.
    shifted along both axes. On PyTorch >= 2.0 every edge carries a
    2-dimensional feature vector, which appears as a trailing dimension of
    size 2 in the output shape. ``device`` is supplied by ``@withCUDA``.
    """
    with_pt20 = torch_geometric.typing.WITH_PT20
    weight_shape = (4, 2) if with_pt20 else (4, )
    dense_dims = (2, ) if with_pt20 else ()

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)
    edge_weight = torch.rand(*weight_shape, device=device)
    adj = to_torch_sparse_tensor(edge_index, edge_weight, layout=layout)

    out = cat([adj, adj], dim=dim)
    # Convert to CSR so the extracted indices are row-sorted for any layout:
    out_index, _ = to_edge_index(out.to_sparse_csr())

    if dim == 0:
        sparse_size = (6, 3)
        expected_row = [0, 1, 1, 2, 3, 4, 4, 5]
        expected_col = [1, 0, 2, 1, 1, 0, 2, 1]
    elif dim == 1:
        sparse_size = (3, 6)
        expected_row = [0, 0, 1, 1, 1, 1, 2, 2]
        expected_col = [1, 4, 0, 2, 3, 5, 1, 4]
    else:
        sparse_size = (6, 6)
        expected_row = [0, 1, 1, 2, 3, 4, 4, 5]
        expected_col = [1, 0, 2, 1, 4, 3, 5, 4]

    assert out.size() == sparse_size + dense_dims
    assert out_index[0].tolist() == expected_row
    assert out_index[1].tolist() == expected_col


if __name__ == '__main__':
    # Benchmark the conversion of a random ``edge_index`` into the various
    # sparse formats. Usage: ``python test_sparse.py [--device cpu|cuda]``.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    num_nodes, num_edges = 10_000, 200_000
    edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)

    funcs = [to_torch_coo_tensor, to_torch_csr_tensor, to_torch_csc_tensor]
    func_names = ['To COO', 'To CSR', 'To CSC']
    # Only benchmark ``SparseTensor`` when the optional torch-sparse backend
    # is available, so the script still runs without it:
    if torch_geometric.typing.WITH_TORCH_SPARSE:
        funcs.insert(0, SparseTensor.from_edge_index)
        func_names.insert(0, 'SparseTensor')

    # CPU runs are much slower, so they use fewer steps and warmups:
    on_cpu = args.device == 'cpu'
    benchmark(
        funcs=funcs,
        func_names=func_names,
        args=(edge_index, None, (num_nodes, num_nodes)),
        num_steps=50 if on_cpu else 500,
        num_warmups=10 if on_cpu else 100,
    )