1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
|
from typing import List
import pytest
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, DynamicBatchSampler
def test_dataloader_with_dynamic_batches():
    """Check ``DynamicBatchSampler`` when used as a ``DataLoader`` batch sampler.

    Three scenarios are covered:

    * Shuffled sampling with a limit of 300: every yielded batch must report
      at most 300 nodes, and the batches together must account for all 1045
      nodes of the ten graphs.
    * ``skip_too_big=True`` with ``num_steps=2`` after a 400-node graph has
      been prepended: the yielded batches must total 404 nodes.
    * ``len()`` of the sampler: raises ``ValueError`` when ``num_steps`` is
      not given and equals ``num_steps`` when it is.
    """
    # Ten graphs with 100, 101, ..., 109 nodes (1045 nodes overall).
    data_list: List[Data] = [Data(num_nodes=n) for n in range(100, 110)]

    # Seed before building the shuffling sampler so the run is reproducible.
    torch.manual_seed(12345)
    shuffled_sampler = DynamicBatchSampler(data_list, 300, shuffle=True)
    shuffled_loader = DataLoader(data_list, batch_sampler=shuffled_sampler)

    seen_nodes = 0
    for batch in shuffled_loader:
        assert batch.num_nodes <= 300
        seen_nodes += batch.num_nodes
    assert seen_nodes == 1045

    # Test skipping
    data_list = [Data(num_nodes=400)] + data_list
    skipping_sampler = DynamicBatchSampler(
        data_list,
        300,
        skip_too_big=True,
        num_steps=2,
    )
    skipping_loader = DataLoader(data_list, batch_sampler=skipping_sampler)

    seen_nodes = sum(batch.num_nodes for batch in skipping_loader)
    assert seen_nodes == 404

    # The length is only defined when `num_steps` is supplied.
    with pytest.raises(ValueError, match="length of 'DynamicBatchSampler'"):
        len(DynamicBatchSampler(data_list, max_num=300))

    bounded_sampler = DynamicBatchSampler(data_list, max_num=300, num_steps=2)
    assert len(bounded_sampler) == 2
|