import torch
from torchaudio.datasets import LIBRISPEECH


class MapMemoryCache(torch.utils.data.Dataset):
    """
    Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        # None marks an entry that has not been cached yet.
        if self._cache[n] is not None:
            return self._cache[n]

        item = self.dataset[n]
        self._cache[n] = item

        return item

    def __len__(self):
        return len(self.dataset)
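
# A minimal usage sketch (the path and tag below are assumptions):
#
#   cached = MapMemoryCache(LIBRISPEECH("./data", "dev-clean"))
#   item = cached[0]  # decoded from disk on first access, then cached
#   item = cached[0]  # served from memory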


class Processed(torch.utils.data.Dataset):
    """
    Wrap a LIBRISPEECH-style dataset: apply ``transforms`` to each waveform
    and ``encode`` each transcript into a tensor of integer labels.
    """

    def __init__(self, dataset, transforms, encode):
        self.dataset = dataset
        self.transforms = transforms
        self.encode = encode

    def __getitem__(self, key):
        item = self.dataset[key]
        return self.process_datapoint(item)

    def __len__(self):
        return len(self.dataset)

    def process_datapoint(self, item):
        # LIBRISPEECH items are tuples:
        # (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)
        transformed = item[0]
        # LibriSpeech transcripts are uppercase; lowercase before encoding.
        target = item[2].lower()

        transformed = self.transforms(transformed)
        # Drop the channel dimension and move time to the first axis, so that
        # pad_sequence can batch variable-length features later on.
        transformed = transformed[0, ...].transpose(0, -1)

        target = self.encode(target)
        target = torch.tensor(target, dtype=torch.long, device=transformed.device)

        return transformed, target
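
# A usage sketch of Processed on its own (the transform and encoder below
# are assumptions for illustration):
#
#   data = Processed(
#       LIBRISPEECH("./data", "dev-clean"),
#       torchaudio.transforms.MelSpectrogram(),
#       lambda text: [ord(c) for c in text],  # hypothetical encoder
#   )
#   features, target = data[0]  # features: (time, n_mels); target: 1-D long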


def split_process_librispeech(
    datasets,
    transforms,
    language_model,
    root,
    folder_in_archive,
):
    """
    Build one processed, memory-cached dataset per entry in ``datasets``.
    Each entry is a LIBRISPEECH tag (e.g. "train-clean-100") or a list of
    tags whose datasets are concatenated.
    """

    def create(tags, cache=True):

        if isinstance(tags, str):
            tags = [tags]
        if isinstance(transforms, list):
            transform_list = transforms
        else:
            # Broadcast a single transform across all tags so the zip below
            # does not silently drop datasets.
            transform_list = [transforms] * len(tags)

        data = torch.utils.data.ConcatDataset(
            [
                Processed(
                    LIBRISPEECH(
                        root,
                        tag,
                        folder_in_archive=folder_in_archive,
                        download=False,
                    ),
                    transform,
                    language_model.encode,
                )
                for tag, transform in zip(tags, transform_list)
            ]
        )

        # Cache decoded items in memory unless explicitly disabled.
        if cache:
            data = MapMemoryCache(data)
        return data

    # For performance, we cache all datasets
    return tuple(create(dataset) for dataset in datasets)
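
# For example (a sketch; the tags, transform, and language model are
# assumptions about the caller):
#
#   train, valid = split_process_librispeech(
#       ["train-clean-100", "dev-clean"],
#       torchaudio.transforms.MelSpectrogram(),
#       language_model,  # any object with an encode(str) -> list[int] method
#       "./data",
#       "LibriSpeech",
#   )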


def collate_factory(model_length_function, transforms=None):
    """
    Build a collate function that pads variable-length feature and target
    tensors into batches. ``model_length_function`` maps an unpadded feature
    tensor to the sequence length the model will see for it.
    """

    if transforms is None:
        # Default to a no-op transform.
        transforms = torch.nn.Identity()

    def collate_fn(batch):

        # Apply the batch-time transforms, skipping any empty datapoints.
        tensors = [transforms(b[0]) for b in batch if b]

        # Record each feature sequence length before padding.
        tensors_lengths = torch.tensor(
            [model_length_function(t) for t in tensors],
            dtype=torch.long,
            device=tensors[0].device,
        )

        # Pad to (batch, time, feature), then swap to (batch, feature, time).
        tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
        tensors = tensors.transpose(1, -1)

        targets = [b[1] for b in batch if b]
        target_lengths = torch.tensor(
            [target.shape[0] for target in targets],
            dtype=torch.long,
            device=tensors.device,
        )
        # Pad targets to (batch, max_target_length).
        targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)

        return tensors, targets, tensors_lengths, target_lengths

    return collate_fn
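

if __name__ == "__main__":
    # Smoke test: a minimal sketch wiring the pieces together. It assumes the
    # LibriSpeech "dev-clean" split is already extracted under
    # ./data/LibriSpeech, and uses a hypothetical character-level encoder in
    # place of a real language model.
    import string

    import torchaudio

    class CharEncoder:
        # Hypothetical stand-in exposing the encode() interface used above.
        labels = list(" '" + string.ascii_lowercase)

        def encode(self, text):
            return [self.labels.index(c) for c in text if c in self.labels]

    (dev,) = split_process_librispeech(
        ["dev-clean"],
        torchaudio.transforms.MelSpectrogram(),
        CharEncoder(),
        "./data",
        "LibriSpeech",
    )
    loader = torch.utils.data.DataLoader(
        dev,
        batch_size=2,
        # Features enter collate_fn as (time, n_mels); time is dimension 0.
        collate_fn=collate_factory(lambda t: t.shape[0]),
    )
    tensors, targets, tensors_lengths, target_lengths = next(iter(loader))
    print(tensors.shape, targets.shape, tensors_lengths, target_lengths)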
