File: test_graclus.py

package info (click to toggle)
pytorch-cluster 1.6.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 648 kB
  • sloc: cpp: 2,076; python: 1,081; sh: 53; makefile: 8
file content (56 lines) | stat: -rw-r--r-- 1,828 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from itertools import product

import pytest
import torch
from torch_cluster import graclus_cluster
from torch_cluster.testing import devices, dtypes, tensor

# Graph fixtures in COO form: entry k is the edge row[k] -> col[k]. Both
# directions of every edge are listed. The second fixture reuses the same
# edges and adds one weight per edge (symmetric across the two directions).
tests = [
    dict(
        row=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        col=[1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
    ),
    dict(
        row=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        col=[1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
        weight=[1, 2, 1, 3, 2, 2, 3, 1, 2, 1],
    ),
]


def assert_correct(row, col, cluster):
    """Assert that ``cluster`` is a valid Graclus-style pairing of the graph.

    The graph is given by ``row``/``col``: entry ``k`` is the edge
    ``row[k] -> col[k]``. The following must hold:

    * every node received a non-negative cluster id,
    * no cluster contains more than two nodes,
    * each node's cluster id is at most its own node index,
    * two nodes that share a cluster are joined by an edge.

    Duplicate edges and self-loops are ignored by the adjacency check, so
    they do not produce spurious failures.

    Args:
        row: Source node index of each edge.
        col: Target node index of each edge.
        cluster: Cluster id assigned to each node (one entry per node).

    Raises:
        AssertionError: If any of the checks above fails.
    """
    row, col, cluster = row.to('cpu'), col.to('cpu'), cluster.to('cpu')
    n = cluster.size(0)
    if n == 0:
        # Nothing to validate; ``min()`` would raise on an empty tensor.
        return

    # Every node was assigned a cluster.
    assert cluster.min() >= 0

    # There are no more than two nodes in each cluster.
    _, count = torch.unique(cluster, return_counts=True)
    assert (count <= 2).all()

    # Cluster value is minimal: never larger than the node's own index.
    assert (cluster <= torch.arange(n, dtype=cluster.dtype)).all()

    # Corresponding clusters must be adjacent.
    for i in range(n):
        # Distinct neighbours of `i`, excluding a possible self-loop, so that
        # parallel edges are not counted more than once.
        neighbors = torch.unique(col[(row == i) & (col != i)])
        same_neighbors = (cluster[neighbors] == cluster[i]).sum()
        # Other nodes sharing `i`'s cluster (subtract `i` itself).
        same_nodes = (cluster == cluster[i]).sum() - 1
        assert same_neighbors == same_nodes


@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_graclus_cluster(test, dtype, device):
    """Run ``graclus_cluster`` eagerly and via TorchScript on one fixture.

    Parametrized over every combination of graph fixture, weight dtype and
    device; each clustering is validated with ``assert_correct``.
    """
    # This combination is excluded. Report it as skipped instead of returning
    # early, which made pytest count it as a pass without running anything.
    # NOTE(review): presumably the CUDA kernel lacks bfloat16 support - confirm.
    if dtype == torch.bfloat16 and torch.device(device).type == 'cuda':
        pytest.skip('bfloat16 is not tested on CUDA')

    row = tensor(test['row'], torch.long, device)
    col = tensor(test['col'], torch.long, device)
    # Unweighted fixtures have no 'weight' key, so `test.get` yields None here;
    # presumably `tensor(None, ...)` passes None through - confirm.
    weight = tensor(test.get('weight'), dtype, device)

    cluster = graclus_cluster(row, col, weight)
    assert_correct(row, col, cluster)

    # The scripted function must produce an equally valid clustering.
    jit = torch.jit.script(graclus_cluster)
    cluster = jit(row, col, weight)
    assert_correct(row, col, cluster)