import os
import os.path as osp
from itertools import product
from typing import Callable, List, Optional

import numpy as np
import torch

from torch_geometric.data import (
    HeteroData,
    InMemoryDataset,
    download_url,
    extract_zip,
)


class DBLP(InMemoryDataset):
r"""A subset of the DBLP computer science bibliography website, as
collected in the `"MAGNN: Metapath Aggregated Graph Neural Network for
Heterogeneous Graph Embedding" <https://arxiv.org/abs/2002.01680>`_ paper.
DBLP is a heterogeneous graph containing four types of entities - authors
(4,057 nodes), papers (14,328 nodes), terms (7,723 nodes), and conferences
(20 nodes).
The authors are divided into four research areas (database, data mining,
artificial intelligence, information retrieval).
Each author is described by a bag-of-words representation of their paper
keywords.
Args:
root (str): Root directory where the dataset should be saved.
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.HeteroData` object and returns a
transformed version. The data object will be transformed before
every access. (default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.HeteroData` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
.. list-table::
:widths: 20 10 10 10
:header-rows: 1
* - Node/Edge Type
- #nodes/#edges
- #features
- #classes
* - Author
- 4,057
- 334
- 4
* - Paper
- 14,328
- 4,231
-
* - Term
- 7,723
- 50
-
* - Conference
- 20
- 0
-
* - Author-Paper
- 196,425
-
-
* - Paper-Term
- 85,810
-
-
* - Conference-Paper
- 14,328
-
-
"""

    url = 'https://www.dropbox.com/s/yh4grpeks87ugr2/DBLP_processed.zip?dl=1'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0], data_cls=HeteroData)

    @property
    def raw_file_names(self) -> List[str]:
        return [
            'adjM.npz', 'features_0.npz', 'features_1.npz', 'features_2.npy',
            'labels.npy', 'node_types.npy', 'train_val_test_idx.npz'
        ]

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self) -> None:
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.remove(path)

    def process(self) -> None:
        import scipy.sparse as sp

        data = HeteroData()

        node_types = ['author', 'paper', 'term', 'conference']

        # Author and paper features are stored as sparse matrices:
        for i, node_type in enumerate(node_types[:2]):
            x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz'))
            data[node_type].x = torch.from_numpy(x.todense()).to(torch.float)

        # Term features are stored as a dense array:
        x = np.load(osp.join(self.raw_dir, 'features_2.npy'))
        data['term'].x = torch.from_numpy(x).to(torch.float)

        # Conferences carry no features; infer their count from the global
        # node type vector (type id 3 denotes a conference):
        node_type_idx = np.load(osp.join(self.raw_dir, 'node_types.npy'))
        node_type_idx = torch.from_numpy(node_type_idx).to(torch.long)
        data['conference'].num_nodes = int((node_type_idx == 3).sum())

        # Research-area labels for the author nodes:
        y = np.load(osp.join(self.raw_dir, 'labels.npy'))
        data['author'].y = torch.from_numpy(y).to(torch.long)

        # Convert the given train/val/test index arrays into boolean masks:
        split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz'))
        for name in ['train', 'val', 'test']:
            idx = split[f'{name}_idx']
            idx = torch.from_numpy(idx).to(torch.long)
            mask = torch.zeros(data['author'].num_nodes, dtype=torch.bool)
            mask[idx] = True
            data['author'][f'{name}_mask'] = mask

        # The raw adjacency matrix indexes all nodes globally in the order
        # author, paper, term, conference. Record the index range of each
        # node type so that per-type blocks can be sliced out:
        s = {}
        N_a = data['author'].num_nodes
        N_p = data['paper'].num_nodes
        N_t = data['term'].num_nodes
        N_c = data['conference'].num_nodes
        s['author'] = (0, N_a)
        s['paper'] = (N_a, N_a + N_p)
        s['term'] = (N_a + N_p, N_a + N_p + N_t)
        s['conference'] = (N_a + N_p + N_t, N_a + N_p + N_t + N_c)

        # Slice every (src, dst) block out of the global adjacency matrix and
        # keep the non-empty ones as edge types:
        A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz'))
        for src, dst in product(node_types, node_types):
            A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo()
            if A_sub.nnz > 0:
                row = torch.from_numpy(A_sub.row).to(torch.long)
                col = torch.from_numpy(A_sub.col).to(torch.long)
                data[src, dst].edge_index = torch.stack([row, col], dim=0)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
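

# A minimal usage sketch, not part of the original module: it shows how the
# class above is typically instantiated and how the resulting `HeteroData`
# object can be inspected. The root directory './data/DBLP' is a hypothetical
# example path; the first call downloads and processes the raw files there.
if __name__ == '__main__':
    dataset = DBLP(root='./data/DBLP')
    data = dataset[0]  # The single `HeteroData` graph held by the dataset.

    print(data)
    print(data['author'].num_nodes)           # 4,057 author nodes
    print(data['author'].train_mask.sum())    # Number of training authors
    print(data['author', 'paper'].edge_index.size())  # [2, num_edges]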