1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
|
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv, JumpingKnowledge, global_mean_pool
class GCN(torch.nn.Module):
    """Graph classifier built from a stack of ``GCNConv`` layers.

    The forward pass applies every convolution followed by ReLU, pools the
    result with ``global_mean_pool`` using ``data.batch``, and feeds the
    pooled features through a two-layer head that ends in ``log_softmax``.

    Args:
        dataset: Object providing ``num_features`` (input width of the first
            convolution) and ``num_classes`` (output width of the head).
        num_layers: Total number of ``GCNConv`` layers; one input layer plus
            ``num_layers - 1`` hidden-to-hidden layers.
        hidden: Width used by every hidden layer.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden)
        # Remaining convolutions all map hidden -> hidden.
        self.convs = torch.nn.ModuleList(
            [GCNConv(hidden, hidden) for _ in range(num_layers - 1)]
        )
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise every learnable layer, in construction order."""
        for layer in (self.conv1, *self.convs, self.lin1, self.lin2):
            layer.reset_parameters()

    def forward(self, data):
        """Return the ``log_softmax`` (last dim) of the head's output.

        Args:
            data: Batch object exposing ``x``, ``edge_index`` and ``batch``.
        """
        feats, edge_index, batch = data.x, data.edge_index, data.batch

        # Every convolution (input layer included) is followed by a ReLU.
        for conv in (self.conv1, *self.convs):
            feats = F.relu(conv(feats, edge_index))

        pooled = global_mean_pool(feats, batch)

        head = F.relu(self.lin1(pooled))
        head = F.dropout(head, p=0.5, training=self.training)
        logits = self.lin2(head)
        return F.log_softmax(logits, dim=-1)

    def __repr__(self):
        return type(self).__name__
class GCNWithJK(torch.nn.Module):
    """GCN graph classifier whose layer outputs are merged by ``JumpingKnowledge``.

    Each ``GCNConv`` layer is followed by ReLU; the outputs of all layers are
    collected and combined by ``self.jump``, pooled with ``global_mean_pool``
    using ``data.batch``, and passed through a two-layer head ending in
    ``log_softmax``.

    Args:
        dataset: Object providing ``num_features`` (input width of the first
            convolution) and ``num_classes`` (output width of the head).
        num_layers: Total number of ``GCNConv`` layers; one input layer plus
            ``num_layers - 1`` hidden-to-hidden layers.
        hidden: Width used by every hidden layer.
        mode: Aggregation scheme forwarded to ``JumpingKnowledge``
            (default ``'cat'``).
    """

    def __init__(self, dataset, num_layers, hidden, mode='cat'):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden, hidden))

        # Per the PyG documentation, 'lstm' mode needs the per-layer channel
        # count and the number of layers; without them it cannot be built.
        # They are supplied unconditionally since the other modes do not
        # require them.
        self.jump = JumpingKnowledge(
            mode, channels=hidden, num_layers=num_layers
        )

        # 'cat' concatenates the num_layers outputs, so the head input is
        # num_layers * hidden wide; any other mode is given a hidden-wide head.
        if mode == 'cat':
            self.lin1 = Linear(num_layers * hidden, hidden)
        else:
            self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise every learnable layer, in construction order."""
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        """Return the ``log_softmax`` (last dim) of the head's output.

        Args:
            data: Batch object exposing ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        xs = [x]  # one entry per convolution, in layer order
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs.append(x)

        x = self.jump(xs)
        x = global_mean_pool(x, batch)

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
|