import os
import os.path as osp
import pickle
from typing import Any, Callable, Dict, Iterable, List, Optional

import torch
from tqdm import tqdm

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.io import fs


class LRGBDataset(InMemoryDataset):
    r"""The `"Long Range Graph Benchmark (LRGB)"
    <https://arxiv.org/abs/2206.08164>`_
    datasets, a collection of 5 graph learning datasets with tasks that are
    based on long-range dependencies in graphs. See the original
    `source code <https://github.com/vijaydwivedi75/lrgb>`_ for more details
    on the individual datasets.

    +------------------------+-------------------+----------------------+
    | Dataset                | Domain            | Task                 |
    +========================+===================+======================+
    | :obj:`PascalVOC-SP`    | Computer Vision   | Node Classification  |
    +------------------------+-------------------+----------------------+
    | :obj:`COCO-SP`         | Computer Vision   | Node Classification  |
    +------------------------+-------------------+----------------------+
    | :obj:`PCQM-Contact`    | Quantum Chemistry | Link Prediction      |
    +------------------------+-------------------+----------------------+
    | :obj:`Peptides-func`   | Chemistry         | Graph Classification |
    +------------------------+-------------------+----------------------+
    | :obj:`Peptides-struct` | Chemistry         | Graph Regression     |
    +------------------------+-------------------+----------------------+

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The name of the dataset (one of :obj:`"PascalVOC-SP"`,
            :obj:`"COCO-SP"`, :obj:`"PCQM-Contact"`, :obj:`"Peptides-func"`,
            :obj:`"Peptides-struct"`); matched case-insensitively.
        split (str, optional): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset.
            (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 15 10 10 10 10
        :header-rows: 1

        * - Name
          - #graphs
          - #nodes
          - #edges
          - #classes
        * - PascalVOC-SP
          - 11,355
          - ~479.40
          - ~2,710.48
          - 21
        * - COCO-SP
          - 123,286
          - ~476.88
          - ~2,693.67
          - 81
        * - PCQM-Contact
          - 529,434
          - ~30.14
          - ~61.09
          - 1
        * - Peptides-func
          - 15,535
          - ~150.94
          - ~307.30
          - 10
        * - Peptides-struct
          - 15,535
          - ~150.94
          - ~307.30
          - 11
    """
    names = [
        'pascalvoc-sp', 'coco-sp', 'pcqm-contact', 'peptides-func',
        'peptides-struct'
    ]

    urls = {
        'pascalvoc-sp':
        'https://www.dropbox.com/s/8x722ai272wqwl4/pascalvocsp.zip?dl=1',
        'coco-sp':
        'https://www.dropbox.com/s/r6ihg1f4pmyjjy0/cocosp.zip?dl=1',
        'pcqm-contact':
        'https://www.dropbox.com/s/qdag867u6h6i60y/pcqmcontact.zip?dl=1',
        'peptides-func':
        'https://www.dropbox.com/s/ycsq37q8sxs1ou8/peptidesfunc.zip?dl=1',
        'peptides-struct':
        'https://www.dropbox.com/s/zgv4z8fcpmknhs8/peptidesstruct.zip?dl=1'
    }

    # Name of the folder each zip archive extracts into under ``root``;
    # ``download`` renames it to ``raw_dir``.
    dwnld_file_name = {
        'pascalvoc-sp': 'voc_superpixels_edge_wt_region_boundary',
        'coco-sp': 'coco_superpixels_edge_wt_region_boundary',
        'pcqm-contact': 'pcqmcontact',
        'peptides-func': 'peptidesfunc',
        'peptides-struct': 'peptidesstruct'
    }

    def __init__(
        self,
        root: str,
        name: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        self.name = name.lower()
        assert self.name in self.names, (
            f"Unknown LRGB dataset '{name}' (expected one of {self.names})")
        assert split in ['train', 'val', 'test'], (
            f"Invalid split '{split}' (expected 'train', 'val' or 'test')")

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        path = osp.join(self.processed_dir, f'{split}.pt')
        self.load(path)

    @property
    def _is_superpixel(self) -> bool:
        # PascalVOC-SP and COCO-SP: their raw splits are pickle files, while
        # all other datasets ship `.pt` files.
        return self.name.endswith('-sp')

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        if self._is_superpixel:
            return ['train.pickle', 'val.pickle', 'test.pickle']
        return ['train.pt', 'val.pt', 'test.pt']

    @property
    def processed_file_names(self) -> List[str]:
        return ['train.pt', 'val.pt', 'test.pt']

    def download(self) -> None:
        # Remove any existing `raw_dir` so that `os.rename` below can move the
        # extracted archive folder into its place.
        fs.rm(self.raw_dir)
        path = download_url(self.urls[self.name], self.root)
        extract_zip(path, self.root)
        os.rename(osp.join(self.root, self.dwnld_file_name[self.name]),
                  self.raw_dir)
        os.unlink(path)

    def process(self) -> None:
        """Convert every raw split into :class:`Data` objects and save it to
        ``processed_dir/{split}.pt``."""
        if self.name == 'pcqm-contact':
            self.process_pcqm_contact()
            return

        # COCO-SP labels are non-contiguous ids and must be remapped; the
        # lookup table is built once and reused for every graph.
        label_lut: Optional[torch.Tensor] = None
        if self.name == 'coco-sp':
            label_lut = self._coco_label_lookup()

        for split in ['train', 'val', 'test']:
            graphs = self._load_raw_graphs(split)
            data_iter = (
                self._graph_to_data(graph, label_lut)
                for graph in tqdm(graphs, desc=f'Processing {split} dataset'))
            data_list = self._filter_and_transform(data_iter)
            path = osp.join(self.processed_dir, f'{split}.pt')
            self.save(data_list, path)

    def _load_raw_graphs(self, split: str) -> Any:
        """Load the raw graph tuples of ``split`` for PascalVOC-SP, COCO-SP,
        Peptides-func or Peptides-struct."""
        if self._is_superpixel:
            # NOTE: `pickle.load` can execute arbitrary code; this is only
            # acceptable because the file comes from the archive in `urls`.
            with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f:
                return pickle.load(f)
        return fs.torch_load(osp.join(self.raw_dir, f'{split}.pt'))

    def _graph_to_data(self, graph: Any,
                       label_lut: Optional[torch.Tensor]) -> Data:
        """Build a :class:`Data` object from a raw ``graph`` tuple
        ``(x, edge_attr, edge_index, y)``.

        PascalVOC-SP / COCO-SP: ``x`` is ``[num_nodes, 14]``, ``edge_attr``
        is ``[num_edges, 2]`` and ``y`` is ``[num_nodes]``; features are cast
        to float and labels to long.
        Peptides-func / -struct: ``x`` is ``[num_nodes, 9]``, ``edge_attr``
        is ``[num_edges, 3]`` and ``y`` is ``[1, 10]`` (func) or ``[1, 11]``
        (struct); all entries are used as-is.
        ``edge_index`` is ``[2, num_edges]`` in both cases.

        If ``label_lut`` is given, node labels are remapped through it.
        """
        x, edge_attr, edge_index, y = graph[0], graph[1], graph[2], graph[3]
        if self._is_superpixel:
            x = x.to(torch.float)
            edge_attr = edge_attr.to(torch.float)
            y = torch.LongTensor(y)
        if label_lut is not None:
            y = self._remap_labels(y, label_lut)
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

    def _filter_and_transform(self, data_iter: Iterable[Data]) -> List[Data]:
        """Apply ``pre_filter`` and then ``pre_transform`` to each sample,
        keeping only the samples accepted by the filter."""
        data_list = []
        for data in data_iter:
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)
        return data_list

    def _coco_label_lookup(self) -> torch.Tensor:
        """Dense lookup table built from :meth:`label_remap_coco`: entry
        ``i`` holds the new index of original label ``i`` or ``-1`` if ``i``
        has no mapping."""
        label_map = self.label_remap_coco()
        lut = torch.full((max(label_map) + 1, ), -1, dtype=torch.long)
        lut[torch.tensor(list(label_map), dtype=torch.long)] = torch.tensor(
            list(label_map.values()), dtype=torch.long)
        return lut

    @staticmethod
    def _remap_labels(y: torch.Tensor,
                      label_lut: torch.Tensor) -> torch.Tensor:
        """Remap every entry of ``y`` through ``label_lut`` in one tensor op.

        Raises:
            KeyError: If ``y`` contains a label without a mapping (negative,
                beyond the table, or a gap in the original id range).
        """
        in_range = (y >= 0) & (y < label_lut.numel())
        out = torch.full_like(y, -1)
        out[in_range] = label_lut[y[in_range]]
        invalid = out < 0
        if bool(invalid.any()):
            bad = y[invalid].unique().tolist()
            raise KeyError(f'Labels without a COCO-SP mapping: {bad}')
        return out

    def label_remap_coco(self) -> Dict[int, int]:
        """Map the 81 non-contiguous original COCO-SP label ids to the
        contiguous class indices ``0..80``, preserving ascending id order."""
        original_label_idx = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19,
            20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39,
            40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
            58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78,
            79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90
        ]
        return {label: idx for idx, label in enumerate(original_label_idx)}

    def process_pcqm_contact(self) -> None:
        """Convert the raw PCQM-Contact splits into :class:`Data` objects.

        Each raw ``graph`` is a tuple
        ``(x, edge_attr, edge_index, edge_label_index, edge_label)`` with
        ``x``: ``[num_nodes, 9]``, ``edge_attr``: ``[num_edges, 3]``,
        ``edge_index``: ``[2, num_edges]``,
        ``edge_label_index``: ``[2, num_labeled_edges]`` and
        ``edge_label``: ``[num_labeled_edges]``.

        Per the original LRGB loader, the labeled node pairs include negative
        (non-contact) pairs alongside the link-prediction targets; see
        https://github.com/vijaydwivedi75/lrgb/blob/main/graphgps/loader/dataset/pcqm4mv2_contact.py#L192
        """
        for split in ['train', 'val', 'test']:
            graphs = fs.torch_load(osp.join(self.raw_dir, f'{split}.pt'))
            data_iter = (
                Data(x=graph[0], edge_index=graph[2], edge_attr=graph[1],
                     edge_label_index=graph[3], edge_label=graph[4])
                for graph in tqdm(graphs, desc=f'Processing {split} dataset'))
            data_list = self._filter_and_transform(data_iter)
            self.save(data_list, osp.join(self.processed_dir, f'{split}.pt'))
