# In this example, you will find data loading implementations using PyTorch
# DataPipes (https://pytorch.org/data/) across various tasks:
# (1) molecular graph data loading pipe
# (2) mesh/point cloud data loading pipe
# In particular, we make use of PyG's built-in DataPipes, e.g., for batching
# multiple PyG data objects together or for converting SMILES strings into
# molecular graph representations. We also showcase how to write your own
# DataPipe (i.e. for loading and parsing mesh data into PyG data objects).
import argparse
import os.path as osp
import time
import torch
from torchdata.datapipes.iter import FileLister, FileOpener, IterDataPipe
from torch_geometric.data import Data, download_url, extract_zip
def molecule_datapipe() -> IterDataPipe:
    """Build a DataPipe yielding molecular graphs from the MoleculeNet HIV set."""
    # Fetch the HIV CSV from MoleculeNet into the shared '../data' directory:
    base_url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets'
    data_dir = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
    csv_path = download_url(f'{base_url}/HIV.csv', data_dir)

    # Parse CSV rows into dicts, convert each SMILES string into a PyG graph,
    # and keep the resulting graph instances cached in memory:
    pipe = FileOpener([csv_path])
    pipe = pipe.parse_csv_as_dict().parse_smiles(target_key='HIV_active')
    return pipe.in_memory_cache()
@torch.utils.data.functional_datapipe('read_mesh')
class MeshOpener(IterDataPipe):
    """Custom DataPipe that reads mesh files and emits PyG ``Data`` objects.

    Each yielded ``Data`` carries node positions (``pos``), triangle faces
    (``face``), and the category parsed from the leading part of the file name.
    """
    def __init__(self, dp: IterDataPipe):
        super().__init__()
        self.dp = dp

    def __iter__(self):
        import meshio  # Lazy import: only needed when meshes are actually read.

        for mesh_path in self.dp:
            # File names look like '<category>_<index>.<ext>'.
            label = osp.basename(mesh_path).split('_')[0]
            mesh = meshio.read(mesh_path)
            vertices = torch.from_numpy(mesh.points).to(torch.float)
            faces = torch.from_numpy(mesh.cells[0].data).t().contiguous()
            yield Data(pos=vertices, face=faces, category=label)
def mesh_datapipe() -> IterDataPipe:
    """Build a DataPipe of k-NN graphs sampled from ModelNet10 training meshes."""
    # Download (and, if necessary, unpack) the ModelNet10 archive:
    base_url = 'http://vision.princeton.edu/projects/2014/3DShapeNets'
    data_dir = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
    zip_path = download_url(f'{base_url}/ModelNet10.zip', data_dir)
    dataset_dir = osp.join(data_dir, 'ModelNet10')
    if not osp.exists(dataset_dir):
        extract_zip(zip_path, dataset_dir)

    def is_train(path: str) -> bool:
        # Keep only meshes belonging to the training split.
        return 'train' in path

    pipe = FileLister([dataset_dir], masks='*.off', recursive=True)
    pipe = pipe.filter(is_train)
    pipe = pipe.read_mesh()          # Custom DataPipe defined in this file.
    pipe = pipe.in_memory_cache()    # Cache parsed graph instances in-memory.
    pipe = pipe.sample_points(1024)  # PyG transforms apply from here on.
    return pipe.knn_graph(k=8)
# Maps each CLI task name (see '--task' below) to the factory building its pipe.
DATAPIPES = {
    'molecule': molecule_datapipe,
    'mesh': mesh_datapipe,
}
if __name__ == '__main__':
    # Select which data-loading pipe to build from the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', default='molecule', choices=DATAPIPES.keys())
    args = parser.parse_args()

    datapipe = DATAPIPES[args.task]()
    print('Example output:')
    print(next(iter(datapipe)))

    # Shuffling + Batching support:
    datapipe = datapipe.shuffle().batch_graphs(batch_size=32)

    # The first epoch will take longer than the remaining ones...
    for message in ('Iterating over all data...',
                    'Iterating over all data a second time...'):
        print(message)
        start = time.perf_counter()
        for _ in datapipe:
            pass
        print(f'Done! [{time.perf_counter() - start:.2f}s]')