import os
import platform
import sys
from concurrent.futures import ProcessPoolExecutor
from typing import List, Tuple
from unittest import skipIf

import numpy as np
import torch
import torchaudio
from torchaudio_unittest.common_utils import get_whitenoise, PytorchTestCase, save_wav, skipIfNoSox, TempDirMixin


class RandomPerturbationFile(torch.utils.data.Dataset):
    """Dataset that reads audio files and applies a random speed perturbation.

    Each item is loaded from ``flist`` and run through a sox effect chain:
    gain, a random speed change, resampling to ``sample_rate``, padding with
    silence, and trimming to the first 2 seconds.

    ``rng`` starts as ``None``. It must be assigned a random generator
    (e.g. from a DataLoader ``worker_init_fn``) before any item is fetched.
    """

    def __init__(self, flist: List[str], sample_rate: int):
        super().__init__()
        self.flist = flist
        self.sample_rate = sample_rate
        # Filled in later (per worker); left unset here on purpose.
        self.rng = None

    def __len__(self):
        return len(self.flist)

    def __getitem__(self, index):
        factor = self.rng.uniform(0.5, 2.0)
        # For a 1-second input the speed change yields 0.5 ~ 2.0 seconds;
        # padding by 1.5 s and trimming to 2 s equalizes the output length.
        effects = [
            ["gain", "-n", "-10"],
            ["speed", f"{factor:.5f}"],
            ["rate", f"{self.sample_rate}"],
            ["pad", "0", "1.5"],
            ["trim", "0", "2"],
        ]
        path = self.flist[index]
        waveform, _ = torchaudio.sox_effects.apply_effects_file(path, effects)
        return waveform


class RandomPerturbationTensor(torch.utils.data.Dataset):
    """Apply random speed perturbation to (synthetic) Tensor data.

    Each item is a ``(tensor, sample_rate)`` pair from ``signals`` run through
    a sox effect chain: gain, a random speed change, resampling to
    ``sample_rate``, padding with silence, and trimming to the first 2 seconds.

    Args:
        signals: ``(tensor, sample_rate)`` pairs, one per dataset item.
        sample_rate: Target rate passed to the sox ``rate`` effect.

    ``rng`` starts as ``None``. It must be assigned a random generator
    (e.g. from a DataLoader ``worker_init_fn``) before any item is fetched.
    """

    def __init__(self, signals: List[Tuple[torch.Tensor, int]], sample_rate: int):
        super().__init__()
        self.signals = signals
        self.sample_rate = sample_rate
        # Set per worker before use; fetching an item while this is None fails.
        self.rng = None

    def __getitem__(self, index):
        speed = self.rng.uniform(0.5, 2.0)
        effects = [
            ["gain", "-n", "-10"],
            ["speed", f"{speed:.5f}"],  # for a 1-second input, duration becomes 0.5 ~ 2.0 seconds.
            ["rate", f"{self.sample_rate}"],
            ["pad", "0", "1.5"],  # add 1.5 seconds silence at the end
            ["trim", "0", "2"],  # get the first 2 seconds
        ]
        tensor, sample_rate = self.signals[index]
        data, _ = torchaudio.sox_effects.apply_effects_tensor(tensor, sample_rate, effects)
        return data

    def __len__(self):
        return len(self.signals)


def init_random_seed(worker_id):
    """Attach a per-worker random generator to the worker's dataset copy.

    Meant to be passed as a DataLoader ``worker_init_fn``. The generator is a
    ``np.random.RandomState`` seeded with ``worker_id``, so each worker gets
    its own reproducible stream.

    Args:
        worker_id: Worker index supplied by the DataLoader.
    """
    info = torch.utils.data.get_worker_info()
    # Per PyTorch docs, get_worker_info() is None outside a worker process,
    # so calling this from the main process fails on the next line.
    target = info.dataset
    generator = np.random.RandomState(worker_id)
    target.rng = generator


