1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
|
import torch
from torch_geometric.data import Data, Dataset, InMemoryDataset
class MyData(Data):
def __init__(self, x=None, edge_index=None, arg=None):
super().__init__(x=x, edge_index=edge_index, arg=arg)
def random(self):
return torch.randn(list(self.x.size()) + list(self.arg.size()))
class MyInMemoryDataset(InMemoryDataset):
def __init__(self):
super().__init__('/tmp/MyInMemoryDataset')
x = torch.randn(4, 5)
edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]])
arg = torch.randn(4, 3)
data_list = [MyData(x, edge_index, arg) for _ in range(10)]
self.data, self.slices = self.collate(data_list)
def _download(self):
pass
def _process(self):
pass
class MyDataset(Dataset):
def __init__(self):
super().__init__('/tmp/MyDataset')
def _download(self):
pass
def _process(self):
pass
def len(self):
return 10
def get(self, idx):
x = torch.randn(4, 5)
edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]])
arg = torch.randn(4, 3)
return MyData(x, edge_index, arg)
def test_inherit():
dataset = MyDataset()
assert len(dataset) == 10
data = dataset[0]
assert data.random().size() == (4, 5, 4, 3)
dataset = MyInMemoryDataset()
assert len(dataset) == 10
data = dataset[0]
assert data.random().size() == (4, 5, 4, 3)
|