import torch
import torch.nn as nn
from torch.nn import Parameter
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
# Note: A registered GNN layer should take 'batch' as input
# and 'batch' as output
# Example 1: Directly define a GraphGym format Conv
# take 'batch' as input and 'batch' as output
@register_layer('exampleconv1')
class ExampleConv1(MessagePassing):
    r"""Example GNN layer in GraphGym format.

    Linearly transforms node features, propagates them over the graph
    using the aggregation configured in ``cfg.gnn.agg``, and adds an
    optional bias to the aggregated output. Takes a ``batch`` object as
    input and returns it with ``batch.x`` updated in place.
    """
    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super().__init__(aggr=cfg.gnn.agg, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Learnable linear transform applied to node features.
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        if not bias:
            # Keep the 'bias' slot registered so state_dict stays consistent.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize weight (Glorot) and bias (zeros)."""
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, batch):
        # Transform features, then aggregate neighbor messages.
        transformed = torch.matmul(batch.x, self.weight)
        batch.x = self.propagate(batch.edge_index, x=transformed)
        return batch

    def message(self, x_j):
        # Identity message: pass neighbor features through unchanged.
        return x_j

    def update(self, aggr_out):
        # Add bias to the aggregated output when it exists.
        if self.bias is None:
            return aggr_out
        return aggr_out + self.bias
# Example 2: First define a PyG format Conv layer
# Then wrap it to become GraphGym format
class ExampleConv2Layer(MessagePassing):
    r"""Example GNN layer in plain PyG format.

    Applies a learnable linear transform to node features and
    aggregates neighbor messages with the scheme configured in
    ``cfg.gnn.agg``, adding an optional bias afterwards. Takes
    ``(x, edge_index)`` and returns the updated node features.
    """
    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super().__init__(aggr=cfg.gnn.agg, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Learnable linear transform applied to node features.
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        if not bias:
            # Keep the 'bias' slot registered so state_dict stays consistent.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize weight (Glorot) and bias (zeros)."""
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, edge_index):
        # Transform features, then aggregate neighbor messages.
        transformed = torch.matmul(x, self.weight)
        return self.propagate(edge_index, x=transformed)

    def message(self, x_j):
        # Identity message: pass neighbor features through unchanged.
        return x_j

    def update(self, aggr_out):
        # Add bias to the aggregated output when it exists.
        if self.bias is None:
            return aggr_out
        return aggr_out + self.bias
@register_layer('exampleconv2')
class ExampleConv2(nn.Module):
    r"""GraphGym wrapper around :class:`ExampleConv2Layer`.

    Adapts the PyG-format layer to the GraphGym convention of taking a
    ``batch`` object as input and returning it with ``batch.x`` updated.
    """
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        # Delegate all computation to the wrapped PyG-format layer.
        self.model = ExampleConv2Layer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        batch.x = self.model(batch.x, batch.edge_index)
        return batch