1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
|
from typing import Callable, Optional
import torch
from torch import Tensor
from torch_geometric.data import HeteroData, InMemoryDataset
from torch_geometric.typing import TensorFrame, torch_frame
from torch_geometric.utils import coalesce as coalesce_fn
def get_random_edge_index(
    num_src_nodes: int,
    num_dst_nodes: int,
    num_edges: int,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    coalesce: bool = False,
) -> Tensor:
    r"""Sample a random ``edge_index`` tensor of shape ``[2, num_edges]``.

    Row 0 (sources) is drawn uniformly from ``[0, num_src_nodes)`` and row 1
    (destinations) uniformly from ``[0, num_dst_nodes)``. Sampling is with
    replacement, so duplicate edges may appear unless ``coalesce=True``.

    Args:
        num_src_nodes: Exclusive upper bound for source indices.
        num_dst_nodes: Exclusive upper bound for destination indices.
        num_edges: Number of edges to sample (before any coalescing).
        dtype: Index dtype forwarded to :func:`torch.randint`. ``None`` uses
            torch's default (``torch.int64``).
        device: Device on which the tensor is created.
        coalesce: If ``True``, the result is passed through
            :func:`torch_geometric.utils.coalesce`. That function sorts the
            edges and removes duplicates, so fewer than ``num_edges`` columns
            may remain.
    """
    shape = (num_edges, )
    # The comprehension runs in order, so the source row is drawn from the
    # RNG before the destination row.
    endpoints = [
        torch.randint(upper, shape, dtype=dtype, device=device)
        for upper in (num_src_nodes, num_dst_nodes)
    ]
    edge_index = torch.stack(endpoints, dim=0)
    return coalesce_fn(edge_index) if coalesce else edge_index
def get_random_tensor_frame(
    num_rows: int,
    device: Optional[torch.device] = None,
) -> TensorFrame:
    r"""Build a random :class:`TensorFrame` with ``num_rows`` rows for tests.

    The frame contains:

    * three categorical columns ``'a'``, ``'b'``, ``'c'`` with integer values
      in ``[0, 3)``;
    * two numerical columns ``'x'``, ``'y'`` drawn from a standard normal;
    * a standard-normal target ``y`` of length ``num_rows``.

    Args:
        num_rows: Number of rows in the frame.
        device: Device on which all tensors are created.
    """
    cat_stype = torch_frame.categorical
    num_stype = torch_frame.numerical

    # Draw in a fixed order (categorical, numerical, target) so that a seeded
    # RNG always produces the same frame.
    cat_feat = torch.randint(0, 3, size=(num_rows, 3), device=device)
    num_feat = torch.randn(size=(num_rows, 2), device=device)
    target = torch.randn(num_rows, device=device)

    return torch_frame.TensorFrame(
        feat_dict={
            cat_stype: cat_feat,
            num_stype: num_feat,
        },
        col_names_dict={
            cat_stype: ['a', 'b', 'c'],
            num_stype: ['x', 'y'],
        },
        y=target,
    )
class FakeHeteroDataset(InMemoryDataset):
    r"""Tiny synthetic heterogeneous dataset holding a single random graph.

    The graph contains:

    * 100 ``'paper'`` nodes with 16-dim random features ``x``, integer labels
      ``y`` in ``[0, 4)``, and disjoint ``train_mask`` / ``val_mask`` /
      ``test_mask`` covering 60 / 20 / 20 papers of one random permutation;
    * 10 ``'author'`` nodes with 8-dim random features ``x``;
    * 300 randomly sampled ``('paper', 'author')`` edges, mirrored as
      ``('author', 'paper')`` edges.

    Args:
        transform: Optional transform forwarded to :class:`InMemoryDataset`.
    """
    def __init__(self, transform: Optional[Callable] = None):
        super().__init__(transform=transform)

        num_papers, num_authors = 100, 10

        graph = HeteroData()
        paper = graph['paper']
        author = graph['author']
        paper.x = torch.randn(num_papers, 16)
        author.x = torch.randn(num_authors, 8)

        forward = get_random_edge_index(
            num_src_nodes=num_papers,
            num_dst_nodes=num_authors,
            num_edges=300,
        )
        graph['paper', 'author'].edge_index = forward
        # The reverse relation swaps the source and destination rows.
        graph['author', 'paper'].edge_index = forward.flip([0])

        paper.y = torch.randint(0, 4, (num_papers, ))

        # Slicing one shared permutation keeps the three splits disjoint.
        order = torch.randperm(num_papers)
        split_bounds = (
            ('train_mask', 0, 60),
            ('val_mask', 60, 80),
            ('test_mask', 80, 100),
        )
        for attr, start, stop in split_bounds:
            mask = torch.zeros(num_papers, dtype=torch.bool)
            mask[order[start:stop]] = True
            setattr(paper, attr, mask)

        self.data, self.slices = self.collate([graph])
|