import tarfile
import warnings
from unittest.mock import patch

import torch
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend
from torchaudio_unittest.backend.common import get_bits_per_sample, get_encoding
from torchaudio_unittest.common_utils import (
    get_wav_data,
    nested_params,
    PytorchTestCase,
    save_wav,
    skipIfNoModule,
    TempDirMixin,
)

from .common import parameterize, skipIfFormatNotSupported

if _mod_utils.is_module_available("soundfile"):
    import soundfile


@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
    @parameterize(
        ["float32", "int32", "int16", "uint8"],
        [8000, 16000],
        [1, 2],
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.info` can check wav file correctly"""
        duration = 1
        path = self.get_temp_path("data.wav")
        data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
        save_wav(path, data, sample_rate)
        info = soundfile_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels
        assert info.bits_per_sample == get_bits_per_sample("wav", dtype)
        assert info.encoding == get_encoding("wav", dtype)

    @parameterize([8000, 16000], [1, 2])
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, sample_rate, num_channels):
        """`soundfile_backend.info` can check flac file correctly"""
        duration = 1
        num_frames = sample_rate * duration
        data = torch.randn(num_frames, num_channels).numpy()
        path = self.get_temp_path("data.flac")
        soundfile.write(path, data, sample_rate)

        info = soundfile_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels
        assert info.bits_per_sample == 16
        assert info.encoding == "FLAC"

    @parameterize([8000, 16000], [1, 2])
    @skipIfFormatNotSupported("OGG")
    def test_ogg(self, sample_rate, num_channels):
        """`soundfile_backend.info` can check ogg file correctly"""
        duration = 1
        num_frames = sample_rate * duration
        data = torch.randn(num_frames, num_channels).numpy()
        path = self.get_temp_path("data.ogg")
        soundfile.write(path, data, sample_rate)

        info = soundfile_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels
        assert info.bits_per_sample == 0
        assert info.encoding == "VORBIS"

    @nested_params(
        [8000, 16000],
        [1, 2],
        [("PCM_24", 24), ("PCM_32", 32)],
    )
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
        """`soundfile_backend.info` can check sph file correctly"""
        duration = 1
        num_frames = sample_rate * duration
        data = torch.randn(num_frames, num_channels).numpy()
        path = self.get_temp_path("data.nist")
        subtype, bits_per_sample = subtype_and_bit_depth
        soundfile.write(path, data, sample_rate, subtype=subtype)

        info = soundfile_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels
        assert info.bits_per_sample == bits_per_sample
        assert info.encoding == "PCM_S"

    def test_unknown_subtype_warning(self):
        """soundfile_backend.info issues a warning when the subtype is unknown

        This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
        dict should be updated.
        """

        def _mock_info_func(_):
            class MockSoundFileInfo:
                samplerate = 8000
                frames = 356
                channels = 2
                subtype = "UNSEEN_SUBTYPE"
                format = "UNKNOWN"

            return MockSoundFileInfo()

        with patch("soundfile.info", _mock_info_func):
            with warnings.catch_warnings(record=True) as w:
                info = soundfile_backend.info("foo")
                assert len(w) == 1
                assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
                assert info.bits_per_sample == 0


@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    def _test_fileobj(self, ext, subtype, bits_per_sample):
        """Query audio via file-like object works"""
        duration = 2
        sample_rate = 16000
        num_channels = 2
        num_frames = sample_rate * duration
        path = self.get_temp_path(f"test.{ext}")

        data = torch.randn(num_frames, num_channels).numpy()
        soundfile.write(path, data, sample_rate, subtype=subtype)

        with open(path, "rb") as fileobj:
            info = soundfile_backend.info(fileobj)
        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels
        assert info.bits_per_sample == bits_per_sample
        assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"

    def test_fileobj_wav(self):
        """Loading audio via file-like object works"""
        self._test_fileobj("wav", "PCM_16", 16)

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Loading audio via file-like object works"""
        self._test_fileobj("flac", "PCM_16", 16)

    def _test_tarobj(self, ext, subtype, bits_per_sample):
        """Query compressed audio via file-like object works"""
        duration = 2
        sample_rate = 16000
        num_channels = 2
        num_frames = sample_rate * duration
        audio_file = f"test.{ext}"
        audio_path = self.get_temp_path(audio_file)
        archive_path = self.get_temp_path("archive.tar.gz")

        data = torch.randn(num_frames, num_channels).numpy()
        soundfile.write(audio_path, data, sample_rate, subtype=subtype)

        with tarfile.TarFile(archive_path, "w") as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, "r") as tarobj:
            fileobj = tarobj.extractfile(audio_file)
            info = soundfile_backend.info(fileobj)
        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels
        assert info.bits_per_sample == bits_per_sample
        assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"

    def test_tarobj_wav(self):
        """Query compressed audio via file-like object works"""
        self._test_tarobj("wav", "PCM_16", 16)

    @skipIfFormatNotSupported("FLAC")
    def test_tarobj_flac(self):
        """Query compressed audio via file-like object works"""
        self._test_tarobj("flac", "PCM_16", 16)
