File: gat.py

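# DGL implementations of GAT in two flavors: a message-passing variant
# (GATConv/GAT) and an SPMV variant (GATSPMVConv/GATSPMV). Note that
# `EdgeSoftmax` and `fn.src_mul_edge` below come from an older DGL API and
# are no longer available under these names in current DGL releases.
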
import dgl.function as fn
import torch
import torch.nn.functional as F
from dgl.nn.pytorch import EdgeSoftmax
from torch.nn import Parameter

from torch_geometric.nn.inits import glorot, zeros


class GATConv(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels, heads=1,
                 negative_slope=0.2, dropout=0):
        super().__init__()

        self.g = g
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.weight = Parameter(torch.empty(in_channels, heads * out_channels))
        self.att = Parameter(torch.empty(1, heads, 2 * out_channels))
        self.bias = Parameter(torch.empty(heads * out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        glorot(self.att)
        zeros(self.bias)

    def gat_msg(self, edge):
        alpha = torch.cat([edge.src['x'], edge.dst['x']], dim=-1)
        alpha = (alpha * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        return {'m': edge.src['x'], 'a': alpha}

    def gat_reduce(self, node):
        alpha = torch.softmax(node.mailbox['a'], dim=1)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        x = (node.mailbox['m'] * alpha.unsqueeze(-1)).sum(dim=1)
        return {'x': x}

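    # update_all() in forward() runs gat_msg and gat_reduce over every edge
    # and node; together they implement the per-head attention rule from the
    # GAT paper for each edge (j -> i):
    #
    #   e_ij     = LeakyReLU(a^T [W x_j || W x_i])
    #   alpha_ij = softmax_j(e_ij)
    #   x_i'     = sum_j alpha_ij * (W x_j)
    #
    # where `a` is self.att and `W` is self.weight, applied once in forward().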
    def forward(self, x):
        # Project node features and split them into `heads` attention heads.
        x = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels)
        self.g.ndata['x'] = x
        self.g.update_all(self.gat_msg, self.gat_reduce)
        x = self.g.ndata.pop('x')
        # Concatenate the heads again and add the bias.
        x = x.view(-1, self.heads * self.out_channels)
        x = x + self.bias
        return x


class GAT(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        self.g = g
        # GAT-paper hyperparameters: 8 heads of 8 channels each (64 features
        # into conv2), LeakyReLU slope 0.2, attention dropout 0.6.
        self.conv1 = GATConv(g, in_channels, 8, heads=8, negative_slope=0.2,
                             dropout=0.6)
        self.conv2 = GATConv(g, 64, out_channels, heads=1, negative_slope=0.2,
                             dropout=0.6)

    def forward(self, x):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)

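# A minimal usage sketch (an assumption, not part of the original file): the
# graph construction uses the legacy mutable DGLGraph API that the rest of
# this file targets.
#
#   import dgl
#   g = dgl.DGLGraph()
#   g.add_nodes(4)
#   g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
#   model = GAT(g, in_channels=16, out_channels=3)
#   out = model(torch.randn(4, 16))  # -> [4, 3] log-probabilities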

class GATSPMVConv(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels, heads=1,
                 negative_slope=0.2, dropout=0):
        super().__init__()
        self.g = g
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.weight = Parameter(torch.empty(in_channels, heads * out_channels))
        self.att_l = Parameter(torch.empty(heads, out_channels, 1))
        self.att_r = Parameter(torch.empty(heads, out_channels, 1))
        self.bias = Parameter(torch.empty(heads * out_channels))
        self.softmax = EdgeSoftmax()
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        glorot(self.att_l)
        glorot(self.att_r)
        zeros(self.bias)

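    # SPMV trick: the paper's attention logit a^T [W x_j || W x_i] splits as
    # a_l^T (W x_j) + a_r^T (W x_i), so each half is precomputed per node with
    # one batched matmul (a1, a2 below) and the per-edge work reduces to a
    # single add, avoiding GATConv's per-edge concatenation.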
    def forward(self, x):
        x = torch.matmul(x, self.weight)
        x = x.reshape((x.size(0), self.heads, -1))  # NxHxD'
        head_x = x.transpose(0, 1)  # HxNxD'
        a1 = torch.bmm(head_x, self.att_l).transpose(0, 1)  # NxHx1
        a2 = torch.bmm(head_x, self.att_r).transpose(0, 1)  # NxHx1
        self.g.ndata.update({'x': x, 'a1': a1, 'a2': a2})
        self.g.apply_edges(self.edge_attention)
        self.edge_softmax()
        # Weight each source feature by its (exponentiated) attention score
        # and sum into the destination node using DGL's fused sparse kernels.
        self.g.update_all(fn.src_mul_edge('x', 'a', 'x'), fn.sum('x', 'x'))
        x = self.g.ndata['x'] / self.g.ndata['z']  # NxHxD'
        return x.view(-1, self.heads * self.out_channels)

    def edge_attention(self, edge):
        a = F.leaky_relu(edge.src['a1'] + edge.dst['a2'], self.negative_slope)
        return {'a': a}

    def edge_softmax(self):
        alpha, normalizer = self.softmax(self.g.edata['a'], self.g)
        self.g.ndata['z'] = normalizer
        if self.training and self.dropout > 0:
            alpha = F.dropout(alpha, p=self.dropout, training=True)
        self.g.edata['a'] = alpha
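    # Note: the legacy EdgeSoftmax returns the exponentiated edge scores
    # together with a per-destination-node normalizer, so the softmax is only
    # completed in forward(), where the aggregated sum is divided by 'z'.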


class GATSPMV(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        self.g = g
        # Same GAT-paper hyperparameters as the message-passing variant above.
        self.conv1 = GATSPMVConv(g, in_channels, 8, heads=8,
                                 negative_slope=0.2, dropout=0.6)
        self.conv2 = GATSPMVConv(g, 64, out_channels, heads=1,
                                 negative_slope=0.2, dropout=0.6)

    def forward(self, x):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)
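
# GATSPMV exposes the same interface as GAT, so it can be exercised the same
# way (a sketch under the same legacy-API assumption as above):
#
#   model = GATSPMV(g, in_channels=16, out_channels=3)
#   out = model(torch.randn(4, 16))  # -> [4, 3] log-probabilities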