import dgl.function as fn
import torch
import torch.nn.functional as F
from torch.nn import Parameter as Param

from torch_geometric.nn.inits import uniform


class RGCNConv(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels, num_relations, num_bases):
        super().__init__()

        self.g = g
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases

        self.basis = Param(torch.empty(num_bases, in_channels, out_channels))
        self.att = Param(torch.empty(num_relations, num_bases))
        self.root = Param(torch.empty(in_channels, out_channels))
        self.bias = Param(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        size = self.num_bases * self.in_channels
        uniform(size, self.basis)
        uniform(size, self.att)
        uniform(size, self.root)
        uniform(size, self.bias)

    def rgcn_reduce(self, node):
        return {'x': node.mailbox['m'].sum(dim=1)}

    def forward(self, x):
        self.w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
        self.w = self.w.view(self.num_relations, self.in_channels,
                             self.out_channels)

        if x is None:

            def msg_func(edge):
                w = self.w.view(-1, self.out_channels)
                index = edge.data['type'] * self.in_channels + edge.src['id']
                m = w.index_select(0, index) * edge.data['norm'].unsqueeze(1)
                return {'m': m}
        else:
            self.g.ndata['x'] = x

            def msg_func(edge):
                w = self.w.index_select(0, edge.data['type'])
                m = torch.bmm(edge.src['x'].unsqueeze(1), w).squeeze()
                m = m * edge.data['norm'].unsqueeze(1)
                return {'m': m}

        self.g.update_all(msg_func, self.rgcn_reduce)
        out = self.g.ndata.pop('x')

        if x is None:
            out = out + self.root
        else:
            out = out + torch.matmul(x, self.root)

        out = out + self.bias
        return out


class RGCN(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels, num_relations):
        super().__init__()
        self.conv1 = RGCNConv(g, in_channels, 16, num_relations, num_bases=30)
        self.conv2 = RGCNConv(g, 16, out_channels, num_relations, num_bases=30)

    def forward(self, x):
        x = F.relu(self.conv1(None))
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)


class RGCNSPMVConv(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels, num_relations, num_bases):
        super().__init__()

        self.g = g
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases

        self.basis = Param(torch.empty(num_bases, in_channels, out_channels))
        self.att = Param(torch.empty(num_relations, num_bases))
        self.root = Param(torch.empty(in_channels, out_channels))
        self.bias = Param(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        size = self.num_bases * self.in_channels
        uniform(size, self.basis)
        uniform(size, self.att)
        uniform(size, self.root)
        uniform(size, self.bias)

    def forward(self, x):
        self.w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
        self.w = self.w.view(self.num_relations, self.in_channels,
                             self.out_channels)

        if x is None:

            def msg_func(edge):
                w = self.w.view(-1, self.out_channels)
                index = edge.data['type'] * self.in_channels + edge.src['id']
                m = w.index_select(0, index) * edge.data['norm'].unsqueeze(1)
                return {'m': m}
        else:
            self.g.ndata['x'] = x

            def msg_func(edge):
                w = self.w.index_select(0, edge.data['type'])
                m = torch.bmm(edge.src['x'].unsqueeze(1), w).squeeze()
                m = m * edge.data['norm'].unsqueeze(1)
                return {'m': m}

        self.g.update_all(msg_func, fn.sum(msg='m', out='x'))
        out = self.g.ndata.pop('x')

        if x is None:
            out = out + self.root
        else:
            out = out + torch.matmul(x, self.root)

        out = out + self.bias
        return out


class RGCNSPMV(torch.nn.Module):
    def __init__(self, g, in_channels, out_channels, num_relations):
        super().__init__()
        self.conv1 = RGCNSPMVConv(g, in_channels, 16, num_relations,
                                  num_bases=30)
        self.conv2 = RGCNSPMVConv(g, 16, out_channels, num_relations,
                                  num_bases=30)

    def forward(self, x):
        x = F.relu(self.conv1(None))
        x = self.conv2(x)
        return F.log_softmax(x, dim=1)
