1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
|
import torch
from torch_geometric.nn import GCNConv, Linear
from torch_geometric.testing import withDevice, withPackage
from torch_geometric.utils import (
bipartite_subgraph,
get_num_hops,
index_to_mask,
k_hop_subgraph,
subgraph,
)
def test_get_num_hops():
    """Check that ``get_num_hops`` reports 2 for a two-``GCNConv`` model.

    The model stacks two ``GCNConv`` layers and a ``Linear`` head; the
    expected hop count asserted below is 2.
    """
    class GNN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(3, 16, normalize=False)
            self.conv2 = GCNConv(16, 16, normalize=False)
            self.lin = Linear(16, 2)

        def forward(self, x, edge_index):
            # `torch.F` does not exist; the tensor's own `relu()` is used.
            x = self.conv1(x, edge_index).relu()
            x = self.conv2(x, edge_index)
            return self.lin(x)

    model = GNN()
    assert get_num_hops(model) == 2

    # Execute the model once so a broken `forward` cannot slip through:
    # 4 nodes with 3 input features -> 4 nodes with 2 output features.
    x = torch.randn(4, 3)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
    assert model(x, edge_index).size() == (4, 2)
def test_subgraph():
    """Check ``subgraph`` with index-tensor, boolean-mask and list subsets.

    The input is a path graph 0-1-2-3-4-5-6 stored with both edge
    directions; keeping nodes {3, 4, 5} must retain exactly the four
    edges between them, in their original order.
    """
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6],
        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5],
    ])
    edge_attr = torch.tensor(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0])

    idx = torch.tensor([3, 4, 5])
    mask = index_to_mask(idx, 7)
    indices = idx.tolist()

    # Every supported subset encoding must yield the same result.
    for subset in [idx, mask, indices]:
        out = subgraph(subset, edge_index, edge_attr, return_edge_mask=True)
        assert out[0].tolist() == [[3, 4, 4, 5], [4, 3, 5, 4]]
        assert out[1].tolist() == [7.0, 8.0, 9.0, 10.0]
        assert out[2].tolist() == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0]

        # With relabelling, nodes 3, 4, 5 become 0, 1, 2.
        out = subgraph(subset, edge_index, edge_attr, relabel_nodes=True)
        assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
        assert out[1].tolist() == [7.0, 8.0, 9.0, 10.0]
@withDevice
@withPackage('pandas')
def test_subgraph_large_index(device):
    """Relabel a subgraph whose only node has a very large id.

    Runs once per available device and only when pandas is installed.
    """
    node_id = 50_000_000
    subset = torch.tensor([node_id], device=device)
    edge_index = torch.tensor([[node_id], [node_id]], device=device)

    sub_edge_index, _ = subgraph(subset, edge_index, relabel_nodes=True)

    # The lone self-loop on the huge node id must end up on node 0.
    assert sub_edge_index.tolist() == [[0], [0]]
def test_bipartite_subgraph():
    """Check ``bipartite_subgraph`` for every subset encoding.

    Source nodes {2, 3, 5} (of 7) and destination nodes {2, 3} (of 4)
    are kept; the subset is given as index tensors, boolean masks,
    Python lists, and a mask/index mix.
    """
    edge_index = torch.tensor([[0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6],
                               [0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1]])
    edge_attr = torch.tensor(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0])

    src_nodes = torch.tensor([2, 3, 5])
    dst_nodes = torch.tensor([2, 3])

    as_index = (src_nodes, dst_nodes)
    as_mask = (index_to_mask(src_nodes, 7), index_to_mask(dst_nodes, 4))
    as_list = (src_nodes.tolist(), dst_nodes.tolist())
    as_mixed = (as_mask[0], dst_nodes)

    kept_attr = [3.0, 4.0, 9.0, 10.0]

    for subset in [as_index, as_mask, as_list, as_mixed]:
        edges, attrs, edge_mask = bipartite_subgraph(
            subset, edge_index, edge_attr, return_edge_mask=True)
        assert edges.tolist() == [[2, 3, 5, 5], [3, 2, 2, 3]]
        assert attrs.tolist() == kept_attr
        assert edge_mask.tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0]

        # Source ids 2, 3, 5 -> 0, 1, 2; destination ids 2, 3 -> 0, 1.
        relabelled = bipartite_subgraph(
            subset, edge_index, edge_attr, relabel_nodes=True)
        assert relabelled[0].tolist() == [[0, 1, 2, 2], [1, 0, 0, 1]]
        assert relabelled[1].tolist() == kept_attr
@withDevice
@withPackage('pandas')
def test_bipartite_subgraph_large_index(device):
    """Relabel a bipartite subgraph whose only node has a very large id.

    Runs once per available device and only when pandas is installed.
    """
    node_id = 50_000_000
    nodes = torch.tensor([node_id], device=device)
    edge_index = torch.tensor([[node_id], [node_id]], device=device)

    # The same node set is used for both the source and destination side.
    sub_edge_index, _ = bipartite_subgraph((nodes, nodes), edge_index,
                                           relabel_nodes=True)

    assert sub_edge_index.tolist() == [[0], [0]]
def test_k_hop_subgraph():
    """Check ``k_hop_subgraph`` with one seed node and with two seeds."""
    # Case 1: single seed (node 6), two hops, with relabelling.
    graph = torch.tensor([
        [0, 1, 2, 3, 4, 5],
        [2, 2, 4, 4, 6, 6],
    ])
    subset, sub_edges, mapping, edge_mask = k_hop_subgraph(
        6, 2, graph, relabel_nodes=True)
    assert subset.tolist() == [2, 3, 4, 5, 6]
    assert sub_edges.tolist() == [[0, 1, 2, 3], [2, 2, 4, 4]]
    # Seed node 6 sits at position 4 of `subset`.
    assert mapping.tolist() == [4]
    assert edge_mask.tolist() == [False, False, True, True, True, True]

    # Case 2: two seeds (nodes 0 and 6), two hops, with relabelling.
    graph = torch.tensor([
        [1, 2, 4, 5],
        [0, 1, 5, 6],
    ])
    subset, sub_edges, mapping, edge_mask = k_hop_subgraph(
        [0, 6], 2, graph, relabel_nodes=True)
    assert subset.tolist() == [0, 1, 2, 4, 5, 6]
    assert sub_edges.tolist() == [[1, 2, 3, 4], [0, 1, 4, 5]]
    assert mapping.tolist() == [0, 5]
    assert edge_mask.tolist() == [True, True, True, True]
|