File: test_segment.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (108 lines) | stat: -rw-r--r-- 3,772 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from itertools import product

import pytest
import torch

import torch_geometric.typing
from torch_geometric.index import index2ptr
from torch_geometric.profile import benchmark
from torch_geometric.testing import withCUDA, withoutExtensions
from torch_geometric.utils import scatter, segment


@withCUDA
@withoutExtensions
@pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max'])
def test_segment(device, without_extensions, reduce):
    """Check :func:`segment` against a dense ``torch.<reduce>`` reference.

    The pointer vector describes five segments over 20 rows: the first one
    is empty (``ptr[0] == ptr[1] == 0``), the other four hold five rows each.
    The empty segment must reduce to zeros; the remaining ones must match
    ``torch.<reduce>`` applied over a ``(4, 5, 16)`` view of the input.

    If neither ``WITH_TORCH_SCATTER`` nor ``WITH_PT20`` is set, calling
    :func:`segment` is expected to raise an ``ImportError`` instead.
    """
    num_features = 16
    values = torch.randn(20, num_features, device=device)
    offsets = torch.tensor([0, 0, 5, 10, 15, 20], device=device)

    backend_missing = not (torch_geometric.typing.WITH_TORCH_SCATTER
                           or torch_geometric.typing.WITH_PT20)
    if backend_missing:
        with pytest.raises(ImportError, match="requires the 'torch-scatter'"):
            segment(values, offsets, reduce=reduce)
        return

    result = segment(values, offsets, reduce=reduce)

    # `torch.min` / `torch.max` with `dim=` return a (values, indices) tuple.
    reference = getattr(torch, reduce)(values.view(4, 5, -1), dim=1)
    if isinstance(reference, tuple):
        reference = reference[0]

    empty_part = torch.zeros(1, num_features, device=device)
    assert torch.allclose(result[:1], empty_part)
    assert torch.allclose(result[1:], reference)


if __name__ == '__main__':
    # Benchmark driver comparing segment-reduction backends
    # (`torch._segment_reduce`, `torch_scatter.segment_csr` when available,
    # and PyG's own `scatter(...)` / `segment(...)`).
    #
    # Insights on GPU:
    # ================
    # * "mean": Prefer `torch._segment_reduce` implementation
    # * others: Prefer `torch_scatter` implementation
    #
    # Insights on CPU:
    # ================
    # * "all": Prefer `torch_scatter` implementation (but the `scatter(...)`
    #          implementation is far superior due to multi-threading usage).
    import argparse

    from torch_geometric.typing import WITH_TORCH_SCATTER, torch_scatter

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda',
                        help='device to run the benchmark on')
    parser.add_argument('--backward', action='store_true',
                        help='also benchmark the backward pass')
    parser.add_argument('--aggr', type=str, default='all',
                        help="comma-separated reductions, or 'all'")
    args = parser.parse_args()

    num_nodes_list = [4_000, 8_000, 16_000, 32_000, 64_000]

    if args.aggr == 'all':
        aggrs = ['sum', 'mean', 'min', 'max']
    else:
        # Tolerate whitespace such as `--aggr "sum, max"`.
        aggrs = [name.strip() for name in args.aggr.split(',')]

    def pytorch_segment(x, ptr, reduce):
        # `torch._segment_reduce` spells min/max as `amin`/`amax`. Derive the
        # name from this function's own `reduce` argument, not from the
        # enclosing loop variable `aggr` (which only matched by coincidence).
        if reduce in ('min', 'max'):
            reduce = f'a{reduce}'
        return torch._segment_reduce(x, reduce, offsets=ptr)

    def own_segment(x, ptr, reduce):
        return torch_scatter.segment_csr(x, ptr, reduce=reduce)

    def optimized_scatter(x, index, reduce, dim_size):
        return scatter(x, index, dim=0, dim_size=dim_size, reduce=reduce)

    def optimized_segment(x, ptr, reduce):
        # Use the `ptr` passed in through `arg_list` rather than capturing the
        # loop's `ptr` from the enclosing scope.
        return segment(x, ptr, reduce=reduce)

    for aggr, num_nodes in product(aggrs, num_nodes_list):
        num_edges = num_nodes * 50  # on average 50 rows per segment
        print(f'aggr: {aggr}, #nodes: {num_nodes}, #edges: {num_edges}')

        x = torch.randn(num_edges, 64, device=args.device)
        index = torch.randint(num_nodes, (num_edges, ), device=args.device)
        # Segment reductions work on contiguous runs, so sort `index` before
        # converting it into a pointer vector.
        index, _ = index.sort()
        ptr = index2ptr(index, size=num_nodes)

        funcs = [pytorch_segment]
        func_names = ['PyTorch segment_reduce']
        arg_list = [(x, ptr, aggr)]

        # `torch_scatter` is only benchmarked when the extension is present.
        if WITH_TORCH_SCATTER:
            funcs.append(own_segment)
            func_names.append('torch_scatter')
            arg_list.append((x, ptr, aggr))

        funcs.append(optimized_scatter)
        func_names.append('Optimized PyG Scatter')
        arg_list.append((x, index, aggr, num_nodes))

        funcs.append(optimized_segment)
        func_names.append('Optimized PyG Segment')
        arg_list.append((x, ptr, aggr))

        benchmark(
            funcs=funcs,
            func_names=func_names,
            args=arg_list,
            num_steps=100 if args.device == 'cpu' else 1000,
            num_warmups=50 if args.device == 'cpu' else 500,
            backward=args.backward,
        )