import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.graphgym.models.head # noqa, register module
import torch_geometric.graphgym.register as register
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_network
@register_network('example')
class ExampleGNN(torch.nn.Module):
    """Example GraphGym network: a stack of message-passing layers
    followed by a task-specific prediction head.

    Args:
        dim_in: Input node-feature dimensionality. All conv layers keep
            this width; the head maps it to ``dim_out``.
        dim_out: Output dimensionality of the prediction head.
        num_layers: Total number of conv layers (default: 2).
        model_type: One of ``'GCN'``, ``'GAT'`` or ``'GraphSage'``,
            selecting the conv layer class.
    """
    def __init__(self, dim_in, dim_out, num_layers=2, model_type='GCN'):
        super().__init__()
        conv_model = self.build_conv_model(model_type)
        self.convs = nn.ModuleList()
        self.convs.append(conv_model(dim_in, dim_in))
        # The remaining num_layers - 1 layers also map dim_in -> dim_in.
        for _ in range(num_layers - 1):
            self.convs.append(conv_model(dim_in, dim_in))

        # The head is looked up from the GraphGym registry by task type
        # (e.g. node/edge/graph-level prediction) configured in cfg.
        GNNHead = register.head_dict[cfg.dataset.task]
        self.post_mp = GNNHead(dim_in=dim_in, dim_out=dim_out)

    def build_conv_model(self, model_type):
        """Return the conv layer class for *model_type*.

        Raises:
            ValueError: If *model_type* is not a supported name.
        """
        if model_type == 'GCN':
            return pyg_nn.GCNConv
        elif model_type == 'GAT':
            return pyg_nn.GATConv
        elif model_type == "GraphSage":
            return pyg_nn.SAGEConv
        else:
            raise ValueError(f'Model {model_type} unavailable')

    def forward(self, batch):
        """Run conv -> ReLU -> dropout per layer, then the head.

        Expects *batch* to expose ``x`` (node features) and
        ``edge_index``; writes the transformed features back to
        ``batch.x`` before handing the batch to ``self.post_mp``.
        """
        x, edge_index = batch.x, batch.edge_index
        # Iterate the ModuleList directly rather than by index.
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            # NOTE(review): dropout probability is hard-coded at 0.1
            # in this example; real models would read it from cfg.
            x = F.dropout(x, p=0.1, training=self.training)
        batch.x = x
        batch = self.post_mp(batch)
        return batch