File: test_imbalanced_sampler.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (116 lines) | stat: -rw-r--r-- 3,650 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from typing import List

import torch

from torch_geometric.data import Data
from torch_geometric.datasets import FakeDataset, FakeHeteroDataset
from torch_geometric.loader import (
    DataLoader,
    ImbalancedSampler,
    NeighborLoader,
)
from torch_geometric.testing import onlyNeighborSampler


def test_dataloader_with_imbalanced_sampler():
    """Check `ImbalancedSampler` with `DataLoader` on a 10/90 label split.

    The sampler is constructed from three inputs in turn: a list of `Data`
    objects carrying integer labels, a plain label tensor, and the same list
    after every label has been wrapped in a one-element tensor. The first
    run must yield class shares strictly between 0.4 and 0.6; the other two
    runs, re-seeded identically, must reproduce the first label sequence.
    """
    def collect_labels(data_loader):
        # Concatenate the labels of every batch the loader yields.
        return torch.cat([batch.y for batch in data_loader])

    minority = [Data(num_nodes=10, y=0) for _ in range(10)]
    majority = [Data(num_nodes=10, y=1) for _ in range(90)]
    data_list: List[Data] = minority + majority

    torch.manual_seed(12345)
    loader = DataLoader(
        data_list,
        batch_size=10,
        sampler=ImbalancedSampler(data_list),
    )
    y = collect_labels(loader)

    counts = y.bincount()
    total = counts.sum()
    share = counts / total

    # One label is drawn per element of the input list.
    assert total == len(data_list)
    assert share.min() > 0.4
    assert share.max() < 0.6

    # Same seed, but the sampler is fed a label tensor instead of the list:
    torch.manual_seed(12345)
    label_tensor = torch.tensor([item.y for item in data_list])
    loader = DataLoader(
        data_list,
        batch_size=10,
        sampler=ImbalancedSampler(label_tensor),
    )

    assert torch.allclose(y, collect_labels(loader))

    # Same seed again, now with each `y` stored as a one-element tensor:
    torch.manual_seed(12345)
    for item in data_list:
        item.y = torch.tensor([item.y])
    loader = DataLoader(
        data_list,
        batch_size=100,
        sampler=ImbalancedSampler(data_list),
    )

    assert torch.allclose(y, collect_labels(loader))


def test_in_memory_dataset_imbalanced_sampler():
    """Check `ImbalancedSampler` when it is built from a `FakeDataset`.

    A two-class fake dataset of 100 graphs is loaded in batches of 10
    through the sampler. The number of collected labels must equal the
    dataset length and every class share must lie strictly between 0.4
    and 0.6.
    """
    torch.manual_seed(12345)
    dataset = FakeDataset(
        num_graphs=100,
        avg_num_nodes=10,
        avg_degree=0,
        num_channels=0,
        num_classes=2,
    )
    loader = DataLoader(
        dataset,
        batch_size=10,
        sampler=ImbalancedSampler(dataset),
    )

    # Gather the labels batch by batch before joining them.
    label_chunks = []
    for batch in loader:
        label_chunks.append(batch.y)
    y = torch.cat(label_chunks)

    counts = y.bincount()
    total = counts.sum()
    share = counts / total

    assert total == len(dataset)
    assert share.min() > 0.4
    assert share.max() < 0.6


@onlyNeighborSampler
def test_neighbor_loader_with_imbalanced_sampler():
    """Check `ImbalancedSampler` with `NeighborLoader` on an edgeless graph.

    The graph has 100 nodes (10 labelled 0, 90 labelled 1) and no edges.
    The sampler is built first from the `Data` object and then from its
    label tensor; with the same seed both loaders must yield the same label
    sequence, and the class shares of the first run must lie strictly
    between 0.4 and 0.6.
    """
    def collect_labels(node_loader):
        # Concatenate the labels of every batch the loader yields.
        return torch.cat([batch.y for batch in node_loader])

    # 10 nodes of class 0 followed by 90 nodes of class 1.
    node_labels = torch.cat(
        [
            torch.zeros(10, dtype=torch.long),
            torch.ones(90, dtype=torch.long),
        ],
        dim=0,
    )
    data = Data(
        edge_index=torch.empty((2, 0), dtype=torch.long),
        y=node_labels,
        num_nodes=node_labels.size(0),
    )

    torch.manual_seed(12345)
    loader = NeighborLoader(
        data,
        batch_size=10,
        sampler=ImbalancedSampler(data),
        num_neighbors=[-1],
    )
    y = collect_labels(loader)

    counts = y.bincount()
    total = counts.sum()
    share = counts / total

    assert total == data.num_nodes
    assert share.min() > 0.4
    assert share.max() < 0.6

    # Same seed, but the sampler is fed the label tensor directly:
    torch.manual_seed(12345)
    loader = NeighborLoader(
        data,
        batch_size=10,
        sampler=ImbalancedSampler(data.y),
        num_neighbors=[-1],
    )

    assert torch.allclose(y, collect_labels(loader))


@onlyNeighborSampler
def test_hetero_neighbor_loader_with_imbalanced_sampler():
    """Check `ImbalancedSampler` with `NeighborLoader` on a hetero graph.

    The first graph of a two-class `FakeHeteroDataset` is sampled with
    `'v0'` as the input node type. From each batch only the first
    `batch_size` labels of `'v0'` are kept (presumably the seed nodes of
    that batch -- confirm against the loader's contract). The number of kept
    labels must equal the number of `'v0'` nodes and every class share must
    lie strictly between 0.4 and 0.6.
    """
    torch.manual_seed(12345)
    data = FakeHeteroDataset(num_classes=2)[0]

    sampler = ImbalancedSampler(data['v0'].y)
    loader = NeighborLoader(
        data,
        batch_size=100,
        input_nodes='v0',
        num_neighbors=[-1],
        sampler=sampler,
    )

    kept_labels = []
    for batch in loader:
        v0_store = batch['v0']
        # Keep only the leading `batch_size` entries of this batch.
        kept_labels.append(v0_store.y[:v0_store.batch_size])
    y = torch.cat(kept_labels)

    counts = y.bincount()
    total = counts.sum()
    share = counts / total

    assert total == data['v0'].num_nodes
    assert share.min() > 0.4
    assert share.max() < 0.6