1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
|
import pytest
import torch
import torch_geometric.typing
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import ToSparseTensor
@pytest.mark.parametrize('layout', [None, torch.sparse_coo, torch.sparse_csr])
def test_to_sparse_tensor_basic(layout):
    """Convert a homogeneous graph and check the resulting ``adj_t``.

    The test asserts the following:

    * ``edge_weight`` ends up as the values of ``adj_t``.
    * ``edge_attr`` is kept but reordered by ``perm``.
    * Exactly three attributes remain on ``data``.

    When ``layout`` is ``None`` and ``torch_sparse`` is installed, ``adj_t``
    is expected to be a ``SparseTensor``. Otherwise it is expected to be a
    native torch sparse tensor with layout ``layout``, falling back to
    ``torch.sparse_csr`` when ``layout`` is ``None``.
    """
    transform = ToSparseTensor(layout=layout)
    assert str(transform) == (f'ToSparseTensor(attr=edge_weight, '
                              f'layout={layout})')

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_weight = torch.randn(edge_index.size(1))
    edge_attr = torch.randn(edge_index.size(1), 8)
    # Expected edge order after the transform: edge_index[:, perm] has its
    # second row sorted ([0, 1, 1, 2]), which matches the row order of the
    # transposed adjacency adj_t.
    perm = torch.tensor([1, 0, 3, 2])

    data = Data(edge_index=edge_index, edge_weight=edge_weight,
                edge_attr=edge_attr, num_nodes=3)
    data = transform(data)
    assert len(data) == 3
    assert data.num_nodes == 3
    assert torch.equal(data.edge_attr, edge_attr[perm])
    assert 'adj_t' in data

    if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE:
        row, col, value = data.adj_t.coo()
        assert row.tolist() == [0, 1, 1, 2]
        assert col.tolist() == [1, 0, 2, 1]
        assert torch.equal(value, edge_weight[perm])
    else:
        adj_t = data.adj_t
        # Parenthesised on purpose: `a == b or c` would parse as
        # `(a == b) or c`, and the truthy `torch.sparse_csr` would make the
        # assertion vacuous.
        assert adj_t.layout == (layout or torch.sparse_csr)
        if layout != torch.sparse_coo:
            adj_t = adj_t.to_sparse_coo()
        adj_t = adj_t.coalesce()
        assert adj_t.indices().tolist() == [
            [0, 1, 1, 2],
            [1, 0, 2, 1],
        ]
        assert torch.equal(adj_t.values(), edge_weight[perm])
def test_to_sparse_tensor_and_keep_edge_index():
    """Check that ``remove_edge_index=False`` preserves edge-level tensors.

    ``edge_index``, ``edge_weight`` and ``edge_attr`` must all still be
    present on the output, each reordered by ``perm``. The output must hold
    five attributes in total.
    """
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    num_edges = edge_index.size(1)
    edge_weight = torch.randn(num_edges)
    edge_attr = torch.randn(num_edges, 8)
    # Expected edge order after the transform (edge_index[1, perm] is
    # [0, 1, 1, 2], i.e. sorted by the second row).
    perm = torch.tensor([1, 0, 3, 2])

    keep_transform = ToSparseTensor(remove_edge_index=False)
    out = keep_transform(Data(edge_index=edge_index, edge_weight=edge_weight,
                              edge_attr=edge_attr, num_nodes=3))

    assert len(out) == 5
    expected = (
        ('edge_index', edge_index[:, perm]),
        ('edge_weight', edge_weight[perm]),
        ('edge_attr', edge_attr[perm]),
    )
    for name, want in expected:
        assert torch.equal(getattr(out, name), want)
@pytest.mark.parametrize('layout', [None, torch.sparse_coo, torch.sparse_csr])
def test_hetero_to_sparse_tensor(layout):
    """Convert every edge type of a ``HeteroData`` graph to ``adj_t``.

    Both edge types share the same connectivity and carry no edge weights.
    With ``torch_sparse`` (and ``layout=None``) the ``SparseTensor`` is
    expected to have no values. Otherwise, the native sparse tensor is
    expected to hold all-ones values, with layout ``layout`` or
    ``torch.sparse_csr`` when ``layout`` is ``None``.
    """
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    data = HeteroData()
    data['v'].num_nodes = 3
    data['w'].num_nodes = 3
    data['v', 'v'].edge_index = edge_index
    data['v', 'w'].edge_index = edge_index

    data = ToSparseTensor(layout=layout)(data)

    # Both edge types are built identically, so run the same checks on each.
    for edge_type in [('v', 'v'), ('v', 'w')]:
        adj_t = data[edge_type].adj_t
        if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE:
            row, col, value = adj_t.coo()
            assert row.tolist() == [0, 1, 1, 2]
            assert col.tolist() == [1, 0, 2, 1]
            assert value is None
        else:
            # Parenthesised on purpose: `a == b or c` would always be truthy.
            assert adj_t.layout == (layout or torch.sparse_csr)
            if layout != torch.sparse_coo:
                adj_t = adj_t.to_sparse_coo()
            adj_t = adj_t.coalesce()
            assert adj_t.indices().tolist() == [
                [0, 1, 1, 2],
                [1, 0, 2, 1],
            ]
            assert adj_t.values().tolist() == [1., 1., 1., 1.]
def test_to_sparse_tensor_num_nodes_equals_num_edges():
    """Only edge-level tensors are reordered when sizes coincide.

    In this graph the node-level tensors (``x``, ``y``) and the edge-level
    tensor (``edge_attr``) have the same leading size (4). The test checks
    two things: ``x`` and ``y`` come back unchanged, and ``edge_attr`` is
    reordered by ``perm``.
    """
    x, y = torch.arange(4), torch.arange(4)
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    num_edges = edge_index.size(1)
    edge_weight = torch.randn(num_edges)
    edge_attr = torch.randn(num_edges, 8)
    perm = torch.tensor([1, 0, 3, 2])

    out = ToSparseTensor()(Data(x=x, edge_index=edge_index,
                                edge_weight=edge_weight, edge_attr=edge_attr,
                                y=y))

    assert len(out) == 4
    # Node-level tensors must be left untouched...
    for original, returned in ((x, out.x), (y, out.y)):
        assert torch.equal(returned, original)
    # ...while the edge-level tensor follows the new edge order.
    assert torch.equal(out.edge_attr, edge_attr[perm])
|