File: test_dynamic_batch_sampler.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (38 lines) | stat: -rw-r--r-- 1,253 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import List

import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, DynamicBatchSampler


def test_dataloader_with_dynamic_batches():
    """Exercise ``DynamicBatchSampler`` as a ``DataLoader`` batch sampler.

    The test checks three things:

    * A shuffled pass yields mini-batches of at most ``budget`` nodes each,
      and together they account for all 1045 nodes of the ten input graphs.
    * With ``skip_too_big=True`` and ``num_steps=2``, the node total over the
      yielded mini-batches matches a pinned expected value.
    * ``len()`` raises a ``ValueError`` when ``num_steps`` is not given, and
      returns ``num_steps`` when it is.
    """
    budget = 300  # maximum number of nodes per mini-batch

    # Ten graphs with 100, 101, ..., 109 nodes (1045 nodes in total).
    data_list: List[Data] = [Data(num_nodes=n) for n in range(100, 110)]

    # Shuffled pass: each mini-batch stays within the budget, and the batches
    # together cover every node exactly as often as the input does.
    torch.manual_seed(12345)
    sampler = DynamicBatchSampler(data_list, budget, shuffle=True)
    batch_sizes = [
        batch.num_nodes
        for batch in DataLoader(data_list, batch_sampler=sampler)
    ]
    assert all(size <= budget for size in batch_sizes)
    assert sum(batch_sizes) == 1045

    # Skipping: prepend a 400-node graph that can never fit within the budget
    # and cap the sampler at two mini-batches.
    data_list = [Data(num_nodes=400), *data_list]
    sampler = DynamicBatchSampler(data_list, budget, skip_too_big=True,
                                  num_steps=2)
    loader = DataLoader(data_list, batch_sampler=sampler)
    # NOTE(review): two disjoint budget-filling batches [100, 101] and
    # [102, 103] would total 201 + 205 == 406. The pinned 404 equals
    # (100 + 101) + (101 + 102), i.e. the 101-node graph appears to be drawn
    # into both mini-batches after the oversized graph is skipped. This seems
    # to pin the sampler's current bookkeeping -- confirm it is intended.
    assert sum(batch.num_nodes for batch in loader) == 404

    # len() is only defined once num_steps bounds the number of mini-batches.
    with pytest.raises(ValueError, match="length of 'DynamicBatchSampler'"):
        len(DynamicBatchSampler(data_list, max_num=budget))
    bounded = DynamicBatchSampler(data_list, max_num=budget, num_steps=2)
    assert len(bounded) == 2