File: test_prefetch.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (57 lines) | stat: -rw-r--r-- 1,579 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch

from torch_geometric.loader import NeighborLoader, PrefetchLoader
from torch_geometric.nn import GraphSAGE
from torch_geometric.testing import withCUDA


@withCUDA
def test_prefetch_loader(device):
    """Check that ``PrefetchLoader`` delivers every batch, unchanged, on
    ``device``.

    ``device`` is injected by the ``withCUDA`` decorator (presumably once per
    available device — see ``torch_geometric.testing``).
    """
    data = [torch.randn(5, 5) for _ in range(10)]

    loader = PrefetchLoader(data, device=device)
    assert str(loader).startswith('PrefetchLoader')
    assert len(loader) == 10

    num_batches = 0
    for batch in loader:
        assert batch.device == device
        # Indexing (rather than zip) makes an extra, unexpected batch fail
        # loudly with an IndexError.
        assert torch.equal(batch.cpu(), data[num_batches])
        num_batches += 1

    # Guard against the loop passing vacuously when the loader yields fewer
    # batches than it reports via ``len()``.
    assert num_batches == len(data)


if __name__ == '__main__':
    # Manual benchmark (not collected by pytest). It compares a GraphSAGE
    # forward pass over NeighborLoader batches that are copied to the GPU
    # inline against the same loader wrapped in PrefetchLoader.
    # Requires a CUDA device plus the `ogb` and `tqdm` packages.
    import argparse

    from ogb.nodeproppred import PygNodePropPredDataset
    from tqdm import tqdm

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_workers', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--num_batches', type=int, default=200)
    parser.add_argument('--root', type=str, default='/tmp/ogb')
    args = parser.parse_args()

    # Use one explicit target device for both passes so the comparison is
    # like-for-like.
    device = torch.device('cuda')

    data = PygNodePropPredDataset('ogbn-products', root=args.root)[0]

    model = GraphSAGE(
        in_channels=data.x.size(-1),
        hidden_channels=64,
        num_layers=2,
    ).to(device)
    model.eval()  # Inference-only benchmark.

    loader = NeighborLoader(
        data,
        # Seed nodes: the first `batch_size * num_batches` node indices, which
        # yields `num_batches` mini-batches.
        input_nodes=torch.arange(args.batch_size * args.num_batches),
        batch_size=args.batch_size,
        num_neighbors=[10, 10],
        num_workers=args.num_workers,
        persistent_workers=args.num_workers > 0,
    )

    with torch.no_grad():
        # NOTE(review): this first pass may also absorb one-off CUDA warm-up
        # costs, which slightly favors the second pass in the comparison.
        print('Forward pass without prefetching...')
        for batch in tqdm(loader):
            batch = batch.to(device)
            model(batch.x, batch.edge_index)

        print('Forward pass with prefetching...')
        for batch in tqdm(PrefetchLoader(loader, device=device)):
            model(batch.x, batch.edge_index)