"""Graph attention network (GAT) layers and two-layer models built on DGL."""
import dgl.function as fn
import torch
import torch.nn.functional as F
from dgl.nn.pytorch import EdgeSoftmax
from torch.nn import Parameter
from torch_geometric.nn.inits import glorot, zeros
class GATConv(torch.nn.Module):
    """Multi-head graph attention layer built from DGL user-defined functions.

    Input features are projected by ``weight`` and reshaped to one
    ``out_channels``-sized vector per head. ``gat_msg`` scores every edge,
    ``gat_reduce`` turns the scores into softmax weights and sums the weighted
    messages, and ``forward`` flattens the heads and adds ``bias``.

    Args:
        g: Graph object exposing ``ndata`` and ``update_all`` (a DGL graph is
            assumed; the layer keeps a reference and writes into ``ndata``).
        in_channels: Size of each input feature vector.
        out_channels: Size of each per-head output vector.
        heads: Number of attention heads.
        negative_slope: Slope passed to ``F.leaky_relu`` for the edge scores.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, g, in_channels, out_channels, heads=1,
                 negative_slope=0.2, dropout=0):
        super().__init__()

        self.g = g
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        # Parameters are allocated uninitialised; reset_parameters() fills them.
        flat_width = heads * out_channels
        self.weight = Parameter(torch.empty(in_channels, flat_width))
        self.att = Parameter(torch.empty(1, heads, 2 * out_channels))
        self.bias = Parameter(torch.empty(flat_width))

        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-initialise ``weight`` and ``att`` and zero ``bias``."""
        for tensor in (self.weight, self.att):
            glorot(tensor)
        zeros(self.bias)

    def gat_msg(self, edge):
        """Build the per-edge message and its raw attention score.

        Returns a dict with ``'m'`` (the source node features) and ``'a'``
        (leaky-ReLU of the ``att``-weighted sum over the concatenated source
        and destination features).
        """
        source_feat = edge.src['x']
        paired = torch.cat([source_feat, edge.dst['x']], dim=-1)
        score = torch.sum(paired * self.att, dim=-1)
        score = F.leaky_relu(score, self.negative_slope)
        return {'m': source_feat, 'a': score}

    def gat_reduce(self, node):
        """Aggregate mailbox messages using softmax attention weights.

        The softmax and the final sum both run over dimension 1 of the
        mailbox tensors (presumably the incoming-message axis -- confirm
        against the DGL mailbox layout).
        """
        weights = F.dropout(
            torch.softmax(node.mailbox['a'], dim=1),
            p=self.dropout,
            training=self.training,
        )
        weighted = node.mailbox['m'] * weights.unsqueeze(-1)
        return {'x': weighted.sum(dim=1)}

    def forward(self, x):
        """Apply the layer to ``x`` and return ``(-1, heads * out_channels)``."""
        graph = self.g
        projected = torch.mm(x, self.weight)
        # One out_channels-sized vector per head for every row of x.
        graph.ndata['x'] = projected.view(-1, self.heads, self.out_channels)
        graph.update_all(self.gat_msg, self.gat_reduce)
        # pop() removes the temporary 'x' entry from the graph again.
        aggregated = graph.ndata.pop('x')
        flat = aggregated.view(-1, self.heads * self.out_channels)
        return flat + self.bias
