1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
|
import os
import os.path as osp
from typing import Callable, Dict, List, Optional
import torch
from torch_geometric.data import (
Data,
InMemoryDataset,
download_url,
extract_tar,
)
from torch_geometric.io import fs
class Wikidata5M(InMemoryDataset):
    r"""The Wikidata-5M dataset from the `"KEPLER: A Unified Model for
    Knowledge Embedding and Pre-trained Language Representation"
    <https://arxiv.org/abs/1911.06136>`_ paper,
    containing 4,594,485 entities, 822 relations,
    20,614,279 train triples, 5,163 validation triples, and 5,133 test triples.

    `Wikidata-5M <https://deepgraphlearning.github.io/project/wikidata5m>`_
    is a large-scale knowledge graph dataset with aligned corpus
    extracted from Wikidata.

    The processed :class:`~torch_geometric.data.Data` object holds the node
    features :obj:`x`, the triples as :obj:`edge_index` / :obj:`edge_type`,
    and boolean :obj:`train_mask`, :obj:`val_mask` and :obj:`test_mask`
    tensors that tell which split every edge belongs to.

    Args:
        root (str): Root directory where the dataset should be saved.
        setting (str, optional):
            If :obj:`"transductive"`, loads the transductive dataset.
            If :obj:`"inductive"`, loads the inductive dataset.
            (default: :obj:`"transductive"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    Raises:
        ValueError: If :obj:`setting` is neither :obj:`"transductive"` nor
            :obj:`"inductive"`.
    """
    def __init__(
        self,
        root: str,
        setting: str = 'transductive',
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        if setting not in {'transductive', 'inductive'}:
            raise ValueError(f"Invalid 'setting' argument (got '{setting}')")

        self.setting = setting

        # Must be defined before ``super().__init__()``: the base class may
        # call ``download()`` (which reads ``self.urls``) during construction.
        self.urls = [
            ('https://www.dropbox.com/s/7jp4ib8zo3i6m10/'
             'wikidata5m_text.txt.gz?dl=1'),
            'https://uni-bielefeld.sciebo.de/s/yuBKzBxsEc9j3hy/download',
        ]
        if self.setting == 'inductive':
            self.urls.append('https://www.dropbox.com/s/csed3cgal3m7rzo/'
                             'wikidata5m_inductive.tar.gz?dl=1')
        else:
            self.urls.append('https://www.dropbox.com/s/6sbhm0rwo4l73jq/'
                             'wikidata5m_transductive.tar.gz?dl=1')

        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        # Order matters: ``process()`` reads these by position
        # (entity corpus, feature tensor, then the train/valid/test splits).
        return [
            'wikidata5m_text.txt.gz',
            # The feature-tensor URL ends in ``/download``, so that is the
            # name the file is saved under.
            'download',
            f'wikidata5m_{self.setting}_train.txt',
            f'wikidata5m_{self.setting}_valid.txt',
            f'wikidata5m_{self.setting}_test.txt',
        ]

    @property
    def processed_file_names(self) -> str:
        return f'{self.setting}_data.pt'

    def download(self) -> None:
        for url in self.urls:
            download_url(url, self.raw_dir)
        # Only the split archive needs unpacking; the archive itself is
        # deleted afterwards.
        path = osp.join(self.raw_dir, f'wikidata5m_{self.setting}.tar.gz')
        extract_tar(path, self.raw_dir)
        os.remove(path)

    def process(self) -> None:
        import gzip

        # Each entity's node ID is its line index in the text corpus
        # (presumably aligned with the rows of ``x`` -- TODO confirm).
        # Decode explicitly as UTF-8: the corpus contains Unicode text, and
        # the platform default codec is not guaranteed to handle it.
        entity_to_id: Dict[str, int] = {}
        with gzip.open(self.raw_paths[0], 'rt', encoding='utf-8') as f:
            for i, line in enumerate(f):
                values = line.strip().split('\t')
                entity_to_id[values[0]] = i

        x = fs.torch_load(self.raw_paths[1])

        # Keep sources and destinations in two flat lists rather than one
        # small ``[src, dst]`` list per edge: with ~20M triples this avoids
        # millions of tiny list objects and the later transpose.
        srcs: List[int] = []
        dsts: List[int] = []
        edge_types: List[int] = []
        split_indices: List[int] = []

        # Relation IDs are assigned in order of first appearance.
        rel_to_id: Dict[str, int] = {}
        # ``split`` is 0/1/2 for train/valid/test (order of raw_file_names).
        for split, path in enumerate(self.raw_paths[2:]):
            with open(path, encoding='utf-8') as f:
                for line in f:
                    # Strip only the newline (``line[:-1]`` would chop a real
                    # character off a final line lacking one), and tolerate
                    # blank lines instead of failing the 3-way unpack.
                    line = line.rstrip('\n')
                    if not line:
                        continue
                    head, rel, tail = line.split('\t')
                    srcs.append(entity_to_id[head])
                    dsts.append(entity_to_id[tail])
                    if rel not in rel_to_id:
                        rel_to_id[rel] = len(rel_to_id)
                    edge_types.append(rel_to_id[rel])
                    split_indices.append(split)

        # Already laid out as a contiguous ``[2, num_edges]`` tensor.
        edge_index = torch.tensor([srcs, dsts])
        edge_type = torch.tensor(edge_types)
        split_index = torch.tensor(split_indices)

        data = Data(
            x=x,
            edge_index=edge_index,
            edge_type=edge_type,
            train_mask=split_index == 0,
            val_mask=split_index == 1,
            test_mask=split_index == 2,
        )

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])
|