import dgl.function as fn
import torch
import torch.nn.functional as F
from dgl.nn.pytorch import EdgeSoftmax
from torch.nn import Parameter

from torch_geometric.nn.inits import glorot, zeros


class GATConv(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels, heads=1,
                 negative_slope=0.2, dropout=0):
        super().__init__()

        self.g = g
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.weight = Parameter(torch.empty(in_channels, heads * out_channels))
        self.att = Parameter(torch.empty(1, heads, 2 * out_channels))
        self.bias = Parameter(torch.empty(heads * out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        glorot(self.att)
        zeros(self.bias)

    def gat_msg(self, edge):
        alpha = torch.cat([edge.src['x'], edge.dst['x']], dim=-1)
        alpha = (alpha * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        return {'m': edge.src['x'], 'a': alpha}

    def gat_reduce(self, node):
        alpha = torch.softmax(node.mailbox['a'], dim=1)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        x = (node.mailbox['m'] * alpha.unsqueeze(-1)).sum(dim=1)
        return {'x': x}

    def forward(self, x):
        x = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels)
        self.g.ndata['x'] = x
        self.g.update_all(self.gat_msg, self.gat_reduce)
        x = self.g.ndata.pop('x')
        x = x.view(-1, self.heads * self.out_channels)
        x = x + self.bias
        return x


class GAT(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        self.g = g
        self.conv1 = GATConv(g, in_channels, 8, 8, 0.6, 0.2)
        self.conv2 = GATConv(g, 64, out_channels, 1, 0.6, 0.2)

    def forward(self, x):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)


class GATSPMVConv(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels, heads=1,
                 negative_slope=0.2, dropout=0):
        super().__init__()
        self.g = g
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.weight = Parameter(torch.empty(in_channels, heads * out_channels))
        self.att_l = Parameter(torch.empty(heads, out_channels, 1))
        self.att_r = Parameter(torch.empty(heads, out_channels, 1))
        self.bias = Parameter(torch.empty(heads * out_channels))
        self.softmax = EdgeSoftmax()
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        glorot(self.att_l)
        glorot(self.att_r)
        zeros(self.bias)

    def forward(self, x):
        x = torch.matmul(x, self.weight)
        x = x.reshape((x.size(0), self.heads, -1))  # NxHxD'
        head_x = x.transpose(0, 1)  # HxNxD'
        a1 = torch.bmm(head_x, self.att_l).transpose(0, 1)  # NxHx1
        a2 = torch.bmm(head_x, self.att_r).transpose(0, 1)  # NxHx1
        self.g.ndata.update({'x': x, 'a1': a1, 'a2': a2})
        self.g.apply_edges(self.edge_attention)
        self.edge_softmax()
        self.g.update_all(fn.src_mul_edge('x', 'a', 'x'), fn.sum('x', 'x'))
        x = self.g.ndata['x'] / self.g.ndata['z']  # NxHxD'
        return x.view(-1, self.heads * self.out_channels)

    def edge_attention(self, edge):
        a = F.leaky_relu(edge.src['a1'] + edge.dst['a2'], self.negative_slope)
        return {'a': a}

    def edge_softmax(self):
        alpha, normalizer = self.softmax(self.g.edata['a'], self.g)
        self.g.ndata['z'] = normalizer
        if self.training and self.dropout > 0:
            alpha = F.dropout(alpha, p=self.dropout, training=True)
        self.g.edata['a'] = alpha


class GATSPMV(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels):
        super().__init__()
        self.g = g
        self.conv1 = GATSPMVConv(g, in_channels, 8, 8, 0.6, 0.2)
        self.conv2 = GATSPMVConv(g, 64, out_channels, 1, 0.6, 0.2)

    def forward(self, x):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)
