# Tests for CachedLoader (torch_geometric.loader).
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.loader import CachedLoader, NeighborLoader
from torch_geometric.testing import withDevice, withPackage
@withDevice
@withPackage('pyg_lib')
def test_cached_loader(device):
    """Check that `CachedLoader` stores and replays `NeighborLoader` batches.

    On the first pass the wrapper's `_cache` must grow by one entry per
    yielded batch, and each batch's `x` and `edge_index` must be on `device`.
    On the second pass every batch must compare equal to the one collected at
    the same position during the first pass. Finally, `clear()` must leave
    `_cache` empty.

    Args:
        device: Target device handed to `CachedLoader` (presumably injected
            by `@withDevice` -- confirm against the testing helpers).
    """
    # Feature matrix of shape (14, 16) and an edge_index of shape (2, 10).
    features = torch.randn(14, 16)
    edges = torch.tensor([
        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],
        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],
    ])
    graph = Data(x=features, edge_index=edges)

    base_loader = NeighborLoader(
        graph,
        num_neighbors=[2],
        batch_size=10,
        shuffle=False,
    )
    cached_loader = CachedLoader(base_loader, device=device)

    # The wrapper reports the same length and starts with nothing cached.
    assert len(cached_loader) == len(base_loader)
    assert len(cached_loader._cache) == 0

    # First pass: each yielded batch adds exactly one cache entry.
    first_pass = []
    for num_seen, batch in enumerate(cached_loader, start=1):
        assert len(cached_loader._cache) == num_seen
        assert batch.x.device == device
        assert batch.edge_index.device == device
        first_pass.append(batch)

    # Second pass: batches must match what was collected before.
    for position, batch in enumerate(cached_loader):
        assert batch == first_pass[position]

    cached_loader.clear()
    assert len(cached_loader._cache) == 0
@withDevice
@withPackage('pyg_lib')
def test_cached_loader_transform(device):
    """Check `CachedLoader` when a `transform` is applied to each batch.

    The transform used here returns the batch's `edge_index`, so every item
    yielded by the wrapper must be a 2-D `Tensor` with a first dimension of
    size 2 located on `device`. A second pass over the wrapper must yield
    tensors that are element-wise equal to those from the first pass.

    Args:
        device: Target device handed to `CachedLoader` (presumably injected
            by `@withDevice` -- confirm against the testing helpers).
    """
    def take_edge_index(batch):
        # Reduce the sampled batch to its edge_index tensor only.
        return batch.edge_index

    # Feature matrix of shape (14, 16) and an edge_index of shape (2, 10).
    features = torch.randn(14, 16)
    edges = torch.tensor([
        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],
        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],
    ])
    graph = Data(x=features, edge_index=edges)

    base_loader = NeighborLoader(
        graph,
        num_neighbors=[2],
        batch_size=10,
        shuffle=False,
    )
    cached_loader = CachedLoader(
        base_loader,
        device=device,
        transform=take_edge_index,
    )

    # The wrapper reports the same length and starts with nothing cached.
    assert len(cached_loader) == len(base_loader)
    assert len(cached_loader._cache) == 0

    # First pass: one cache entry per batch, each a (2, *) tensor on device.
    first_pass = []
    for num_seen, item in enumerate(cached_loader, start=1):
        assert len(cached_loader._cache) == num_seen
        assert isinstance(item, Tensor)
        assert item.dim() == 2
        assert item.size(0) == 2
        assert item.device == device
        first_pass.append(item)

    # Second pass: tensors must equal the ones collected before.
    for position, item in enumerate(cached_loader):
        assert torch.equal(item, first_pass[position])
|