import os.path as osp
import time

import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d, LeakyReLU, Linear

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DeepGCNLayer, GATConv, MemPooling

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TUD')
dataset = TUDataset(path, name="PROTEINS_full", use_node_attr=True)
dataset._data.x = dataset._data.x[:, :-3]  # only use non-binary features.
dataset = dataset.shuffle()

n = (len(dataset)) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]

test_loader = DataLoader(test_dataset, batch_size=20)
val_loader = DataLoader(val_dataset, batch_size=20)
train_loader = DataLoader(train_dataset, batch_size=20)


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout):
        super().__init__()
        self.dropout = dropout

        self.lin = Linear(in_channels, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for i in range(2):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout)
            norm = BatchNorm1d(hidden_channels)
            act = LeakyReLU()
            self.convs.append(
                DeepGCNLayer(conv, norm, act, block='res+', dropout=dropout))

        self.mem1 = MemPooling(hidden_channels, 80, heads=5, num_clusters=10)
        self.mem2 = MemPooling(80, out_channels, heads=5, num_clusters=1)

    def forward(self, x, edge_index, batch):
        x = self.lin(x)
        for conv in self.convs:
            x = conv(x, edge_index)

        x, S1 = self.mem1(x, batch)
        x = F.leaky_relu(x)
        x = F.dropout(x, p=self.dropout)
        x, S2 = self.mem2(x)

        return (
            F.log_softmax(x.squeeze(1), dim=-1),
            MemPooling.kl_loss(S1) + MemPooling.kl_loss(S2),
        )


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Net(dataset.num_features, 32, dataset.num_classes, dropout=0.1)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=4e-5)


def train():
    model.train()

    model.mem1.k.requires_grad = False
    model.mem2.k.requires_grad = False
    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)[0]
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()

    kl_loss = 0.
    model.mem1.k.requires_grad = True
    model.mem2.k.requires_grad = True
    optimizer.zero_grad()
    for data in train_loader:
        data = data.to(device)
        kl_loss += model(data.x, data.edge_index, data.batch)[1]
    kl_loss /= len(train_loader.dataset)
    kl_loss.backward()
    optimizer.step()


@torch.no_grad()
def test(loader):
    model.eval()
    total_correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)[0]
        total_correct += int((out.argmax(dim=-1) == data.y).sum())
    return total_correct / len(loader.dataset)


times = []
patience = start_patience = 250
test_acc = best_val_acc = 0.
for epoch in range(1, 2001):
    start = time.time()
    train()
    val_acc = test(val_loader)
    if epoch % 500 == 0:
        optimizer.param_groups[0]['lr'] *= 0.5
    if best_val_acc < val_acc:
        patience = start_patience
        best_val_acc = val_acc
        test_acc = test(test_loader)
    else:
        patience -= 1
    print(f'Epoch {epoch:02d}, Val: {val_acc:.3f}, Test: {test_acc:.3f}')
    if patience <= 0:
        break
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
