File: gcn.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (80 lines) | stat: -rw-r--r-- 2,439 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import dgl.function as fn
import torch
import torch.nn.functional as F
from torch.nn import Parameter

from torch_geometric.nn.inits import glorot, zeros


class GCNConv(torch.nn.Module):
    """Graph convolution layer written with DGL user-defined functions.

    The input features are projected by ``weight``. Each edge then carries
    the projected source feature scaled by the source node's ``'norm'``
    value. Every node sums the messages it receives, scales the sum by its
    own ``'norm'`` value, and finally ``bias`` is added.

    NOTE(review): the caller is assumed to have filled ``g.ndata['norm']``
    with a tensor that broadcasts against the projected features — confirm.
    """

    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        # Graph shared with the caller; forward() temporarily stores 'x' on it.
        self.g = g
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-initialise the weight matrix and zero the bias."""
        glorot(self.weight)
        zeros(self.bias)

    def gcn_msg(self, edge):
        """Message UDF: emit the source feature times the source norm as 'm'."""
        source = edge.src
        return {'m': source['x'] * source['norm']}

    def gcn_reduce(self, node):
        """Reduce UDF: sum incoming 'm' messages, scale by the node norm."""
        # dim=1 is presumably the per-node message axis of DGL's mailbox.
        aggregated = torch.sum(node.mailbox['m'], dim=1)
        return {'x': aggregated * node.data['norm']}

    def forward(self, x):
        """Project ``x``, propagate over ``self.g`` and add the bias."""
        graph = self.g
        graph.ndata['x'] = x @ self.weight
        graph.update_all(self.gcn_msg, self.gcn_reduce)
        # Pop so the temporary 'x' field does not linger on the shared graph.
        return graph.ndata.pop('x') + self.bias


class GCN(torch.nn.Module):
    """Two-layer GCN built from ``GCNConv`` layers, returning log-probabilities.

    Args:
        g: graph shared by both convolution layers.
        in_channels: size of the input node features.
        out_channels: size of the output (passed to ``log_softmax`` over dim 1).
        hidden_channels: width of the hidden layer. Defaults to 16, the value
            that used to be hard-coded.
        dropout: dropout probability applied to the hidden features while
            training. Defaults to 0.5, the ``F.dropout`` default that used to
            be applied implicitly.
    """

    def __init__(self, g, in_channels, out_channels, hidden_channels=16,
                 dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(g, in_channels, hidden_channels)
        self.conv2 = GCNConv(g, hidden_channels, out_channels)

    def forward(self, x):
        """Run conv -> ReLU -> dropout -> conv -> log_softmax."""
        x = F.relu(self.conv1(x))
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)


class GCNSPMVConv(torch.nn.Module):
    """Graph convolution layer expressed with DGL built-in message functions.

    It computes the same result as the UDF-based layer: project by
    ``weight``, scale by the source ``'norm'``, sum over incoming edges,
    scale by the destination ``'norm'``, then add ``bias``. The work is
    delegated to ``dgl.function`` built-ins (presumably so DGL can run it as
    a sparse matrix product, hence the "SPMV" name).

    NOTE(review): the caller is assumed to have filled ``g.ndata['norm']``
    with a tensor that broadcasts against the projected features — confirm.
    """

    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        # Graph shared with the caller; forward() temporarily stores 'x' on it.
        self.g = g
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-initialise the weight matrix and zero the bias."""
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x):
        """Project ``x``, propagate over ``self.g`` and add the bias."""
        # update_all only writes 'x', so the same norm tensor serves both scalings.
        norm = self.g.ndata['norm']
        self.g.ndata['x'] = torch.matmul(x, self.weight) * norm
        # fn.copy_u replaces fn.copy_src. copy_src was a deprecated alias that
        # newer DGL releases no longer provide.
        self.g.update_all(fn.copy_u(u='x', out='m'),
                          fn.sum(msg='m', out='x'))
        # Pop so the temporary 'x' field does not linger on the shared graph.
        x = self.g.ndata.pop('x') * norm
        return x + self.bias


class GCNSPMV(torch.nn.Module):
    """Two-layer GCN built from ``GCNSPMVConv`` layers, returning log-probabilities.

    Args:
        g: graph shared by both convolution layers.
        in_channels: size of the input node features.
        out_channels: size of the output (passed to ``log_softmax`` over dim 1).
        hidden_channels: width of the hidden layer. Defaults to 16, the value
            that used to be hard-coded.
        dropout: dropout probability applied to the hidden features while
            training. Defaults to 0.5, the ``F.dropout`` default that used to
            be applied implicitly.
    """

    def __init__(self, g, in_channels, out_channels, hidden_channels=16,
                 dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNSPMVConv(g, in_channels, hidden_channels)
        self.conv2 = GCNSPMVConv(g, hidden_channels, out_channels)

    def forward(self, x):
        """Run conv -> ReLU -> dropout -> conv -> log_softmax."""
        x = F.relu(self.conv1(x))
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)