import os
import os.path as osp
from typing import Callable, Dict, List, Optional

import torch

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_tar,
)
from torch_geometric.io import fs


class Wikidata5M(InMemoryDataset):
    r"""The Wikidata-5M dataset from the `"KEPLER: A Unified Model for
    Knowledge Embedding and Pre-trained Language Representation"
    <https://arxiv.org/abs/1911.06136>`_ paper,
    containing 4,594,485 entities, 822 relations,
    20,614,279 train triples, 5,163 validation triples, and 5,133 test triples.

    `Wikidata-5M <https://deepgraphlearning.github.io/project/wikidata5m>`_
    is a large-scale knowledge graph dataset with an aligned text corpus
    extracted from Wikidata.

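    A minimal usage sketch (the :obj:`root` path below is a placeholder):

    .. code-block:: python

        from torch_geometric.datasets import Wikidata5M

        dataset = Wikidata5M(root='/tmp/Wikidata5M')
        data = dataset[0]  # A single graph holding all triples.

        # Boolean masks select the edges (triples) of each split:
        train_edge_index = data.edge_index[:, data.train_mask]
        train_edge_type = data.edge_type[data.train_mask]
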
    Args:
        root (str): Root directory where the dataset should be saved.
        setting (str, optional):
            If :obj:`"transductive"`, loads the transductive split, in which
            validation and test entities also appear in the training graph.
            If :obj:`"inductive"`, loads the inductive split, in which
            validation and test triples involve entities unseen during
            training. (default: :obj:`"transductive"`)
        transform (callable, optional): A function/transform that takes in a
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            a :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        root: str,
        setting: str = 'transductive',
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        if setting not in {'transductive', 'inductive'}:
            raise ValueError(f"Invalid 'setting' argument (got '{setting}')")

        self.setting = setting

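        # Raw data sources: gzipped entity descriptions, a pre-computed node
        # feature tensor (served under the generic file name 'download'),
        # and the triple archive matching the chosen setting.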
        self.urls = [
            ('https://www.dropbox.com/s/7jp4ib8zo3i6m10/'
             'wikidata5m_text.txt.gz?dl=1'),
            'https://uni-bielefeld.sciebo.de/s/yuBKzBxsEc9j3hy/download',
        ]
        if self.setting == 'inductive':
            self.urls.append('https://www.dropbox.com/s/csed3cgal3m7rzo/'
                             'wikidata5m_inductive.tar.gz?dl=1')
        else:
            self.urls.append('https://www.dropbox.com/s/6sbhm0rwo4l73jq/'
                             'wikidata5m_transductive.tar.gz?dl=1')

        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
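        # NOTE: The order must match the indices used in `process()`:
        # `raw_paths[0]` holds the entity descriptions, `raw_paths[1]`
        # (named 'download' after the last segment of its URL) holds the
        # node feature tensor, and `raw_paths[2:]` hold the train/valid/test
        # triples.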
        return [
            'wikidata5m_text.txt.gz',
            'download',
            f'wikidata5m_{self.setting}_train.txt',
            f'wikidata5m_{self.setting}_valid.txt',
            f'wikidata5m_{self.setting}_test.txt',
        ]

    @property
    def processed_file_names(self) -> str:
        return f'{self.setting}_data.pt'

    def download(self) -> None:
        for url in self.urls:
            download_url(url, self.raw_dir)
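        # Only the triple archive needs extraction; the gzipped text file is
        # read directly during processing.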
        path = osp.join(self.raw_dir, f'wikidata5m_{self.setting}.tar.gz')
        extract_tar(path, self.raw_dir)
        os.remove(path)

    def process(self) -> None:
        import gzip

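        # Map each entity ID to a consecutive integer index, given by its
        # line number in the gzipped entity description file.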
        entity_to_id: Dict[str, int] = {}
        with gzip.open(self.raw_paths[0], 'rt') as f:
            for i, line in enumerate(f):
                values = line.strip().split('\t')
                entity_to_id[values[0]] = i

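        # Pre-computed node features, stored as a single tensor whose rows
        # are assumed to follow the entity order defined above.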
        x = fs.torch_load(self.raw_paths[1])

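        # Collect the triples of all splits into a single graph; `split`
        # encodes the file order (0=train, 1=valid, 2=test), and relations
        # receive consecutive IDs in order of first appearance.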
        edge_indices = []
        edge_types = []
        split_indices = []

        rel_to_id: Dict[str, int] = {}
        for split, path in enumerate(self.raw_paths[2:]):
            with open(path) as f:
                for line in f:
                    head, rel, tail = line[:-1].split('\t')
                    src = entity_to_id[head]
                    dst = entity_to_id[tail]
                    edge_indices.append([src, dst])
                    if rel not in rel_to_id:
                        rel_to_id[rel] = len(rel_to_id)
                    edge_types.append(rel_to_id[rel])
                    split_indices.append(split)

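        # `edge_index` has shape [2, num_edges]; `edge_type` and
        # `split_index` align with its columns.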
        edge_index = torch.tensor(edge_indices).t().contiguous()
        edge_type = torch.tensor(edge_types)
        split_index = torch.tensor(split_indices)

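        # A single `Data` object holds all splits; the masks select edges
        # (triples) rather than nodes.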
        data = Data(
            x=x,
            edge_index=edge_index,
            edge_type=edge_type,
            train_mask=split_index == 0,
            val_mask=split_index == 1,
            test_mask=split_index == 2,
        )

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])