@skipIfNoSox
@skipIf(
    platform.system() == "Darwin" and sys.version_info < (3, 8),
    "This test is known to get stuck for macOS with Python < 3.8. "
    "See https://github.com/pytorch/pytorch/issues/46409",
)
class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
    """Test `apply_effects_file` and `apply_effects_tensor` in multi-process dataloader setting"""

    def _generate_dataset(self, num_samples=128):
        """Write ``num_samples`` 1-second, 2-channel white-noise WAV files.

        Sample rate and dtype are drawn at random for each file.
        Returns the list of written paths.
        """
        flist = []
        for i in range(num_samples):
            sample_rate = np.random.choice([8000, 16000, 44100])
            dtype = np.random.choice(["float32", "int32", "int16", "uint8"])
            data = get_whitenoise(n_channels=2, sample_rate=sample_rate, duration=1, dtype=dtype)
            path = self.get_temp_path(f"{i:03d}_{dtype}_{sample_rate}.wav")
            save_wav(path, data, sample_rate)
            flist.append(path)
        return flist

    def _generate_signals(self, num_samples=128):
        """Return ``num_samples`` ``(waveform, sample_rate)`` pairs of 1-second,
        2-channel float32 white noise at a randomly chosen sample rate."""
        signals = []
        for _ in range(num_samples):
            sample_rate = np.random.choice([8000, 16000, 44100])
            data = get_whitenoise(n_channels=2, sample_rate=sample_rate, duration=1, dtype="float32")
            signals.append((data, sample_rate))
        return signals

    def _check_batches(self, dataset, sample_rate, batch_size=32, num_workers=16):
        """Load ``dataset`` through a multi-worker DataLoader and check every batch.

        Each batch must be ``(batch_size, 2, 2 * sample_rate)``. The batch
        count is also checked so the test cannot pass without iterating.
        """
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            worker_init_fn=init_random_seed,
        )
        num_batches = 0
        for batch in loader:
            # unittest assertions, unlike bare `assert`, survive `python -O`.
            self.assertEqual(tuple(batch.shape), (batch_size, 2, 2 * sample_rate))
            num_batches += 1
        # Every sample must land in a full batch.
        self.assertEqual(num_batches, len(dataset) // batch_size)

    def test_apply_effects_file(self):
        sample_rate = 12000
        flist = self._generate_dataset()
        dataset = RandomPerturbationFile(flist, sample_rate)
        self._check_batches(dataset, sample_rate)

    def test_apply_effects_tensor(self):
        sample_rate = 12000
        signals = self._generate_signals()
        dataset = RandomPerturbationTensor(signals, sample_rate)
        self._check_batches(dataset, sample_rate)


def speed(path):
    """Load ``path`` with the sox_io backend and apply a speed + rate chain.

    The ``rate`` effect resamples back to the file's own sample rate. Only
    the processed waveform (element 0 of the effect result) is returned.
    Used as the worker task for the ProcessPoolExecutor regression test
    (https://github.com/pytorch/audio/issues/1021).
    """
    waveform, sr = torchaudio.backend.sox_io_backend.load(path)
    chain = [["speed", "1.03756523535464655"], ["rate", f"{sr}"]]
    result = torchaudio.sox_effects.apply_effects_tensor(waveform, sr, chain)
    return result[0]


@skipIfNoSox
class TestProcessPoolExecutor(TempDirMixin, PytorchTestCase):
    """Run sox effects inside a ProcessPoolExecutor worker process."""

    backend = "sox_io"

    def setUp(self):
        """Write ten 1-second mono white-noise WAV files at 16 kHz into the temp dir."""
        # Let the base test case run its own setup first.
        super().setUp()
        sample_rate = 16000
        self.flist = []
        for i in range(10):
            path = self.get_temp_path(f"{i}.wav")
            data = get_whitenoise(n_channels=1, sample_rate=sample_rate, duration=1, dtype="float")
            save_wav(path, data, sample_rate)
            self.flist.append(path)

    @skipIf(os.environ.get("CI") == "true", "This test now hangs in CI")
    def test_executor(self):
        """Test that apply_effects_tensor with speed + rate does not crash

        https://github.com/pytorch/audio/issues/1021
        """
        # The context manager shuts the pool down and reaps its worker
        # process even if one of the futures raises.
        with ProcessPoolExecutor(1) as executor:
            futures = [executor.submit(speed, path) for path in self.flist]
            for future in futures:
                future.result()