class GAT(torch.nn.Module):
    """Two-layer graph attention network built from ``GATConv`` layers.

    The first layer uses 8 heads of 8 channels each; the second layer uses a
    single head that maps the concatenated heads to ``out_channels``.

    Args:
        g: Graph object handed to both ``GATConv`` layers.
        in_channels: Size of each input feature vector.
        out_channels: Size of each output vector (log-probabilities).
    """

    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        self.g = g

        hidden, heads = 8, 8
        # Keyword arguments on purpose: the previous positional call bound 0.6
        # to negative_slope and 0.2 to dropout, i.e. the two were swapped
        # relative to the 0.6 feature dropout used in forward().
        self.conv1 = GATConv(g, in_channels, hidden, heads=heads,
                             negative_slope=0.2, dropout=0.6)
        # The second layer consumes the concatenated heads of the first.
        self.conv2 = GATConv(g, hidden * heads, out_channels, heads=1,
                             negative_slope=0.2, dropout=0.6)

    def forward(self, x):
        """Return row-wise log-softmax scores for the input features ``x``."""
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)
class GATSPMVConv(torch.nn.Module):
    """Multi-head graph attention layer using DGL built-in message functions.

    Attention scores are split into a source part (``att_l``) and a
    destination part (``att_r``) computed once per node; edges only add the
    two parts. Aggregation uses ``fn.src_mul_edge`` / ``fn.sum``.

    Args:
        g: Graph object exposing ``ndata``, ``edata``, ``apply_edges`` and
            ``update_all`` (a DGL graph is assumed).
        in_channels: Size of each input feature vector.
        out_channels: Size of each per-head output vector.
        heads: Number of attention heads.
        negative_slope: Slope passed to ``F.leaky_relu`` for the edge scores.
        dropout: Dropout probability applied to the edge attention values.
    """

    def __init__(self, g, in_channels, out_channels, heads=1,
                 negative_slope=0.2, dropout=0):
        super().__init__()
        self.g = g
        self.in_channels = in_channels  # kept for parity with GATConv
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        # Parameters are allocated uninitialised; reset_parameters() fills them.
        self.weight = Parameter(torch.empty(in_channels, heads * out_channels))
        self.att_l = Parameter(torch.empty(heads, out_channels, 1))
        self.att_r = Parameter(torch.empty(heads, out_channels, 1))
        self.bias = Parameter(torch.empty(heads * out_channels))
        self.softmax = EdgeSoftmax()
        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-initialise the weight and attention tensors; zero the bias."""
        glorot(self.weight)
        glorot(self.att_l)
        glorot(self.att_r)
        zeros(self.bias)

    def forward(self, x):
        """Apply the layer to ``x`` and return ``(-1, heads * out_channels)``."""
        x = torch.matmul(x, self.weight)
        x = x.reshape((x.size(0), self.heads, -1))  # NxHxD'
        head_x = x.transpose(0, 1)  # HxNxD'
        a1 = torch.bmm(head_x, self.att_l).transpose(0, 1)  # NxHx1
        a2 = torch.bmm(head_x, self.att_r).transpose(0, 1)  # NxHx1
        self.g.ndata.update({'x': x, 'a1': a1, 'a2': a2})
        self.g.apply_edges(self.edge_attention)
        self.edge_softmax()
        self.g.update_all(fn.src_mul_edge('x', 'a', 'x'), fn.sum('x', 'x'))
        # Divide the aggregated messages by the per-node normalizer 'z'
        # stored by edge_softmax().
        x = self.g.ndata['x'] / self.g.ndata['z']  # NxHxD'
        x = x.view(-1, self.heads * self.out_channels)
        # The bias used to be allocated and initialised but never applied,
        # leaving a dead parameter; add it here as GATConv.forward does.
        return x + self.bias

    def edge_attention(self, edge):
        """Return the raw edge score: leaky-ReLU of source ``a1`` + destination ``a2``."""
        a = F.leaky_relu(edge.src['a1'] + edge.dst['a2'], self.negative_slope)
        return {'a': a}

    def edge_softmax(self):
        """Run ``self.softmax`` on the edge scores and store its two outputs.

        The normalizer goes to ``ndata['z']`` and the (optionally dropped-out)
        attention values overwrite ``edata['a']``. The attention values are
        presumably still unnormalized, since ``forward`` divides by ``z``
        afterwards -- confirm against the installed ``EdgeSoftmax``.
        """
        alpha, normalizer = self.softmax(self.g.edata['a'], self.g)
        self.g.ndata['z'] = normalizer
        if self.training and self.dropout > 0:
            alpha = F.dropout(alpha, p=self.dropout, training=True)
        self.g.edata['a'] = alpha
class GATSPMV(torch.nn.Module):
    """Two-layer graph attention network built from ``GATSPMVConv`` layers.

    The first layer uses 8 heads of 8 channels each; the second layer uses a
    single head that maps the concatenated heads to ``out_channels``.

    Args:
        g: Graph object handed to both ``GATSPMVConv`` layers.
        in_channels: Size of each input feature vector.
        out_channels: Size of each output vector (log-probabilities).
    """

    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        self.g = g

        hidden, heads = 8, 8
        # Keyword arguments on purpose: the previous positional call bound 0.6
        # to negative_slope and 0.2 to dropout, i.e. the two were swapped
        # relative to the 0.6 feature dropout used in forward().
        self.conv1 = GATSPMVConv(g, in_channels, hidden, heads=heads,
                                 negative_slope=0.2, dropout=0.6)
        # The second layer consumes the concatenated heads of the first.
        self.conv2 = GATSPMVConv(g, hidden * heads, out_channels, heads=1,
                                 negative_slope=0.2, dropout=0.6)

    def forward(self, x):
        """Return row-wise log-softmax scores for the input features ``x``."""
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)
|