1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
|
import glob
import os
import os.path as osp
from typing import Callable, List, Optional
import torch
from torch_geometric.data import InMemoryDataset, download_url, extract_zip
from torch_geometric.io import fs, read_off, read_txt_array
class SHREC2016(InMemoryDataset):
r"""The SHREC 2016 partial matching dataset from the `"SHREC'16: Partial
Matching of Deformable Shapes"
<http://www.dais.unive.it/~shrec2016/shrec16-partial.pdf>`_ paper.
The reference shape can be referenced via :obj:`dataset.ref`.
.. note::
Data objects hold mesh faces instead of edge indices.
To convert the mesh to a graph, use the
:obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
To convert the mesh to a point cloud, use the
:obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
sample a fixed number of points on the mesh faces according to their
face area.
Args:
root (str): Root directory where the dataset should be saved.
partiality (str): The partiality of the dataset (one of :obj:`"Holes"`,
:obj:`"Cuts"`).
category (str): The category of the dataset (one of
:obj:`"Cat"`, :obj:`"Centaur"`, :obj:`"David"`, :obj:`"Dog"`,
:obj:`"Horse"`, :obj:`"Michael"`, :obj:`"Victoria"`,
:obj:`"Wolf"`).
train (bool, optional): If :obj:`True`, loads the training dataset,
otherwise the test dataset. (default: :obj:`True`)
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
(default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean
value, indicating whether the data object should be included in the
final dataset. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""
train_url = ('http://www.dais.unive.it/~shrec2016/data/'
'shrec2016_PartialDeformableShapes.zip')
test_url = ('http://www.dais.unive.it/~shrec2016/data/'
'shrec2016_PartialDeformableShapes_TestSet.zip')
categories = [
'cat', 'centaur', 'david', 'dog', 'horse', 'michael', 'victoria',
'wolf'
]
partialities = ['holes', 'cuts']
def __init__(
self,
root: str,
partiality: str,
category: str,
train: bool = True,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
force_reload: bool = False,
) -> None:
assert partiality.lower() in self.partialities
self.part = partiality.lower()
assert category.lower() in self.categories
self.cat = category.lower()
super().__init__(root, transform, pre_transform, pre_filter,
force_reload=force_reload)
self.__ref__ = fs.torch_load(self.processed_paths[0])
path = self.processed_paths[1] if train else self.processed_paths[2]
self.load(path)
@property
def ref(self) -> str:
ref = self.__ref__
if self.transform is not None:
ref = self.transform(ref)
return ref
@property
def raw_file_names(self) -> List[str]:
return ['training', 'test']
@property
def processed_file_names(self) -> List[str]:
name = f'{self.part}_{self.cat}.pt'
return [f'{i}_{name}' for i in ['ref', 'training', 'test']]
def download(self) -> None:
path = download_url(self.train_url, self.raw_dir)
extract_zip(path, self.raw_dir)
os.unlink(path)
path = osp.join(self.raw_dir, 'shrec2016_PartialDeformableShapes')
os.rename(path, osp.join(self.raw_dir, 'training'))
path = download_url(self.test_url, self.raw_dir)
extract_zip(path, self.raw_dir)
os.unlink(path)
path = osp.join(self.raw_dir,
'shrec2016_PartialDeformableShapes_TestSet')
os.rename(path, osp.join(self.raw_dir, 'test'))
def process(self) -> None:
ref_data = read_off(
osp.join(self.raw_paths[0], 'null', f'{self.cat}.off'))
train_list = []
name = f'{self.part}_{self.cat}_*.off'
paths = glob.glob(osp.join(self.raw_paths[0], self.part, name))
paths = [path[:-4] for path in paths]
paths = sorted(paths, key=lambda e: (len(e), e))
for path in paths:
data = read_off(f'{path}.off')
y = read_txt_array(f'{path}.baryc_gt')
data.y = y[:, 0].to(torch.long) - 1
data.y_baryc = y[:, 1:]
train_list.append(data)
test_list = []
name = f'{self.part}_{self.cat}_*.off'
paths = glob.glob(osp.join(self.raw_paths[1], self.part, name))
paths = [path[:-4] for path in paths]
paths = sorted(paths, key=lambda e: (len(e), e))
for path in paths:
test_list.append(read_off(f'{path}.off'))
if self.pre_filter is not None:
train_list = [d for d in train_list if self.pre_filter(d)]
test_list = [d for d in test_list if self.pre_filter(d)]
if self.pre_transform is not None:
ref_data = self.pre_transform(ref_data)
train_list = [self.pre_transform(d) for d in train_list]
test_list = [self.pre_transform(d) for d in test_list]
torch.save(ref_data, self.processed_paths[0])
self.save(train_list, self.processed_paths[1])
self.save(test_list, self.processed_paths[2])
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({len(self)}, '
f'partiality={self.part}, category={self.cat})')
|