File: test_temporal_dataloader.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (28 lines) | stat: -rw-r--r-- 934 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import pytest
import torch

from torch_geometric.data import TemporalData
from torch_geometric.loader import TemporalDataLoader


@pytest.mark.parametrize('batch_size,drop_last', [
    (4, True),
    (2, False),
    (3, False),
])
def test_temporal_dataloader(batch_size, drop_last):
    r"""Check that :class:`TemporalDataLoader` yields contiguous, in-order
    slices of a :class:`TemporalData` object.

    Ten events are split into batches of ``batch_size`` consecutive events.
    With ``drop_last=True`` a trailing partial batch is discarded; otherwise
    it is kept and may contain fewer than ``batch_size`` events.
    """
    num_events = 10
    src = dst = t = torch.arange(num_events)
    msg = torch.randn(num_events, 16)

    data = TemporalData(src=src, dst=dst, t=t, msg=msg)

    loader = TemporalDataLoader(
        data,
        batch_size=batch_size,
        drop_last=drop_last,
    )

    # Number of full batches, plus one trailing partial batch if it is kept:
    num_batches = num_events // batch_size
    if not drop_last and num_events % batch_size != 0:
        num_batches += 1
    assert len(loader) == num_batches

    num_seen = 0
    for i, batch in enumerate(loader):
        # Offset by `batch_size` (not `len(batch)`) so that a shorter trailing
        # batch is still compared against the correct slice of `data`:
        start = i * batch_size
        end = min(start + batch_size, num_events)
        assert len(batch) == end - start
        assert batch.src.tolist() == data.src[start:end].tolist()
        assert batch.dst.tolist() == data.dst[start:end].tolist()
        assert batch.t.tolist() == data.t[start:end].tolist()
        assert batch.msg.tolist() == data.msg[start:end].tolist()
        num_seen += 1

    # Guard against `len(loader)` and actual iteration disagreeing:
    assert num_seen == num_batches