1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
import torch
from torch_geometric.distributed import LocalFeatureStore
from torch_geometric.testing import onlyDistributedTest
@onlyDistributedTest
def test_local_feature_store_global_id():
    """Look up stored features by global ID.

    Only a shuffled subset of the global IDs is registered for the
    ``'paper'`` group; querying by global ID must return the rows of the
    reference matrix at exactly those IDs.
    """
    # Reference matrix: 9 rows, 3 columns, where row ``i`` is ``[i, i, i]``.
    reference = torch.arange(9.0).view(-1, 1).repeat(1, 3)
    stored_ids = torch.tensor([1, 2, 3, 5, 8, 4])

    local_store = LocalFeatureStore()
    local_store.put_global_id(stored_ids, group_name='paper')
    local_store.put_tensor(reference[stored_ids], group_name='paper',
                           attr_name='feat')

    # The query IDs are global IDs, not positions within ``stored_ids``.
    query = torch.tensor([3, 8, 4])
    fetched = local_store.get_tensor_from_global_id(
        group_name='paper',
        attr_name='feat',
        index=query,
    )
    assert torch.equal(fetched, reference[query])
@onlyDistributedTest
def test_local_feature_store_utils():
    """Check tensor-attribute introspection on a store holding one tensor."""
    # Reference matrix: 9 rows, 3 columns, where row ``i`` is ``[i, i, i]``.
    reference = torch.arange(9.0).view(-1, 1).repeat(1, 3)
    selected_rows = torch.tensor([1, 2, 3, 5, 8, 4])

    fs = LocalFeatureStore()
    fs.put_tensor(reference[selected_rows], group_name='paper',
                  attr_name='feat')

    registered = fs.get_all_tensor_attrs()
    assert len(registered) == 1

    # The single attribute reports the names it was stored under, carries
    # no index, and its size is (number of selected rows, feature width).
    only_attr = registered[0]
    assert only_attr.group_name == 'paper'
    assert only_attr.attr_name == 'feat'
    assert only_attr.index is None
    assert fs.get_tensor_size(only_attr) == (6, 3)
@onlyDistributedTest
def test_homogeneous_feature_store():
    """Check that ``from_data`` registers homogeneous node and edge tensors.

    Node attributes are expected under the ``None`` group and edge
    attributes under the ``(None, None)`` group.
    """
    num_nodes, num_edges = 6, 12

    # Keep the RNG draw order fixed: node ids, x, y, edge ids, edge_attr.
    node_id = torch.randperm(num_nodes)
    x = torch.randn(num_nodes, 32)
    y = torch.randint(0, 2, (num_nodes, ))
    edge_id = torch.randperm(num_edges)
    edge_attr = torch.randn(num_edges, 16)

    store = LocalFeatureStore.from_data(node_id, x, y, edge_id, edge_attr)

    attrs = store.get_all_tensor_attrs()
    assert len(attrs) == 3
    x_attr, y_attr, edge_feat_attr = attrs
    assert x_attr.group_name is None
    assert x_attr.attr_name == 'x'
    assert y_attr.group_name is None
    assert y_attr.attr_name == 'y'
    assert edge_feat_attr.group_name == (None, None)
    assert edge_feat_attr.attr_name == 'edge_attr'

    # Node-level lookups use the ``None`` group.
    assert torch.equal(store.get_global_id(group_name=None), node_id)
    for name, expected in (('x', x), ('y', y)):
        assert torch.equal(store.get_tensor(group_name=None, attr_name=name),
                           expected)

    # Edge-level lookups use the ``(None, None)`` group.
    edge_group = (None, None)
    assert torch.equal(store.get_global_id(group_name=edge_group), edge_id)
    assert torch.equal(
        store.get_tensor(group_name=edge_group, attr_name='edge_attr'),
        edge_attr,
    )
@onlyDistributedTest
def test_heterogeneous_feature_store():
    """Check that ``from_hetero_data`` registers per-type node/edge tensors.

    Node attributes are expected under their node type and edge attributes
    under their edge-type triple.
    """
    ntype = 'paper'
    etype = ('paper', 'to', 'paper')

    # Keep the RNG draw order fixed: node ids, x, y, edge ids, edge_attr.
    node_id_dict = {ntype: torch.randperm(6)}
    x_dict = {ntype: torch.randn(6, 32)}
    y_dict = {ntype: torch.randint(0, 2, (6, ))}
    edge_id_dict = {etype: torch.randperm(12)}
    edge_attr_dict = {etype: torch.randn(12, 16)}

    store = LocalFeatureStore.from_hetero_data(node_id_dict, x_dict, y_dict,
                                               edge_id_dict, edge_attr_dict)

    attrs = store.get_all_tensor_attrs()
    assert len(attrs) == 3
    expected_keys = [(ntype, 'x'), (ntype, 'y'), (etype, 'edge_attr')]
    for attr, (group, name) in zip(attrs, expected_keys):
        assert attr.group_name == group
        assert attr.attr_name == name

    # Node-level lookups are keyed by the node type.
    assert torch.equal(store.get_global_id(group_name=ntype),
                       node_id_dict[ntype])
    for name, source in (('x', x_dict), ('y', y_dict)):
        assert torch.equal(store.get_tensor(group_name=ntype, attr_name=name),
                           source[ntype])

    # Edge-level lookups are keyed by the edge-type triple.
    assert torch.equal(store.get_global_id(group_name=etype),
                       edge_id_dict[etype])
    assert torch.equal(
        store.get_tensor(group_name=etype, attr_name='edge_attr'),
        edge_attr_dict[etype],
    )
|