import copy
import os.path as osp
import time

import torch
import torch.nn.functional as F
from tqdm import tqdm

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)
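
# `dataset[0]` holds the full Reddit graph: node features `x`, labels `y`, the
# `edge_index`, and the train/val/test node masks used below.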
# Already send node features/labels to GPU for faster access during sampling:
data = dataset[0].to(device, 'x', 'y')
kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}
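
# Mini-batch training loader: for each seed node drawn from the training set,
# sample up to 25 neighbors in the first hop and 10 in the second. The seed
# nodes always come first in each sampled mini-batch, which is why the code
# below slices predictions and labels with `[:batch.batch_size]`.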
train_loader = NeighborLoader(data, input_nodes=data.train_mask,
                              num_neighbors=[25, 10], shuffle=True, **kwargs)
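
# Loader for layer-wise inference: `num_neighbors=[-1]` takes *all* neighbors
# of each node, one hop at a time, over the full graph (no train/val/test
# split, hence `input_nodes=None`).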
subgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None,
                                 num_neighbors=[-1], shuffle=False, **kwargs)
# No need to maintain these features during evaluation:
del subgraph_loader.data.x, subgraph_loader.data.y
# Add global node index information.
subgraph_loader.data.num_nodes = data.num_nodes
subgraph_loader.data.n_id = torch.arange(data.num_nodes)
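# During inference, `batch.n_id` maps every node of a sampled subgraph back to
# its row in the full-sized feature matrix `x_all`.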


class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = x.relu_()
                x = F.dropout(x, p=0.5, training=self.training)
        return x

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader):
        pbar = tqdm(total=len(subgraph_loader.dataset) * len(self.convs))
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch:
        for i, conv in enumerate(self.convs):
            xs = []
            for batch in subgraph_loader:
                x = x_all[batch.n_id.to(x_all.device)].to(device)
                x = conv(x, batch.edge_index.to(device))
                if i < len(self.convs) - 1:
                    x = x.relu_()
                xs.append(x[:batch.batch_size].cpu())
                pbar.update(batch.batch_size)
            x_all = torch.cat(xs, dim=0)
        pbar.close()
        return x_all
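

# Two-layer GraphSAGE with 256 hidden channels, trained end-to-end with Adam.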
model = SAGE(dataset.num_features, 256, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train(epoch):
    model.train()

    pbar = tqdm(total=int(len(train_loader.dataset)))
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = total_correct = total_examples = 0
    for batch in train_loader:
        optimizer.zero_grad()
        y = batch.y[:batch.batch_size]
        y_hat = model(batch.x, batch.edge_index.to(device))[:batch.batch_size]
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * batch.batch_size
        total_correct += int((y_hat.argmax(dim=-1) == y).sum())
        total_examples += batch.batch_size
        pbar.update(batch.batch_size)
    pbar.close()

    return total_loss / total_examples, total_correct / total_examples


@torch.no_grad()
def test():
    model.eval()
    y_hat = model.inference(data.x, subgraph_loader).argmax(dim=-1)
    y = data.y.to(y_hat.device)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((y_hat[mask] == y[mask]).sum()) / int(mask.sum()))
    return accs


times = []
for epoch in range(1, 11):
    start = time.time()
    loss, acc = train(epoch)
    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {test_acc:.4f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")