import dgl.function as fn
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn.inits import glorot, zeros


class GCNConv(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        self.g = g
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def gcn_msg(self, edge):
        # Message: scale source features by the source-side normalization.
        return {'m': edge.src['x'] * edge.src['norm']}

    def gcn_reduce(self, node):
        # Reduce: sum incoming messages and apply the destination-side
        # normalization, completing the D^{-1/2} A D^{-1/2} aggregation.
        return {'x': node.mailbox['m'].sum(dim=1) * node.data['norm']}

    def forward(self, x):
        self.g.ndata['x'] = torch.matmul(x, self.weight)
        self.g.update_all(self.gcn_msg, self.gcn_reduce)
        x = self.g.ndata.pop('x')
        x = x + self.bias
        return x


class GCN(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(g, in_channels, 16)
        self.conv2 = GCNConv(g, 16, out_channels)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)


class GCNSPMVConv(torch.nn.Module):
    # Same layer as GCNConv, but using DGL's built-in message/reduce
    # functions so the aggregation runs as a sparse matrix-vector product.
    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        self.g = g
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x):
        x = torch.matmul(x, self.weight)
        self.g.ndata['x'] = x * self.g.ndata['norm']
        # fn.copy_src is the legacy name; newer DGL releases expose it as fn.copy_u.
        self.g.update_all(fn.copy_src(src='x', out='m'),
                          fn.sum(msg='m', out='x'))
        x = self.g.ndata.pop('x') * self.g.ndata['norm']
        x = x + self.bias
        return x


class GCNSPMV(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNSPMVConv(g, in_channels, 16)
        self.conv2 = GCNSPMVConv(g, 16, out_channels)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)
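

# A minimal usage sketch (an addition, not part of the original benchmark). It
# shows how the graph `g` consumed by these modules is assumed to be prepared:
# `g.ndata['norm']` should hold deg^{-1/2} with shape [num_nodes, 1], so that
# the source-side and destination-side scalings together give the symmetric
# normalization D^{-1/2} A D^{-1/2}. The graph size, feature sizes, and the
# dgl.graph constructor (DGL >= 0.5) are illustrative assumptions.
if __name__ == '__main__':
    import dgl

    num_nodes, in_channels, num_classes = 4, 8, 3
    src = torch.tensor([0, 1, 2, 3])
    dst = torch.tensor([1, 2, 3, 0])
    g = dgl.add_self_loop(dgl.graph((src, dst), num_nodes=num_nodes))

    # Symmetric GCN normalization, stored on the nodes for both variants.
    deg = g.in_degrees().float()
    norm = deg.pow(-0.5)
    norm[torch.isinf(norm)] = 0
    g.ndata['norm'] = norm.unsqueeze(1)

    x = torch.randn(num_nodes, in_channels)
    out = GCN(g, in_channels, num_classes)(x)  # GCNSPMV has the same interface
    print(out.shape)  # torch.Size([4, 3])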