File: test_data_parallel.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (25 lines) | stat: -rw-r--r-- 859 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.nn import DataParallel
from torch_geometric.testing import onlyCUDA


@onlyCUDA
def test_data_parallel_single_gpu():
    """Scattering onto a single device id should produce exactly one batch."""
    # Constructing DataParallel is expected to emit a performance warning.
    with pytest.warns(UserWarning, match="much slower"):
        module = DataParallel(torch.nn.Identity())

    # Four graphs with differing node counts, one feature per node.
    node_counts = [2, 3, 10, 4]
    data_list = [Data(x=torch.randn(count, 1)) for count in node_counts]

    batches = module.scatter(data_list, device_ids=[0])
    assert len(batches) == 1


@onlyCUDA
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUs')
def test_data_parallel_multi_gpu():
    """Scattering over repeated device ids [0, 1, 0, 1] should yield 3 batches."""
    # Constructing DataParallel is expected to emit a performance warning.
    with pytest.warns(UserWarning, match="much slower"):
        module = DataParallel(torch.nn.Identity())

    # Four graphs with differing node counts, one feature per node.
    node_counts = [2, 3, 10, 4]
    data_list = [Data(x=torch.randn(count, 1)) for count in node_counts]

    batches = module.scatter(data_list, device_ids=[0, 1, 0, 1])
    assert len(batches) == 3