1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
|
import numpy as np
import torch
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn.conv import GATConv, SAGEConv
from torch_geometric.testing import onlyOnline, withPackage
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import erdos_renyi_graph
@withPackage('torch_sparse')
def test_neighbor_sampler_basic():
edge_index = erdos_renyi_graph(num_nodes=10, edge_prob=0.5)
adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(10, 10)).t()
E = edge_index.size(1)
loader = NeighborSampler(edge_index, sizes=[2, 4], batch_size=2)
assert str(loader) == 'NeighborSampler(sizes=[2, 4])'
assert len(loader) == 5
for batch_size, n_id, adjs in loader:
assert batch_size == 2
assert all(np.isin(n_id, torch.arange(10)).tolist())
assert n_id.unique().size(0) == n_id.size(0)
for (edge_index, e_id, size) in adjs:
assert int(edge_index[0].max() + 1) <= size[0]
assert int(edge_index[1].max() + 1) <= size[1]
assert all(np.isin(e_id, torch.arange(E)).tolist())
assert e_id.unique().size(0) == e_id.size(0)
assert size[0] >= size[1]
out = loader.sample([1, 2])
assert len(out) == 3
loader = NeighborSampler(adj_t, sizes=[2, 4], batch_size=2)
for batch_size, n_id, adjs in loader:
for (adj_t, e_id, size) in adjs:
assert adj_t.size(0) == size[1]
assert adj_t.size(1) == size[0]
@withPackage('torch_sparse')
def test_neighbor_sampler_invalid_kwargs():
# Ignore `collate_fn` and `dataset` arguments:
edge_index = torch.tensor([[0, 1], [1, 0]])
NeighborSampler(edge_index, sizes=[-1], collate_fn=None, dataset=None)
@onlyOnline
@withPackage('torch_sparse')
def test_neighbor_sampler_on_cora(get_dataset):
dataset = get_dataset(name='Cora')
data = dataset[0]
batch = torch.arange(10)
loader = NeighborSampler(data.edge_index, sizes=[-1, -1, -1],
node_idx=batch, batch_size=10)
class SAGE(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.convs = torch.nn.ModuleList()
self.convs.append(SAGEConv(in_channels, 16))
self.convs.append(SAGEConv(16, 16))
self.convs.append(SAGEConv(16, out_channels))
def batch(self, x, adjs):
for i, (edge_index, _, size) in enumerate(adjs):
x_target = x[:size[1]] # Target nodes are always placed first.
x = self.convs[i]((x, x_target), edge_index)
return x
def full(self, x, edge_index):
for conv in self.convs:
x = conv(x, edge_index)
return x
model = SAGE(dataset.num_features, dataset.num_classes)
_, n_id, adjs = next(iter(loader))
out1 = model.batch(data.x[n_id], adjs)
out2 = model.full(data.x, data.edge_index)[batch]
assert torch.allclose(out1, out2, atol=1e-7)
class GAT(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.convs = torch.nn.ModuleList()
self.convs.append(GATConv(in_channels, 16, heads=2))
self.convs.append(GATConv(32, 16, heads=2))
self.convs.append(GATConv(32, out_channels, heads=2, concat=False))
def batch(self, x, adjs):
for i, (edge_index, _, size) in enumerate(adjs):
x_target = x[:size[1]] # Target nodes are always placed first.
x = self.convs[i]((x, x_target), edge_index)
return x
def full(self, x, edge_index):
for conv in self.convs:
x = conv(x, edge_index)
return x
_, n_id, adjs = next(iter(loader))
out1 = model.batch(data.x[n_id], adjs)
out2 = model.full(data.x, data.edge_index)[batch]
assert torch.allclose(out1, out2, atol=1e-7)
|