1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
import torch
from torch_geometric.distributed import LocalGraphStore
from torch_geometric.testing import get_random_edge_index, onlyDistributedTest
@onlyDistributedTest
def test_local_graph_store():
    """Round-trip one homogeneous COO edge set through ``LocalGraphStore``.

    Puts an edge index and its edge IDs under the same edge attribute, reads
    both back, and checks that removing them leaves the store empty.
    """
    graph_store = LocalGraphStore()

    num_edges = 300
    edge_index = get_random_edge_index(100, 100, num_edges)
    # One edge ID per stored edge, so the ID tensor lines up with
    # ``edge_index`` (as in the other tests of this module).
    edge_id = torch.randperm(num_edges)

    graph_store.put_edge_index(
        edge_index,
        edge_type=None,
        layout='coo',
        size=(100, 100),
    )
    graph_store.put_edge_id(
        edge_id,
        edge_type=None,
        layout='coo',
        size=(100, 100),
    )

    # Edge index and edge IDs share a single edge attribute.
    edge_attrs = graph_store.get_all_edge_attrs()
    assert len(edge_attrs) == 1
    edge_attr = edge_attrs[0]
    assert torch.equal(graph_store.get_edge_index(edge_attr), edge_index)
    assert torch.equal(graph_store.get_edge_id(edge_attr), edge_id)
    # A store populated only via put_edge_* is expected not to be flagged
    # as sorted.
    assert not graph_store.is_sorted

    graph_store.remove_edge_index(edge_attr)
    graph_store.remove_edge_id(edge_attr)
    assert len(graph_store.get_all_edge_attrs()) == 0
@onlyDistributedTest
def test_homogeneous_graph_store():
    """Build a homogeneous store from destination-sorted COO data.

    With ``is_sorted=True`` the test expects the store to expose exactly the
    tensors it was given, under a single untyped COO edge attribute of size
    ``(100, 100)`` that is marked as sorted.
    """
    expected_id = torch.randperm(300)
    expected_index = get_random_edge_index(100, 100, 300)
    # Sort only the second (destination) row in place.
    expected_index[1] = expected_index[1].sort().values

    store = LocalGraphStore.from_data(
        expected_id,
        expected_index,
        num_nodes=100,
        is_sorted=True,
    )

    attrs = store.get_all_edge_attrs()
    assert len(attrs) == 1
    attr = attrs[0]
    assert attr.edge_type is None
    assert attr.layout.value == 'coo'
    assert attr.is_sorted
    assert attr.size == (100, 100)

    lookup = dict(edge_type=None, layout='coo')
    assert torch.equal(store.get_edge_id(**lookup), expected_id)
    assert torch.equal(store.get_edge_index(**lookup), expected_index)
@onlyDistributedTest
def test_heterogeneous_graph_store():
    """Build a heterogeneous store with one ``paper -> paper`` relation.

    With ``is_sorted=True`` the test expects the store to expose exactly the
    per-type tensors it was given, under one COO edge attribute keyed by the
    relation, marked as sorted, with size ``(100, 100)``.
    """
    rel = ('paper', 'to', 'paper')
    ids_by_type = {rel: torch.randperm(300)}
    index = get_random_edge_index(100, 100, 300)
    # Sort only the second (destination) row in place.
    index[1] = index[1].sort().values
    index_by_type = {rel: index}

    store = LocalGraphStore.from_hetero_data(
        ids_by_type,
        index_by_type,
        num_nodes_dict={'paper': 100},
        is_sorted=True,
    )

    attrs = store.get_all_edge_attrs()
    assert len(attrs) == 1
    attr = attrs[0]
    assert attr.edge_type == rel
    assert attr.layout.value == 'coo'
    assert attr.is_sorted
    assert attr.size == (100, 100)

    stored_id = store.get_edge_id(rel, layout='coo')
    assert torch.equal(stored_id, ids_by_type[rel])
    stored_index = store.get_edge_index(rel, layout='coo')
    assert torch.equal(stored_index, index_by_type[rel])
@onlyDistributedTest
def test_sorted_graph_store():
    """Check that unsorted input is reordered when ``is_sorted=False``.

    The expected tensors show edges ordered by their second (destination)
    row, with each edge ID moving together with its edge. The same input is
    checked through both the homogeneous and heterogeneous constructors.
    """
    expected_index = torch.tensor([[1, 7, 5, 6, 1], [0, 0, 1, 1, 2]])
    expected_id = torch.tensor([0, 1, 2, 3, 4])
    unsorted_index = torch.tensor([[1, 5, 7, 1, 6], [0, 1, 0, 2, 1]])
    unsorted_id = torch.tensor([0, 2, 1, 4, 3])

    def assert_reordered(store, *key, **lookup):
        # Forward the lookup arguments unchanged to both getters.
        assert torch.equal(store.get_edge_index(*key, **lookup),
                           expected_index)
        assert torch.equal(store.get_edge_id(*key, **lookup), expected_id)

    homo_store = LocalGraphStore.from_data(
        unsorted_id,
        unsorted_index,
        num_nodes=8,
        is_sorted=False,
    )
    assert_reordered(homo_store, edge_type=None, layout='coo')

    rel = ('paper', 'to', 'paper')
    hetero_store = LocalGraphStore.from_hetero_data(
        {rel: unsorted_id},
        {rel: unsorted_index},
        num_nodes_dict={'paper': 8},
        is_sorted=False,
    )
    assert_reordered(hetero_store, rel, layout='coo')
|