1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
|
import argparse
import torch
import torch.nn.functional as F
from tqdm import tqdm
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import HeteroConv, Linear, SAGEConv
from torch_geometric.utils import trim_to_layer
# Command-line options. ``--use-sparse-tensor`` makes the script feed the
# model ``adj_t`` (SparseTensor) connectivity instead of ``edge_index``.
parser = argparse.ArgumentParser()
parser.add_argument('--use-sparse-tensor', action='store_true')
args = parser.parse_args()

# Prefer CUDA, then Intel XPU, and fall back to the CPU otherwise.
device = torch.device(
    'cuda' if torch.cuda.is_available() else
    'xpu' if torch_geometric.is_xpu_available() else
    'cpu')

# Always add reverse edges (see ``T.ToUndirected`` for the ``merge`` flag);
# optionally convert the connectivity to SparseTensor adjacency matrices.
transforms = [T.ToUndirected(merge=True)]
if args.use_sparse_tensor:
    transforms += [T.ToSparseTensor()]

dataset = OGB_MAG(
    root='../../data',
    preprocess='metapath2vec',
    transform=T.Compose(transforms),
)
# Only the 'x' and 'y' attributes are moved to ``device`` up front.
data = dataset[0].to(device, 'x', 'y')
class HierarchicalHeteroGraphSage(torch.nn.Module):
    """Stack of heterogeneous GraphSAGE layers with a linear 'paper' head.

    Each of the ``num_layers`` layers is a ``HeteroConv`` that holds one
    ``SAGEConv((-1, -1), hidden_channels)`` per edge type. PyG infers the
    input sizes of these convolutions on first use. The per-edge-type
    results are combined with ``aggr='sum'``. A final ``Linear`` maps the
    'paper' node embeddings to ``out_channels`` logits.

    Args:
        edge_types: Edge types of the heterogeneous graph; each one gets
            its own ``SAGEConv`` in every layer.
        hidden_channels: Width of every hidden layer.
        out_channels: Number of output logits per 'paper' node.
        num_layers: Number of ``HeteroConv`` layers.
    """

    def __init__(self, edge_types, hidden_channels, out_channels, num_layers):
        super().__init__()
        self.convs = torch.nn.ModuleList([
            HeteroConv(
                {et: SAGEConv((-1, -1), hidden_channels) for et in edge_types},
                aggr='sum',
            ) for _ in range(num_layers)
        ])
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict, num_sampled_edges_dict,
                num_sampled_nodes_dict):
        """Compute logits for the 'paper' nodes of a sampled mini-batch.

        Before layer ``layer`` runs, the node features and connectivity are
        passed through ``trim_to_layer``, which uses the per-hop sampled
        node and edge counts. Every conv output is then passed through ReLU.

        Returns:
            ``self.lin`` applied to the final 'paper' embeddings.
        """
        for layer, conv in enumerate(self.convs):
            x_dict, edge_index_dict, _ = trim_to_layer(
                layer=layer,
                num_sampled_nodes_per_hop=num_sampled_nodes_dict,
                num_sampled_edges_per_hop=num_sampled_edges_dict,
                x=x_dict,
                edge_index=edge_index_dict,
            )
            out_dict = conv(x_dict, edge_index_dict)
            x_dict = {node_type: h.relu() for node_type, h in out_dict.items()}
        return self.lin(x_dict['paper'])
# ``trim_to_layer`` (used in the model's forward pass) expects exactly one
# sampled hop per conv layer. The model depth and the length of every
# loader's ``num_neighbors`` list are therefore derived from one value, so
# they cannot drift apart.
num_layers = 2
neighbors_per_hop = 10

model = HierarchicalHeteroGraphSage(
    edge_types=data.edge_types,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=num_layers,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Options shared by the training and validation loaders.
kwargs = {'batch_size': 1024, 'num_workers': 0}
# Each loader gets its own ``num_neighbors`` list, so the two samplers never
# share a mutable object.
train_loader = NeighborLoader(
    data,
    num_neighbors=[neighbors_per_hop] * num_layers,
    shuffle=True,
    input_nodes=('paper', data['paper'].train_mask),
    **kwargs,
)
val_loader = NeighborLoader(
    data,
    num_neighbors=[neighbors_per_hop] * num_layers,
    shuffle=False,
    input_nodes=('paper', data['paper'].val_mask),
    **kwargs,
)
def train():
    """Run one epoch over ``train_loader`` and return the mean training loss.

    Each mini-batch loss is weighted by its number of seed 'paper' nodes
    (``batch['paper'].batch_size``). The return value is therefore an
    average per seed node rather than per batch.
    """
    model.train()

    num_seen = loss_sum = 0
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()

        # Use whichever connectivity format the dataset transforms produced.
        if args.use_sparse_tensor:
            connectivity = batch.adj_t_dict
        else:
            connectivity = batch.edge_index_dict
        logits = model(
            batch.x_dict,
            connectivity,
            num_sampled_nodes_dict=batch.num_sampled_nodes_dict,
            num_sampled_edges_dict=batch.num_sampled_edges_dict,
        )

        # Only the leading ``seeds`` rows are the batch's input nodes.
        seeds = batch['paper'].batch_size
        loss = F.cross_entropy(logits[:seeds], batch['paper'].y[:seeds])
        loss.backward()
        optimizer.step()

        num_seen += seeds
        loss_sum += float(loss) * seeds

    return loss_sum / num_seen
@torch.no_grad()
def test(loader):
    """Return the classification accuracy on the seed 'paper' nodes of ``loader``.

    Args:
        loader: A ``NeighborLoader`` whose batches carry 'paper' seed nodes.

    Returns:
        The fraction of seed nodes whose argmax prediction equals the label.
    """
    model.eval()

    num_seen = num_correct = 0
    for batch in tqdm(loader):
        batch = batch.to(device)
        connectivity = (batch.adj_t_dict if args.use_sparse_tensor
                        else batch.edge_index_dict)
        logits = model(
            batch.x_dict,
            connectivity,
            num_sampled_nodes_dict=batch.num_sampled_nodes_dict,
            num_sampled_edges_dict=batch.num_sampled_edges_dict,
        )

        seeds = batch['paper'].batch_size
        hits = logits[:seeds].argmax(dim=-1) == batch['paper'].y[:seeds]
        num_seen += seeds
        num_correct += int(hits.sum())

    return num_correct / num_seen
# Five epochs: train once, then report the loss and validation accuracy.
for epoch in range(1, 6):
    loss, val_acc = train(), test(val_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')
|